Understanding Flash Attention: Writing the Algorithm from Scratch in Triton

Why is Flash Attention so fast? Let's find out how it works, then solidify our understanding by writing a GPU kernel for the algorithm in Triton.

Alex Dremov
Kernel jobs data management | Image by the author


Flash Attention is a revolutionary technique that dramatically accelerates the attention mechanism in transformer-based models; the original paper reports, for example, a 3× training speedup on GPT-2 with a sequence length of 1K. By cleverly tiling data and minimizing memory transfers, it tackles the notorious GPU memory bottleneck that large language models often struggle with.

In this post, we'll dive into how Flash Attention leverages I/O-awareness to reduce overhead, then take it a step further by writing our own Flash Attention kernel in Triton.

💥
I will provide a simple explanation of how Flash Attention works. We will then implement the explained algorithm in Triton!

What is Attention?

The attention mechanism (or scaled dot-product attention) is a core element of transformer models, the leading architecture for language modeling. Popular models such as GPT, LLaMA, and BERT all rely on attention.

The formula is pretty simple:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,$$

where \(Q\), \(K\), \(V\) are the query, key, and value tensors and \(d_k\) is the dimension of the keys.

The rest is history.

Even though the formula looks simple, its computation involves multiplications of large tensors and a lot of data movement. Considering that this is a core part of the transformer architecture, optimizing the algorithm greatly improves the performance of the model in general.

In the naive implementation, attention requires \(O(n^2)\) additional memory and \(O(n^2)\) compute time complexity, where \(n\) is the sequence length. That's a lot!
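For reference, the naive approach can be sketched in a few lines of NumPy (an illustrative single-head version written for this explanation, not code from the author's repository). It materializes the full \(n \times n\) score matrix, which is exactly where the \(O(n^2)\) memory goes:

```python
import numpy as np

def naive_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single-head attention; q, k, v have shape (n, d)."""
    d_k = q.shape[-1]
    # Full (n, n) score matrix: this is the O(n^2) memory cost.
    scores = q @ k.T / np.sqrt(d_k)
    # Row-wise softmax, shifted by the row max for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n, d = 6, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out = naive_attention(q, k, v)
print(out.shape)  # (6, 4)
```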

Flash Attention

Core Idea

The main idea of Flash Attention can be summarized in a simple quote from the original paper:

We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory.

That is, modern GPUs have several types of memory:

  • SRAM: fast, on-chip, small
  • HBM (high bandwidth memory): slower than SRAM, large. This is what we usually refer to as GPU memory.

Check out the memory hierarchy in the image below to see the differences in bandwidth and sizes of different memory types.

Image from FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Tri Dao et al.
💡
To conduct computation, data must be transferred from HBM to SRAM, and this transfer is not overhead-free!

The Flash Attention algorithm proposes a method of computing attention in tiles, without explicitly materializing the attention scores tensor:

$$\text{AttentionScores}(Q, K) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$$

💥
Not materializing a matrix means that at any given time, the matrix does not exist in its full shape in memory.

It's easy to see that this matrix requires \(O(n^2)\) memory to store. For long sequences, that's a lot of data! So, if we manage to avoid explicitly materializing this matrix, we can save lots of memory.
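To put numbers on this, here is a small back-of-the-envelope helper (hypothetical, assuming fp16 scores, i.e. 2 bytes per element). At \(n = 32768\), a single head of a single sample already needs 2 GiB just for the scores; a hypothetical layer with 32 heads would need 64 GiB.

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2,
                       batch: int = 1, heads: int = 1) -> int:
    """Bytes needed to materialize the (seq_len x seq_len) attention scores."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

# Memory for one head of one sample, fp16 scores.
for n in (1024, 8192, 32768):
    gib = score_matrix_bytes(n) / 2**30
    print(f"n={n}: {gib:.3f} GiB")
```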

However, this matrix is needed for transformer training, as it is part of backpropagation and gradient calculation. The authors propose to recompute it during the backward pass instead (again without explicit materialization). Not only does this save lots of memory, but it also provides large speedups, as we don't need to transfer this enormous matrix between different GPU memory types.

Overall, this approach not only speeds up computation by taking GPU I/O specifics into account, but also allows processing much longer sequences, as the additional memory complexity drops to \(O(n)\).

Tiled Attention Calculation

The last thing to understand is how to compute attention in tiles. Basically, this means that we will calculate attention over the full sequence by processing incoming tokens in small portions.

Computing \(QK^T\) in tiles is easy. Since the attention (head) dimension is small, we can load full rows of \(Q\) and full columns of \(K^T\) and multiply them tile by tile.

😑
Yes, if we want to have an enormous attention dimension, Flash Attention will not work without algorithm modifications.

As head dimensions are usually small even in very large models (commonly 64 or 128), this limitation is fair.
Tiled QK^T | Image by the author
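A minimal NumPy sketch of tiled \(QK^T\), assuming single-head inputs of shape (n, d) and arbitrary tile sizes. Each tile only needs a slice of \(Q\) rows and a slice of \(K\) rows, always with the full head dimension. The sketch writes every tile into a full output only so that it can be checked against `q @ k.T`; Flash Attention instead consumes each tile immediately.

```python
import numpy as np

def tiled_qkt(q: np.ndarray, k: np.ndarray, tile_q: int, tile_k: int) -> np.ndarray:
    """Compute Q @ K^T tile by tile. Each tile uses full rows of Q and full
    rows of K (i.e. columns of K^T), which is cheap because d is small."""
    n_q, n_k = q.shape[0], k.shape[0]
    out = np.empty((n_q, n_k), dtype=q.dtype)
    for i in range(0, n_q, tile_q):
        q_tile = q[i:i + tile_q]          # (<= tile_q, d), full head dim
        for j in range(0, n_k, tile_k):
            k_tile = k[j:j + tile_k]      # (<= tile_k, d), full head dim
            out[i:i + tile_q, j:j + tile_k] = q_tile @ k_tile.T
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal((10, 4)), rng.standard_normal((7, 4))
print(np.allclose(tiled_qkt(q, k, tile_q=3, tile_k=4), q @ k.T))  # True
```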

So, we have a tile of \(QK^T\) computed in SRAM. All that's left is to apply softmax and multiply by \(V\), and that's it!

$$\text{Softmax}(z_i) = \frac{e^{z_{i}}}{\sum_{j=1}^T e^{z_{j}}} \; \; \text{for}\; i = 1, 2,\ldots, T$$

That's where the trick is.

The problem is that the softmax denominator requires aggregation over the sequence length to normalize scores, and we do not have access to the whole length as we load data in tiles.

To address this, we can use a concatenated (online) softmax: we compute the softmax "in batches" and adjust previously computed values as new data arrives.

Following the original paper, we can define rules to compute the softmax over a concatenation of data. Given two vectors \(x^{(1)}\) and \(x^{(2)}\), we need the softmax denominator \(l(x)\) of their concatenation \(x = \left[x^{(1)}, x^{(2)}\right]\). Denoting the maximum of a vector by \(m(x)\) and computing each denominator relative to that maximum, we can derive the statistics of the concatenation from those of its parts:

$$m(x) = m\left(\left[x^{(1)}, x^{(2)}\right]\right) = \max\left(m(x^{(1)}), m(x^{(2)})\right),$$

$$l(x) = l\left(\left[x^{(1)}, x^{(2)}\right]\right) = e^{m(x^{(1)}) - m(x)}\,l(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\,l(x^{(2)}).$$

The last equality is easy to verify with the definition \(l(x)=\sum_{j=1}^{T} e^{x_{j} - m(x)}\): multiplying each partial denominator by \(e^{m(x^{(k)}) - m(x)}\) turns its terms into \(e^{x_j - m(x)}\), so the two sums add up to \(l(x)\).
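The concatenation rule is easy to check numerically. This NumPy sketch (function names are illustrative) computes \(m\) and \(l\) for two chunks separately, merges them with the formulas above, and compares the result with the statistics of the full vector:

```python
import numpy as np

def stats(x: np.ndarray) -> tuple[float, float]:
    """Return (m(x), l(x)) with l(x) = sum_j exp(x_j - m(x))."""
    m = x.max()
    return m, np.exp(x - m).sum()

def combine(m1: float, l1: float, m2: float, l2: float) -> tuple[float, float]:
    """Merge the statistics of two chunks using the concatenation rule."""
    m = max(m1, m2)
    return m, np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2

rng = np.random.default_rng(0)
x1, x2 = rng.standard_normal(5), rng.standard_normal(3)
m, l = combine(*stats(x1), *stats(x2))
m_ref, l_ref = stats(np.concatenate([x1, x2]))
print(np.isclose(m, m_ref), np.isclose(l, l_ref))  # True True
```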

So, now we have what we want: we can calculate the softmax per tile and then, by re-normalizing with the formula above, obtain the global softmax. The last step is to multiply each tile of probabilities by the corresponding tile of \(V\) and apply the same re-normalization to the accumulated result, which is valid because matrix multiplication is a linear operation.

And all of this without loading the full sequence into memory or materializing \(QK^T\)!

💥
Notice that we calculate \(\text{Softmax}\left(QK^T\right)\) in tiles only, without needing to have the whole matrix at any moment.
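Putting it together, here is a NumPy reference of tiled attention with the running softmax, assuming single-head inputs and natural `exp` instead of `exp2`. The loop order (one \(Q\) tile at a time, iterating over all \(K\), \(V\) tiles) mirrors the Triton kernel later in this post; it is a readable sketch for checking the math, not a fast implementation. The running max `m_i`, denominator `l_i`, and accumulator `acc` are rescaled by `alpha` whenever the max grows:

```python
import numpy as np

def tiled_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                    tile_q: int = 4, tile_k: int = 4) -> np.ndarray:
    """Attention computed tile by tile with a running (online) softmax.
    Only (tile_q, tile_k) blocks of scores ever exist at once."""
    n_q, d = q.shape
    n_k = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n_q, v.shape[1]))
    for i in range(0, n_q, tile_q):              # one "job" per Q tile
        q_tile = q[i:i + tile_q] * scale
        rows = q_tile.shape[0]
        m_i = np.full(rows, -np.inf)             # running row max m(x)
        l_i = np.zeros(rows)                     # running denominator l(x)
        acc = np.zeros((rows, v.shape[1]))       # running numerator @ V
        for j in range(0, n_k, tile_k):          # iterate over K, V tiles
            s = q_tile @ k[j:j + tile_k].T       # score tile
            m_ij = np.maximum(m_i, s.max(axis=1))
            p = np.exp(s - m_ij[:, None])        # shifted by the new max
            alpha = np.exp(m_i - m_ij)           # rescales old statistics
            l_i = l_i * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ v[j:j + tile_k]
            m_i = m_ij
        out[i:i + rows] = acc / l_i[:, None]     # apply final denominator
    return out

def reference_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[1])
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((10, 8))
k, v = rng.standard_normal((13, 8)), rng.standard_normal((13, 5))
print(np.allclose(tiled_attention(q, k, v, 4, 3), reference_attention(q, k, v)))  # True
```

Tile sizes are arbitrary here and need not divide the sequence length; any choice gives the same result up to floating-point rounding, which is exactly the property Flash Attention relies on.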

Also, for numerical stability, the actual algorithm computes not \(\text{Softmax}(x)\) but \(\text{Softmax}(x - \max(x))\). This is allowed because softmax is invariant to adding the same constant to all inputs, and it keeps every exponent at most \(e^0 = 1\), avoiding overflow.
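A quick NumPy illustration of why the shift matters: both versions agree on moderate inputs, but exponentiating large scores directly overflows to inf and produces nan, while the shifted version is unaffected.

```python
import numpy as np

def softmax_naive(x: np.ndarray) -> np.ndarray:
    e = np.exp(x)
    return e / e.sum()

def softmax_stable(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())   # largest exponent becomes exp(0) = 1
    return e / e.sum()

x = np.array([1.0, 2.0, 3.0])
print(np.allclose(softmax_naive(x), softmax_stable(x)))  # True: shift-invariant

big = x + 1000.0
with np.errstate(over="ignore", invalid="ignore"):
    naive_big = softmax_naive(big)                 # exp(1001) overflows to inf
print(np.isnan(naive_big).all())                   # True: inf / inf is nan
print(np.allclose(softmax_stable(big), softmax_stable(x)))  # True
```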

Triton Implementation

Now, we can easily implement the outlined algorithm in Triton, which is a tool that allows us to write efficient GPU kernels with the ease of Python.

💡
To learn more about Triton, check out the official tutorials in the Triton documentation.

Outlining the Algorithm

The first step is to decide how to assign jobs and what data each job will load. Because of how tiled softmax works, each job needs access to \(K\) and \(V\) over the whole sequence length, so each job will iterate over \(K\) and \(V\) in tiles. There is no algorithmic restriction on how \(Q\) is split, so each job will load just one \(Q\) tile and work with it only. This way, we maximize job parallelism.

Jobs data management | Image by the author

In summary, each job will load a single \(Q\) tile, iterate over all tiles in \(K\) and \(V\), and store one tile of result corresponding to the \(Q\) tile.
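The job layout can be sketched as a launch grid with one program per (\(Q\) tile, batch element, head). The function names and the choice to flatten batch and head into the second grid axis are assumptions for illustration; in an actual Triton kernel, the two ids would come from `tl.program_id(0)` and `tl.program_id(1)`, and the author's repository may organize the grid differently:

```python
def launch_grid(batch: int, heads: int, max_seq_len: int, tile_q: int) -> tuple[int, int]:
    """Hypothetical grid: axis 0 enumerates Q tiles along the sequence,
    axis 1 enumerates flattened (batch, head) pairs."""
    num_q_tiles = (max_seq_len + tile_q - 1) // tile_q   # same as tl.cdiv
    return (num_q_tiles, batch * heads)

def job_coords(pid0: int, pid1: int, heads: int, tile_q: int) -> tuple[int, int, int]:
    """Map program ids back to (batch index, head index, first Q token),
    i.e. the q_token_idx used by the kernel."""
    batch_idx, head_idx = divmod(pid1, heads)
    q_token_idx = pid0 * tile_q
    return batch_idx, head_idx, q_token_idx

grid = launch_grid(batch=2, heads=8, max_seq_len=1000, tile_q=64)
print(grid)                        # (16, 16)
print(job_coords(3, 9, 8, 64))     # (1, 1, 192)
```

Each of these programs then loops over all `cdiv(seq_len, TILE_K_SIZE)` tiles of \(K\) and \(V\), so the number of parallel jobs grows with the sequence length while each job's SRAM footprint stays fixed.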

The Kernel

What's left is to write the actual code. Let's focus on the core part first and add the Triton-specific boilerplate later.

Below is Triton pseudocode with every line explained. Pointer arithmetic and loads are elided with `...`.

```python
def self_attn_fwd(...):
    # load this sample's sequence length
    seq_len = ...

    # running max of qk^T per query row (initialized to -inf)
    m_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32) - float("inf")

    # running softmax denominator
    l_i = tl.zeros([TILE_Q_SIZE], dtype=tl.float32)

    # result tile
    # we accumulate (softmax numerator) @ V here
    # and divide by the softmax denominator at the very end
    acc = tl.zeros([TILE_Q_SIZE, HEAD_DIM], dtype=tl.float32)

    # notice: all accumulators above are kept
    # in fp32 for higher precision

    # account for variable sample lengths in the batch
    q_tile_indices = q_token_idx + tl.arange(0, TILE_Q_SIZE)
    q_lens_mask = (
        q_tile_indices[:, None] < seq_len
    )

    # load the q tile into SRAM, shape (TILE_Q_SIZE, HEAD_DIM)
    q_tile = ...

    # softmax scale multiplied by log2(e) = 1.4426950408889634
    # so that we can use the faster exp2(...) instead of exp(...)
    softmax_scale = tl.cast(SM_SCALE * 1.4426950408889634, q_tile.dtype)

    # indices of tokens inside a kv tile
    tile_k_arange = tl.arange(0, TILE_K_SIZE)

    # iterate over all tiles in k, v
    for kv_tile_idx in tl.range(
        0, tl.cdiv(seq_len, TILE_K_SIZE), num_stages=PIPELINING
    ):
        # index of the first token in the kv tile
        kv_token_idx = kv_tile_idx * TILE_K_SIZE

        kt_tile = ...  # load K^T tile no. kv_tile_idx into SRAM
        v_tile = ...  # load V tile no. kv_tile_idx into SRAM

        # compute a tile of QK^T (already scaled for exp2)
        qk = tl.dot(
            q_tile * softmax_scale,
            kt_tile,
            input_precision=INPUT_PRECISION,
            out_dtype=tl.float32
        )

        # mask out kv tokens beyond the sequence length
        kv_indices = kv_token_idx + tile_k_arange
        mask = q_lens_mask & (
            kv_indices[None, :] < seq_len
        )

        # set masked-out values to -inf
        # so that softmax ignores them
        qk = tl.where(mask, qk, tl.cast(-float("inf"), qk.dtype))

        # new maximum over the part of the sequence seen so far
        # m(x) = max(m(x1), m(x2))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))

        # e^(x2 - m(x)), computed as exp2 on log2(e)-scaled scores
        p = tl.math.exp2(qk - m_ij[:, None])

        # current tile's contribution to the softmax denominator
        l_ij = tl.sum(p, 1)

        # from the softmax formula: e^(m(x1) - m(x))
        alpha = tl.math.exp2(m_i - m_ij)

        # update the denominator using the formula
        # l(x) = e^(m(x1) - m(x)) * l(x1) + e^(m(x2) - m(x)) * l(x2)
        # notice: the second term is exactly l_ij, because p was
        # already computed relative to the new maximum m(x)
        l_i = l_i * alpha + l_ij

        # rescale the previous acc to account for the maximum change,
        # as e^(xi - m(x1)) * alpha = e^(xi - m(x))
        acc = acc * alpha[:, None]

        # multiply p by v and add to acc
        acc += tl.dot(
            p.to(v_tile.dtype),
            v_tile,
            input_precision=INPUT_PRECISION,
            out_dtype=tl.float32,
        )

        # store the new maximum
        m_i = m_ij

    # finally, apply the softmax denominator
    acc = acc / l_i[:, None]

    # zero out rows of padded query tokens: all their scores were -inf,
    # which yields NaN values that must not leak into the result
    acc = tl.where(q_lens_mask, acc, 0.0)

    # save the result
    tl.store(..., acc)
```

See? Easy!

What's important is how simple such a kernel becomes once we understand the idea of tiled softmax. Apart from that, there's nothing complicated from the algorithmic perspective.
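The kernel folds \(\log_2 e\) into the softmax scale so it can call `exp2`, relying on the identity \(2^{x \log_2 e} = e^{x}\). Because the max shift is applied in the same scaled units, `p`, `alpha`, and therefore `l_i` and `acc` come out the same as with `exp`. A quick NumPy check of the identity (not Triton code):

```python
import numpy as np

LOG2_E = np.log2(np.e)  # about 1.4426950408889634

x = np.linspace(-20.0, 5.0, 11)
print(np.allclose(np.exp2(x * LOG2_E), np.exp(x)))  # True

# The max shift works the same way in the scaled (log2) domain:
# exp2(s*LOG2_E - max(s*LOG2_E)) == exp(s - max(s))
s = np.array([0.3, -1.2, 2.5])
shifted_exp2 = np.exp2(s * LOG2_E - (s * LOG2_E).max())
print(np.allclose(shifted_exp2, np.exp(s - s.max())))  # True
```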

💥
This kernel can be made even faster with further Triton-specific optimizations. However, they are beyond the scope of this article.

This pseudocode is pretty close to the actual code, which you can find on my GitHub (link below). All I added is data management and PyTorch wrappers.

kernels/src/self_attention/kernel.py at main · alexdremov/kernels
❗
Don't hesitate to ask if something isn't clear. I'm here in the comments 😁.

The full kernel in the repository was extensively tested to match PyTorch's scaled_dot_product_attention. You can also check out the tests to see how to use it.
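Matching SDPA on padded batches hinges on the masking in the kernel: `q_lens_mask`, the kv-length mask, and the final `tl.where`. The same masking can be sanity-checked on the CPU with a NumPy sketch (hypothetical helper, single head, one padded sample) that compares a padded computation against attention over the unpadded tokens only:

```python
import numpy as np

def masked_attention(q, k, v, seq_len):
    """q, k, v: padded arrays of shape (max_len, d); only the first seq_len
    tokens are real. Padded keys get -inf scores, padded query rows are zeroed."""
    max_len, d = q.shape
    idx = np.arange(max_len)
    q_mask = idx[:, None] < seq_len                    # like q_lens_mask
    mask = q_mask & (idx[None, :] < seq_len)           # like the kv mask
    s = np.where(mask, q @ k.T / np.sqrt(d), -np.inf)
    with np.errstate(invalid="ignore"):                # padded rows: -inf - (-inf) = nan
        w = np.exp(s - s.max(axis=1, keepdims=True))
        out = (w / w.sum(axis=1, keepdims=True)) @ v
    return np.where(q_mask, out, 0.0)                  # like the final tl.where

rng = np.random.default_rng(0)
max_len, d, seq_len = 8, 4, 5
q, k, v = (rng.standard_normal((max_len, d)) for _ in range(3))
padded = masked_attention(q, k, v, seq_len)

# Reference: plain attention over the real tokens only.
s = q[:seq_len] @ k[:seq_len].T / np.sqrt(d)
w = np.exp(s - s.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ v[:seq_len]
print(np.allclose(padded[:seq_len], ref), np.all(padded[seq_len:] == 0))  # True True
```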

Benchmarking

While we wrote the kernel in Triton to better understand the algorithm, it's interesting to compare its performance with a naive implementation and with PyTorch's scaled_dot_product_attention.

Benchmarking implementations for different sequence lengths | Image by the author

As expected, the Flash Attention kernel clearly outperforms the naive implementation. The dashed line marks the range of sequence lengths for which the naive implementation runs out of CUDA memory.

Our Triton implementation is slightly slower than PyTorch SDPA, but the difference is not large. Considering that PyTorch SDPA is a well-optimized CUDA kernel, that's a nice result.

Benchmarking code is also available in the repository.

kernels/benchmark/benchmark_self_attention.py at main · alexdremov/kernels

Conclusions

In this post, I covered the motivation behind the Flash Attention algorithm as well as its details. Finally, we implemented it from scratch in Triton, getting large speedups over the naive implementation and performance close to PyTorch SDPA.

I hope this post improved your understanding of Flash Attention. Feel free to leave a comment below if you have any questions.

References

Tri Dao et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.
Triton documentation: Tutorials.
GitHub - alexdremov/kernels: Collection of useful kernels
