Flash Attention

softmax

naive

\(\sigma(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}} \quad \text{for } i = 1, 2, \dots, K\)


safe softmax

Safe softmax subtracts the maximum input value from every input before applying the softmax equation. This prevents the exponentials from becoming extremely large, which would otherwise cause floating-point overflow.
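The subtraction does not change the result, because a constant shift cancels between the numerator and the denominator. Writing \(m = \max_j z_j\) (the same kind of running maximum \(m\) that the online algorithm in these notes tracks), the identity can be sketched as:

\[\frac{e^{z_i - m}}{\sum_{j=1}^{K} e^{z_j - m}} = \frac{e^{-m}\, e^{z_i}}{e^{-m} \sum_{j=1}^{K} e^{z_j}} = \sigma(z_i)\]

Every shifted exponent \(z_i - m\) is at most 0, so each term lies in \((0, 1]\).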

>>> np.exp(0.5)
np.float64(1.6487212707001282)
>>> np.exp(1)
np.float64(2.718281828459045)
>>> np.exp(10)
np.float64(22026.465794806718)
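>>> x = np.array([0.5, 1.0, 10.0])  # example input (the values used above)
>>> max_x = np.max(x)
>>> np.exp(x - max_x) / np.sum(np.exp(x - max_x))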
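As a minimal sketch of the difference, the two functions below compare the naive formula with the max-subtracted version. The names `naive_softmax` and `safe_softmax` and the sample input are illustrative assumptions, not part of the original notes.

```python
import numpy as np

def naive_softmax(x):
    # Direct use of the definition; exp() overflows for large inputs.
    e = np.exp(x)
    return e / np.sum(e)

def safe_softmax(x):
    # Shift by the maximum so the largest exponent is exp(0) = 1.
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / np.sum(e)

# Illustrative values: exp(1000) is beyond the float64 range.
x = np.array([1000.0, 1001.0, 1002.0])

# Silence the overflow warnings that the naive version triggers.
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(x)  # inf / inf produces nan

safe = safe_softmax(x)

print(naive)
print(safe)
```

The safe version gives the same probabilities as softmax of `[0, 1, 2]`, since only the differences between the inputs matter.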


safe softmax with online normalizer calculation

Essentially, the algorithm maintains the running maximum \(m\) and the normalization term \(d\) as it iterates over the elements of the input array. At each iteration it first rescales the normalizer \(d\) to the new maximum \(m_j\), and only then adds the new value to the normalizer.
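Written as a recurrence over inputs \(x_1, \dots, x_N\), one step can be sketched as follows. The indexing and the initial values \(m_0 = -\infty\), \(d_0 = 0\) are assumptions chosen here to match the description.

\[m_j = \max(m_{j-1}, x_j), \qquad d_j = d_{j-1}\, e^{m_{j-1} - m_j} + e^{x_j - m_j}\]

After the last element, \(d_N = \sum_{j=1}^{N} e^{x_j - m_N}\), which is the denominator of safe softmax.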

online softmax

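import numpy as np

x = np.array([0.1, 0.5, 0.4, 0.2, 0.3, 0.3])  # example input (assumed for illustration)

# loop 1: track the running maximum of x and the accumulated sum of exponentials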
max_x = -np.inf
accum_exp = 0.
for t in x:
    max_x_new = t if t > max_x else max_x
    accum_exp = np.exp(max_x - max_x_new) * accum_exp + np.exp(t - max_x_new)
    max_x = max_x_new

# loop 2: get the softmax output by dividing exp(x - max(x)) by `accum_exp`
out = [0. for _ in range(len(x))]
for i, t in enumerate(x):
    out[i] = np.exp(t - max_x) / accum_exp

tiling

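>>> a = np.array([0.1, 0.5, 0.4, 0.2, 0.3, 0.3])
>>> # reference: normalizer over the whole array, using the global max 0.5
>>> sum([np.exp(aa - 0.5) for aa in a])  # ≈ 4.9534
>>> # tile 1 = [0.1, 0.5, 0.4], local max 0.5
>>> l1 = np.exp(0.1 - 0.5) + np.exp(0.5 - 0.5) + np.exp(0.4 - 0.5)  # ≈ 2.5752
>>> # tile 2 = [0.2, 0.3, 0.3], local max 0.3
>>> l2 = np.exp(0.2 - 0.3) + np.exp(0.3 - 0.3) + np.exp(0.3 - 0.3)  # ≈ 2.9048
>>> # rescale each tile's sum to the global max 0.5, then add
>>> np.exp(0.5 - 0.5) * l1 + np.exp(0.3 - 0.5) * l2  # ≈ 4.9534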
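The same merge can be written as two small helpers, which makes the rescaling step explicit. This is a sketch only; `tile_stats` and `merge_stats` are hypothetical names introduced here, and the split into two tiles of three elements mirrors the session above.

```python
import numpy as np

def tile_stats(tile):
    # Local maximum and local sum of exponentials for one tile.
    m = np.max(tile)
    l = np.sum(np.exp(tile - m))
    return m, l

def merge_stats(m1, l1, m2, l2):
    # Rescale both partial sums to the shared maximum before adding them.
    m = max(m1, m2)
    l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
    return m, l

a = np.array([0.1, 0.5, 0.4, 0.2, 0.3, 0.3])
m1, l1 = tile_stats(a[:3])  # first tile
m2, l2 = tile_stats(a[3:])  # second tile
m, l = merge_stats(m1, l1, m2, l2)

# Normalizer computed in one pass over the whole array, for comparison.
reference = np.sum(np.exp(a - np.max(a)))
print(m, l, reference)
```

Because the merge only needs each tile's `(m, l)` pair, tiles can be processed one at a time without keeping the whole array in fast memory.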

Flash Attention (Tiling)
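The tiling trick carries over to attention by keeping a running maximum, a running normalizer, and an unnormalized output accumulator per query row while iterating over blocks of keys and values. The NumPy code below is a rough sketch of that forward-pass idea only. The function names, the block size, and the tensor shapes are assumptions made for illustration, and a real FlashAttention kernel is a fused GPU implementation that also tiles the queries.

```python
import numpy as np

def reference_attention(q, k, v):
    # Standard attention: materializes the full (n_q, n_k) score matrix.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=4):
    # Processes K/V block by block, keeping per-row running statistics:
    # m = running max, l = running normalizer, acc = unnormalized output.
    n_q, d = q.shape
    m = np.full((n_q, 1), -np.inf)
    l = np.zeros((n_q, 1))
    acc = np.zeros((n_q, v.shape[-1]))
    for start in range(0, k.shape[0], block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        s = q @ k_blk.T / np.sqrt(d)  # scores for this block only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)     # rescales the previous partial results
        p = np.exp(s - m_new)
        l = scale * l + p.sum(axis=-1, keepdims=True)
        acc = scale * acc + p @ v_blk
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((10, 16))
v = rng.standard_normal((10, 16))

out = tiled_attention(q, k, v, block=4)
print(np.allclose(out, reference_attention(q, k, v)))
```

The update of `l` is the same rescale-then-add step as in online softmax; `acc` is rescaled by the same factor so that the final division by `l` yields the normalized output.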

C++, Torch Profiling

from torch.utils.cpp_extension import load
fanano = load(name='fanano', sources=['fanano.cpp', 'fanano.cu'], verbose=True)

Note: an undefined symbol error at build time means that an argument type is wrong.

When profiling, note that the legacy torch.autograd.profiler.profile and torch.profiler.profile give different results; the CUDA execution times in particular differ significantly.

Last Modified: 2024/10/29 18:50:38

© 2000 - Sang Park Except where otherwise noted, content on this site is licensed under a CC BY 4.0.