Designing Hardware-Aware Algorithms: FlashAttention
If you weren’t already aware, the T in ChatGPT stands for Transformer, the linchpin architecture for developing state-of-the-art AI. The Transformer is a neural network architecture, initially developed for machine translation and introduced in the paper Attention Is All You Need, whose core mechanism is self-attention. Through a stack of layers that build an internal mathematical representation of how the elements of a sequence relate to one another and how relevant they are, it transforms an input sequence into an output sequence.
If Attention is All You Need, Let’s Make it Better…
The advent of the Transformer architecture ushered in a new era of AI research, much of it focused on making the architecture’s core mechanism, attention, more efficient. Attention’s scalability is limited by time and memory costs that grow quadratically, O(N^2), with sequence length N. This is troublesome because efficiently modeling long sequences is essential for capturing long-range dependencies, which are required to model lengthy texts, codebases, high-resolution images, and more. To address this, many researchers have been developing hardware-aware and memory-efficient algorithms, such as FlashAttention.
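To put the quadratic cost in concrete terms, the short Python sketch below computes how much memory a single N×N attention-score matrix takes. It assumes 2 bytes per element (FP16/BF16) and counts one head of one sequence only; a real model multiplies this by the batch size and the number of heads.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to materialize one N x N score matrix (S or P) for one head.
    Assumes 2 bytes per element (FP16/BF16) by default."""
    return seq_len * seq_len * bytes_per_element

for n in (1024, 2048, 4096, 8192):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:5d}: {mib:7.1f} MiB per head")
```

Each doubling of N quadruples the footprint, so going from 1K to 8K tokens grows a single head’s score matrix from 2 MiB to 128 MiB.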
Introduction
The goal of this article is to highlight the concepts that made FlashAttention (2022) successful in achieving wall-clock speedup over the standard attention mechanism. The techniques leveraged in its second (2023) and third (2024) iterations will be covered in subsequent blog posts.
Prerequisites
Familiarity with the following will help with understanding the topics presented in this article:
- The transformer and the attention mechanism
- Matrix multiplication
- Softmax operation
- Forward propagation/backward propagation (AKA forward pass, backward pass)
- The GPU memory hierarchy
- GPU performance optimization
- CUDA programming concepts (thread blocks, warps, kernels)
- Floating point formats (FP16, BF16, FP8)
Designing Hardware-Aware And Memory-Efficient Algorithms
Modern accelerators, such as NVIDIA’s Ampere and Hopper GPUs, offer an abundance of floating-point operations per second (FLOPS), a metric of a device’s theoretical computational throughput. However, these same accelerators are often limited by memory bandwidth: the rate at which data can be transferred between the GPU’s memory and its processing units. Designing hardware-aware and memory-efficient algorithms for GPUs therefore requires careful consideration of how to best leverage the memory hierarchy so that as much of the theoretical peak FLOPS as possible is actually used.
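The tension between FLOPS and memory bandwidth is often summarized with the roofline model: an operation’s attainable throughput is the smaller of the peak compute rate and its arithmetic intensity (FLOPs per byte moved) times the memory bandwidth. The sketch below uses illustrative numbers, roughly 312 TFLOPS of FP16 tensor-core compute and 1.5 TB/s of HBM bandwidth (assumptions loosely modeled on an A100-class GPU, not measurements), to contrast a low-intensity element-wise operation with a large matrix multiplication.

```python
# Illustrative hardware numbers (assumptions, roughly A100-class FP16 tensor-core
# peak and HBM bandwidth); substitute your own device's spec-sheet values.
PEAK_FLOPS = 312e12      # FLOP/s
HBM_BANDWIDTH = 1.5e12   # bytes/s

def attainable_flops(intensity: float, peak: float = PEAK_FLOPS,
                     bandwidth: float = HBM_BANDWIDTH) -> float:
    """Roofline model: throughput is capped by compute or by memory traffic."""
    return min(peak, intensity * bandwidth)

def matmul_intensity(n: int, bytes_per_element: int = 2) -> float:
    """FLOPs per byte for an n x n @ n x n matmul that reads A, B and writes C once."""
    flops = 2 * n**3
    bytes_moved = 3 * n * n * bytes_per_element
    return flops / bytes_moved

ridge = PEAK_FLOPS / HBM_BANDWIDTH  # intensity needed to become compute-bound
print(f"ridge point: {ridge:.0f} FLOPs/byte")
print(f"element-wise op (~1 FLOP/byte): {attainable_flops(1.0) / PEAK_FLOPS:.1%} of peak")
print(f"4096x4096 matmul: {attainable_flops(matmul_intensity(4096)) / PEAK_FLOPS:.1%} of peak")
```

Element-wise operations such as softmax sit far below the ridge point, so they are memory-bound even though they perform few FLOPs, while large matrix multiplications can saturate the compute units.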
FlashAttention is an excellent example of a hardware-aware and memory-efficient algorithm that enables longer context in Transformers by optimizing the attention mechanism for the hardware it’s computed on.
FlashAttention (2022)
FlashAttention is introduced as an “IO-aware exact attention algorithm that uses tiling to reduce the number of reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.”
GPU Memory: HBM & SRAM
The terminology surrounding GPU memory can be confusing, with numerous terms describing identical or overlapping concepts. FlashAttention involves two memory types: HBM and SRAM.
| Memory | AKA | Key characteristics | 
|---|---|---|
| HBM (High Bandwidth Memory) | GPU memory, global memory | Slow, larger memory capacity | 
| SRAM (Static Random-Access Memory) | L1 cache, shared memory | Fast, smaller memory capacity, on-chip | 
GPU Compute Model
Diagram from Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: streaming multiprocessors (SMs), shown in blue, contain compute units and SRAM. Global memory accesses to and from HBM are slow and should be minimized whenever possible.
Computing Attention

The Attention Line-up
Here’s a refresher of the variables involved in calculating the self-attention layer of the transformer.
- Query (Q): A query vector represents the current input element for which attention is being computed. The query vectors form a query matrix of size N×d, where N is the sequence length (on the order of 1K-8K) and d is the head dimension (typically 64-128).
- Key (K): The key matrix has the same dimensions as the query matrix. The key vectors are multiplied by the query vectors to calculate the similarity scores.
- Similarity Score (S): The similarity score measures how similar the query is to each element in the sequence. Multiplying the query matrix by the transposed key matrix produces an N×N matrix of similarity scores, S = QK^T.
- Attention Probability (P in the algorithm, A in the diagram): The attention probabilities are computed by applying the softmax operation to each row of the similarity scores, S. Softmax normalizes the scores so that every entry is positive and each row sums to 1, forming a probability distribution over the sequence.
- Value (V): The value vectors of the N×d value matrix contain information about each element in the sequence; the attention probabilities are multiplied by V to produce the N×d output, O = PV.
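The definitions above map onto three steps of code. The following pure-Python sketch (lists of lists as matrices, no external libraries) computes S = QK^T, a row-wise softmax to get P, and O = PV. Like the description above, it omits the 1/sqrt(d) scaling of S used in the original Transformer, and it subtracts each row’s maximum before exponentiating, a standard overflow-avoidance trick that does not change the result. Function names are illustrative.

```python
import math

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax_rows(s):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in s:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def standard_attention(q, k, v):
    s = matmul(q, transpose(k))   # N x N similarity scores
    p = softmax_rows(s)           # N x N attention probabilities
    return matmul(p, v)           # N x d output

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O = standard_attention(Q, K, V)   # 3 x 2 output
```

Note that both S and P are full N×N matrices here; that is exactly the intermediate state the standard algorithm writes to and reads from HBM.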
Standard Attention Algorithm
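Attention algorithm as depicted in the FlashAttention paper. In step 1, the Q and K matrices are loaded from HBM to compute S = QK^T, and S is written back to HBM. In step 2, S is read from HBM, softmax is applied to it, and the result is written to HBM as P. This step takes the longest. In step 3, P and V are loaded from HBM to compute the output O = PV, which is then written to HBM.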
From Aleksa Gordić’s YouTube video featuring FlashAttention author Tri Dao: The diagram explains how reading and writing the intermediate matrices (S and A/P) is the main bottleneck when computing attention. Note that A in this diagram is the same thing as P in the algorithm above.
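A back-of-the-envelope count shows why the intermediate matrices dominate. The sketch below tallies the matrix elements that the three-step standard algorithm (compute and write S; read S and write P = softmax(S); read P and V and write O = PV) moves between HBM and the compute units, under the simplifying assumption that each matrix is read or written exactly once per step with no caching. The function name is illustrative.

```python
def standard_attention_hbm_elements(n: int, d: int) -> int:
    """Elements moved between HBM and the SMs by three-step standard attention
    (simplified model: each matrix read or written once per step)."""
    step1 = 2 * n * d + n * n      # read Q, K; write S
    step2 = n * n + n * n          # read S; write P
    step3 = n * n + n * d + n * d  # read P, V; write O
    return step1 + step2 + step3

for n in (1024, 4096):
    total = standard_attention_hbm_elements(n, 64)
    quadratic = 4 * n * n          # traffic that involves the N x N matrices S and P
    print(f"N={n}: {total:,} elements, {quadratic / total:.0%} from S and P")
```

With d = 64, well over 90% of the traffic comes from reading and writing S and P, and that share keeps growing with N.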
FlashAttention is IO-aware
Now that we’ve established that the standard attention implementation lacks IO-awareness, with its redundant reads and writes to and from slow GPU memory (HBM), let’s discuss the hurdles FlashAttention had to overcome to achieve IO-awareness.
Kernel Fusion
FlashAttention boosts performance by fusing the attention computation into a single CUDA kernel. While kernel fusion may seem straightforward, the FlashAttention algorithm had to be carefully designed so that the data each step works on fits within the limited on-chip memory (SRAM).
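To make “fits within on-chip memory” concrete, here is a rough working-set estimate for one step of a fused, tiled attention kernel: a Br×d block of Q, a matching output accumulator, Bc×d blocks of K and V, a Br×Bc block of scores, and two softmax statistics per query row. The function names, the power-of-two square-tile search, and the 96 KiB budget are hypothetical illustrations, not FlashAttention’s actual block-size rule; real kernels often keep accumulators and statistics in higher precision, so treat the byte counts as approximate.

```python
def tile_working_set_bytes(br, bc, d, bytes_per_element=2):
    """Rough on-chip footprint of one step of a fused, tiled attention kernel."""
    q_and_out = 2 * br * d        # Q block and output accumulator (Br x d each)
    k_and_v = 2 * bc * d          # K and V blocks (Bc x d each)
    scores = br * bc              # S (and then P) block for this tile pair
    row_stats = 2 * br            # running max and running sum per query row
    return (q_and_out + k_and_v + scores + row_stats) * bytes_per_element

def largest_square_tile(d, sram_budget_bytes, bytes_per_element=2):
    """Largest power-of-two B (Br = Bc = B, at least 1) whose working set fits."""
    b = 1
    while tile_working_set_bytes(2 * b, 2 * b, d, bytes_per_element) <= sram_budget_bytes:
        b *= 2
    return b

BUDGET = 96 * 1024   # hypothetical per-thread-block SRAM budget in bytes
for d in (64, 128):
    b = largest_square_tile(d, BUDGET)
    print(f"d={d}: B={b}, working set {tile_working_set_bytes(b, b, d)} bytes")
```

The score block grows with Br×Bc while the Q/K/V blocks grow with the head dimension d, so larger heads leave less room for large tiles.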
Tiling
Tiling is a technique that involves partitioning data into smaller blocks, or “tiles”, that can fit into on-chip memory. Memory bandwidth requirements are reduced with tiling-assisted kernel fusion since data is transferred from global memory to the streaming multiprocessors only once per tile.
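The sketch below shows tiling on the operation where it works most naturally, matrix multiplication. Each innermost pass touches only one tile of A, B, and C (standing in for blocks a GPU kernel would stage in SRAM), and the partial products over the shared dimension are simply added up, which is valid because addition can be regrouped without changing the result (up to floating-point rounding). This is a plain-Python illustration of the loop structure, not a GPU kernel.

```python
def tiled_matmul(a, b, tile=2):
    """C = A @ B computed tile by tile (matrices are lists of lists)."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):              # tile rows of C
        for j0 in range(0, m, tile):          # tile columns of C
            for k0 in range(0, k, tile):      # walk the shared dimension tile by tile
                # only A[i0:i0+tile, k0:k0+tile] and B[k0:k0+tile, j0:j0+tile]
                # are needed here; their partial product is added into C's tile
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        c[i][j] += sum(a[i][t] * b[t][j]
                                       for t in range(k0, min(k0 + tile, k)))
    return c

A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[1, 0], [0, 1], [1, 1]]
C = tiled_matmul(A, B)   # same result as an untiled multiply
```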
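Tiling is particularly effective for associative operations like matrix multiplication. Associativity allows the computation to be reordered and split without affecting the final result, so smaller tiles can be processed independently and their partial results combined. The softmax operation in self-attention, however, is not associative in the same way: every output depends on the maximum and the sum of exponentials over the entire row, so the order of the computations does matter and a tile of scores cannot simply be finalized on its own.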
Making Softmax Associative
Leveraging the online softmax trick to make softmax associative is arguably the key innovation of FlashAttention.
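The online softmax trick can be sketched for a single row of scores. Instead of needing the whole row at once, it streams over blocks while keeping two numbers: the running maximum m and the running sum l of exp(x - m). When a new block arrives, both partial sums are rescaled to the larger maximum before being added. The resulting (m, l) is the same no matter how the row is split or in which order the blocks arrive (up to floating-point rounding), and the softmax is then exp(x - m) / l. Names here are illustrative.

```python
import math

def online_softmax_stats(blocks):
    """One pass over the blocks of a score row, keeping only a running max m and
    a running sum l of exp(x - m). Returns (m, l) for the whole row."""
    m, l = float("-inf"), 0.0
    for block in blocks:
        m_block = max(block)
        l_block = sum(math.exp(x - m_block) for x in block)
        m_new = max(m, m_block)
        # rescale both partial sums to the shared max before adding them
        l = l * math.exp(m - m_new) + l_block * math.exp(m_block - m_new)
        m = m_new
    return m, l

row = [1.0, 3.0, -2.0, 0.5, 4.0]
m, l = online_softmax_stats([row[:2], row[2:4], row[4:]])
probs = [math.exp(x - m) / l for x in row]   # same as a one-shot softmax of row
```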
To perform the softmax reduction incrementally, the attention computation is restructured as indicated by the figure. The inputs Q, K, and V are split into blocks. Instead of materializing the intermediate matrices (S and A/P) in HBM, FlashAttention computes them block by block in SRAM, keeping a running maximum and a running sum of exponentials for each row. Each block’s partial output is rescaled to the correct denominator (normalization factor) before the partial results are added up, giving the same result as the standard attention implementation.
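Putting the pieces together, the following sketch computes attention one K/V block at a time, carrying the running max m, the running denominator l, and an unnormalized output accumulator for each query row, and dividing by l only at the end. It checks itself against a naive version that builds each full score row. This is a simplified, single-threaded illustration of the math under our own loop order (query rows outer, K/V blocks inner); the actual kernel organizes its loops over blocks of queries and keys differently and runs on GPU thread blocks. Function names are illustrative.

```python
import math

def naive_attention(q, k, v):
    """Reference: build each full score row, softmax it, then weight the rows of V."""
    out = []
    for q_row in q:
        s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        out.append([sum(e[j] / z * v[j][c] for j in range(len(v)))
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, block_size=2):
    """Online-softmax attention: K and V are consumed one block at a time; only a
    running max m, running denominator l and unnormalized output acc are kept."""
    out = []
    for q_row in q:
        m, l = float("-inf"), 0.0
        acc = [0.0] * len(v[0])
        for j0 in range(0, len(k), block_size):
            k_blk, v_blk = k[j0:j0 + block_size], v[j0:j0 + block_size]
            s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_blk]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)            # shrink old terms to the new max
            p = [math.exp(x - m_new) for x in s]   # unnormalized block probabilities
            l = l * scale + sum(p)
            acc = [acc[c] * scale + sum(pj * vr[c] for pj, vr in zip(p, v_blk))
                   for c in range(len(acc))]
            m = m_new
        out.append([a / l for a in acc])           # divide by the final denominator once
    return out

Q = [[0.1, 0.9], [1.2, -0.4]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0], [3.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(tiled_attention(Q, K, V))
print(naive_attention(Q, K, V))
```

Because each block’s contribution is rescaled whenever the running maximum changes, the final division by l yields the exact softmax-weighted sum rather than an approximation.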
Recomputation in the Backward Pass
Redundant read/write operations are avoided by not storing the intermediate S and A/P matrices for the backward pass and instead recomputing them. FlashAttention stores only the output O and the softmax normalization statistics (m, l); in the backward pass, S and A/P are recomputed block by block in SRAM from the Q, K, and V blocks. Although recomputation adds FLOPs, it reduces the number of slow HBM accesses.
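Storing (m, l) suffices because every entry of P can be rebuilt as P_ij = exp(S_ij - m_i) / l_i once S_ij has been recomputed from Q and K. The sketch below keeps only two numbers per query row after the forward pass (2N values instead of the N×N matrix P) and then regenerates one block of columns of P, as a backward pass would. For brevity it computes the statistics from the full score row; in FlashAttention they come out of the online softmax during the forward pass. Function names are illustrative.

```python
import math

def row_scores(q_row, k_rows):
    """Dot products of one query row with each key row (one row of S)."""
    return [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_rows]

def softmax_stats(q, k):
    """Per-row statistics kept after the forward pass: m_i and l_i = sum_j exp(S_ij - m_i).
    Computed from the full row here for brevity."""
    stats = []
    for q_row in q:
        s = row_scores(q_row, k)
        m = max(s)
        stats.append((m, sum(math.exp(x - m) for x in s)))
    return stats

def recompute_p_block(q, k_block, stats):
    """Rebuild the columns of P for one block of keys from Q, that K block and (m, l)."""
    return [[math.exp(x - m) / l for x in row_scores(q_row, k_block)]
            for q_row, (m, l) in zip(q, stats)]

Q = [[0.2, -1.0], [1.5, 0.3]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]
stats = softmax_stats(Q, K)                     # 2 numbers per query row
P_block = recompute_p_block(Q, K[2:4], stats)   # columns 2..3 of P, never stored
```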
Conclusion
By cleverly reordering the attention computation with classical techniques like tiling and recomputation to exploit the asymmetric GPU memory hierarchy, FlashAttention sped up the attention mechanism and reduced memory usage from quadratic to linear in sequence length. This algorithm does an excellent job of demonstrating both the art and effectiveness of designing hardware-aware algorithms.