[FlashAttention] Update flash-attention-mma 0.0.1 🎉 (#159)
* Update flash_attn_mma_stage.cu

* Update flash_attn_mma_tiling.cu

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 12, 2024
1 parent b1b923a commit 81404c1
Showing 4 changed files with 4 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -42,7 +42,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|

-I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-atttention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.

![flash-attn-mma](https://github.com/user-attachments/assets/3e20fdaa-9b31-4dcd-91d5-204905842dce)
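The "Multi-Stages" mentioned above refers to pipelining global-to-shared-memory copies with tensor-core math via Ampere's `cp.async` PTX instructions. A minimal sketch of that idea, assuming illustrative macro names (`CP_ASYNC_CG`, `CP_ASYNC_COMMIT_GROUP`, `CP_ASYNC_WAIT_GROUP`) and a hypothetical `prefetch_k_tile` helper that are not taken from this commit:

```C++
#include <cuda_fp16.h>
#include <cstdint>

// Hedged sketch of cp.async-based multi-stage prefetch (sm_80+). These macro
// names and the helper below are illustrative, not this kernel's actual code.
#define CP_ASYNC_CG(dst, src, bytes)                                   \
  asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n"  \
               ::"r"(dst), "l"(src), "n"(bytes))
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
#define CP_ASYNC_WAIT_GROUP(n)  asm volatile("cp.async.wait_group %0;\n" ::"n"(n))

// Each thread copies 16 bytes (8 halfs) of the next K tile into the alternate
// shared-memory stage, so the copy overlaps with MMA work on the current stage.
__device__ void prefetch_k_tile(half* smem_k_stage, const half* gmem_k,
                                int thread_offset) {
  uint32_t dst = static_cast<uint32_t>(
      __cvta_generic_to_shared(smem_k_stage + thread_offset));
  CP_ASYNC_CG(dst, gmem_k + thread_offset, 16);
  CP_ASYNC_COMMIT_GROUP();
  // Before consuming this stage: CP_ASYNC_WAIT_GROUP(0); __syncthreads();
}
```

With two stages, the wait can be relaxed to `CP_ASYNC_WAIT_GROUP(1)` so that one copy group stays in flight while the current tile is consumed.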

8 changes: 0 additions & 8 deletions kernels/flash-attn/README.md
@@ -10,16 +10,8 @@

## 📖 Overview

-This includes the following (performance is continuously being optimized, stay tuned...):
-
-- [X] flash_attn_cuda_kernel (F32)
-- [x] flash_attn_mma_naive_kernel (ldmatrix + MMA)
-- [X] flash_attn_mma_stage_kernel (ldmatrix + MMA, Stages, Tile MMA/Warp, Copy Async, Collective Store, SMEM Padding)
-
-The FlashAttention kernels in this repository are intended only for learning CUDA programming; for the best performance, please use the official FlashAttention: [flash-attention](https://github.com/Dao-AILab/flash-attention)
-
## 📖 Kernel Invocation
- flash_attn_mma_stage_kernel:
```C++
template<
const int kHeadDim, // Headdim, 32,64,128
  // ... (remaining template parameters and kernel body truncated in this diff view)
```
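For context, one plausible way such a templated kernel gets dispatched from the host side; the `launch_flash_attn_mma_stages` name appears later in this commit, but its signature and the `flash_attn_mma_stages` wrapper below are assumptions for illustration only:

```C++
#include <torch/extension.h>
#include <stdexcept>

// Assumed launcher signature; only the template parameter kHeadDim (32/64/128)
// is suggested by the snippet above.
template <const int kHeadDim>
void launch_flash_attn_mma_stages(torch::Tensor Q, torch::Tensor K,
                                  torch::Tensor V, torch::Tensor O, int stages);

// Hypothetical host wrapper: pick a compile-time head dim and forward the call.
void flash_attn_mma_stages(torch::Tensor Q, torch::Tensor K,
                           torch::Tensor V, torch::Tensor O, int stages) {
  const int d = Q.size(3);  // Q: [batch_size, num_heads, seq_len, head_dim]
  switch (d) {
    case 32:  launch_flash_attn_mma_stages< 32>(Q, K, V, O, stages); break;
    case 64:  launch_flash_attn_mma_stages< 64>(Q, K, V, O, stages); break;
    case 128: launch_flash_attn_mma_stages<128>(Q, K, V, O, stages); break;
    default:  throw std::runtime_error("unsupported head_dim");
  }
}
```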
2 changes: 1 addition & 1 deletion kernels/flash-attn/mma/flash_attn_mma_stage.cu
@@ -17,7 +17,7 @@ using namespace nvcuda;
// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].

// The FlashAttention-2 algorithm is described in the following paper:
-// https://arxiv.org/abs/2110.08210
+// https://arxiv.org/pdf/2307.08691

// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
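Given that tiling, a hedged sketch of the launch geometry such a kernel might use; the grid-axis ordering and warp count are assumptions, not read from this diff:

```C++
#include <cuda_runtime.h>

// Illustrative geometry for Q,K,V,O of shape [B, H, N, d]: one block owns one
// [Br, d] Q tile for a single (batch, head) pair and iterates over all
// ceil(N / Bc) K/V tiles inside the kernel.
template <int Br>
dim3 make_flash_attn_grid(int B, int H, int N) {
  return dim3((N + Br - 1) / Br,  // Q tiles along the sequence dimension
              H,                  // one slot per attention head
              B);                 // one slot per batch element
}
// A matching block size would typically be dim3(32 * kNumWarps),
// e.g. 8 warps = 256 threads cooperating on one [Br, Bc] score tile.
```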
10 changes: 2 additions & 8 deletions kernels/flash-attn/mma/flash_attn_mma_tiling.cu
@@ -17,7 +17,7 @@ using namespace nvcuda;
// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].

// The FlashAttention-2 algorithm is described in the following paper:
-// https://arxiv.org/abs/2110.08210
+// https://arxiv.org/pdf/2307.08691

// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
@@ -609,14 +609,8 @@ flash_attn_mma_stages_kernel(half* Q,
// 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
#pragma unroll
for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=2
-FA_MMA_CHECK_PRINT_REG(R_S[i][0][0], R_Q[i][0], "Check failed, R_S[%d][0][0], R_Q[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-FA_MMA_CHECK_PRINT_REG(R_S[i][0][1], R_Q[i][1], "Check failed, R_S[%d][0][1], R_Q[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-FA_MMA_CHECK_PRINT_REG(R_S[i][1][0], R_Q[i][2], "Check failed, R_S[%d][1][0], R_Q[%d][2], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
-FA_MMA_CHECK_PRINT_REG(R_S[i][1][1], R_Q[i][3], "Check failed, R_S[%d][1][1], R_Q[%d][3], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id);
#pragma unroll
for (int j = 0; j < kWarpTileHeadDimV; ++j) { // kWarpTileHeadDimV=1,2,3,4,...
-FA_MMA_PRINT_REG(R_V[j][0], "[Before] MMA P@V, R_V[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id);
-FA_MMA_PRINT_REG(R_V[j][1], "[Before] MMA P@V, R_V[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id);
HMMA16816(R_O[i][j][0], R_O[i][j][1],
// FIXME(DefTruth): Still have some error while using R_S
// as registers for P(A) matrix directly. I will remove this
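The `HMMA16816` invocation above is, in kernels written this way, normally a thin wrapper over the `mma.sync.aligned.m16n8k16` tensor-core PTX instruction; a plausible definition is sketched below (the macro body actually used by this kernel is not shown in the diff, so treat this as an assumption):

```C++
// Plausible HMMA16816-style macro: one m16n8k16 half-precision MMA,
// D = A * B + C, with A in four 32-bit registers (eight halfs per thread),
// B in two, and C/D in two each -- matching the {a0..a7} fragment layout
// described in the comments above.
#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1)      \
  asm volatile(                                                          \
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "               \
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"                \
      : "=r"(RD0), "=r"(RD1)                                             \
      : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3),                          \
        "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
```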
@@ -795,7 +789,7 @@ void launch_flash_attn_mma_stages(
constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN*kMmaTileHeadDimV));
constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64
constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64
-constexpr int kPad = 0;
+constexpr int kPad = 8;

// Calculate SRAM size needed per block, Q,K,V,S smem size
const int smem_max_size = ((Br * (kHeadDim + kPad)) +
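With the constants above (Br = Bc = 64, kHeadDim = 64) and the new kPad = 8 (an 8-half row pad, presumably to reduce shared-memory bank conflicts), a back-of-the-envelope check of the per-stage tile sizes; the full `smem_max_size` expression is truncated in this view, so the totals below are illustrative only, assuming half-precision Q/K/V tiles of (kHeadDim + kPad) columns:

```C++
#include <cstdio>
#include <cuda_fp16.h>

// Illustrative shared-memory accounting under the stated assumptions:
// one [64, 64 + 8] half tile is 64 * 72 * 2 = 9216 bytes.
int main() {
  constexpr int Br = 64, Bc = 64, kHeadDim = 64, kPad = 8;
  constexpr int q_tile = Br * (kHeadDim + kPad) * (int)sizeof(half);  // 9216 B
  constexpr int k_tile = Bc * (kHeadDim + kPad) * (int)sizeof(half);  // 9216 B
  constexpr int v_tile = Bc * (kHeadDim + kPad) * (int)sizeof(half);  // 9216 B
  printf("Q + K + V tiles per stage: %d bytes\n",
         q_tile + k_tile + v_tile);  // 27648 bytes
  return 0;
}
```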
