[LayerNorm][FP16] Add pack support for f16x8 LD/ST (#46)
* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update layer_norm.cu

* Update layer_norm.py

* Update README.md
DefTruth authored Sep 25, 2024
1 parent e28cb4d commit 4667308
Showing 4 changed files with 291 additions and 34 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -9,7 +9,7 @@
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
</div>

🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for beginners, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).

<img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">

@@ -21,7 +21,7 @@

|📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level |
|:---|:---|:---|:---|:---|
| ✔️ [nsys/ncu usage(timeline/ptx/sass)](./nvidia-nsight/)|/|/|[link](./nvidia-nsight/)|⭐️|
| ✔️ [nsys/ncu(timeline/ptx/sass)](./nvidia-nsight/)|/|/|[link](./nvidia-nsight/)|⭐️|
| ✔️ [elementwise_f32](./elementwise/elementwise.cu)|f32|/|[link](./elementwise/)|⭐️|
| ✔️ [elementwise_f32x4](./elementwise/elementwise.cu)|f32|/|[link](./elementwise/)|⭐️|
| ✔️ [elementwise_f16](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️|
@@ -80,6 +80,7 @@
| ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x2_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32x4(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
75 changes: 64 additions & 11 deletions layer-norm/README.md
@@ -9,6 +9,7 @@
- [X] layer_norm_f16_f16_kernel
- [X] layer_norm_f16x2_f16_kernel
- [X] layer_norm_f16x8_f16_kernel
- [X] layer_norm_f16x8_pack_f16_kernel
- [X] layer_norm_f16_f32_kernel
- [X] PyTorch bindings

@@ -23,15 +24,67 @@ python3 layer_norm.py
Output:

```bash
--------------------------------------------------------------------------------
out_f32: [0.54253572, -0.13322251, 1.65217566], time:0.01894307ms
out_f32x4: [0.54253572, -0.13322251, 1.65217566], time:0.00595951ms
out_f32_th: [0.54200566, -0.13309236, 1.65056157], time:0.07212615ms
--------------------------------------------------------------------------------
out_f16f16: [0.54248047, -0.13330078, 1.65332031], time:0.01863098ms
out_f16x2f16: [0.54248047, -0.13330078, 1.65332031], time:0.00949597ms
out_f16x8f16: [0.54248047, -0.13317871, 1.65234375], time:0.00466394ms
out_f16f32: [0.54248047, -0.13317871, 1.65234375], time:0.01892662ms
out_f16_th: [0.54199219, -0.13305664, 1.65039062], time:0.07164359ms
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=512
-------------------------------------------------------------------------------------
out_f32: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.01897240ms
out_f32x4: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.00600266ms
out_f32_th: ['-1.76119995 ', '0.04760556 ', '0.50809568 '], time:0.07085347ms
-------------------------------------------------------------------------------------
out_f16f16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.01869035ms
out_f16f32: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.01897883ms
out_f16x2f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00951219ms
out_f16x8f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00467825ms
out_f16x8packf16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.00430202ms
out_f16_th: ['-1.76171875 ', '0.04760742 ', '0.50830078 '], time:0.07009959ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=1024
-------------------------------------------------------------------------------------
out_f32: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.05123448ms
out_f32x4: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.01073551ms
out_f32_th: ['-0.65587735 ', '1.33511555 ', '-0.29157916 '], time:0.07034254ms
-------------------------------------------------------------------------------------
out_f16f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.05320668ms
out_f16f32: ['-0.65576172 ', '1.3359375 ', '-0.29150391 '], time:0.05061388ms
out_f16x2f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.01861978ms
out_f16x8f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00745845ms
out_f16x8packf16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00648832ms
out_f16_th: ['-0.65527344 ', '1.33398438 ', '-0.29150391 '], time:0.07068610ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=2048
-------------------------------------------------------------------------------------
out_f32x4: ['0.92044634 ', '0.37421227 ', '-2.49094558 '], time:0.02202415ms
out_f32_th: ['0.92022169 ', '0.37412092 ', '-2.49033761 '], time:0.12026787ms
-------------------------------------------------------------------------------------
out_f16x2f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.05346847ms
out_f16x8f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01381087ms
out_f16x8packf16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01159072ms
out_f16_th: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.08454061ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=4096
-------------------------------------------------------------------------------------
out_f32x4: ['-2.05339074 ', '0.25924587 ', '0.42393678 '], time:0.18885875ms
out_f32_th: ['-2.05314016 ', '0.25921422 ', '0.42388505 '], time:0.77834105ms
-------------------------------------------------------------------------------------
out_f16x8f16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.03327322ms
out_f16x8packf16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.02402687ms
out_f16_th: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.17436218ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=8192
-------------------------------------------------------------------------------------
out_f16x8f16: ['-1.0234375 ', '-0.3371582 ', '-1.54882812 '], time:0.19311237ms
out_f16x8packf16: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.18668032ms
out_f16_th: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.84443021ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=8192, K=8192
-------------------------------------------------------------------------------------
out_f16x8f16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.38361049ms
out_f16x8packf16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.40809250ms
out_f16_th: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:1.99517584ms
-------------------------------------------------------------------------------------
```
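
The gap between `out_f16x8f16` and `out_f16x8packf16` in the tables above comes largely from packed 128-bit loads/stores: each thread moves its eight halves in a single `float4`-sized transaction instead of eight separate 16-bit accesses. A minimal sketch of just that access pattern (illustrative only, simplified from the kernel in `layer_norm.cu`, and assuming 16-byte-aligned pointers with a length that is a multiple of 8):

```cuda
#include <cuda_fp16.h>

// Each thread copies 8 halves (128 bits) with one vectorized load and one
// vectorized store, mirroring the LDST128BITS pattern used in layer_norm.cu.
__global__ void copy_f16x8_pack(const half* __restrict__ x,
                                half* __restrict__ y, int n) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 7 < n) {
    half pack[8];
    // one 128-bit load instead of eight 16-bit loads
    reinterpret_cast<float4*>(pack)[0] =
        *reinterpret_cast<const float4*>(x + idx);
    // one 128-bit store
    *reinterpret_cast<float4*>(y + idx) = reinterpret_cast<float4*>(pack)[0];
  }
}
```

The real kernel additionally reduces the packed values to a per-row mean and variance before the packed store; only the memory access pattern is shown here.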
139 changes: 130 additions & 9 deletions layer-norm/layer_norm.cu
@@ -14,6 +14,8 @@
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// Warp Reduce Sum
@@ -325,6 +327,55 @@ __global__ void layer_norm_f16_f32_kernel(half* x, half* y, float g, float b, in
}
}

template<const int NUM_THREADS=256>
__global__ void layer_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, float b, int N, int K) {
int tid = threadIdx.x; // 0..(K/8)-1
int bid = blockIdx.x; // 0..N-1
int idx = (bid * blockDim.x + threadIdx.x) * 8;
const half epsilon = __float2half(1e-5f);
const half g_ = __float2half(g);
const half b_ = __float2half(b);
const half K_ = __int2half_rn(K);
const half z_ = __float2half(0.0f);

__shared__ half s_mean; // shared within block
__shared__ half s_variance; // shared within block
// temporary register(memory), .local space in ptx, addressable
half pack_x[8], pack_y[8]; // 8x16 bits=128 bits.
// reinterpret as float4 and load 128 bits in 1 memory issue.
LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits

half value = z_;
#pragma unroll
for (int i = 0; i < 8; ++i) {
value += ((idx + i) < N * K ? pack_x[i] : z_);
}
half sum = block_reduce_sum_f16_f16<NUM_THREADS>(value);
if (tid == 0) s_mean = sum / K_;
// wait for s_mean in shared memory to be ready for all threads
__syncthreads();

half variance = z_;
#pragma unroll
for (int i = 0; i < 8; ++i) {
half v_hat = pack_x[i] - s_mean;
variance += ((idx + i) < N * K ? v_hat * v_hat : z_);
}
variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
// wait for s_variance in shared memory to be ready for all threads
__syncthreads();

#pragma unroll
for (int i = 0; i < 8; ++i) {
// TODO: use __hfma2, __hsub2, __hmul2 here
pack_y[i] = __hfma((pack_x[i] - s_mean) * s_variance, g_, b_);
}
// reinterpret as float4 and store 128 bits in 1 memory issue.
if ((idx + 7) < N * K) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
// TODO: support non 8-multiple K here
}
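
A possible shape for the `__hfma2`/`__hsub2`/`__hmul2` TODO above (not part of this commit): a sketch of how the normalize loop could handle two halves per instruction once the mean and inverse standard deviation are broadcast into `half2` registers.

```cuda
// Hypothetical half2 variant of the normalize loop; s_mean, s_variance,
// g_ and b_ are the same values the kernel above already holds.
half2 mean2 = __half2half2(s_mean);
half2 rstd2 = __half2half2(s_variance);
half2 g2    = __half2half2(g_);
half2 b2    = __half2half2(b_);
#pragma unroll
for (int i = 0; i < 8; i += 2) {
  half2 v = HALF2(pack_x[i]);                   // two adjacent packed halves
  half2 n = __hmul2(__hsub2(v, mean2), rstd2);  // (x - mean) * rstd
  HALF2(pack_y[i]) = __hfma2(n, g2, b2);        // n * g + b
}
```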

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -350,7 +401,7 @@ layer_norm_f32_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F32_KERNEL(N, K) \
dim3 block((K)); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -382,7 +433,7 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K) \
dim3 block((K)/4); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -400,9 +451,15 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F32x4_KERNEL(1024) \
break; \
case 2048: \
LANUCH_LAYER_NORM_F32x4_KERNEL(2048) \
break; \
case 4096: \
LANUCH_LAYER_NORM_F32x4_KERNEL(4096) \
break; \
default: \
throw std::runtime_error( \
"only support K: 64/128/256/512/1024"); \
"only support K: 64/128/.../1024*4"); \
break; \
}

@@ -433,7 +490,7 @@ layer_norm_f16_f16_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16F16_KERNEL(N, K) \
dim3 block((K)); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -465,7 +522,7 @@ layer_norm_f16_f32_kernel<(K)><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16F32_KERNEL(N, K) \
dim3 block((K)); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -497,7 +554,7 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16x2F16_KERNEL(N, K) \
dim3 block((K)/2); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -515,9 +572,12 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F16x2F16_KERNEL(1024) \
break; \
case 2048: \
LANUCH_LAYER_NORM_F16x2F16_KERNEL(2048) \
break; \
default: \
throw std::runtime_error( \
"only support K: 64/128/256/512/1024"); \
"only support K: 64/128/.../1024*2"); \
break; \
}

@@ -529,7 +589,7 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \

#define DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K) \
dim3 block((K)/8); \
dim3 grid((N)); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
@@ -547,12 +607,62 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \
case 1024: \
LANUCH_LAYER_NORM_F16x8F16_KERNEL(1024) \
break; \
case 2048: \
LANUCH_LAYER_NORM_F16x8F16_KERNEL(2048) \
break; \
case 4096: \
LANUCH_LAYER_NORM_F16x8F16_KERNEL(4096) \
break; \
case 8192: \
LANUCH_LAYER_NORM_F16x8F16_KERNEL(8192) \
break; \
default: \
throw std::runtime_error( \
"only support K: 64/128/256/512/1024"); \
"only support K: 64/128/.../1024*8"); \
break; \
}

#define LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(K) \
layer_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
reinterpret_cast<half*>(x.data_ptr()), \
reinterpret_cast<half*>(y.data_ptr()), \
g, b, N, (K));

#define DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K) \
dim3 block((K)/8); \
dim3 grid((N)); \
switch ((K)) \
{ \
case 64: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(64) \
break; \
case 128: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(128) \
break; \
case 256: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(256) \
break; \
case 512: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(512) \
break; \
case 1024: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(1024) \
break; \
case 2048: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(2048) \
break; \
case 4096: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(4096) \
break; \
case 8192: \
LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(8192) \
break; \
default: \
throw std::runtime_error( \
"only support K: 64/128/.../1024*8"); \
break; \
}
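
For a concrete picture of the launch configuration, the N=4096, K=2048 case from the benchmark resolves the two macros above to roughly the following (x and y are the torch::Tensor arguments of the wrapper below; one block per row, one thread per 8-half pack):

```cuda
// Approximate expansion of DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
// for N = 4096 rows and K = 2048 columns.
dim3 block(2048 / 8);  // 256 threads, one per packed group of 8 halves
dim3 grid(4096);       // one block per row
layer_norm_f16x8_pack_f16_kernel<2048 / 8><<<grid, block>>>(
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    g, b, 4096, 2048);
```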

void layer_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -580,6 +690,16 @@ void layer_norm_f16x8_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K)
}

void layer_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
CHECK_TORCH_TENSOR_SHAPE(x, y)
const int N = x.size(0);
const int K = x.size(1);
DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
}


void layer_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -595,6 +715,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f16)
TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x2_f16)
TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_f16)
TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f16)
TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
}
