[RMSNorm][FP16] Pack f16x8 rmsnorm (#47)
* Update .gitmodules

* Update README.md

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md

* Update rms_norm.cu

* Update rms_norm.py

* Update README.md

* Update README.md
DefTruth authored Sep 25, 2024
1 parent 4667308 commit 54c761d
Showing 5 changed files with 351 additions and 210 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,3 +1,3 @@
 [submodule "third-party/cutlass"]
 path = third-party/cutlass
-url = git@github.com:NVIDIA/cutlass.git
+url = https://github.com/NVIDIA/cutlass.git
4 changes: 2 additions & 2 deletions README.md
@@ -88,8 +88,8 @@
 | ✔️ [rms_norm_f16x2_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x16_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
87 changes: 71 additions & 16 deletions rms-norm/README.md
@@ -10,8 +10,8 @@
 - [X] rms_norm_f16x2_f16_kernel
 - [X] rms_norm_f16x8_f16_kernel
 - [X] rms_norm_f16x8_f32_kernel
-- [X] rms_norm_f16x16_f16_kernel
-- [X] rms_norm_f16x16_f32_kernel
+- [X] rms_norm_f16x8_pack_f16_kernel
+- [X] rms_norm_f16x8_pack_f32_kernel
 - [X] rms_norm_f16_f32_kernel
 - [X] PyTorch bindings

@@ -26,18 +26,73 @@ python3 rms_norm.py
Output:

```bash
---------------------------------------------------------------------------------
-out_f32: [0.92419142, -0.08846965, 1.06359947], time:0.03389192ms
-out_f32x4: [0.92419147, -0.08846966, 1.06359959], time:0.00855207ms
-out_f32_th: [0.92419606, -0.08847010, 1.06360483], time:0.04171062ms
---------------------------------------------------------------------------------
-out_f16f16: [0.92431641, -0.08843994, 1.06347656], time:0.03518176ms
-out_f16x2f16: [0.92431641, -0.08843994, 1.06347656], time:0.01200986ms
-out_f16x8f16: [0.92431641, -0.08843994, 1.06347656], time:0.00625682ms
-out_f16x8f32: [0.92431641, -0.08843994, 1.06347656], time:0.00625014ms
-out_f16x16f16: [0.92431641, -0.08843994, 1.06347656], time:0.02620339ms
-out_f16x16f32: [0.92431641, -0.08843994, 1.06347656], time:0.01505637ms
-out_f16f32: [0.92431641, -0.08843994, 1.06347656], time:0.03300810ms
-out_f16_th: [0.92431641, -0.08843994, 1.06347656], time:0.04187107ms
---------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=512
+out_f32: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.01198173ms
+out_f32x4: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.00517488ms
+out_f32_th: ['0.04078539 ', '0.74503714 ', '0.87150306 '], time:0.04351616ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.01200986ms
+out_f16f32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.01180410ms
+out_f16x2f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00670171ms
+out_f16x8f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411820ms
+out_f16x8f32: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411677ms
+out_f16x8packf16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411630ms
+out_f16x8packf32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.00399137ms
+out_f16_th: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.04383564ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=1024
+out_f32: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.03398657ms
+out_f32x4: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.00862885ms
+out_f32_th: ['-0.76329684 ', '-0.62112319 ', '-1.4553194 '], time:0.04355550ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.03526235ms
+out_f16f32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.03302288ms
+out_f16x2f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.01215649ms
+out_f16x8f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00632071ms
+out_f16x8f32: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00631690ms
+out_f16x8packf16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00528240ms
+out_f16x8packf32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.00519514ms
+out_f16_th: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.04399920ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=2048
+out_f32x4: ['-0.17984088 ', '-1.76387513 ', '-0.32782754 '], time:0.01650691ms
+out_f32_th: ['-0.17984176 ', '-1.76388371 ', '-0.32782915 '], time:0.09451318ms
+-------------------------------------------------------------------------------------
+out_f16x2f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.03497124ms
+out_f16x8f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01254177ms
+out_f16x8f32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01253581ms
+out_f16x8packf16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00903535ms
+out_f16x8packf32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00894380ms
+out_f16_th: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.04889655ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=4096
+out_f32x4: ['-1.14100003 ', '-0.71529448 ', '2.26544118 '], time:0.18783689ms
+out_f32_th: ['-1.14100587 ', '-0.71529812 ', '2.26545286 '], time:0.52556086ms
+-------------------------------------------------------------------------------------
+out_f16x8f16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605795ms
+out_f16x8f32: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605533ms
+out_f16x8packf16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.01718473ms
+out_f16x8packf32: ['-1.140625 ', '-0.71533203 ', '2.26367188 '], time:0.01735568ms
+out_f16_th: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.11150384ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=8192
+out_f16x8f16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19292974ms
+out_f16x8f32: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19298863ms
+out_f16x8packf16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.18497562ms
+out_f16x8packf32: ['-0.40844727 ', '-0.14294434 ', '-0.93310547 '], time:0.18479729ms
+out_f16_th: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.59557104ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=8192, K=8192
+out_f16x8f16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38169765ms
+out_f16x8f32: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38264203ms
+out_f16x8packf16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.40794849ms
+out_f16x8packf32: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:0.40747380ms
+out_f16_th: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:1.35807014ms
+-------------------------------------------------------------------------------------
```
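Reviewer note: the log above compares per-token RMSNorm variants across `K`. As a reading aid, here is a minimal plain-Python sketch of what each row computes and of what an "f16x8 pack" amounts to in bytes. It is not taken from `rms_norm.cu` or `rms_norm.py`. The function names, the scalar gain `g`, and `eps = 1e-5` are assumptions made for illustration, and the packed walk only mimics an access pattern in which eight halves (16 bytes, presumably one 128-bit load on the GPU) are handled together.

```python
import math
import struct

PACK = 8  # eight f16 values per pack; 8 * 2 bytes = 16 bytes = 128 bits

def rms_norm_row(row, g=1.0, eps=1e-5):
    """Reference per-token RMSNorm: y = x / sqrt(mean(x^2) + eps) * g.

    `g` and `eps` are illustrative assumptions, not values read from the commit.
    """
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v * inv_rms * g for v in row]

def rms_norm_row_packed(row, g=1.0, eps=1e-5):
    """Same math, but the row is traversed in packs of PACK values."""
    if len(row) % PACK != 0:
        raise ValueError("row length must be a multiple of PACK")
    sum_sq = 0.0
    for i in range(0, len(row), PACK):
        pack = row[i:i + PACK]            # one pack of eight values
        sum_sq += sum(v * v for v in pack)
    inv_rms = 1.0 / math.sqrt(sum_sq / len(row) + eps)
    out = []
    for i in range(0, len(row), PACK):
        out.extend(v * inv_rms * g for v in row[i:i + PACK])
    return out

# Eight IEEE half-precision values ('e' format) occupy exactly 16 bytes.
pack_bytes = struct.calcsize("<%de" % PACK)

# Small usage example on one 64-element "token".
row = [((i * 37) % 17 - 8) / 4.0 for i in range(64)]
reference = rms_norm_row(row)
packed = rms_norm_row_packed(row)
```

Both functions should agree up to floating-point summation order. The small differences in the log between the `f16` and `f32` suffixed variants are a separate effect that this sketch does not model.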