[FA2] Release flash-attn-mma shared-kv/qkv🎉 (#162)
* Update README.md
* Update flash_attn_mma_share_kv.cu
* Create flash_attn_mma_share_qkv.cu
* Update flash_attn_mma_split_q.cu
* Update flash_attn_mma_split_kv.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma_share_qkv.cu
DefTruth authored Dec 17, 2024
1 parent 6ff6490 commit 0c6785f
Showing 8 changed files with 1,280 additions and 315 deletions.
101 changes: 62 additions & 39 deletions README.md
@@ -42,24 +42,49 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|

I have also implemented **FlashAttention-2** using pure MMA PTX instructions, with support for Multi-Stages, Tile MMA, Tile Warp, Fully Shared QKV SMEM, Prefetch Q s2r, Collective Store, etc. Currently, small-scale attention `(B<=4, H<=48, SeqLen<=8192)` can run faster than the official FA2 on some devices, for example, the NVIDIA RTX 3080 Laptop.

![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)

- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 # NVIDIA RTX 3080 Laptop
------------------------------------------------------------------------------------------------------------------------
B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 1617, Warmup: 1, Iters: 10
------------------------------------------------------------------------------------------------------------------------
B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10
mma(split-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.586338ms, TFLOPS:25.08
mma(split-kv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:5.326223ms, TFLOPS:26.31
mma(split-q+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:3.834152ms, TFLOPS:36.54
mma(split-q+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:4.328346ms, TFLOPS:32.37
mma(split-q+share-kv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.636528ms, TFLOPS:53.15
mma(split-q+share-qkv+stage1): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.594471ms, TFLOPS:54.01
mma(split-q+share-qkv+stage2): ['0.01960754 ', '0.01452637 ', '-0.02592468 '], time:2.574611ms, TFLOPS:54.42
(flash): ['0.01963806 ', '0.0145874 ', '-0.02593994 '], time:3.764462ms, TFLOPS:37.22
-----------------------------------------------------------------------------------------------------------------------
```
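
As a quick sanity check on the TFLOPS column, attention costs roughly 4·B·H·N²·D FLOPs (two matmuls, Q@K^T and P@V, at 2 FLOPs per multiply-add). The sketch below, using the printed shapes and times, reproduces the reported numbers to within about 2%:

```C++
// Rough sanity check of the TFLOPS column above. Attention runs two
// [N,N]/[N,D]-shaped matmuls (Q@K^T and P@V), i.e. ~4*B*H*N*N*D FLOPs.
#include <cstdio>

int main() {
  const double B = 1, H = 8, N = 8192, D = 64;
  const double flops = 4.0 * B * H * N * N * D;  // ~1.374e11 FLOPs
  const double ms = 2.636528;                    // mma(split-q+share-kv+stage1)
  // Prints ~52.1 TFLOPS vs the reported 53.15 -- consistent to within ~2%.
  std::printf("TFLOPS: %.2f\n", flops / (ms * 1e-3) / 1e12);
  return 0;
}
```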

However, for large-scale attention computations, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.

|Tensor Cores|Loop over Seqlen/Headdim|Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
|**Shared KV** SMEM|Fully **Shared QKV** SMEM|**Prefetch Q** s2r|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMAs (warps), is slower than the `Split Q` policy, which splits only Q across MMAs (warps) while keeping KV accessible to all of them, as sketched after the figure below.

![flash-attn](https://github.com/user-attachments/assets/11490fbc-2a4a-4630-abe8-91a9d1251cba)

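To make the two policies concrete, here is a toy mapping (a sketch with hypothetical helper names, not the repo's code) from warp id to the sub-tile of the score matrix S = Q@K^T that each warp computes:

```C++
// Toy warp-to-tile mapping for the two policies (illustrative only).
struct Tile { int row0, col0, rows, cols; };

// Split KV: 8 warps in a 2x4 grid over [Br, Bc]; 4 warps share each score
// row, so online-softmax stats must be reduced across warps via smem/shuffle.
Tile split_kv_tile(int warp_id, int Br, int Bc) {  // warp_id in [0, 8)
  const int warp_qp = warp_id % 2;                 // rows: MMA 0/2/4/6 vs 1/3/5/7
  const int warp_kv = warp_id / 2;
  return { warp_qp * (Br / 2), warp_kv * (Bc / 4), Br / 2, Bc / 4 };
}

// Split Q: 4 warps stacked along Br; each warp owns whole rows of S.
Tile split_q_tile(int warp_id, int Br, int Bc) {   // warp_id in [0, 4)
  return { warp_id * (Br / 4), 0, Br / 4, Bc };
}
```

Keeping whole rows of S inside one warp is what lets `Split Q` skip the cross-warp softmax reduction, which is where its speedup over `Split KV` comes from.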
- 📚 Split KV (Basic, FlashAttention-1)
<div id="mma-split-kv"></div>

```C++
// Split QKV across MMA(Warps) using naive matmul MMA&Warp tiling policy.
// |  [64,64]  |    warp_KV 0    |    warp_KV 1    |    warp_KV 2    |    warp_KV 3    |
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
// | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --|
template<
const int kHeadDim, // Headdim, 32,64,128
const int kMmaAtomM, // MMA Atom M, 16
const int kMmaAtomN, // MMA Atom N, 8
const int kMmaAtomK, // MMA Atom K, 16
const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M
const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N
const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M
const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|...
const int kStage, // only support 1 or 2 now.
const int kPad // 0,8
>
__global__ void
flash_attn_mma_stages_split_kv_kernel(half* Q, // [B, H, N, D]
                                      half* K, // [B, H, D, N] K^T transposed
                                      half* V, // [B, H, N, D]
                                      half* O, // [B, H, N, D]
                                      int QKV_seqlen);
```
- 📚 Split Q (Faster, FlashAttention-2)
<div id="mma-split-q"></div>
```C++
// Split Q across MMA(Warps) and keep access KV for all MMA(Warps).
// |   64x64   |      warp_KV 0       |
// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
template<
const int kHeadDim, // Headdim, 32,64,128
const int kMmaAtomM, // MMA Atom M, 16
const int kMmaAtomN, // MMA Atom N, 8
const int kMmaAtomK, // MMA Atom K, 16
const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
const int kStage, // only support 1 or 2 now.
const int kPad // 0,8
>
__global__ void
flash_attn_mma_stages_split_q_kernel(half* Q, // [B, H, N, D]
                                     half* K, // [B, H, D, N] K^T transposed
                                     half* V, // [B, H, N, D]
                                     half* O, // [B, H, N, D]
                                     int QKV_seqlen);
```
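
The tile shapes quoted in the template comments can be restated as compile-time arithmetic; a minimal check of both configurations, with values taken directly from the comments above:

```C++
// Compile-time restatement of the tile math in the template comments:
//   Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ
//   Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK
constexpr int kMmaAtomM = 16, kMmaAtomN = 8;
// Split KV defaults: MMA tile 2x4, warp tile 2x2 -> Br=16*2*2=64, Bc=8*4*2=64.
static_assert(kMmaAtomM * 2 * 2 == 64 && kMmaAtomN * 4 * 2 == 64, "split-kv 64x64");
// Split Q  defaults: MMA tile 4x1, warp tile 1x8 -> Br=16*4*1=64, Bc=8*1*8=64.
static_assert(kMmaAtomM * 4 * 1 == 64 && kMmaAtomN * 1 * 8 == 64, "split-q 64x64");
```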

- 📚 Split Q + Shared KV SMEM (Faster+)
<div id="mma-share-kv"></div>

```C++
// K and V share the same shared memory buffer, improving block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
                                               half* K,
                                               half* V,
                                               half* O,
                                               int QKV_seqlen);
```
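
A minimal sketch of the shared-KV idea (an illustration under stated assumptions, not the kernel's actual code): the K tile is dead once the Q@K^T MMAs complete, so the V tile can overwrite the same shared-memory bytes:

```C++
#include <cuda_fp16.h>

// Hypothetical sketch of the shared-KV layout (assumption, not the repo's
// implementation): one smem buffer holds the K tile, then is reused for V.
template <int Bc, int kHeadDim>
__global__ void share_kv_sketch_kernel() {
  __shared__ half s_kv[Bc][kHeadDim];  // single buffer instead of s_K + s_V
  // Per KV-tile iteration:
  //   1. stage K tile -> s_kv; __syncthreads();
  //   2. MMA: S = Q @ K^T (K read from s_kv into registers);
  //   3. __syncthreads(); stage V tile -> s_kv (overwriting the K bytes);
  //   4. MMA: O += P @ V (V read from s_kv).
  (void)s_kv;  // sketch only; loads and MMAs elided
}
```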
- 📚 Split Q + Fully Shared QKV SMEM (Faster++)
<div id="mma-share-qkv"></div>
```C++
// Q, K and V fully share the same shared memory, with Q prefetched from
// smem to registers (s2r), improving block occupancy.
__global__ void
flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q,
                                                half* K,
                                                half* V,
                                                half* O,
                                                int QKV_seqlen);
```
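
A back-of-envelope shared-memory budget suggests why the sharing improves occupancy (assuming d=64, Br=Bc=64, fp16 tiles, a single stage, and no padding):

```C++
// Illustrative smem budget per block (assumptions: d=64, Br=Bc=64, fp16,
// one stage, no padding). Less smem per block -> more blocks per SM.
constexpr int Br = 64, Bc = 64, d = 64, kBytesPerHalf = 2;
constexpr int kTileBytes = Br * d * kBytesPerHalf;  // one 64x64 half tile, 8 KiB
constexpr int kSplitQ    = 3 * kTileBytes;          // separate Q, K, V: 24 KiB
constexpr int kShareKV   = 2 * kTileBytes;          // Q + one K/V buffer: 16 KiB
constexpr int kShareQKV  = 1 * kTileBytes;          // one buffer for all: 8 KiB
static_assert(kShareQKV < kShareKV && kShareKV < kSplitQ, "sharing shrinks smem");
```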

## ©️Citations🎉🎉


<div id="cuda-kernel"></div>

|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
|:---|:---|:---|:---|:---|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
| ✔️ [flash_attn_mma_stages_split_kv*](./kernels/flash-attn/mma/flash_attn_mma_split_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages_split_q*](./kernels/flash-attn/mma/flash_attn_mma_split_q.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [sgemm_naive_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️|
| ✔️ [sgemm_sliced_k_f32](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./kernels/sgemm/sgemm.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
