From b1b923a7ec713f65c15d87fdceb9f64fc7109c9b Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:28:33 +0800 Subject: [PATCH] =?UTF-8?q?[FlashAttention]=20Release=20flash-atttention-m?= =?UTF-8?q?ma=200.0.1=20=F0=9F=8E=89=20(#158)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update makefile * Update .gitignore * Update hgemm_mma_stage.cu * Create flash_attn_mma.py * Delete kernels/flash-attn/flash_attn.py * Update README.md * Update hgemm_mma_stage.cu * Update hgemm_mma_stage_tn_cute.cu * Update README.md * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Create flexiable_flash_attn_mma.cu * Create flash_qattn_mma.cu * Create flexiable_flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flash_attn_mma_fp8.cu * Delete kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn.cc * Update flash_attn_cuda.cu * Update flash_attn_mma_old.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * add more tests * add more tests * add more tests * add more tests * add more tests * add more tests * add more tests * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Create custom_mma_utils.h * Update custom_mma_utils.h * Update flash_attn_mma.cu * Update flash_attn_mma.cu * Update custom_mma_utils.h * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.cu * Update custom_mma_utils.h * Update flash_attn_mma.cu * Update flash_attn_mma.py * Delete kernels/flash-attn/mma/custom_mma_utils.h * Delete kernels/flash-attn/mma/flexiable_flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flexiable_flash_attn_mma.cu * Delete kernels/flash-attn/mma/flash_qattn_mma.cu * Delete kernels/flash-attn/mma/flash_attn_mma_old.cu * Delete kernels/flash-attn/mma/flash_attn_mma_bak.cu * Delete kernels/flash-attn/mma/flash_attn_mma.cu * Create flash_attn_mma_naive.cu * Create flash_attn_mma_stage.cu * Create flash_attn_mma_tiling.cu * Update utils.h * Update flash_attn_cuda.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma_stage.cu * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma_stage.cu * Update flash_attn_mma_tiling.cu * Update README.md * Update flash_attn_mma_naive.cu * Update README.md * Update flash_attn_mma.py * Update flash_attn_mma.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update flash_attn_mma_stage.cu * Update README.md * Update README.md * Update flash_attn_mma_stage.cu * Update README.md * Update README.md * Update README.md --- .github/workflows/.gitignore | 3 + .gitignore | 3 + .gitmodules | 4 +- LICENSE | 2 - README.md | 17 +- kernels/flash-attn/README.md | 457 ++++++--- .../flash-attn/cutlass/flash_attn_cute_fp8.cu | 0 kernels/flash-attn/flash_attn.py | 92 -- kernels/flash-attn/flash_attn_mma.py | 221 +++++ kernels/flash-attn/mma/flash_attn_mma.cu | 81 -- kernels/flash-attn/mma/flash_attn_mma_fp8.cu | 0 ...ttn_mma_old.cu => flash_attn_mma_naive.cu} | 113 ++- .../flash-attn/mma/flash_attn_mma_stage.cu | 895 +++++++++++++++++ .../flash-attn/mma/flash_attn_mma_tiling.cu | 901 ++++++++++++++++++ kernels/flash-attn/naive/flash_attn_cuda.cu | 6 +- kernels/flash-attn/pybind/flash_attn.cc | 10 +- kernels/flash-attn/utils/utils.h | 208 ++++ kernels/hgemm/.gitignore | 8 +- kernels/hgemm/README.md | 7 +- .../hgemm/cutlass/hgemm_mma_stage_tn_cute.cu | 17 +- kernels/hgemm/makefile | 2 +- kernels/hgemm/mma/hgemm_mma_stage.cu | 10 +- others/.gitignore | 3 + slides/.gitignore | 3 + third-party/.gitignore | 18 +- 25 files changed, 2724 insertions(+), 357 deletions(-) delete mode 100644 kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu delete mode 100644 kernels/flash-attn/flash_attn.py create mode 100644 kernels/flash-attn/flash_attn_mma.py delete mode 100644 kernels/flash-attn/mma/flash_attn_mma.cu delete mode 100644 kernels/flash-attn/mma/flash_attn_mma_fp8.cu rename kernels/flash-attn/mma/{flash_attn_mma_old.cu => flash_attn_mma_naive.cu} (76%) create mode 100644 kernels/flash-attn/mma/flash_attn_mma_stage.cu create mode 100644 kernels/flash-attn/mma/flash_attn_mma_tiling.cu diff --git a/.github/workflows/.gitignore b/.github/workflows/.gitignore index 171541b1..65f67af8 100644 --- a/.github/workflows/.gitignore +++ b/.github/workflows/.gitignore @@ -19,3 +19,6 @@ __pycache__ *.bin outupt bin +*.log +*.txt +*.tex \ No newline at end of file diff --git a/.gitignore b/.gitignore index 171541b1..65f67af8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ __pycache__ *.bin outupt bin +*.log +*.txt +*.tex \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 9df792ca..bd48fab6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,4 @@ [submodule "third-party/cutlass"] path = third-party/cutlass url = https://github.com/NVIDIA/cutlass.git - tag = v3.5.1 - - + tag = v3.5.1 \ No newline at end of file diff --git a/LICENSE b/LICENSE index c154efd7..f288702d 100644 --- a/LICENSE +++ b/LICENSE @@ -672,5 +672,3 @@ may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read . - - diff --git a/README.md b/README.md index e9396eb2..68b26dff 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,18 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d |Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32| |✔️|✔️|✔️|✔️| +I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp and Collective Store. Performance is continuously being optimized. Stay tuned for updates ~ Please refer to [flash-atttention-mma⚡️⚡️](./kernels/flash-attn) for more details. + +![flash-attn-mma](https://github.com/user-attachments/assets/3e20fdaa-9b31-4dcd-91d5-204905842dce) + +|CUDA Cores|Sliced K (Loop over N/D)|Tile Block (Br, Bc, Bd)|MMA (m16n8k16)| +|:---:|:---:|:---:|:---:| +|✔️|✔️|✔️|✔️| +|Pack LDST (128 bits)|SMEM Padding|Copy Async |Tile MMA (More Threads) +|✔️|✔️|✔️|✔️| +|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|Row Major (NN)| +|✔️|✔️|✔️|✔️| + ## ©️Citations🎉🎉 ```BibTeX @@ -198,8 +210,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d | ✔️ [hgemv_k32_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️| | ✔️ [hgemv_k128_f16x4](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️| | ✔️ [hgemv_k16_f16](./kernels/hgemv/hgemv.cu)|f16|f16|[link](./kernels/hgemv/)|⭐️⭐️⭐️| -| ✔️ [flash_attn_f32](./kernels/flash-attn/flash_attn.cu)|f32|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️| -| ✔️ [flash_attn_mma_m16n8k16*](./kernels/flash-attn/flash_attn_mma.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️| +| ✔️ [flash_attn_cuda](./kernels/flash-attn/naive/flash_attn_cuda.cu)|f32|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️| +| ✔️ [flash_attn_mma_naive*](./kernels/flash-attn/mma/flash_attn_mma_naive.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️| +| ✔️ [flash_attn_mma_stage*](./kernels/flash-attn/mma/flash_attn_mma_stage.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️| | ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️| | ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️| diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md index 6a000e4d..692b3fdd 100644 --- a/kernels/flash-attn/README.md +++ b/kernels/flash-attn/README.md @@ -1,191 +1,432 @@ -# FlashAttention +## ⚡️⚡️FlashAttention-2 MMA: Write FlashAttention using Tensor Cores with pure MMA PTX -## 0x00 说明 +|CUDA Cores|Sliced K (Loop over N/D)|Tile Block (Br, Bc, Bd)|MMA (m16n8k16)| +|:---:|:---:|:---:|:---:| +|✔️|✔️|✔️|✔️| +|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads) +|✔️|✔️|✔️|✔️| +|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|Row Major (NN)| +|✔️|✔️|✔️|✔️| -包含以下内容: +## 📖 说明 -- [X] flash_attn_1_fwd_f32_kernel -- [x] flash_attn_2_fwd_f16_mma_m16n8k16_kernel (ldmatrix + MMA) -- [X] PyTorch bindings +包含以下内容:(性能持续优化中,敬请期待...) + +- [X] flash_attn_cuda_kernel (F32) +- [x] flash_attn_mma_naive_kernel (ldmatrix + MMA) +- [X] flash_attn_mma_stage_kernel (ldmatrix + MMA, Stages, Tile MMA/Warp, Copy Async, Collective Store, SMEM Padding) 本仓库FlashAttention仅用于学习CUDA编程,考虑性能最优请使用FlashAttention官方版本:[flash-attention](https://github.com/Dao-AILab/flash-attention) -### 运行测试 +## 📖 Kernel 调用 +- flash_attn_mma_stage_kernel: +```C++ +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N + const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|... + const int kStage, // Multi-Stages, only support 1/2. + const int kPad // 0,8 + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen) { +} +``` + +## 📖 目前性能 + +目前,在小规模Attention(SeqLen<=4096)的情形,本仓库实现的flash-atttenion-mma基本持平或略优于FA官方的性能,在大规模Attention计算,仍然有较大的性能差距。性能持续优化中,敬请期待~ 性能测试示例如下:(NVIDIA L20) + +- B=2, H=2, N=4096, D=64 + ```bash -# 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ... -export TORCH_CUDA_ARCH_LIST=Ada -python3 flash_attn.py +python3 flash_attn_mma.py --naive --B 2 --H 2 --D 64 --N 4096 +---------------------------------------------------------------------------------------------------- + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 8942, Warmup: 2, Iters: 10 +---------------------------------------------------------------------------------------------------- + B=2, H=2, N=4096, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.03945923 ', '0.01776123 ', '0.02627563 '], time:1.318264ms + mma(naive): ['-0.03945923 ', '0.01774597 ', '0.02626038 '], time:9.853077ms + mma(stage1): ['-0.03945923 ', '0.01776123 ', '0.02624512 '], time:0.336719ms + mma(stage2): ['-0.03945923 ', '0.01776123 ', '0.02624512 '], time:0.304818ms + (flash): ['-0.03945923 ', '0.01776123 ', '0.02626038 '], time:0.328016ms +---------------------------------------------------------------------------------------------------- ``` -日志如下: + +- B=2, H=2, N=4096, D=128 + ```bash +python3 flash_attn_mma.py --naive --B 2 --H 2 --D 128 --N 4096 ---------------------------------------------------------------------------------------------------- - B: batch_size, H: n_head, N: seq_len, D: head_dim + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 2806, Warmup: 2, Iters: 10 ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=256, D=64 - out_FA1f32: ['0.01037013 ', '-0.09995531 ', '0.09193697 '], time:9.288564ms - out_f32_th(naive): ['0.01037012 ', '-0.09995528 ', '0.09193695 '], time:0.086453ms + B=2, H=2, N=4096, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.00286484 ', '-0.00598907 ', '-0.02156067 '], time:1.377940ms + mma(naive): ['0.00284004 ', '-0.00598526 ', '-0.02157593 '], time:19.166064ms + mma(stage1): ['0.00284004 ', '-0.00598526 ', '-0.02156067 '], time:0.678110ms + mma(stage2): ['0.00284004 ', '-0.00598526 ', '-0.02156067 '], time:0.659609ms + (flash): ['0.0028553 ', '-0.00598145 ', '-0.02156067 '], time:0.548506ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.01031494 ', '-0.09997559 ', '0.09197998 '], time:0.047593ms - out_f16_th(naive): ['0.01040649 ', '-0.10003662 ', '0.09197998 '], time:0.053408ms +``` + +- B=2, H=2, N=1024, D=128 + +```bash +python3 flash_attn_mma.py --naive --B 2 --H 2 --D 128 --N 1024 ---------------------------------------------------------------------------------------------------- + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4166, Warmup: 2, Iters: 10 ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=256, D=128 + B=2, H=2, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02110291 ', '0.04946899 ', '-0.04928589 '], time:0.145769ms + mma(naive): ['-0.02116394 ', '0.04946899 ', '-0.04946899 '], time:1.236653ms + mma(stage1): ['-0.02114868 ', '0.04943848 ', '-0.04943848 '], time:0.070930ms + mma(stage2): ['-0.02114868 ', '0.04943848 ', '-0.04943848 '], time:0.069165ms + (flash): ['-0.02113342 ', '0.04949951 ', '-0.04931641 '], time:0.151205ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.15332031 ', '0.15917969 ', '0.07592773 '], time:0.091217ms - out_f16_th(naive): ['0.15368652 ', '0.15905762 ', '0.07580566 '], time:0.052757ms +``` + +- B=2, H=2, N=8192, D=64 +```bash +python3 flash_attn_mma.py --naive --B 2 --H 2 --D 64 --N 8192 ---------------------------------------------------------------------------------------------------- + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 434, Warmup: 2, Iters: 10 ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=512, D=64 - out_FA1f32: ['0.01696955 ', '-0.05399467 ', '-0.03177956 '], time:37.062004ms - out_f32_th(naive): ['0.01696953 ', '-0.05399465 ', '-0.03177955 '], time:0.471001ms + B=2, H=2, N=8192, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.00259781 ', '-0.00584412 ', '-0.00161552 '], time:5.139947ms + mma(naive): ['-0.00258827 ', '-0.00583267 ', '-0.00162792 '], time:39.265347ms + mma(stage1): ['-0.00261307 ', '-0.00583267 ', '-0.00162888 '], time:1.131415ms + mma(stage2): ['-0.00261307 ', '-0.00583267 ', '-0.00162888 '], time:1.082253ms + (flash): ['-0.00259209 ', '-0.00584793 ', '-0.00160122 '], time:0.786042ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.01699829 ', '-0.0539856 ', '-0.0317688 '], time:0.168507ms - out_f16_th(naive): ['0.01699829 ', '-0.0539856 ', '-0.03173828 '], time:0.132778ms +``` + +## 📖 运行测试 +```bash +# 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ... +pip install flash-attn +export TORCH_CUDA_ARCH_LIST=Ada +python3 flash_attn_mma.py +``` + +- NVIDIA L20 +```bash +python3 flash_attn_mma.py --naive --N 4096 --B 2 --H 2 --D 128 ---------------------------------------------------------------------------------------------------- + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 762, Warmup: 2, Iters: 10 ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=512, D=128 + B=2, H=2, N=4096, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.06402588 ', '0.01030731 ', '0.02693176 '], time:1.380467ms + mma(naive): ['0.06408691 ', '0.01036835 ', '0.0269165 '], time:19.160128ms + mma(stage1): ['0.06390381 ', '0.01038361 ', '0.02685547 '], time:0.681663ms + mma(stage2): ['0.06390381 ', '0.01038361 ', '0.02685547 '], time:0.661945ms + (flash): ['0.06402588 ', '0.01029968 ', '0.02694702 '], time:0.550222ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.06872559 ', '-0.07714844 ', '0.04348755 '], time:0.326455ms - out_f16_th(naive): ['0.06872559 ', '-0.07720947 ', '0.04345703 '], time:0.152197ms +``` + +- 更多日志如下: +```bash ---------------------------------------------------------------------------------------------------- + B: batch_size, H: n_head, N: seq_len, D: head_dim, seed: 4791, Warmup: 2, Iters: 10 ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=1024, D=64 - out_FA1f32: ['-0.04256601 ', '0.0555016 ', '0.05054659 '], time:148.082373ms - out_f32_th(naive): ['-0.04256602 ', '0.05550159 ', '0.05054657 '], time:2.673364ms + B=1, H=1, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02313232 ', '-0.0690918 ', '0.01024628 '], time:0.096679ms + mma(naive): ['-0.02304077 ', '-0.0690918 ', '0.01034546 '], time:0.627875ms + mma(stage1): ['-0.02307129 ', '-0.06915283 ', '0.01036835 '], time:0.049162ms + mma(stage2): ['-0.02307129 ', '-0.06915283 ', '0.01036835 '], time:0.044727ms + (flash): ['-0.02305603 ', '-0.06915283 ', '0.01024628 '], time:0.114155ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.0425415 ', '0.05551147 ', '0.05053711 '], time:0.633800ms - out_f16_th(naive): ['-0.0425415 ', '0.05545044 ', '0.05053711 '], time:1.276960ms ---------------------------------------------------------------------------------------------------- + B=1, H=1, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.02102661 ', '-0.05181885 ', '-0.05075073 '], time:0.098372ms + mma(naive): ['0.02102661 ', '-0.05178833 ', '-0.05081177 '], time:1.219583ms + mma(stage1): ['0.02101135 ', '-0.05175781 ', '-0.0508728 '], time:0.069022ms + mma(stage2): ['0.02101135 ', '-0.05175781 ', '-0.0508728 '], time:0.065708ms + (flash): ['0.02110291 ', '-0.05181885 ', '-0.05075073 '], time:0.122333ms ---------------------------------------------------------------------------------------------------- - B=8, H=8, N=1024, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.00053024 ', '0.04940796 ', '-0.01649475 '], time:1.235073ms - out_f16_th(naive): ['-0.00051165 ', '0.04946899 ', '-0.01644897 '], time:1.371036ms + B=1, H=1, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.04943848 ', '0.03289795 ', '0.03292847 '], time:0.145054ms + mma(naive): ['-0.04937744 ', '0.03283691 ', '0.03292847 '], time:2.422976ms + mma(stage1): ['-0.04946899 ', '0.03283691 ', '0.03292847 '], time:0.084805ms + mma(stage2): ['-0.04946899 ', '0.03283691 ', '0.03292847 '], time:0.075865ms + (flash): ['-0.04940796 ', '0.03283691 ', '0.03295898 '], time:0.123882ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=256, D=64 - out_FA1f32: ['0.06706338 ', '-0.01847678 ', '-0.02532079 '], time:9.592953ms - out_f32_th(naive): ['0.0670634 ', '-0.01847675 ', '-0.02532081 '], time:0.150659ms + B=1, H=1, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.04684448 ', '0.04428101 ', '-0.0496521 '], time:0.155616ms + mma(naive): ['0.04684448 ', '0.04425049 ', '-0.04962158 '], time:4.777288ms + mma(stage1): ['0.046875 ', '0.04421997 ', '-0.0496521 '], time:0.124383ms + mma(stage2): ['0.046875 ', '0.04421997 ', '-0.0496521 '], time:0.116611ms + (flash): ['0.04684448 ', '0.04428101 ', '-0.0496521 '], time:0.151396ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.06719971 ', '-0.01847839 ', '-0.02529907 '], time:0.060866ms - out_f16_th(naive): ['0.06713867 ', '-0.01846313 ', '-0.0252533 '], time:0.063777ms ---------------------------------------------------------------------------------------------------- + B=1, H=4, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.02505493 ', '-0.03884888 ', '0.03839111 '], time:0.114441ms + mma(naive): ['0.02503967 ', '-0.03878784 ', '0.03833008 '], time:0.641346ms + mma(stage1): ['0.02502441 ', '-0.03881836 ', '0.03833008 '], time:0.051594ms + mma(stage2): ['0.02502441 ', '-0.03881836 ', '0.03833008 '], time:0.046849ms + (flash): ['0.02508545 ', '-0.03890991 ', '0.03842163 '], time:0.159287ms ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=256, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.05142212 ', '0.03041077 ', '-0.08868408 '], time:0.132723ms - out_f16_th(naive): ['-0.05151367 ', '0.03018188 ', '-0.08911133 '], time:0.079043ms + B=1, H=4, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.05755615 ', '-0.0489502 ', '-0.065979 '], time:0.128007ms + mma(naive): ['0.05743408 ', '-0.04898071 ', '-0.065979 '], time:1.237154ms + mma(stage1): ['0.05749512 ', '-0.0489502 ', '-0.065979 '], time:0.071168ms + mma(stage2): ['0.05749512 ', '-0.0489502 ', '-0.065979 '], time:0.068855ms + (flash): ['0.05752563 ', '-0.04898071 ', '-0.065979 '], time:0.171375ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=512, D=64 - out_FA1f32: ['-0.03446965 ', '0.05762016 ', '0.07836776 '], time:38.253429ms - out_f32_th(naive): ['-0.03446964 ', '0.05762014 ', '0.07836778 '], time:1.357274ms + B=1, H=4, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02835083 ', '0.00565338 ', '0.05978394 '], time:0.321698ms + mma(naive): ['-0.02848816 ', '0.00557709 ', '0.05969238 '], time:2.502728ms + mma(stage1): ['-0.02851868 ', '0.00556564 ', '0.0597229 '], time:0.111842ms + mma(stage2): ['-0.02851868 ', '0.00556564 ', '0.0597229 '], time:0.107741ms + (flash): ['-0.02844238 ', '0.00562286 ', '0.0597229 '], time:0.164914ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.03445435 ', '0.05758667 ', '0.07836914 '], time:0.218937ms - out_f16_th(naive): ['-0.03445435 ', '0.05758667 ', '0.07830811 '], time:0.500908ms ---------------------------------------------------------------------------------------------------- + B=1, H=4, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.0413208 ', '-0.01803589 ', '0.00205803 '], time:0.341320ms + mma(naive): ['0.04135132 ', '-0.01794434 ', '0.00203705 '], time:4.843640ms + mma(stage1): ['0.04138184 ', '-0.01802063 ', '0.00209618 '], time:0.237679ms + mma(stage2): ['0.04138184 ', '-0.01802063 ', '0.00209618 '], time:0.232315ms + (flash): ['0.04135132 ', '-0.01802063 ', '0.00205231 '], time:0.219083ms ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=512, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.00230026 ', '-0.05194092 ', '0.0164032 '], time:0.493281ms - out_f16_th(naive): ['-0.00205803 ', '-0.05209351 ', '0.01664734 '], time:0.568807ms + B=1, H=8, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.06848145 ', '-0.03616333 ', '-0.02568054 '], time:0.139976ms + mma(naive): ['0.06842041 ', '-0.03619385 ', '-0.02557373 '], time:0.642300ms + mma(stage1): ['0.06848145 ', '-0.03619385 ', '-0.02557373 '], time:0.064182ms + mma(stage2): ['0.06848145 ', '-0.03619385 ', '-0.02557373 '], time:0.062895ms + (flash): ['0.06848145 ', '-0.03613281 ', '-0.02572632 '], time:0.141335ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=1024, D=64 - out_FA1f32: ['0.02074369 ', '-0.01090947 ', '-0.01393144 '], time:152.446897ms - out_f32_th(naive): ['0.02074368 ', '-0.01090949 ', '-0.01393143 '], time:5.296123ms + B=1, H=8, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02468872 ', '0.01733398 ', '-0.06427002 '], time:0.161743ms + mma(naive): ['-0.02461243 ', '0.0173645 ', '-0.06433105 '], time:1.234579ms + mma(stage1): ['-0.02459717 ', '0.01733398 ', '-0.06439209 '], time:0.131607ms + mma(stage2): ['-0.02459717 ', '0.01733398 ', '-0.06439209 '], time:0.125146ms + (flash): ['-0.02471924 ', '0.01733398 ', '-0.06420898 '], time:0.172949ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.02073669 ', '-0.01097107 ', '-0.01395416 '], time:0.834603ms - out_f16_th(naive): ['0.02073669 ', '-0.01092529 ', '-0.01390839 '], time:2.576745ms ---------------------------------------------------------------------------------------------------- + B=1, H=8, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.05279541 ', '0.0152359 ', '0.01768494 '], time:0.682831ms + mma(naive): ['-0.05279541 ', '0.01522827 ', '0.01771545 '], time:2.491093ms + mma(stage1): ['-0.05282593 ', '0.01521301 ', '0.01768494 '], time:0.177312ms + mma(stage2): ['-0.05282593 ', '0.01521301 ', '0.01768494 '], time:0.162506ms + (flash): ['-0.05279541 ', '0.0152359 ', '0.01768494 '], time:0.208259ms ---------------------------------------------------------------------------------------------------- - B=8, H=16, N=1024, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.08306885 ', '0.03659058 ', '0.04852295 '], time:1.907628ms - out_f16_th(naive): ['0.08319092 ', '0.03668213 ', '0.04858398 '], time:2.696407ms + B=1, H=8, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.00988007 ', '-0.07843018 ', '-0.04537964 '], time:0.758100ms + mma(naive): ['-0.00988007 ', '-0.07836914 ', '-0.04541016 '], time:4.834414ms + mma(stage1): ['-0.00990295 ', '-0.07830811 ', '-0.04544067 '], time:0.354052ms + mma(stage2): ['-0.00990295 ', '-0.07830811 ', '-0.04544067 '], time:0.341821ms + (flash): ['-0.00987244 ', '-0.07843018 ', '-0.04541016 '], time:0.322461ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=256, D=64 - out_FA1f32: ['0.09634054 ', '-0.02606717 ', '0.13369624 '], time:9.618666ms - out_f32_th(naive): ['0.09634058 ', '-0.02606717 ', '0.13369617 '], time:0.147052ms + B=2, H=1, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.0115509 ', '-0.06903076 ', '-0.07427979 '], time:0.104880ms + mma(naive): ['0.01156616 ', '-0.06903076 ', '-0.07427979 '], time:0.631714ms + mma(stage1): ['0.01158142 ', '-0.06896973 ', '-0.07427979 '], time:0.050020ms + mma(stage2): ['0.01158142 ', '-0.06896973 ', '-0.07427979 '], time:0.045276ms + (flash): ['0.01151276 ', '-0.06903076 ', '-0.07427979 '], time:0.116420ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.09649658 ', '-0.02606201 ', '0.13366699 '], time:0.060964ms - out_f16_th(naive): ['0.09631348 ', '-0.02613831 ', '0.13366699 '], time:0.063334ms ---------------------------------------------------------------------------------------------------- + B=2, H=1, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.01117706 ', '0.01002502 ', '-0.05667114 '], time:0.106740ms + mma(naive): ['0.01112366 ', '0.01013184 ', '-0.05661011 '], time:1.222134ms + mma(stage1): ['0.01113892 ', '0.01016998 ', '-0.05661011 '], time:0.069237ms + mma(stage2): ['0.01113892 ', '0.01016998 ', '-0.05661011 '], time:0.065756ms + (flash): ['0.01117706 ', '0.01009369 ', '-0.05667114 '], time:0.127745ms ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=256, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.0680542 ', '0.18212891 ', '0.09741211 '], time:0.132513ms - out_f16_th(naive): ['-0.0680542 ', '0.18212891 ', '0.09747314 '], time:0.079212ms + B=2, H=1, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01078033 ', '0.0486145 ', '-0.02133179 '], time:0.206518ms + mma(naive): ['-0.01079559 ', '0.0486145 ', '-0.02133179 '], time:2.455759ms + mma(stage1): ['-0.01078033 ', '0.04867554 ', '-0.02133179 '], time:0.086260ms + mma(stage2): ['-0.01078033 ', '0.04867554 ', '-0.02133179 '], time:0.076795ms + (flash): ['-0.01076508 ', '0.0486145 ', '-0.02130127 '], time:0.140572ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=512, D=64 - out_FA1f32: ['0.06110233 ', '-0.03080001 ', '0.06487844 '], time:38.171313ms - out_f32_th(naive): ['0.06110234 ', '-0.0308 ', '0.06487839 '], time:1.358862ms + B=2, H=1, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02957153 ', '-0.05664062 ', '0.00559998 '], time:0.215340ms + mma(naive): ['-0.02957153 ', '-0.05667114 ', '0.00561523 '], time:4.799604ms + mma(stage1): ['-0.02952576 ', '-0.05661011 ', '0.00555038 '], time:0.126314ms + mma(stage2): ['-0.02952576 ', '-0.05661011 ', '0.00555038 '], time:0.123835ms + (flash): ['-0.02960205 ', '-0.05667114 ', '0.0055809 '], time:0.169039ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.06112671 ', '-0.03077698 ', '0.06488037 '], time:0.218849ms - out_f16_th(naive): ['0.06109619 ', '-0.03079224 ', '0.06488037 '], time:0.497117ms ---------------------------------------------------------------------------------------------------- + B=2, H=4, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.05426025 ', '-0.04974365 ', '0.03216553 '], time:0.140238ms + mma(naive): ['-0.05422974 ', '-0.04974365 ', '0.03216553 '], time:0.640368ms + mma(stage1): ['-0.05422974 ', '-0.04971313 ', '0.03225708 '], time:0.063801ms + mma(stage2): ['-0.05422974 ', '-0.04971313 ', '0.03225708 '], time:0.062037ms + (flash): ['-0.05426025 ', '-0.04980469 ', '0.03216553 '], time:0.138855ms ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=512, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.00991058 ', '-0.18884277 ', '-0.04980469 '], time:0.493472ms - out_f16_th(naive): ['-0.0098877 ', '-0.18884277 ', '-0.04980469 '], time:0.573759ms + B=2, H=4, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.01756287 ', '0.04769897 ', '0.03887939 '], time:0.163817ms + mma(naive): ['0.0176239 ', '0.04779053 ', '0.03903198 '], time:1.236463ms + mma(stage1): ['0.01760864 ', '0.04772949 ', '0.0390625 '], time:0.131607ms + mma(stage2): ['0.01760864 ', '0.04772949 ', '0.0390625 '], time:0.124693ms + (flash): ['0.01759338 ', '0.04776001 ', '0.03890991 '], time:0.167727ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=1024, D=64 - out_FA1f32: ['-0.01831236 ', '-0.07696866 ', '-0.04614653 '], time:152.500360ms - out_f32_th(naive): ['-0.01831233 ', '-0.07696865 ', '-0.04614652 '], time:5.295737ms + B=2, H=4, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.02493286 ', '-0.01035309 ', '-0.01535797 '], time:0.681067ms + mma(naive): ['0.02497864 ', '-0.01041412 ', '-0.01535797 '], time:2.491713ms + mma(stage1): ['0.02500916 ', '-0.01036072 ', '-0.01535797 '], time:0.177479ms + mma(stage2): ['0.02500916 ', '-0.01036072 ', '-0.01535797 '], time:0.161791ms + (flash): ['0.02494812 ', '-0.01035309 ', '-0.0153656 '], time:0.207233ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.01831055 ', '-0.07696533 ', '-0.04614258 '], time:0.834262ms - out_f16_th(naive): ['-0.01826477 ', '-0.0769043 ', '-0.04614258 '], time:2.576706ms ---------------------------------------------------------------------------------------------------- + B=2, H=4, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.05224609 ', '0.05612183 ', '-0.00789642 '], time:0.755811ms + mma(naive): ['-0.05227661 ', '0.05609131 ', '-0.00796509 '], time:4.834819ms + mma(stage1): ['-0.05230713 ', '0.05609131 ', '-0.00799561 '], time:0.354147ms + mma(stage2): ['-0.05230713 ', '0.05609131 ', '-0.00799561 '], time:0.342011ms + (flash): ['-0.05227661 ', '0.05612183 ', '-0.00790405 '], time:0.317454ms ---------------------------------------------------------------------------------------------------- - B=16, H=8, N=1024, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.04501343 ', '0.07751465 ', '-0.01131439 '], time:1.907537ms - out_f16_th(naive): ['0.04501343 ', '0.07745361 ', '-0.01132965 '], time:2.697947ms + B=2, H=8, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.10241699 ', '0.0279541 ', '0.0413208 '], time:0.255227ms + mma(naive): ['-0.10229492 ', '0.02796936 ', '0.0413208 '], time:0.641489ms + mma(stage1): ['-0.10235596 ', '0.02799988 ', '0.04135132 '], time:0.097919ms + mma(stage2): ['-0.10235596 ', '0.02799988 ', '0.04135132 '], time:0.090098ms + (flash): ['-0.10235596 ', '0.02792358 ', '0.04129028 '], time:0.148273ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=256, D=64 - out_FA1f32: ['0.05493443 ', '0.03093347 ', '-0.05244123 '], time:12.086096ms - out_f32_th(naive): ['0.05493441 ', '0.03093351 ', '-0.05244119 '], time:0.518868ms + B=2, H=8, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.00484085 ', '-0.09161377 ', '-0.05404663 '], time:0.302172ms + mma(naive): ['0.00475311 ', '-0.0914917 ', '-0.0539856 '], time:1.234841ms + mma(stage1): ['0.004776 ', '-0.0914917 ', '-0.05404663 '], time:0.194931ms + mma(stage2): ['0.004776 ', '-0.0914917 ', '-0.05404663 '], time:0.180697ms + (flash): ['0.00481796 ', '-0.09161377 ', '-0.05404663 '], time:0.204754ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.05496216 ', '0.03089905 ', '-0.05227661 '], time:0.083928ms - out_f16_th(naive): ['0.05487061 ', '0.03102112 ', '-0.05239868 '], time:0.133991ms ---------------------------------------------------------------------------------------------------- + B=2, H=8, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01339722 ', '-0.00278282 ', '0.05957031 '], time:1.561284ms + mma(naive): ['-0.01334381 ', '-0.0027523 ', '0.05957031 '], time:2.487493ms + mma(stage1): ['-0.01335907 ', '-0.00276947 ', '0.05953979 '], time:0.299239ms + mma(stage2): ['-0.01335907 ', '-0.00276947 ', '0.05953979 '], time:0.286341ms + (flash): ['-0.01339722 ', '-0.00278664 ', '0.05960083 '], time:0.265479ms ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=256, D=128 ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.03808594 ', '-0.19189453 ', '0.00264549 '], time:0.192747ms - out_f16_th(naive): ['-0.03778076 ', '-0.19189453 ', '0.00281334 '], time:0.178058ms + B=2, H=8, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01472473 ', '0.01069641 ', '0.00494003 '], time:1.626682ms + mma(naive): ['-0.01460266 ', '0.01072693 ', '0.00496674 '], time:4.847050ms + mma(stage1): ['-0.01470184 ', '0.0107193 ', '0.00500488 '], time:0.680041ms + mma(stage2): ['-0.01470184 ', '0.0107193 ', '0.00500488 '], time:0.667119ms + (flash): ['-0.01469421 ', '0.01070404 ', '0.00495911 '], time:0.430512ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + B=4, H=1, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01221466 ', '-0.0252533 ', '0.02658081 '], time:0.113845ms + mma(naive): ['-0.01217651 ', '-0.02531433 ', '0.02661133 '], time:0.644326ms + mma(stage1): ['-0.01215363 ', '-0.02522278 ', '0.02659607 '], time:0.050974ms + mma(stage2): ['-0.01215363 ', '-0.02522278 ', '0.02659607 '], time:0.046182ms + (flash): ['-0.0121994 ', '-0.0252533 ', '0.02659607 '], time:0.124216ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=512, D=64 - out_FA1f32: ['0.02739076 ', '0.01203587 ', '0.09457675 '], time:48.142586ms - out_f32_th(naive): ['0.02739077 ', '0.01203588 ', '0.09457672 '], time:2.749476ms + B=4, H=1, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.15124512 ', '-0.00973511 ', '-0.04898071 '], time:0.126386ms + mma(naive): ['-0.15124512 ', '-0.00974274 ', '-0.0489502 '], time:1.234722ms + mma(stage1): ['-0.15112305 ', '-0.00977325 ', '-0.0489502 '], time:0.070763ms + mma(stage2): ['-0.15112305 ', '-0.00977325 ', '-0.0489502 '], time:0.068855ms + (flash): ['-0.15136719 ', '-0.009758 ', '-0.04898071 '], time:0.144768ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + B=4, H=1, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.00852203 ', '0.03509521 ', '-0.01040649 '], time:0.322938ms + mma(naive): ['0.00856018 ', '0.0350647 ', '-0.01039124 '], time:2.482176ms + mma(stage1): ['0.00857544 ', '0.03512573 ', '-0.01041412 '], time:0.111032ms + mma(stage2): ['0.00857544 ', '0.03512573 ', '-0.01041412 '], time:0.107479ms + (flash): ['0.00852966 ', '0.03509521 ', '-0.01039124 '], time:0.162792ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['0.02740479 ', '0.01203918 ', '0.09454346 '], time:0.291946ms - out_f16_th(naive): ['0.02740479 ', '0.01203156 ', '0.09460449 '], time:1.350477ms +---------------------------------------------------------------------------------------------------- + B=4, H=1, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.02624512 ', '-0.0218811 ', '0.00286674 '], time:0.342917ms + mma(naive): ['0.02622986 ', '-0.02198792 ', '0.00295258 '], time:4.843545ms + mma(stage1): ['0.02619934 ', '-0.02203369 ', '0.00297737 '], time:0.238228ms + mma(stage2): ['0.02619934 ', '-0.02203369 ', '0.00297737 '], time:0.232315ms + (flash): ['0.02624512 ', '-0.02191162 ', '0.00291061 '], time:0.219464ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=512, D=128 + B=4, H=4, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.02259827 ', '-0.04086304 ', '-0.01893616 '], time:0.237060ms + mma(naive): ['-0.02270508 ', '-0.04071045 ', '-0.01905823 '], time:0.642824ms + mma(stage1): ['-0.02276611 ', '-0.04071045 ', '-0.01899719 '], time:0.098085ms + mma(stage2): ['-0.02276611 ', '-0.04071045 ', '-0.01899719 '], time:0.090194ms + (flash): ['-0.0226593 ', '-0.040802 ', '-0.01896667 '], time:0.148749ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + B=4, H=4, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.03939819 ', '-0.020401 ', '0.06835938 '], time:0.305700ms + mma(naive): ['-0.03936768 ', '-0.02033997 ', '0.06835938 '], time:1.237082ms + mma(stage1): ['-0.03936768 ', '-0.02037048 ', '0.0682373 '], time:0.191927ms + mma(stage2): ['-0.03936768 ', '-0.02037048 ', '0.0682373 '], time:0.181007ms + (flash): ['-0.03942871 ', '-0.02038574 ', '0.06835938 '], time:0.202513ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.06494141 ', '-0.06427002 ', '-0.04528809 '], time:0.690589ms - out_f16_th(naive): ['-0.06500244 ', '-0.06427002 ', '-0.04519653 '], time:1.470513ms +---------------------------------------------------------------------------------------------------- + B=4, H=4, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.00169182 ', '-0.03265381 ', '-0.01131439 '], time:1.560163ms + mma(naive): ['-0.00167274 ', '-0.03265381 ', '-0.01129913 '], time:2.554107ms + mma(stage1): ['-0.00167847 ', '-0.03265381 ', '-0.01128387 '], time:0.299549ms + mma(stage2): ['-0.00167847 ', '-0.03265381 ', '-0.01128387 '], time:0.286245ms + (flash): ['-0.00166225 ', '-0.03265381 ', '-0.0112915 '], time:0.265408ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=1024, D=64 - out_FA1f32: ['-0.02254915 ', '0.00821745 ', '0.09361463 '], time:196.162612ms - out_f32_th(naive): ['-0.02254917 ', '0.00821746 ', '0.09361461 '], time:10.451190ms + B=4, H=4, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['0.00659943 ', '0.01834106 ', '-0.01256561 '], time:1.623297ms + mma(naive): ['0.00662994 ', '0.01834106 ', '-0.01259613 '], time:4.849243ms + mma(stage1): ['0.00666428 ', '0.01835632 ', '-0.0125885 '], time:0.682616ms + mma(stage2): ['0.00666428 ', '0.01835632 ', '-0.0125885 '], time:0.666404ms + (flash): ['0.0066185 ', '0.01834106 ', '-0.0125885 '], time:0.430489ms +---------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + B=4, H=8, N=1024, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['0.1619873 ', '-0.02633667 ', '-0.07122803 '], time:0.614214ms + mma(naive): ['0.16186523 ', '-0.02633667 ', '-0.07128906 '], time:0.641060ms + mma(stage1): ['0.16186523 ', '-0.02638245 ', '-0.07122803 '], time:0.160074ms + mma(stage2): ['0.16186523 ', '-0.02638245 ', '-0.07122803 '], time:0.153947ms + (flash): ['0.1619873 ', '-0.02633667 ', '-0.07128906 '], time:0.179791ms ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.02252197 ', '0.00821686 ', '0.09368896 '], time:1.106799ms - out_f16_th(naive): ['-0.02255249 ', '0.00818634 ', '0.09368896 '], time:5.125363ms +---------------------------------------------------------------------------------------------------- + B=4, H=8, N=1024, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01779175 ', '0.00307846 ', '0.03710938 '], time:0.656843ms + mma(naive): ['-0.01780701 ', '0.00312424 ', '0.0369873 '], time:1.240921ms + mma(stage1): ['-0.01782227 ', '0.00312233 ', '0.0369873 '], time:0.357890ms + mma(stage2): ['-0.01782227 ', '0.00312233 ', '0.0369873 '], time:0.345254ms + (flash): ['-0.01786804 ', '0.00312233 ', '0.03707886 '], time:0.260258ms ---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - B=16, H=16, N=1024, D=128 + B=4, H=8, N=2048, D=64, Warmup: 2, Iters: 10 + naive(unfused): ['-0.01933289 ', '0.00508499 ', '0.00285912 '], time:3.031611ms + mma(naive): ['-0.01930237 ', '0.00502014 ', '0.0028553 '], time:2.496433ms + mma(stage1): ['-0.01933289 ', '0.00498962 ', '0.00286484 '], time:0.574112ms + mma(stage2): ['-0.01933289 ', '0.00498962 ', '0.00286484 '], time:0.547338ms + (flash): ['-0.0193634 ', '0.00505829 ', '0.00289345 '], time:0.430322ms +---------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- - out_FA2MMAf16: ['-0.07330322 ', '-0.06152344 ', '0.00090456 '], time:3.174434ms - out_f16_th(naive): ['-0.07336426 ', '-0.06149292 ', '0.00105381 '], time:5.335908ms + B=4, H=8, N=2048, D=128, Warmup: 2, Iters: 10 + naive(unfused): ['-0.05450439 ', '-0.03857422 ', '0.00600052 '], time:3.121519ms + mma(naive): ['-0.05447388 ', '-0.03863525 ', '0.00599289 '], time:4.852891ms + mma(stage1): ['-0.05450439 ', '-0.03857422 ', '0.00598526 '], time:1.331329ms + mma(stage2): ['-0.05450439 ', '-0.03857422 ', '0.00598526 '], time:1.311874ms + (flash): ['-0.05456543 ', '-0.03857422 ', '0.0059967 '], time:0.760818ms ---------------------------------------------------------------------------------------------------- ``` diff --git a/kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu b/kernels/flash-attn/cutlass/flash_attn_cute_fp8.cu deleted file mode 100644 index e69de29b..00000000 diff --git a/kernels/flash-attn/flash_attn.py b/kernels/flash-attn/flash_attn.py deleted file mode 100644 index 6afe2eba..00000000 --- a/kernels/flash-attn/flash_attn.py +++ /dev/null @@ -1,92 +0,0 @@ -import math -import time -import torch -from torch.nn import functional as F -from torch.utils.cpp_extension import load -from functools import partial -from typing import Optional - -torch.set_grad_enabled(False) -# Load the CUDA kernel as a python module -lib = load(name='flash_attn_lib', - sources=[ - './naive/flash_attn_cuda.cu', - './mma/flash_attn_mma_old.cu', - 'pybind/flash_attn.cc'], - extra_cuda_cflags=[ - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math" - ], - extra_cflags=['-std=c++17']) - -# un-fused naive attn -def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): - att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) - att = F.softmax(att, dim=-1) - y = att @ v - return y - - -def run_benchmark(perf_func: callable, - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - tag: str, out: Optional[torch.Tensor] = None, - warmup: int = 5, iters: int = 10, - show_all: bool = False): - if out is not None: - out.fill_(0) - if out is not None: - for i in range(warmup): - perf_func(q, k, v, out) - else: - for i in range(warmup): - _ = perf_func(q, k, v) - - torch.cuda.synchronize() - start = time.time() - # iters - if out is not None: - for i in range(iters): - perf_func(q, k, v, out) - else: - for i in range(iters): - out = perf_func(q, k, v) - torch.cuda.synchronize() - end = time.time() - total_time = (end - start) * 1000 # ms - mean_time = total_time / iters - out_info = f"out_{tag}" - out_val = out.flatten().detach().cpu().numpy().tolist()[:3] - out_val = [round(v, 8) for v in out_val] - out_val = [f"{v:<12}" for v in out_val] - print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms") - if show_all: print(out[0, 0, 0, :]) - return out.clone(), mean_time - -Bs = [8, 16] -Hs = [8, 16] -Ns = [1024, 2048, 4096] -Ds = [64, 128] # only support [64, 128] now -# batch_size, n_head, seq_len, head_dim (B,nh,N,d) -BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds] - -print("-" * 100) -print(" "* 25 + "B: batch_size, H: n_head, N: seq_len, D: head_dim") -for (B, H, N, D) in BHNDs: - print("-" * 100) - print(" " * 40 + f"B={B}, H={H}, N={N}, D={D}") - q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous() - k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous() - v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous() - o = torch.randn(B, H, N, D, device="cuda", dtype=torch.half).contiguous() - torch.cuda.synchronize() - - # using fp16 Tesor Core MMA instruction - run_benchmark(lib.flash_attn_2_fwd_f16_mma_m16n8k16, q, k, v, "FA2MMAf16", o) - run_benchmark(naive_attn, q, k, v, "f16_th(naive)") - print("-" * 100) diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py new file mode 100644 index 00000000..c92f4597 --- /dev/null +++ b/kernels/flash-attn/flash_attn_mma.py @@ -0,0 +1,221 @@ +import os +import math +import time +import torch +from torch.nn import functional as F +from torch.utils.cpp_extension import load +from typing import Optional +from flash_attn import flash_attn_func +import argparse +import random +import numpy as np + +torch.set_grad_enabled(False) +torch.set_printoptions(precision=6, threshold=8, edgeitems=3, + linewidth=120, sci_mode=False) + + +def set_rand_seed(seed:int=1): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_project_dir(): + return os.path.dirname(os.path.dirname( + os.path.dirname(os.path.abspath(__file__)))) + + +project_dir = get_project_dir() + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--no-rand-q", '--no-rq', action="store_true") + parser.add_argument("--no-rand-k", '--no-rk', action="store_true") + parser.add_argument("--no-rand-v", '--no-rv', action="store_true") + parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true") + parser.add_argument("--naive", action="store_true") + parser.add_argument("--sdpa", action="store_true") + parser.add_argument("--check", action="store_true") + parser.add_argument("--show-all", '--show', action="store_true") + parser.add_argument("--B", type=int, default=None) + parser.add_argument("--H", type=int, default=None) + parser.add_argument("--N", type=int, default=None) + parser.add_argument("--D", type=int, default=None) + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--iters", type=int, default=10) + parser.add_argument("--range-k", '--gk', action="store_true") + return parser.parse_args() + + +args = get_args() +print(args) + + +# Load the CUDA kernel as a python module +lib = load(name='flash_attn_lib', + sources=[ + './naive/flash_attn_cuda.cu', + './mma/flash_attn_mma_naive.cu', + './mma/flash_attn_mma_stage.cu', + './pybind/flash_attn.cc'], + extra_cuda_cflags=[ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + f"-I {project_dir}/kernels/flash-attn/utils", + "-DFLASH_ATTN_MMA_DEBUG" if args.debug else "" + ], + extra_cflags=['-std=c++17']) + + +def run_benchmark(perf_func: callable, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tag: str, + out: Optional[torch.Tensor] = None, + s: Optional[torch.Tensor] = None, # BUDEG + stages: int = -1, + warmup: int = args.warmup, + iters: int = args.iters, + show_all: bool = args.show_all): + if out is not None: + out.fill_(0) + if s is not None: + s.fill_(0) + if out is not None: + for i in range(warmup): + if stages >= 1: + if s is not None: + perf_func(q, k, v, out, s, stages) + else: + perf_func(q, k, v, out, stages) + else: + perf_func(q, k, v, out) + else: + for i in range(warmup): + _ = perf_func(q, k, v) + + torch.cuda.synchronize() + start = time.time() + # iters + if out is not None: + for i in range(iters): + if stages >= 1: + if s is not None: + perf_func(q, k, v, out, s, stages) + else: + perf_func(q, k, v, out, stages) + else: + perf_func(q, k, v, out) + else: + for i in range(iters): + out = perf_func(q, k, v) + torch.cuda.synchronize() + end = time.time() + total_time = (end - start) * 1000 # ms + mean_time = total_time / iters + out_info = f"{tag}" + out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist() + out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist() + out_val_first = [round(v, 8) for v in out_val_first] + out_val_last = [round(v, 8) for v in out_val_last] + out_val = out_val_first[:2] + out_val.append(out_val_last[-1]) + out_val = [f"{v:<12}" for v in out_val] + print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms") + if show_all: + print(out) + time.sleep(0.05) + return out.clone(), mean_time + + +def get_qkvo(B, H, N, D): + if not (args.no_rand_q or args.no_rand_qkv): + q = torch.randn((B, H, N, D), dtype=torch.half, device="cuda") + else: + q = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous() + if not (args.no_rand_k or args.no_rand_qkv): + k = torch.randn((B, H, N, D), dtype=torch.half, device="cuda") + else: + k = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous() + if args.range_k: + for i in range(N): + k[:, :, i, :] = (i + 1) / N + k = k.cuda().half().contiguous() + if not (args.no_rand_v or args.no_rand_qkv): + v = torch.randn((B, H, N, D), dtype=torch.half, device="cuda") + else: + v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous() + + o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous() + + return q, k, v, o + + +# un-fused naive attn +def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))) + att = F.softmax(att, dim=-1) + y = att @ v + return y + + +Bs = [1, 2, 4] if not args.B else [args.B] +Hs = [1, 4, 8] if not args.H else [args.H] +Ns = [1024, 2048] if not args.N else [args.N] +Ds = [64, 128] if not args.D else [args.D] +# batch_size, n_head, seq_len, head_dim (B,H,N,D) +BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds] + +seed = args.seed if args.seed else random.choice(range(10000)) +set_rand_seed(seed) +print("-" * 100) +print(" "* 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, " + f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}") + +for (B, H, N, D) in BHNDs: + print("-" * 100) + print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}") + q, k, v, o = get_qkvo(B, H, N, D) + tk = k.transpose(-2, -1).contiguous() + fq = q.transpose(1, 2).contiguous() + fk = k.transpose(1, 2).contiguous() + fv = v.transpose(1, 2).contiguous() + torch.cuda.synchronize() + + if args.naive: + out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)") + + # using fp16 Tesor Core MMA instruction + out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o) + out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1) + out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2) + out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") + + if args.sdpa: + out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)") + print("-" * 100) + + torch.cuda.synchronize() + if args.check: + out_flash = out_flash.transpose(1, 2) + for i in range(int(N/8)): + if i < 4: + print("-" * 100) + print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n") + print(out_flash[:, :, (i*8):(i+1)*8, :].float()) + print(f"out_mma_stage1[:, :, {(i*8)}:{(i+1)*8}, :]:\n") + print(out_mma_stage1[:, :, (i*8):(i+1)*8, :].float()) + print("-" * 100) + print(f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}") diff --git a/kernels/flash-attn/mma/flash_attn_mma.cu b/kernels/flash-attn/mma/flash_attn_mma.cu deleted file mode 100644 index 45a22c94..00000000 --- a/kernels/flash-attn/mma/flash_attn_mma.cu +++ /dev/null @@ -1,81 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -using namespace nvcuda; - -#define WARP_SIZE 32 -#define DEVICE_INLINE __device__ inline -#define HOST_DEVICE_INLINE __device__ __host__ inline -#define INT4(value) (reinterpret_cast(&(value))[0]) -#define FLOAT4(value) (reinterpret_cast(&(value))[0]) -#define HALF2(value) (reinterpret_cast(&(value))[0]) -#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) -#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) -#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) -#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) -// gmem -> smem -#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) -#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) -#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) -// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. -#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) -#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) -// smem -> gmem: requires sm_90 or higher. -#define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::) -#define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::) -#define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n)) -#define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) -// ldmatrix -#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) -#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) -#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) -#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) -#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) -#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) -// stmatrix: requires sm_90 or higher. -#define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) -#define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) -#define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) -#define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) -#define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) -#define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) -// mma m16n8k16 -#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) - -HOST_DEVICE_INLINE -int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } - -// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction. -// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. -// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. - -// The FlashAttention-2 algorithm is described in the following paper: -// https://arxiv.org/abs/2110.08210 - -// m16n8k16_mma2x4_warp4x4 -// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [Nxd] -// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d] -// Br or Bc = 64,128,256, etc. -// grid(batch, head_num, N/Br), block(256=8*mma) -template -__global__ void flash_attn_mma_kernel( - half* Q, half* K, half* V, half* O, int N) { - // step 0: S_tile[Br,N] = Q_tile[Br,d] * K[N,d], slice-k manner matmul - // across K's N dim, each K_tile/V_tile inner loop has shape [Bc,d]. - // step 1: P_tile[Br,N] = softmax(S_tile[Br,N]), row wise. - // step 2: O_tile[Br,d] = P_tile[Br,N] * V[N,d], matmul. - const int Tr = div_ceil(N, Br); // Tr Q_tile[Br,d] - const int Tc = div_ceil(N, Bc); // Tc K/V_tile[Bc,d] - const float scale = 1.0 / sqrt((float)d); - - - -} \ No newline at end of file diff --git a/kernels/flash-attn/mma/flash_attn_mma_fp8.cu b/kernels/flash-attn/mma/flash_attn_mma_fp8.cu deleted file mode 100644 index e69de29b..00000000 diff --git a/kernels/flash-attn/mma/flash_attn_mma_old.cu b/kernels/flash-attn/mma/flash_attn_mma_naive.cu similarity index 76% rename from kernels/flash-attn/mma/flash_attn_mma_old.cu rename to kernels/flash-attn/mma/flash_attn_mma_naive.cu index f6be038a..3d458ca5 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_old.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_naive.cu @@ -8,28 +8,11 @@ #include #include #include +#include "utils.h" -#define WARP_SIZE 32 -#define INT4(value) (reinterpret_cast(&(value))[0]) -#define FLOAT4(value) (reinterpret_cast(&(value))[0]) -#define HALF2(value) (reinterpret_cast(&(value))[0]) -#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) -#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) - -// Load matrix to REGISTER -#define LDMATRIX_X4(R0, R1, R2, R3, addr) \ - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" \ - : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) \ - : "r"(addr)) - -// half mma 16x8x16 (only support "ARCH >= SM_80") -#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) \ - asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" \ - : "=r"(RD0), "=r"(RD1) \ - : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) template -__global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( +__global__ void flash_attn_mma_naive_kernel( half* Q, half* K, half* V, const int N, const int Tc, const int Tr, const float scale, half* O) { @@ -39,6 +22,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( // warp and lane Id int warpId = threadIdx.x / 32; int laneId = threadIdx.x % 32; + int tid = threadIdx.x; // Offset into Q, K, V, O - different for each batch and head int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh @@ -54,7 +38,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( half reg[32]; for (int i = 0; i < Tr; i++) { - // Read Q from global memory to shared memory + // Read Q from global memory to shared memory [Br,d] for (int x = threadIdx.x * 8; x < tile_size; x += 1024) { int dim_x = x % d; int dim_y = x / d; @@ -85,26 +69,55 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( uint32_t RB[4]; uint32_t RD[4]; - // Read K from global memory to shared memory + // Read K from global memory to shared memory, 8*128=1024 for (int x = threadIdx.x * 8; x < tile_size; x += 1024) { - int dim_x = x % d; - int dim_y = x / d; - - int new_dim_x = dim_x % 16; - int new_dim_y = (dim_y / 16 * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16); - - LDST128BITS(Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS(K[qkv_offset + (j * tile_size) + x]); + int dim_x = x % d; // d=64, 0~63, col + int dim_y = x / d; // x=(0~127=64x2)*8,d=64 or 128 + // shared memory: Br*d=64x64, reshape [256,16], 变换后的row按照16递增 + // 变换后的col,则为0和8,表示两个MMA需要的8x8矩阵,按照K=16, N=16=2x8来布局。 + // 对于一个M16N16K16,当col>15后,属于新的MMA,因此按照K=16在行数递增 + // 满足ldmatrix.x4的加载要求的布局,加载4个8x8,也就是一个16x16的矩阵。 + // [Naive] Load K, g->s, tid: 0, x:0, (row,col):(0,0)->(0,0) + // [Naive] Load K, g->s, tid: 1, x:8, (row,col):(0,8)->(0,8) + // [Naive] Load K, g->s, tid: 8, x:64, (row,col):(1,0)->(1,0) + // [Naive] Load K, g->s, tid: 9, x:72, (row,col):(1,8)->(1,8) + // [Naive] Load K, g->s, tid: 10, x:80, (row,col):(1,16)->(17,0) + // [Naive] Load K, g->s, tid: 0, x:1024, (row,col):(16,0)->(64,0) + // x是8的倍数,因此这里结果为 0,8 + int new_dim_x = dim_x % 16; + int new_dim_y = ((dim_y / 16) * (d / 16) * 16) + (dim_x / 16 * 16) + (dim_y % 16); + LDST128BITS(Kj[new_dim_y * 16 + new_dim_x]) = LDST128BITS( + K[qkv_offset + (j * tile_size) + x]); } __syncthreads(); - // Q @ K^T + // Q[Br,d] @ K^T[Br,d], tile_d=16, matmul K=16 + // 先 loop over K(d) for (int k = 0; k < d / 16; k++) { // Bc x d to Bc / 4 x d (4 is warp size) - uint32_t Qi_lane_addr = __cvta_generic_to_shared(&Qi[(warpId * 16 * d) + (laneId % 16) * 16 + (laneId / 16) * 8 + (k * 16 * 16)]); + uint32_t Qi_lane_addr = __cvta_generic_to_shared( + &Qi[(warpId * 16 * d) + (laneId % 16) * 16 + (laneId / 16) * 8 + (k * 16 * 16)]); LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], Qi_lane_addr); - + + // tile_Bc=16=2x8, matmul N=8, 256 x 16 + // 先 loop over N(Bc) for (int len = 0; len < Bc; len += 16) { - uint32_t Kj_lane_addr = __cvta_generic_to_shared(&Kj[(len * d) + (laneId % 16) * 16 + (laneId / 16) * 8 + (k * 16 * 16)]); + // (len * d): 每个迭代处理[16,d]的块 + // (laneId % 16) * 16: 16表示col=16,smem是按照col=16布局的,(laneId % 16), + // 0~15则表示K16N16块中的行数. + // (laneId / 16) * 8: (0~1)*8, 0/8, 表示K16N16块中的列数 + // (k * 16 * 16): 表示K16N16的块大小,每次加载16x16大小的块 + // T0|(0, 0),...,(0,7) | T16|(0, 8),...,(0,15) | + // T1|(1, 0),...,(1,7) | T17|(1, 8),...,(1,15) | + // T2|(2, 0),...,(2,7) | T18|(2, 8),...,(2,15) | + // |(., 0),...,(.,7) | |(., 8),...,(.,15) | + // |(7, 0),...,(7,7) | |(7, 8),...,(7,15) | + // |(8, 0),...,(8,7) | |(8, 8),...,(8,15) | + // |(9, 0),...,(9,7) | |(9, 8),...,(9,15) | + // |(., 0),...,(.,7) | |(., 8),...,(.,15) | + // T15|(15,0),...,(15,7)| T31|(15,8),...,(15,15)| + uint32_t Kj_lane_addr = __cvta_generic_to_shared( + &Kj[(len * d) + (laneId % 16) * 16 + (laneId / 16) * 8 + (k * 16 * 16)]); // be careful "not 0 1 2 3" LDMATRIX_X4(RB[0], RB[2], RB[1], RB[3], Kj_lane_addr); @@ -120,9 +133,10 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( RB[2], RB[3], RC[(len / 16) * 2 + 1][0], RC[(len / 16) * 2 + 1][1]); } - } + } // end for loop over d. + __syncthreads(); - + // Read V from global memory to shared memory for (int x = threadIdx.x * 8; x < tile_size; x += 1024) { LDST128BITS(reg[0]) = LDST128BITS(V[qkv_offset + (j * tile_size) + x]); @@ -147,7 +161,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( LDST128BITS(reg[8]) = LDST128BITS(RC[2][0]); LDST128BITS(reg[16]) = LDST128BITS(RC[4][0]); LDST128BITS(reg[24]) = LDST128BITS(RC[6][0]); - + // thread level reduce max #pragma unroll for (int xi = 0; xi < Bc / 16; xi++) { @@ -171,6 +185,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( thread_max[tc_yi] = max(thread_max[tc_yi], __shfl_xor_sync(0xffffffff, thread_max[tc_yi], s, 4)); } } + // thread level reduce sum #pragma unroll @@ -203,24 +218,32 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( LDST128BITS(RC[4][0]) = LDST128BITS(reg[16]); LDST128BITS(RC[6][0]) = LDST128BITS(reg[24]); - // P @ V - for (int k = 0; k < d / 16; k++) { + // P[Br=M,Bc=K] @ V[Bc=K,d=N] + // 先 loop over N(d), RD[4] + for (int k = 0; k < d / 16; k++) { // 64/16=4 RD[0] = RD[1] = RD[2] = RD[3] = 0; + // 再 loop over K(Bc) for (int len = 0; len < Bc; len += 16) { uint32_t Vj_lane_addr = __cvta_generic_to_shared(&Vj[(k * 16 * Bc) + (len * 16) + (laneId % 16) * 16 + (laneId / 16) * 8]); LDMATRIX_X4(RB[0], RB[2], RB[1], RB[3], Vj_lane_addr); + // RC[8][2] {0,1|2,3|4,5|6,7}[0|1] + // len = 0, RC[0][0] [0][1] [1][0] [1][1] + // len = 16, RC[2][0] [2][1] [3][0] [3][1] + // len = 32, RC[4][0] [4][1] [5][0] [5][1] + // len = 48, RC[6][0] [6][1] [7][0] [7][1] + // RD[4]在K维度累加了4次 HMMA16816(RD[0], RD[1], RC[len / 16 * 2 + 0][0], RC[len / 16 * 2 + 0][1], RC[len / 16 * 2 + 1][0], RC[len / 16 * 2 + 1][1], RB[0], RB[1], RD[0], RD[1]); - + // RC[0][0] [0][1] [1][0] [1][1] HMMA16816(RD[2], RD[3], RC[len / 16 * 2 + 0][0], RC[len / 16 * 2 + 0][1], RC[len / 16 * 2 + 1][0], RC[len / 16 * 2 + 1][1], RB[2], RB[3], RD[2], RD[3]); - } - + } // end for Bc + LDST128BITS(reg[0]) = LDST128BITS(RD[0]); #pragma unroll for(int tc_yi = 0; tc_yi < 2; tc_yi++) { @@ -243,7 +266,7 @@ __global__ void flash_attn_2_fwd_f16_mma_m16n8k16_kernel( exp_max * __half2float(reg[tc_xi * 4 + tc_yi * 2 + 1])); } } - } + } // end for d // update m, l for(int tc_yi = 0; tc_yi < 2; tc_yi++) { @@ -292,7 +315,7 @@ if (((T2).size(0) != (T1).size(0)) || \ throw std::runtime_error("Tensor size mismatch!"); \ } -void flash_attn_2_fwd_f16_mma_m16n8k16( +void flash_attn_mma_naive( torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { // TODO: determine Bc, Br dynamically CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) @@ -322,10 +345,8 @@ void flash_attn_2_fwd_f16_mma_m16n8k16( dim3 grid(B, nh); // batch_size x num_heads dim3 block(128); // 4 Warps per block - // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (d == 64) { - flash_attn_2_fwd_f16_mma_m16n8k16_kernel<<< + flash_attn_mma_naive_kernel<<< grid, block, sram_size>>>( reinterpret_cast(Q.data_ptr()), reinterpret_cast(K.data_ptr()), @@ -335,7 +356,7 @@ void flash_attn_2_fwd_f16_mma_m16n8k16( ); } if (d == 128) { - flash_attn_2_fwd_f16_mma_m16n8k16_kernel<<< + flash_attn_mma_naive_kernel<<< grid, block, sram_size>>>( reinterpret_cast(Q.data_ptr()), reinterpret_cast(K.data_ptr()), diff --git a/kernels/flash-attn/mma/flash_attn_mma_stage.cu b/kernels/flash-attn/mma/flash_attn_mma_stage.cu new file mode 100644 index 00000000..17f2936d --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_stage.cu @@ -0,0 +1,895 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; +#include "utils.h" + +// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction. +// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. +// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. + +// The FlashAttention-2 algorithm is described in the following paper: +// https://arxiv.org/abs/2110.08210 + +// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d] +// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d] +// Currently, we only support Br = Bc = 64. +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N + const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|... + const int kStage, + const int kPad + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen) { + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 + static_assert(kMmaTileSeqLenQ == 2 && kMmaTileSeqLenK == 4); // Q@K^T + static_assert(kMmaTileSeqLenP == 2 && kMmaTileHeadDimV == 4); // P@V + static_assert(kWarpTileSeqLenQ == 2 && kWarpTileSeqLenK == 2); // Q@K^T + // e.g, kWarpTileHeadDimV: 1->d 32, 2->d 64, 3->d 96, 4-> d 128, ..., etc. + static_assert(kWarpTileSeqLenP == 2 && kWarpTileHeadDimV == ( + kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V + static_assert(kStage > 0 && kStage < 3); // 1,2 + static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64 + constexpr int KNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*2*4=256, num threads + // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] + const float scale = 1.0f / sqrt((float) kHeadDim); + + // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma) + const int QKV_batch_id = blockIdx.x; // Batch size, bx + const int QKV_head_id = blockIdx.y; // Head num, by + const int Q_tile_id = blockIdx.z; // Q tile_id, range [0, Tr), bz. + const int O_tile_id = Q_tile_id; // O tile_id, same as Q. + const int tid = threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_QP = warp_id % 2; // 0,1 + const int warp_KV = warp_id / 2; // 0,1,2,3 + // The layout of 8 MMA(2x4) [before] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 16x2,8x4=32x32: + // | [32,32] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + // The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // gridDim.y = head_num, gridDim.z = N/Br = Tr. + const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + + (QKV_head_id * kHeadDim * QKV_seqlen)); // transpose K, [d,seqlen] + const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] + const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] + + // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 256 threads. + int load_smem_Q_Br = (tid / (KNumThreads / Br)); // Br 64, tid / 4, row 0~64 + int load_smem_Q_d = (tid % (KNumThreads / Br)) * (kHeadDim / (KNumThreads / Br)); // (tid % 4) * 16, 0,16,32,48 + // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 256 threads. + int load_smem_K_d = (tid / (KNumThreads / kHeadDim)); // d 64, tid / 4, row 0~64 + int load_smem_K_Bc = (tid % (KNumThreads / kHeadDim)) * (Bc / (KNumThreads / kHeadDim)); // (tid % 4) * 16, 0,16,32,48 + // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 256 threads. + int load_smem_V_Bc = (tid / (KNumThreads / Bc)); // Bc 64, tid / 4, row 0~64 + int load_smem_V_d = (tid % (KNumThreads / Bc)) * (kHeadDim / (KNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48 + // global Q row of current head for tile [Br,d] per block. + int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br; + if (load_gmem_Q_Br >= QKV_seqlen) return; + // KV tile gmem load index starts from 0 and increments with + // each iteration as we loop over seqlen. + int load_gmem_K_Bc_offset = 0; + int load_gmem_V_Bc_offset = 0; + + // Shared memory for Q,K,V,O, d=64->24M, d=128=48M + extern __shared__ half smem[]; + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + // K multi-stages: currently, only apply multi stages for K across seq_len. + half* Q_tile_smem = smem; // 8M/16M + half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M + half* V_tile_smem = K_tile_smem + kStage * K_tile_size; + half* S_tile_smem = V_tile_smem + V_tile_size; // for temp S=Q@K^T + // TODO: KV may shared same smem to reduce smem usage for headdim>=256 + // half* V_tile_smem = K_tile_smem; // KV may shared same smem 8M/16M + // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M, shared KV smem: 24M + // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M=64M, shared KV smem: 48M + // stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M=128M, shared KV smem: 96M + // stage 1, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*1+32M=96M, shared KV smem: 64M + + uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); + uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); + uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem); + uint32_t smem_S_base_ptr = __cvta_generic_to_shared(S_tile_smem); + + // --------------------- Registers/SMEM for thread block ------------------------- + // block m_old, l_old, store in lane, use float to keep precision. + float lane_block_row_max_old[kWarpTileSeqLenQ][2]; + float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; + fill_2D_regs(lane_block_row_max_old, -INFINITY); + fill_2D_regs(lane_block_row_sum_old, 0.0f); + // m[Br], l[Br], for the output of P[Br,Bc]=Q[Br,d]@K^T[d,Bc], + // 64x(4)x4=1024 bytes, 1M+1M=2M. TODO: 64x4=256, may use each + // thread to store a max/sum value instead of using shared memory + // and mapping based on thread ID and row number in Br. + __shared__ float block_row_max_new_smem[Br][kMmaTileSeqLenK]; + __shared__ float block_row_sum_new_smem[Br][kMmaTileSeqLenK]; + + // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- + // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. + uint32_t R_Q[kWarpTileSeqLenQ][ 4]; + uint32_t R_K[kWarpTileSeqLenK][ 2]; + uint32_t R_V[kWarpTileHeadDimV][2]; + // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] + // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. + uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [2][2][2] + // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. + // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ. + uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [2][2/4][2] + // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. + uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [2][2/4][2] + fill_3D_regs(R_S, 0); + fill_3D_regs(R_D, 0); + fill_3D_regs(R_O, 0); + + // load Q from gmem -> smem, only load once. + { + int load_gmem_Q_d = load_smem_Q_d; + int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + ( + load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // load K from gmem -> smem, (kStage - 1) K^T tiles, [d,Bc] + if constexpr (kStage > 1) { + #pragma unroll + for (int stage = 0; stage < (kStage - 1); ++stage) { + // update the offset of n according to stages + load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (stage * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } + + // wait Q and at least (kStage - 1) for K ready. + if constexpr (kStage > 1) { + CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2 + __syncthreads(); + } + + // : for K^T[d,seqlen] with K^T_tile[d,Bc] + // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] + #pragma unroll 1 + for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { + // TODO: process last tile_K_seqlen ? pad to multiple of 8. + // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + int smem_sel = (tile_K_seqlen) % kStage; + // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + int smem_sel_next = (tile_K_seqlen + (kStage - 1)) % kStage; + + // multi stages pipeling gmem -> smem + // NOTE: kStage must be > 1 for pipeling. For s1, smem_sel + // and smem_sel_next will always equal 0, thus, we can not + // prefetch KV from gmem to smem before tile_K_seqlen MMA done. + + if constexpr (kStage > 1) { + // First, prefetch curr V tile_K_seqlen [Bc,d] (no stages) + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc] + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel_next * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } else { + // wait all memory issues ready for last tile. (may not need) + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + } else { + // If no stages, kStage = 1, we have to load current K tile + // from gmem to smem and have to wait it ready for Q@K^T MMA. + + // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Then, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Wait K tile ready and let V tile copy async. + CP_ASYNC_WAIT_GROUP(1); + __syncthreads(); + } + + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] + // Matmul with NN layout, Q row major, K row major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + fill_3D_regs(R_S, 0); + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for Q_tile_smem. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + lane_smem_Q_ptr); // now, R_Q + } + + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. + // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) + int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); + int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + uint32_t lane_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + lane_smem_K_d * (Bc + kPad) + + lane_smem_K_Bc) * sizeof(half) + ); + LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + } // end for kWarpTileSeqLenK + + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } // end loop over d, S=Q@K^T + __syncthreads(); + + // TODO: May reuse K smem for V, for example, stages 2, stage + // 0 K smem can be reuse as V smem 0 because we do not need + // K values on stage 0 K smem anymore. + + // Now, we got a computed tile of S[Br,N], tile with shape [Br,Bc]. + // Assume [Br, Bc] = [64, 64] = 64x64 = 4096 values. Each thread holds + // a portion of this [Br, Bc] block, specifically, R_S = R_S[2][2][2]. + // This means that each Warp (MMA) repeats 2 times in the N direction + // for both Q and K, resulting in 2x2 = 4 sets of MMA results. Each set + // of results is stored in 2 32-bit registers, with each register holding + // 2 half-precision values. In other words, each thread stores (4x2)x2 = 16 + // half-precision values. With a total of 256 threads, the total number of + // half-precision values is 256x16 = 4096, which exactly matches the total + // [Br, Bc] = [64, 64] values. + + // The layout of 8 MMA m16n8k16 (2x4) [after] kWarpTileQPxkWarpTileKV(2x2) -> 32x2,32x2=64x64: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| row max + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| row max + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| row max + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| row max + + // WIP: online safe softmax, warp/block reduce max/sum, row wise + // warp 0/2/4/6, [0][2] row 0~15, col 0/8/16/32, max, [1][2] row 16~31, col 0/8/16/32, max + // warp 1/3/5/7, [0][2] row 32~47, col 0/8/16/32, max, [1][2] row 48~61, col 0/8/16/32, max + float lane_row_max_new[kWarpTileSeqLenQ][2]; + float lane_row_sum_new[kWarpTileSeqLenQ][2]; + fill_2D_regs(lane_row_max_new, -INFINITY); + fill_2D_regs(lane_row_sum_new, 0.0f); + + // Row max for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc. + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for C. (m16n8k16) + // Row\Col 0 1 2 3 4 5 6 7 + // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1} + // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1} + // 2 ... + // ... + // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1} + // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3} + // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3} + // 10 ... + // ... + // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3} + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // This should be the row max after S = (Q @ K^T) / sqrt(d) + float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale; + float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale; + lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0); + lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce max, warp_size = 4 + // Each thread contains the maximum of 2 rows of Br, + // and only the values of T0, T4, ..., T28 are used. + // Br, row_id = warp_QP<0|1> * 32 + i<0|1> * 16 + 0 * 8 + (lane / 4) <0~7> + lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); + // Br, row_id = warp_QP<0|1> * 32 + i<0|1> * 16 + 1 * 8 + (lane / 4) <8~15> + lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); + + if (lane_id % 4 == 0) { // only need T0,T4,...,T28 + block_row_max_new_smem[ // Br, row_id, 0~7, 16~23, 32~39, 48~55 + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][warp_KV] = lane_row_max_new[i][0]; + block_row_max_new_smem[ // Br, row_id, 8~15, 24~31, 40~47, 56~63 + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][warp_KV] = lane_row_max_new[i][1]; + } + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Block level reduce max, row wise, 64x4=256 + float wrp_row_max_new = ( + block_row_max_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK]); // [0~63][0~4] + float blk_row_max_new = warp_reduce_max(wrp_row_max_new); + block_row_max_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK] = ( + blk_row_max_new); + __syncthreads(); + + // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Use latest global row max without update. + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; + float block_row_max_new_0 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + // Br 1, row_id, 8~15, 24~31, 40~47, 56~63; + float block_row_max_new_1 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + // m_new = max(m_old, m_new) + block_row_max_new_0 = (tile_K_seqlen > 0 ? + max(lane_block_row_max_old[i][0], block_row_max_new_0) : + block_row_max_new_0); + block_row_max_new_1 = (tile_K_seqlen > 0 ? + max(lane_block_row_max_old[i][1], block_row_max_new_1) : + block_row_max_new_1); + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // P = Exp(S - m_new) + t_reg_S_0.x = __expf(t_reg_S_0.x * scale - block_row_max_new_0); + t_reg_S_0.y = __expf(t_reg_S_0.y * scale - block_row_max_new_0); + t_reg_S_1.x = __expf(t_reg_S_1.x * scale - block_row_max_new_1); + t_reg_S_1.y = __expf(t_reg_S_1.y * scale - block_row_max_new_1); + lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y); + lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y); + // Update R_S for P[Br,Bc] = Exp(S-m), point wise. + HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0); + HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce sum, warp_size = 4 + lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]); + lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]); + + if (lane_id % 4 == 0) { // only need T0,T4,...,T28 + block_row_sum_new_smem[ // Br, row_id, 0~7, 16~23, 32~39, 48~55 + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][warp_KV] = lane_row_sum_new[i][0]; + block_row_sum_new_smem[ // Br, row_id, 8~15, 24~31, 40~47, 56~63 + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][warp_KV] = lane_row_sum_new[i][1]; + } + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Block level reduce sum, row wise, 64x4=256 + float wrp_row_sum_new = ( + block_row_sum_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK]); // [0~63][0~4] + float blk_row_sum_new = warp_reduce_sum(wrp_row_sum_new); + block_row_sum_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK] = ( + blk_row_sum_new); + __syncthreads(); + + // Retile warp for [Br,d], kWarpTileHeadDimV: 1=32/(4*8); 2=64/(4*8); 4=128/(4*8). + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. + + // If headdim=<32>, then, kWarpTileHeadDimV = 1, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x1) tiling to (32x2,32x1)=(64x32), will look like: + // | [64,32] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + + // If headdim=<64>, then, kWarpTileHeadDimV = 2, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x2) tiling to (32x2,32x2)=(64x64), will look like: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + + // If headdim=<128>, then, kWarpTileHeadDimV = 4, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x2x2) tiling to (32x2,32x2x2)=(64x64x2), will look like: + // | [64,64x2] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0,MMA 0,MMA 0 --|-- MMA 2,MMA 2,MMA 2,MMA 2 --|-- MMA 4,MMA 4,MMA 4,MMA 4 --|-- MMA 6,MMA 6,MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0,MMA 0,MMA 0 --|-- MMA 2,MMA 2,MMA 2,MMA 2 --|-- MMA 4,MMA 4,MMA 4,MMA 4 --|-- MMA 6,MMA 6,MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1,MMA 1,MMA 1 --|-- MMA 3,MMA 3,MMA 3,MMA 3 --|-- MMA 5,MMA 5,MMA 5,MMA 5 --|-- MMA 7,MMA 7,MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1,MMA 1,MMA 1 --|-- MMA 3,MMA 3,MMA 3,MMA 3 --|-- MMA 5,MMA 5,MMA 5,MMA 5 --|-- MMA 7,MMA 7,MMA 7,MMA 7 --| + + // Write R_P(R_S) to P_smem [Br,Bc] + // store S(for DEBUG now) [Br,Bc] of [seqlen,seqlen] [64,64] + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + R_Q[0][0] = R_S[i][j][0]; R_Q[1][0] = R_S[i][j][1]; // warp_size 4 + R_Q[0][1] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 1, 4); + R_Q[0][2] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 2, 4); + R_Q[0][3] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 3, 4); + R_Q[1][1] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 1, 4); + R_Q[1][2] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 2, 4); + R_Q[1][3] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 3, 4); + + // st.global.v4 128 bits. + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_S_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int store_lane_smem_S_Br = store_warp_regs_S_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_S_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int store_lane_smem_S_Bc = store_warp_regs_S_Bc; // (0~3)*16+(0/8) + int store_smem_S_addr_0 = ((store_lane_smem_S_Br + 0) * (Bc + kPad) + store_lane_smem_S_Bc); + int store_smem_S_addr_1 = ((store_lane_smem_S_Br + 8) * (Bc + kPad) + store_lane_smem_S_Bc); + LDST128BITS(S_tile_smem[store_smem_S_addr_0]) = LDST128BITS(R_Q[0][0]); + LDST128BITS(S_tile_smem[store_smem_S_addr_1]) = LDST128BITS(R_Q[1][0]); + } + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. + // Here, we have to wait V ready before compute O = P @ V + if constexpr (kStage == 2) { + // NOTE: For kStage > 1, we have send V mem issues before K + CP_ASYNC_WAIT_GROUP(1); // s1->-1, s2->0, s3->1, s4->2 + } else { + CP_ASYNC_WAIT_GROUP(0); + } + __syncthreads(); + + // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention. + // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major. + // Make sure to clear the states in R_O before MMA for P@V for each step. + fill_3D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for S_tile_smem. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // S[Br,Bc]=[M,K] + int warp_smem_S_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP) + i * kMmaAtomM; + int lane_smem_S_Br = warp_smem_S_Br + lane_id % 16; // 0~15 + int lane_smem_S_Bc = tile_V_Bc * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_S_ptr = ( + smem_S_base_ptr + (lane_smem_S_Br * (Bc + kPad) + + lane_smem_S_Bc) * sizeof(half) + ); + LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + lane_smem_S_ptr); // now, R_P + } + + // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans. + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N + int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K + int lane_smem_V_d = warp_smem_V_d; // 0 + uint32_t lane_smem_V_ptr = ( + smem_V_base_ptr + (lane_smem_V_Bc * (kHeadDim + kPad) + + lane_smem_V_d) * sizeof(half) + ); + LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V + } + + // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these + // registers for P(A) matrix directly ? How to do that ? + // according to the A matrix layout for MMA m16n8k16 instruction. + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for A matrix with .f16. + // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5} + // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5} + // 2 (dashed arrow pointing right) + // ... + // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5} + // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7} + // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7} + // 10 (dashed arrow pointing right) + // ... + // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=2 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // kWarpTileHeadDimV=1,2,3,4,... + HMMA16816(R_O[i][j][0], R_O[i][j][1], + // FIXME(DefTruth): Still have some error while using R_S + // as registers for P(A) matrix directly. I will remove this + // comments when I have fixed it. + // R_S[i][0][0], R_S[i][0][1], R_S[i][1][0], R_S[i][1][1], + R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + R_V[j][0], R_V[j][1], + R_O[i][j][0], R_O[i][j][1]); + } + } + } // end for V Bc. + __syncthreads(); + + // Rescale O -> Update row sum Exp -> then, Update row max. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP + // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper) + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63 + float block_row_max_new_0 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + float block_row_max_new_1 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + float block_row_sum_new_0 = block_row_sum_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + float block_row_sum_new_1 = block_row_sum_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + block_row_max_new_0 = (tile_K_seqlen > 0 ? max(block_row_max_old_0, + block_row_max_new_0) : + block_row_max_new_0); + block_row_max_new_1 = (tile_K_seqlen > 0 ? max(block_row_max_old_1, + block_row_max_new_1) : + block_row_max_new_1); + block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 : + block_row_max_new_0); + block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 : + block_row_max_new_1); + + // rescale factor for O and l, exp(m_old - m) + float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0); + float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1); + // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old. + // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3} + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + // Note that the formula in the FA2 paper is incorrect; here, + // the inverse of the exp function should not be taken, as it + // would result in an error during rescaling, namely, you have + // use exp(m_old - m_new), not 1/(m_old - m_new). + // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V + t_reg_D_0.x = rescale_o_factor_0 * t_reg_D_0.x + t_reg_O_0.x; + t_reg_D_0.y = rescale_o_factor_0 * t_reg_D_0.y + t_reg_O_0.y; + t_reg_D_1.x = rescale_o_factor_1 * t_reg_D_1.x + t_reg_O_1.x; + t_reg_D_1.y = rescale_o_factor_1 * t_reg_D_1.y + t_reg_O_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } // end for kWarpTileHeadDimV. + + // Now, we can update m, l after O has been scaled. + // 1. First, update block row sum Exp for each lane which + // need both m_new and m_old. + float block_row_sum_old_0 = lane_block_row_sum_old[i][0]; + float block_row_sum_old_1 = lane_block_row_sum_old[i][1]; + // Update l = exp(m_old - m_new) * l_old + row_sum(P). + lane_block_row_sum_old[i][0] = ( + rescale_o_factor_0 * block_row_sum_old_0 + block_row_sum_new_0); + lane_block_row_sum_old[i][1] = ( + rescale_o_factor_1 * block_row_sum_old_1 + block_row_sum_new_1); + // 2. Then, update block row max for each lane. + lane_block_row_max_old[i][0] = block_row_max_new_0; + lane_block_row_max_old[i][1] = block_row_max_new_1; + } + + // NOTE: After compute P @ V, we have to wait next K tile ready in smem. + // do not need to wait any things if kStage == 1. + if constexpr (kStage == 2) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + } // end loop over N + __syncthreads(); + + // Finaly, we still have to rescale O once more. + // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); + float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]); + t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x; + t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y; + t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x; + t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } + } + + // Store O(D): Write O[Br,d] from regs -> gmem, collective store + // with reg reuse & warp shuffle. need R[2][4], may reuse + // R_Q[kWarpTileSeqLenQ][4]=[2][4]. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + R_Q[0][0] = R_D[i][j][0]; R_Q[1][0] = R_D[i][j][1]; // warp_size 4 + R_Q[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Q[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Q[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Q[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Q[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Q[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = (O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = (O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0]); + } + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ +} + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#include +#include +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ +if (((T2).size(0) != (T1).size(0)) || \ + ((T2).size(1) != (T1).size(1)) || \ + ((T2).size(2) != (T1).size(2)) || \ + ((T2).size(3) != (T1).size(3))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +template +void launch_flash_attn_mma_stages( + torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + constexpr int kMmaAtomM = 16; + constexpr int kMmaAtomN = 8; + constexpr int kMmaAtomK = 16; + constexpr int kMmaTileSeqLenQ = 2; + constexpr int kMmaTileSeqLenP = 2; + constexpr int kMmaTileSeqLenK = 4; + constexpr int kMmaTileHeadDimV = 4; + constexpr int kWarpTileSeqLenQ = 2; + constexpr int kWarpTileSeqLenP = 2; + constexpr int kWarpTileSeqLenK = 2; + constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN*kMmaTileHeadDimV)); + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64 + constexpr int kPad = 8; + + // Calculate SRAM size needed per block, Q,K,V,S smem size + const int smem_max_size = ((Br * (kHeadDim + kPad)) + + (kStage * kHeadDim * (Bc + kPad)) + + (Bc * (kHeadDim + kPad)) + + (Br * (Bc + kPad))) * sizeof(half); + + const int QKV_batch = Q.size(0); + const int QKV_head = Q.size(1); + const int QKV_seqlen = Q.size(2); // QKV_seqlen + assert(QKV_seqlen % Bc == 0); // multiple of Bc=64 + + dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br) + dim3 block(WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK); // 8 warps per block + + cudaFuncSetAttribute( + flash_attn_mma_stages_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + >, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 98304 + ); + + flash_attn_mma_stages_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + ><<>>( + reinterpret_cast(Q.data_ptr()), + reinterpret_cast(K.data_ptr()), + reinterpret_cast(V.data_ptr()), + reinterpret_cast(O.data_ptr()), + QKV_seqlen + ); +} + +void flash_attn_mma_stages(torch::Tensor Q, torch::Tensor K, + torch::Tensor V, torch::Tensor O, + int stages) { + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + const int d = Q.size(3); // B, H, N, d + + if (stages == 2) { + switch (d) + { + case 64: + launch_flash_attn_mma_stages<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages<128, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } else { + switch (d) + { + case 64: + launch_flash_attn_mma_stages<64, 1>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages<96, 1>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages<128, 1>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } +} diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling.cu new file mode 100644 index 00000000..72c24a68 --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_tiling.cu @@ -0,0 +1,901 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; +#include "utils.h" + +// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction. +// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. +// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. + +// The FlashAttention-2 algorithm is described in the following paper: +// https://arxiv.org/abs/2110.08210 + +// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d] +// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d] +// Currently, we only support Br = Bc = 64. +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 2, more MMA(warp), M=16*2=32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 4, more MMA(warp), N=8*4= 32, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 2, more MMA(warp), M=16*2=32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 4, more MMA(warp), N=8*4= 32, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileSeqLenK, // 2, more values, N, Bc=32*2=64, matmul N + const int kWarpTileSeqLenP, // 2, more values, M, Br=32*2=64, matmul M + const int kWarpTileHeadDimV, // 2, more values, N, d=32*(1|2|3|4|...)=32|64|96|128|... + const int kStage, + const int kPad + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen) { + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NN, P[Br,Bc]@V[Bc,d] NN, all row major. + static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 + static_assert(kMmaTileSeqLenQ == 2 && kMmaTileSeqLenK == 4); // Q@K^T + static_assert(kMmaTileSeqLenP == 2 && kMmaTileHeadDimV == 4); // P@V + static_assert(kWarpTileSeqLenQ == 2 && kWarpTileSeqLenK == 2); // Q@K^T + // e.g, kWarpTileHeadDimV: 1->d 32, 2->d 64, 3->d 96, 4-> d 128, ..., etc. + static_assert(kWarpTileSeqLenP == 2 && kWarpTileHeadDimV == ( + kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V + static_assert(kStage > 0 && kStage < 3); // 1,2 + static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64 + constexpr int KNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*2*4=256, num threads + // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K^T_tile[d,Bc] + const float scale = 1.0f / sqrt((float) kHeadDim); + + // Launch: grid(batch, head_num, N/Br=Tr), block(256=8*mma or 128=4*mma) + const int QKV_batch_id = blockIdx.x; // Batch size, bx + const int QKV_head_id = blockIdx.y; // Head num, by + const int Q_tile_id = blockIdx.z; // Q tile_id, range [0, Tr), bz. + const int O_tile_id = Q_tile_id; // O tile_id, same as Q. + const int tid = threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_QP = warp_id % 2; // 0,1 + const int warp_KV = warp_id / 2; // 0,1,2,3 + // The layout of 8 MMA(2x4) [before] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 16x2,8x4=32x32: + // | [32,32] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + // The layout of 8 MMA(2x4) [after] kWarpTileSeqLenQxkWarpTileSeqLenK(2x2) -> 32x2,32x2=64x64: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 2 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // gridDim.y = head_num, gridDim.z = N/Br = Tr. + const int Q_gmem_offset = ((QKV_batch_id * gridDim.y * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] + const int K_gmem_offset = ((QKV_batch_id * gridDim.y * kHeadDim * QKV_seqlen) + + (QKV_head_id * kHeadDim * QKV_seqlen)); // transpose K, [d,seqlen] + const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] + const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] + + // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 256 threads. + int load_smem_Q_Br = (tid / (KNumThreads / Br)); // Br 64, tid / 4, row 0~64 + int load_smem_Q_d = (tid % (KNumThreads / Br)) * (kHeadDim / (KNumThreads / Br)); // (tid % 4) * 16, 0,16,32,48 + // Mapping K gmem -> tid -> smem, K^T[d,Bc]=[64 or 128,64], 256 threads. + int load_smem_K_d = (tid / (KNumThreads / kHeadDim)); // d 64, tid / 4, row 0~64 + int load_smem_K_Bc = (tid % (KNumThreads / kHeadDim)) * (Bc / (KNumThreads / kHeadDim)); // (tid % 4) * 16, 0,16,32,48 + // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 256 threads. + int load_smem_V_Bc = (tid / (KNumThreads / Bc)); // Bc 64, tid / 4, row 0~64 + int load_smem_V_d = (tid % (KNumThreads / Bc)) * (kHeadDim / (KNumThreads / Bc)); // (tid % 4) * 16, 0,16,32,48 + // global Q row of current head for tile [Br,d] per block. + int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br; + if (load_gmem_Q_Br >= QKV_seqlen) return; + // KV tile gmem load index starts from 0 and increments with + // each iteration as we loop over seqlen. + int load_gmem_K_Bc_offset = 0; + int load_gmem_V_Bc_offset = 0; + + // Shared memory for Q,K,V,O, d=64->24M, d=128=48M + extern __shared__ half smem[]; + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int K_tile_size = kHeadDim * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int V_tile_size = Bc * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + constexpr int S_tile_size = Br * (Bc + kPad); // 64*64=4096, ~8192 bytes=8M, KV may shared 8M + // K multi-stages: currently, only apply multi stages for K across seq_len. + half* Q_tile_smem = smem; // 8M/16M + half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M + half* V_tile_smem = K_tile_smem + kStage * K_tile_size; + half* S_tile_smem = V_tile_smem + V_tile_size; // for temp S=Q@K^T + // TODO: KV may shared same smem to reduce smem usage for headdim>=256 + // half* V_tile_smem = K_tile_smem; // KV may shared same smem 8M/16M + // stage 2, no shared KV smem, Br=Bc=64, d=64: 8M+(8M)*2+8M =32M, shared KV smem: 24M + // stage 2, no shared KV smem, Br=Bc=64, d=128: 16M+(16M)*2+16M=64M, shared KV smem: 48M + // stage 2, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*2+32M=128M, shared KV smem: 96M + // stage 1, no shared KV smem, Br=Bc=64, d=256: 32M+(32M)*1+32M=96M, shared KV smem: 64M + + uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); + uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); + uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem); + uint32_t smem_S_base_ptr = __cvta_generic_to_shared(S_tile_smem); + + // --------------------- Registers/SMEM for thread block ------------------------- + // block m_old, l_old, store in lane, use float to keep precision. + float lane_block_row_max_old[kWarpTileSeqLenQ][2]; + float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; + fill_2D_regs(lane_block_row_max_old, -INFINITY); + fill_2D_regs(lane_block_row_sum_old, 0.0f); + // m[Br], l[Br], for the output of P[Br,Bc]=Q[Br,d]@K^T[d,Bc], + // 64x(4)x4=1024 bytes, 1M+1M=2M. TODO: 64x4=256, may use each + // thread to store a max/sum value instead of using shared memory + // and mapping based on thread ID and row number in Br. + __shared__ float block_row_max_new_smem[Br][kMmaTileSeqLenK]; + __shared__ float block_row_sum_new_smem[Br][kMmaTileSeqLenK]; + + // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- + // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. + uint32_t R_Q[kWarpTileSeqLenQ][ 4]; + uint32_t R_K[kWarpTileSeqLenK][ 2]; + uint32_t R_V[kWarpTileHeadDimV][2]; + // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] + // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. + uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [2][2][2] + // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. + // TODO: may reuse R_D as R_O? kWarpTileSeqLenP=kWarpTileSeqLenQ. + uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [2][2/4][2] + // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. + uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [2][2/4][2] + fill_3D_regs(R_S, 0); + fill_3D_regs(R_D, 0); + fill_3D_regs(R_O, 0); + + // load Q from gmem -> smem, only load once. + { + int load_gmem_Q_d = load_smem_Q_d; + int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + ( + load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // load K from gmem -> smem, (kStage - 1) K^T tiles, [d,Bc] + if constexpr (kStage > 1) { + #pragma unroll + for (int stage = 0; stage < (kStage - 1); ++stage) { + // update the offset of n according to stages + load_gmem_K_Bc_offset = stage * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (stage * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } + + // wait Q and at least (kStage - 1) for K ready. + if constexpr (kStage > 1) { + CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2 + __syncthreads(); + } + + // : for K^T[d,seqlen] with K^T_tile[d,Bc] + // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] + #pragma unroll 1 + for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { + // TODO: process last tile_K_seqlen ? pad to multiple of 8. + // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + int smem_sel = (tile_K_seqlen) % kStage; + // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + int smem_sel_next = (tile_K_seqlen + (kStage - 1)) % kStage; + + // multi stages pipeling gmem -> smem + // NOTE: kStage must be > 1 for pipeling. For s1, smem_sel + // and smem_sel_next will always equal 0, thus, we can not + // prefetch KV from gmem to smem before tile_K_seqlen MMA done. + + if constexpr (kStage > 1) { + // First, prefetch curr V tile_K_seqlen [Bc,d] (no stages) + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Then, prefetch next stage K (tile_K_seqlen + 1) [d,Bc] + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel_next * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } else { + // wait all memory issues ready for last tile. (may not need) + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + } else { + // If no stages, kStage = 1, we have to load current K tile + // from gmem to smem and have to wait it ready for Q@K^T MMA. + + // First, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_d = load_smem_K_d; // load K^T [d,Bc] from [d,seqlen] + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; // < seqlen + int load_gmem_K_addr = (K_gmem_offset + load_gmem_K_d * QKV_seqlen + load_gmem_K_Bc); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + load_smem_K_d * (Bc + kPad) + + load_smem_K_Bc) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (Bc / (KNumThreads / kHeadDim)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Then, prefetch curr K tile_K_seqlen [d,Bc] (no stages) + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (KNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // Wait K tile ready and let V tile copy async. + CP_ASYNC_WAIT_GROUP(1); + __syncthreads(); + } + + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] + // Matmul with NN layout, Q row major, K row major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[d,Bc] + fill_3D_regs(R_S, 0); + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for Q_tile_smem. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + lane_smem_Q_ptr); // now, R_Q + } + + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. + // ldmatrix.x2.trans for K_tile_smem, [kMmaAtomK,Bc] from [d,Bc]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; // (N) + int lane_smem_K_d = tile_K_d * kMmaAtomK + lane_id % 16; // 0~15 (K); + int lane_smem_K_Bc = warp_smem_K_Bc; // 0(N) + uint32_t lane_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + lane_smem_K_d * (Bc + kPad) + + lane_smem_K_Bc) * sizeof(half) + ); + LDMATRIX_X2_T(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + } // end for kWarpTileSeqLenK + + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } // end loop over d, S=Q@K^T + __syncthreads(); + + // TODO: May reuse K smem for V, for example, stages 2, stage + // 0 K smem can be reuse as V smem 0 because we do not need + // K values on stage 0 K smem anymore. + + // Now, we got a computed tile of S[Br,N], tile with shape [Br,Bc]. + // Assume [Br, Bc] = [64, 64] = 64x64 = 4096 values. Each thread holds + // a portion of this [Br, Bc] block, specifically, R_S = R_S[2][2][2]. + // This means that each Warp (MMA) repeats 2 times in the N direction + // for both Q and K, resulting in 2x2 = 4 sets of MMA results. Each set + // of results is stored in 2 32-bit registers, with each register holding + // 2 half-precision values. In other words, each thread stores (4x2)x2 = 16 + // half-precision values. With a total of 256 threads, the total number of + // half-precision values is 256x16 = 4096, which exactly matches the total + // [Br, Bc] = [64, 64] values. + + // The layout of 8 MMA m16n8k16 (2x4) [after] kWarpTileQPxkWarpTileKV(2x2) -> 32x2,32x2=64x64: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| row max + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| row max + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| row max + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| row max + + // WIP: online safe softmax, warp/block reduce max/sum, row wise + // warp 0/2/4/6, [0][2] row 0~15, col 0/8/16/32, max, [1][2] row 16~31, col 0/8/16/32, max + // warp 1/3/5/7, [0][2] row 32~47, col 0/8/16/32, max, [1][2] row 48~61, col 0/8/16/32, max + float lane_row_max_new[kWarpTileSeqLenQ][2]; + float lane_row_sum_new[kWarpTileSeqLenQ][2]; + fill_2D_regs(lane_row_max_new, -INFINITY); + fill_2D_regs(lane_row_sum_new, 0.0f); + + // Row max for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc. + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for C. (m16n8k16) + // Row\Col 0 1 2 3 4 5 6 7 + // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1} + // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1} + // 2 ... + // ... + // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1} + // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3} + // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3} + // 10 ... + // ... + // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3} + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // This should be the row max after S = (Q @ K^T) / sqrt(d) + float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale; + float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale; + lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0); + lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce max, warp_size = 4 + // Each thread contains the maximum of 2 rows of Br, + // and only the values of T0, T4, ..., T28 are used. + // Br, row_id = warp_QP<0|1> * 32 + i<0|1> * 16 + 0 * 8 + (lane / 4) <0~7> + lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); + // Br, row_id = warp_QP<0|1> * 32 + i<0|1> * 16 + 1 * 8 + (lane / 4) <8~15> + lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); + + if (lane_id % 4 == 0) { // only need T0,T4,...,T28 + block_row_max_new_smem[ // Br, row_id, 0~7, 16~23, 32~39, 48~55 + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][warp_KV] = lane_row_max_new[i][0]; + block_row_max_new_smem[ // Br, row_id, 8~15, 24~31, 40~47, 56~63 + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][warp_KV] = lane_row_max_new[i][1]; + } + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Block level reduce max, row wise, 64x4=256 + float wrp_row_max_new = ( + block_row_max_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK]); // [0~63][0~4] + float blk_row_max_new = warp_reduce_max(wrp_row_max_new); + block_row_max_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK] = ( + blk_row_max_new); + __syncthreads(); + + // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Use latest global row max without update. + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; + float block_row_max_new_0 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + // Br 1, row_id, 8~15, 24~31, 40~47, 56~63; + float block_row_max_new_1 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + // m_new = max(m_old, m_new) + block_row_max_new_0 = (tile_K_seqlen > 0 ? + max(lane_block_row_max_old[i][0], block_row_max_new_0) : + block_row_max_new_0); + block_row_max_new_1 = (tile_K_seqlen > 0 ? + max(lane_block_row_max_old[i][1], block_row_max_new_1) : + block_row_max_new_1); + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // P = Exp(S - m_new) + t_reg_S_0.x = __expf(t_reg_S_0.x * scale - block_row_max_new_0); + t_reg_S_0.y = __expf(t_reg_S_0.y * scale - block_row_max_new_0); + t_reg_S_1.x = __expf(t_reg_S_1.x * scale - block_row_max_new_1); + t_reg_S_1.y = __expf(t_reg_S_1.y * scale - block_row_max_new_1); + lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y); + lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y); + // Update R_S for P[Br,Bc] = Exp(S-m), point wise. + HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0); + HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce sum, warp_size = 4 + lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]); + lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]); + + if (lane_id % 4 == 0) { // only need T0,T4,...,T28 + block_row_sum_new_smem[ // Br, row_id, 0~7, 16~23, 32~39, 48~55 + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][warp_KV] = lane_row_sum_new[i][0]; + block_row_sum_new_smem[ // Br, row_id, 8~15, 24~31, 40~47, 56~63 + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][warp_KV] = lane_row_sum_new[i][1]; + } + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Block level reduce sum, row wise, 64x4=256 + float wrp_row_sum_new = ( + block_row_sum_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK]); // [0~63][0~4] + float blk_row_sum_new = warp_reduce_sum(wrp_row_sum_new); + block_row_sum_new_smem[tid / kMmaTileSeqLenK][tid % kMmaTileSeqLenK] = ( + blk_row_sum_new); + __syncthreads(); + + // Retile warp for [Br,d], kWarpTileHeadDimV: 1=32/(4*8); 2=64/(4*8); 4=128/(4*8). + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. + + // If headdim=<32>, then, kWarpTileHeadDimV = 1, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x1) tiling to (32x2,32x1)=(64x32), will look like: + // | [64,32] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 0 |-- MMA 0 --|-- MMA 2 --|-- MMA 4 --|-- MMA 6 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + // | warp_QP 1 |-- MMA 1 --|-- MMA 3 --|-- MMA 5 --|-- MMA 7 --| + + // If headdim=<64>, then, kWarpTileHeadDimV = 2, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x2) tiling to (32x2,32x2)=(64x64), will look like: + // | [64,64] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0 --|-- MMA 2,MMA 2 --|-- MMA 4,MMA 4 --|-- MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1 --|-- MMA 3,MMA 3 --|-- MMA 5,MMA 5 --|-- MMA 7,MMA 7 --| + + // If headdim=<128>, then, kWarpTileHeadDimV = 4, the layout of 8 MMA m16n8k16 (2x4) after + // kWarpTileSeqLenPxkWarpTileHeadDimV(2x2x2) tiling to (32x2,32x2x2)=(64x64x2), will look like: + // | [64,64x2] | warp_KV 0 | warp_KV 1 | warp_KV 2 | warp_KV 3 | + // | warp_QP 0 |-- MMA 0,MMA 0,MMA 0,MMA 0 --|-- MMA 2,MMA 2,MMA 2,MMA 2 --|-- MMA 4,MMA 4,MMA 4,MMA 4 --|-- MMA 6,MMA 6,MMA 6,MMA 6 --| + // | warp_QP 0 |-- MMA 0,MMA 0,MMA 0,MMA 0 --|-- MMA 2,MMA 2,MMA 2,MMA 2 --|-- MMA 4,MMA 4,MMA 4,MMA 4 --|-- MMA 6,MMA 6,MMA 6,MMA 6 --| + // | warp_QP 1 |-- MMA 1,MMA 1,MMA 1,MMA 1 --|-- MMA 3,MMA 3,MMA 3,MMA 3 --|-- MMA 5,MMA 5,MMA 5,MMA 5 --|-- MMA 7,MMA 7,MMA 7,MMA 7 --| + // | warp_QP 1 |-- MMA 1,MMA 1,MMA 1,MMA 1 --|-- MMA 3,MMA 3,MMA 3,MMA 3 --|-- MMA 5,MMA 5,MMA 5,MMA 5 --|-- MMA 7,MMA 7,MMA 7,MMA 7 --| + + // Write R_P(R_S) to P_smem [Br,Bc] + // store S(for DEBUG now) [Br,Bc] of [seqlen,seqlen] [64,64] + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + R_Q[0][0] = R_S[i][j][0]; R_Q[1][0] = R_S[i][j][1]; // warp_size 4 + R_Q[0][1] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 1, 4); + R_Q[0][2] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 2, 4); + R_Q[0][3] = __shfl_sync((0xffffffff), R_S[i][j][0], lane_id + 3, 4); + R_Q[1][1] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 1, 4); + R_Q[1][2] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 2, 4); + R_Q[1][3] = __shfl_sync((0xffffffff), R_S[i][j][1], lane_id + 3, 4); + + // st.global.v4 128 bits. + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_S_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int store_lane_smem_S_Br = store_warp_regs_S_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_S_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int store_lane_smem_S_Bc = store_warp_regs_S_Bc; // (0~3)*16+(0/8) + int store_smem_S_addr_0 = ((store_lane_smem_S_Br + 0) * (Bc + kPad) + store_lane_smem_S_Bc); + int store_smem_S_addr_1 = ((store_lane_smem_S_Br + 8) * (Bc + kPad) + store_lane_smem_S_Bc); + LDST128BITS(S_tile_smem[store_smem_S_addr_0]) = LDST128BITS(R_Q[0][0]); + LDST128BITS(S_tile_smem[store_smem_S_addr_1]) = LDST128BITS(R_Q[1][0]); + } + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. + // Here, we have to wait V ready before compute O = P @ V + if constexpr (kStage == 2) { + // NOTE: For kStage > 1, we have send V mem issues before K + CP_ASYNC_WAIT_GROUP(1); // s1->-1, s2->0, s3->1, s4->2 + } else { + CP_ASYNC_WAIT_GROUP(0); + } + __syncthreads(); + + // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention. + // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major. + // Make sure to clear the states in R_O before MMA for P@V for each step. + fill_3D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for S_tile_smem. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // S[Br,Bc]=[M,K] + int warp_smem_S_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP) + i * kMmaAtomM; + int lane_smem_S_Br = warp_smem_S_Br + lane_id % 16; // 0~15 + int lane_smem_S_Bc = tile_V_Bc * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_S_ptr = ( + smem_S_base_ptr + (lane_smem_S_Br * (Bc + kPad) + + lane_smem_S_Bc) * sizeof(half) + ); + LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + lane_smem_S_ptr); // now, R_P + } + + // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans. + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N + int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K + int lane_smem_V_d = warp_smem_V_d; // 0 + uint32_t lane_smem_V_ptr = ( + smem_V_base_ptr + (lane_smem_V_Bc * (kHeadDim + kPad) + + lane_smem_V_d) * sizeof(half) + ); + LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V + } + + // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these + // registers for P(A) matrix directly ? How to do that ? + // according to the A matrix layout for MMA m16n8k16 instruction. + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for A matrix with .f16. + // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5} + // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5} + // 2 (dashed arrow pointing right) + // ... + // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5} + // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7} + // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7} + // 10 (dashed arrow pointing right) + // ... + // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=2 + FA_MMA_CHECK_PRINT_REG(R_S[i][0][0], R_Q[i][0], "Check failed, R_S[%d][0][0], R_Q[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id); + FA_MMA_CHECK_PRINT_REG(R_S[i][0][1], R_Q[i][1], "Check failed, R_S[%d][0][1], R_Q[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id); + FA_MMA_CHECK_PRINT_REG(R_S[i][1][0], R_Q[i][2], "Check failed, R_S[%d][1][0], R_Q[%d][2], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id); + FA_MMA_CHECK_PRINT_REG(R_S[i][1][1], R_Q[i][3], "Check failed, R_S[%d][1][1], R_Q[%d][3], tile_V_Bc: %d, tid: %d, lane: %d", i, i, tile_V_Bc, tid, lane_id); + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // kWarpTileHeadDimV=1,2,3,4,... + FA_MMA_PRINT_REG(R_V[j][0], "[Before] MMA P@V, R_V[%d][0], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id); + FA_MMA_PRINT_REG(R_V[j][1], "[Before] MMA P@V, R_V[%d][1], tile_V_Bc: %d, tid: %d, lane: %d", j, tile_V_Bc, tid, lane_id); + HMMA16816(R_O[i][j][0], R_O[i][j][1], + // FIXME(DefTruth): Still have some error while using R_S + // as registers for P(A) matrix directly. I will remove this + // comments when I have fixed it. + // R_S[i][0][0], R_S[i][0][1], R_S[i][1][0], R_S[i][1][1], + R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + R_V[j][0], R_V[j][1], + R_O[i][j][0], R_O[i][j][1]); + } + } + } // end for V Bc. + __syncthreads(); + + // Rescale O -> Update row sum Exp -> then, Update row max. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP + // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper) + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63 + float block_row_max_new_0 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + float block_row_max_new_1 = block_row_max_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + float block_row_sum_new_0 = block_row_sum_new_smem[ + warp_QP * 32 + i * 16 + 0 * 8 + (lane_id / 4)][0]; + float block_row_sum_new_1 = block_row_sum_new_smem[ + warp_QP * 32 + i * 16 + 1 * 8 + (lane_id / 4)][0]; + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + block_row_max_new_0 = (tile_K_seqlen > 0 ? max(block_row_max_old_0, + block_row_max_new_0) : + block_row_max_new_0); + block_row_max_new_1 = (tile_K_seqlen > 0 ? max(block_row_max_old_1, + block_row_max_new_1) : + block_row_max_new_1); + block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 : + block_row_max_new_0); + block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 : + block_row_max_new_1); + + // rescale factor for O and l, exp(m_old - m) + float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0); + float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1); + // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old. + // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3} + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + // Note that the formula in the FA2 paper is incorrect; here, + // the inverse of the exp function should not be taken, as it + // would result in an error during rescaling, namely, you have + // use exp(m_old - m_new), not 1/(m_old - m_new). + // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V + t_reg_D_0.x = rescale_o_factor_0 * t_reg_D_0.x + t_reg_O_0.x; + t_reg_D_0.y = rescale_o_factor_0 * t_reg_D_0.y + t_reg_O_0.y; + t_reg_D_1.x = rescale_o_factor_1 * t_reg_D_1.x + t_reg_O_1.x; + t_reg_D_1.y = rescale_o_factor_1 * t_reg_D_1.y + t_reg_O_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } // end for kWarpTileHeadDimV. + + // Now, we can update m, l after O has been scaled. + // 1. First, update block row sum Exp for each lane which + // need both m_new and m_old. + float block_row_sum_old_0 = lane_block_row_sum_old[i][0]; + float block_row_sum_old_1 = lane_block_row_sum_old[i][1]; + // Update l = exp(m_old - m_new) * l_old + row_sum(P). + lane_block_row_sum_old[i][0] = ( + rescale_o_factor_0 * block_row_sum_old_0 + block_row_sum_new_0); + lane_block_row_sum_old[i][1] = ( + rescale_o_factor_1 * block_row_sum_old_1 + block_row_sum_new_1); + // 2. Then, update block row max for each lane. + lane_block_row_max_old[i][0] = block_row_max_new_0; + lane_block_row_max_old[i][1] = block_row_max_new_1; + } + + // NOTE: After compute P @ V, we have to wait next K tile ready in smem. + // do not need to wait any things if kStage == 1. + if constexpr (kStage == 2) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + } // end loop over N + __syncthreads(); + + // Finaly, we still have to rescale O once more. + // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); + float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]); + t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x; + t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y; + t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x; + t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } + } + + // Store O(D): Write O[Br,d] from regs -> gmem, collective store + // with reg reuse & warp shuffle. need R[2][4], may reuse + // R_Q[kWarpTileSeqLenQ][4]=[2][4]. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + R_Q[0][0] = R_D[i][j][0]; R_Q[1][0] = R_D[i][j][1]; // warp_size 4 + R_Q[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Q[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Q[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Q[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Q[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Q[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = (O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = (O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0]); + } + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ +} + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#include +#include +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ +if (((T2).size(0) != (T1).size(0)) || \ + ((T2).size(1) != (T1).size(1)) || \ + ((T2).size(2) != (T1).size(2)) || \ + ((T2).size(3) != (T1).size(3))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +template +void launch_flash_attn_mma_stages( + torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + constexpr int kMmaAtomM = 16; + constexpr int kMmaAtomN = 8; + constexpr int kMmaAtomK = 16; + constexpr int kMmaTileSeqLenQ = 2; + constexpr int kMmaTileSeqLenP = 2; + constexpr int kMmaTileSeqLenK = 4; + constexpr int kMmaTileHeadDimV = 4; + constexpr int kWarpTileSeqLenQ = 2; + constexpr int kWarpTileSeqLenP = 2; + constexpr int kWarpTileSeqLenK = 2; + constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN*kMmaTileHeadDimV)); + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*2*2=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*4*2=64 + constexpr int kPad = 0; + + // Calculate SRAM size needed per block, Q,K,V,S smem size + const int smem_max_size = ((Br * (kHeadDim + kPad)) + + (kStage * kHeadDim * (Bc + kPad)) + + (Bc * (kHeadDim + kPad)) + + (Br * (Bc + kPad))) * sizeof(half); + + const int QKV_batch = Q.size(0); + const int QKV_head = Q.size(1); + const int QKV_seqlen = Q.size(2); // QKV_seqlen + assert(QKV_seqlen % Bc == 0); // multiple of Bc=64 + + dim3 grid(QKV_batch, QKV_head, div_ceil(QKV_seqlen, Br)); // batch_size x num_heads x Tr(=N/Br) + dim3 block(WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK); // 8 warps per block + + cudaFuncSetAttribute( + flash_attn_mma_stages_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + >, + cudaFuncAttributeMaxDynamicSharedMemorySize, + 98304 + ); + + flash_attn_mma_stages_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + ><<>>( + reinterpret_cast(Q.data_ptr()), + reinterpret_cast(K.data_ptr()), + reinterpret_cast(V.data_ptr()), + reinterpret_cast(O.data_ptr()), + QKV_seqlen + ); +} + +void flash_attn_mma_stages(torch::Tensor Q, torch::Tensor K, + torch::Tensor V, torch::Tensor O, + int stages) { + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K^T [B,H,D,N], transposed. + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + const int d = Q.size(3); // B, H, N, d + + if (stages == 2) { + switch (d) + { + case 64: + launch_flash_attn_mma_stages<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages<128, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } else { + switch (d) + { + case 64: + launch_flash_attn_mma_stages<64, 1>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages<96, 1>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages<128, 1>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } +} diff --git a/kernels/flash-attn/naive/flash_attn_cuda.cu b/kernels/flash-attn/naive/flash_attn_cuda.cu index c42a9c9c..e1dd6058 100644 --- a/kernels/flash-attn/naive/flash_attn_cuda.cu +++ b/kernels/flash-attn/naive/flash_attn_cuda.cu @@ -16,7 +16,7 @@ #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) -__global__ void flash_attn_1_fwd_f32_kernel( +__global__ void flash_attn_cuda_kernel( const float* Q, const float* K, const float* V, @@ -135,7 +135,7 @@ if (((T2).size(0) != (T1).size(0)) || \ throw std::runtime_error("Tensor size mismatch!"); \ } -void flash_attn_1_fwd_f32( +void flash_attn_cuda( torch::Tensor Q, torch::Tensor K, torch::Tensor V, @@ -174,7 +174,7 @@ void flash_attn_1_fwd_f32( dim3 grid(B, nh); // batch_size x num_heads dim3 block(Bc); // Bc threads per block - flash_attn_1_fwd_f32_kernel<<>>( + flash_attn_cuda_kernel<<>>( reinterpret_cast(Q.data_ptr()), reinterpret_cast(K.data_ptr()), reinterpret_cast(V.data_ptr()), diff --git a/kernels/flash-attn/pybind/flash_attn.cc b/kernels/flash-attn/pybind/flash_attn.cc index fe96caba..13cb4aed 100644 --- a/kernels/flash-attn/pybind/flash_attn.cc +++ b/kernels/flash-attn/pybind/flash_attn.cc @@ -5,10 +5,12 @@ #define TORCH_BINDING_COMMON_EXTENSION(func) \ m.def(STRINGFY(func), &func, STRINGFY(func)); -void flash_attn_1_fwd_f32(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O); -void flash_attn_2_fwd_f16_mma_m16n8k16(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O); +void flash_attn_cuda(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O); +void flash_attn_mma_naive(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O); +void flash_attn_mma_stages(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - TORCH_BINDING_COMMON_EXTENSION(flash_attn_1_fwd_f32) - TORCH_BINDING_COMMON_EXTENSION(flash_attn_2_fwd_f16_mma_m16n8k16) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_cuda) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_naive) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages) } diff --git a/kernels/flash-attn/utils/utils.h b/kernels/flash-attn/utils/utils.h index e69de29b..f71b3380 100644 --- a/kernels/flash-attn/utils/utils.h +++ b/kernels/flash-attn/utils/utils.h @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; + +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline + +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +// gmem -> smem +#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) +#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) +// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. +#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +// smem -> gmem: requires sm_90 or higher. +#define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::) +#define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::) +#define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n)) +#define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +// ldmatrix +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +// stmatrix: requires sm_90 or higher. +#define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) +#define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) +#define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) +#define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) +#define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) +#define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) +// mma m16n8k16 +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) +#define HMMA16816F32(RD0, RD1, RD2, RD3, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1, RC2, RC3) asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=r"(RD0), "=r"(RD1), "=r"(RD2), "=r"(RD3): "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1), "r"(RC2), "r"(RC3)) + + +HOST_DEVICE_INLINE +int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } + +template +DEVICE_INLINE T warp_reduce_sum(T val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, mask, kWarpSize); + } + return val; +} + +template +DEVICE_INLINE T warp_reduce_max(T val) { + #pragma unroll + for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { + T val_compare = __shfl_xor_sync(0xffffffff, val, mask, kWarpSize); + val = val > val_compare ? val : val_compare; + } + return val; +} + +template +DEVICE_INLINE void fill_3D_regs(T (&R)[M][N][K], T val) { + #pragma unroll + for (int i = 0; i < M; ++i) { + #pragma unroll + for (int j = 0; j < N; ++j) { + #pragma unroll + for (int k = 0; k < K; ++k) { + R[i][j][k] = val; + } + } + } +} + +template +DEVICE_INLINE void fill_2D_regs(T (&R)[M][N], T val) { + #pragma unroll + for (int i = 0; i < M; ++i) { + #pragma unroll + for (int j = 0; j < N; ++j) { + R[i][j] = val; + } + } +} + +#define INFHALF __float2half(65536.0f) +#define ZEROHALF __float2half(0.0f) + +#ifdef FLASH_ATTN_MMA_DEBUG +#define FA_MMA_PRINT_T0_REG(R, format, ...) \ +{ \ + if (tid == 0) { \ + float2 v_reg = __half22float2(HALF2(R)); \ + printf("[T0] " format ", V0=%f, V1=%f\n", \ + ##__VA_ARGS__, v_reg.x, v_reg.y); \ + } \ +} + +#define FA_MMA_PRINT_T32_REG(R, format, ...) \ +{ \ + if (tid < 32) { \ + float2 v_reg = __half22float2(HALF2(R)); \ + printf("[T%d] " format ", V0=%f, V1=%f\n", \ + tid, ##__VA_ARGS__, v_reg.x, v_reg.y);\ + } \ +} + +#define FA_MMA_PRINT_REG(R, format, ...) \ +{ \ + { \ + float2 v_reg = __half22float2(HALF2(R)); \ + printf(format", V0=%f, V1=%f\n", \ + ##__VA_ARGS__, v_reg.x, v_reg.y); \ + } \ +} + +#define FA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) \ +{ \ + { \ + float2 v_reg_0 = __half22float2(HALF2(R0)); \ + float2 v_reg_1 = __half22float2(HALF2(R1)); \ + if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \ + (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \ + printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \ + ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \ + } \ + } \ +} + +#define FA_MMA_PRINT_T0(format, ...) \ +{ \ + if (tid == 0) { \ + printf("[T0] " format, ##__VA_ARGS__); \ + } \ +} + +#define FA_MMA_PRINT_T32(format, ...) \ +{ \ + if (tid < 32) { \ + printf("[T%d] " format, tid, ##__VA_ARGS__);\ + } \ +} + +#define FA_MMA_PRINT_L0_REG(R, format, ...) \ +{ \ + if (lane_id == 0) { \ + float2 v_reg = __half22float2(HALF2(R)); \ + printf("[L0] " format", V0=%f, V1=%f\n", \ + ##__VA_ARGS__, v_reg.x, v_reg.y); \ + } \ +} + +#define FA_MMA_PRINT_L0(format, ...) \ +{ \ + if (lane_id == 0) { \ + printf("[L0] " format, ##__VA_ARGS__); \ + } \ +} + +#define FA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) \ +{ \ + if (tid == 0 && blockIdx.z == 0) { \ + printf("----------------------------------------\n"); \ + printf(format, ##__VA_ARGS__); \ + for (int i = 0; i < Br; ++i) { \ + for (int j = 0; j < kMmaTileSeqLenK; ++j) { \ + printf("[%d][%d]=%f", i, j, (B)[i][j]); \ + } \ + printf("\n"); \ + } \ + printf("----------------------------------------\n"); \ + } \ + __syncthreads(); \ +} + +#else + +#define FA_MMA_PRINT_REG(R, format, ...) {} +#define FA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) {} +#define FA_MMA_PRINT_T0_REG(R, format, ...) {} +#define FA_MMA_PRINT_T32_REG(R, format, ...) {} +#define FA_MMA_PRINT_L0_REG(R, format, ...) {} +#define FA_MMA_PRINT_T0(format, ...) {} +#define FA_MMA_PRINT_T32(format, ...) {} +#define FA_MMA_PRINT_L0(format, ...) {} +#define FA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) {} + +#endif diff --git a/kernels/hgemm/.gitignore b/kernels/hgemm/.gitignore index fc99f1fd..09d0d4d0 100755 --- a/kernels/hgemm/.gitignore +++ b/kernels/hgemm/.gitignore @@ -23,4 +23,10 @@ bin output *.egg-info *.whl -dist \ No newline at end of file +dist +*.pdf +*.tex +*.log +*.md5 +*.aux* +*.dpth diff --git a/kernels/hgemm/README.md b/kernels/hgemm/README.md index 2ea8e64c..6d66037c 100755 --- a/kernels/hgemm/README.md +++ b/kernels/hgemm/README.md @@ -19,7 +19,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d |CUDA Cores|Sliced K (Loop over K)|Tile Block (BMxBN)|Tile Thread (t 8x8)| |:---:|:---:|:---:|:---:| |✔️|✔️|✔️|✔️| -|WMMA (m16n16k16)|MMA (m16n8k16)|Pack LDS T(pack 128 bits)|SMEM Padding| +|WMMA (m16n16k16)|MMA (m16n8k16)|Pack LDST (pack 128 bits)|SMEM Padding| |✔️|✔️|✔️|✔️| |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages(2/3/4/5)| |✔️|✔️|✔️|✔️| @@ -81,7 +81,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::T void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); ``` -## 📖 目录 +## 📖 Contents - [📖 Prerequisites](#prerequisites) - [📖 Installation](#install) @@ -371,3 +371,6 @@ TODO - [cuda_learning](https://github.com/ifromeast/cuda_learning) - [cuda_hgemm](https://github.com/Bruce-Lee-LY/cuda_hgemm) - [cuda-tensorcore-hgemm](https://github.com/nicolaswilde/cuda-tensorcore-hgemm) +- [How_to_optimize_in_GPU](https://github.com/Liu-xiandong/How_to_optimize_in_GPU/tree/master/sgemv) +- [cute_gemm](https://github.com/weishengying/cute_gemm) +- [cutlass](https://github.com/NVIDIA/cutlass) diff --git a/kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu b/kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu index d16942ef..a56b52ba 100644 --- a/kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu +++ b/kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu @@ -69,6 +69,12 @@ __global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel( auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx); auto tAgA_copy = g2s_thr_copy_a.partition_S(gA); // (CPY, CPY_M, CPY_K, num_tile_k) auto tAsA_copy = g2s_thr_copy_a.partition_D(sA); // (CPY, CPY_M, CPY_K, kStage) +#ifdef CUTE_HGEMM_DEBUG + if (thread0()) { + print("\npartition_S(tAgA_copy): \n"); print(tAgA_copy); print("\n"); + print("\nThrCopy(g2s_thr_copy_a): \n"); print(g2s_thr_copy_a); print("\n"); + } +#endif G2SCopyB g2s_tiled_copy_b; auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(idx); @@ -235,8 +241,10 @@ void launch_hgemm_mma_stages_block_swizzle_tn_cute(const T *a, make_shape(Int{}, Int{}, Int{})) ); // (m,n) -> smem_idx #ifdef CUTE_HGEMM_DEBUG - print(SmemLayoutA{}); print("\n"); - print(SmemLayoutB{}); print("\n"); + print("SmemLayoutA: "); print(SmemLayoutA{}); print("\n"); + print("SmemLayoutB: "); print(SmemLayoutB{}); print("\n"); + print("SmemLayoutB: "); print(SmemLayoutB{}); print("\n"); + print("SmemLayoutAtom A&B Latex: \n"); print_latex(SmemLayoutAtom{}); print("\n"); #endif // mma @@ -258,7 +266,8 @@ void launch_hgemm_mma_stages_block_swizzle_tn_cute(const T *a, using MMA_P_T = Tile, Int, Int>; using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{})); #ifdef CUTE_HGEMM_DEBUG - print(MMA{}); print("\n"); + print("MMA: "); print(MMA{}); print("\n"); + print("MMA Latex: \n"); print_latex(MMA{}); print("\n"); #endif // copy from global memory to shared memory @@ -281,6 +290,8 @@ void launch_hgemm_mma_stages_block_swizzle_tn_cute(const T *a, #ifdef CUTE_HGEMM_DEBUG print("G2SCopyA: "); print(G2SCopyA{}); print("\n"); print("G2SCopyB: "); print(G2SCopyB{}); print("\n"); + print("G2SCopyA Latex: \n"); print_latex(G2SCopyA{}); print("\n"); + print("G2SCopyB Latex: \n"); print_latex(G2SCopyB{}); print("\n"); #endif // copy from shared memory to register // use mma tiled ,so no tiled here diff --git a/kernels/hgemm/makefile b/kernels/hgemm/makefile index 52ef01fc..5518dfef 100644 --- a/kernels/hgemm/makefile +++ b/kernels/hgemm/makefile @@ -10,7 +10,7 @@ default: cute_89: nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.bin $(DEFAULT_FLAGS_89) cute_89_debug: - nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG + nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format" mma_89: nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89) clean: diff --git a/kernels/hgemm/mma/hgemm_mma_stage.cu b/kernels/hgemm/mma/hgemm_mma_stage.cu index 5da0171c..52feac84 100644 --- a/kernels/hgemm/mma/hgemm_mma_stage.cu +++ b/kernels/hgemm/mma/hgemm_mma_stage.cu @@ -207,7 +207,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_kernel( } // make sure all memory issues ready. - if ((K_STAGE - 2) > 0) { + if constexpr ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } @@ -439,7 +439,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel( } // make sure all memory issues ready. - if ((K_STAGE - 2) > 0) { + if constexpr ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } @@ -876,7 +876,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel( } // make sure all memory issues ready. - if ((K_STAGE - 2) > 0) { + if constexpr ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } @@ -1304,7 +1304,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4_kernel( } // make sure all memory issues ready. - if ((K_STAGE - 2) > 0) { + if constexpr ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } @@ -1760,7 +1760,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr_kernel( } // make sure all memory issues ready. - if ((K_STAGE - 2) > 0) { + if constexpr ((K_STAGE - 2) > 0) { CP_ASYNC_WAIT_GROUP(0); __syncthreads(); } diff --git a/others/.gitignore b/others/.gitignore index 171541b1..65f67af8 100644 --- a/others/.gitignore +++ b/others/.gitignore @@ -19,3 +19,6 @@ __pycache__ *.bin outupt bin +*.log +*.txt +*.tex \ No newline at end of file diff --git a/slides/.gitignore b/slides/.gitignore index 171541b1..65f67af8 100644 --- a/slides/.gitignore +++ b/slides/.gitignore @@ -19,3 +19,6 @@ __pycache__ *.bin outupt bin +*.log +*.txt +*.tex \ No newline at end of file diff --git a/third-party/.gitignore b/third-party/.gitignore index eb907036..65f67af8 100644 --- a/third-party/.gitignore +++ b/third-party/.gitignore @@ -7,8 +7,18 @@ build *.whl tmp -bin -output __pycache__ - - +*.onnx +*.engine +*.pt +*.pth +*.nsys* +*.ncu* +*.sqlite* +*.engine +*.bin +outupt +bin +*.log +*.txt +*.tex \ No newline at end of file