[FA2] Release flash-attn-mma split-kv/q🎉 (#160)
* Update and rename flash_attn_mma_tiling.cu to flexiable_flash_attn_mma.cu

* Update flexiable_flash_attn_mma.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flexiable_flash_attn_mma.cu

* Update flash_attn_mma.py

* Update flash_attn_mma.py

* Rename flexiable_flash_attn_mma.cu to flexiable_flash_attn_mma_split_kv.cu

* Create flexiable_flash_attn_mma_split_q.cu

* Update flexiable_flash_attn_mma_split_kv.cu

* Update flexiable_flash_attn_mma_split_q.cu

* Update flash_attn.cc

* Update flash_attn_mma.py

* Update flexiable_flash_attn_mma_split_kv.cu

* Update flexiable_flash_attn_mma_split_q.cu

* Update flash_attn_mma_stage.cu

* Update flexiable_flash_attn_mma_split_kv.cu

* Update flexiable_flash_attn_mma_split_q.cu

* Update flash_attn_mma.py

* Update flash_attn_mma.py

* Update README.md

* Update README.md

* Update flexiable_flash_attn_mma_split_kv.cu

* Update flash_attn_mma_naive.cu

* Update utils.h

* Update flash_attn_mma.py

* Update flexiable_flash_attn_mma_split_q.cu

* support flash-attn-mma split-q

* support flash-attn-mma split-q

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update .gitmodules

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* support flash-attn-mma split-q

* support flash-attn-mma split-q

* Update README.md

* support flash-attn-mma split-q

* support flash-attn-mma split-q

* support flash-attn-mma split-q

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* support flash-attn-mma split-q

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 15, 2024
1 parent 81404c1 commit 5afd8c1
Showing 19 changed files with 1,513 additions and 1,558 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/.gitignore
@@ -21,4 +21,5 @@ outupt
bin
*.log
*.txt
*.tex
*.tex
__pycache__
3 changes: 2 additions & 1 deletion .gitignore
@@ -21,4 +21,5 @@ outupt
bin
*.log
*.txt
*.tex
*.tex
__pycache__
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,4 +1,4 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
tag = v3.5.1
tag = v3.5.1
2 changes: 1 addition & 1 deletion LICENSE
@@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
<https://www.gnu.org/licenses/why-not-lgpl.html>.
159 changes: 114 additions & 45 deletions README.md

Large diffs are not rendered by default.

520 changes: 144 additions & 376 deletions kernels/flash-attn/README.md

Large diffs are not rendered by default.

Empty file.
129 changes: 88 additions & 41 deletions kernels/flash-attn/flash_attn_mma.py
@@ -36,8 +36,8 @@ def get_args():
parser.add_argument("--no-rand-k", '--no-rk', action="store_true")
parser.add_argument("--no-rand-v", '--no-rv', action="store_true")
parser.add_argument("--no-rand-qkv", '--no-rqkv', action="store_true")
parser.add_argument("--naive", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--run-torch-unfused", '--torch', action="store_true")
parser.add_argument("--run-torch-sdpa", '--sdpa', action="store_true")
parser.add_argument("--check", action="store_true")
parser.add_argument("--show-all", '--show', action="store_true")
parser.add_argument("--B", type=int, default=None)
@@ -46,6 +46,7 @@ def get_args():
parser.add_argument("--D", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--debug", action="store_true")
parser.add_argument("--verbose", '--v', action="store_true")
parser.add_argument("--warmup", type=int, default=2)
parser.add_argument("--iters", type=int, default=10)
parser.add_argument("--range-k", '--gk', action="store_true")
@@ -59,10 +60,10 @@
# Load the CUDA kernel as a python module
lib = load(name='flash_attn_lib',
sources=[
'./naive/flash_attn_cuda.cu',
'./mma/flash_attn_mma_naive.cu',
'./mma/flash_attn_mma_stage.cu',
'./pybind/flash_attn.cc'],
'./mma/flash_attn_mma_split_kv.cu',
'./mma/flash_attn_mma_split_q.cu',
'./pybind/flash_attn.cc'
],
extra_cuda_cflags=[
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -72,10 +73,43 @@ def get_args():
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"-Xptxas -v",
"-diag-suppress 177",
f"-I {project_dir}/kernels/flash-attn/utils",
"-DFLASH_ATTN_MMA_DEBUG" if args.debug else ""
],
extra_cflags=['-std=c++17'])
extra_cflags=['-std=c++17'],
verbose=args.verbose)
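# Usage sketch for the newly bound kernels (an assumption inferred from the
# run_benchmark calls further down, which pass q, tk, v, the output buffer o and
# a `stages` argument; the exact pybind signature lives in pybind/flash_attn.cc):
#   q, k, v, o, tk, fq, fk, fv = get_qkvo(B=1, H=8, N=1024, D=64)
#   lib.flash_attn_mma_stages_split_kv(q, tk, v, o, 2)  # split-kv, 2-stage pipeline
#   lib.flash_attn_mma_stages_split_q(q, tk, v, o, 2)   # split-q,  2-stage pipeline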


def get_mha_tflops(B, H, N, D, T=1.0):
# Q @ K^T FLOPs
flops_qk = B * H * N * N * (2 * D - 1)

# Scaling FLOPs
flops_scaling = B * H * N * N

# Safe_Softmax FLOPs
flops_row_max = B * H * N * (N - 1) # row max
flops_subtract_max = B * H * N * N # sub max
flops_exp = B * H * N * N # pointwise exp
flops_row_sum = B * H * N * (N - 1) # row sum
flops_normalization = B * H * N * N # normalization

flops_safe_softmax = flops_row_max + flops_subtract_max + flops_exp + flops_row_sum + flops_normalization

# P @ V FLOPs
flops_pv = B * H * N * D * (2 * N - 1)

# Total FLOPs
total_flops = flops_qk + flops_scaling + flops_safe_softmax + flops_pv

# Convert to TFLOPS
# 1 TFLOPS = 10^12 FLOPS
# ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
tflops = total_flops * 1e-12 / (T)

return tflops
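# Quick sanity check for the helper above (illustrative only): with B=1, H=8,
# N=4096, D=64 the formulas give ~3.50e10 FLOPs per forward pass, so a mean
# runtime of 1 ms corresponds to roughly 35 TFLOPS:
#   get_mha_tflops(1, 8, 4096, 64, T=1e-3)  # ~= 35.0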


def run_benchmark(perf_func: callable,
@@ -123,8 +157,14 @@ def run_benchmark(perf_func: callable,
out = perf_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
total_secs = (end - start)
total_time = (end - start) * 1000 # ms
mean_time = total_time / iters
mean_secs = total_secs / iters
B, H, N, D = q.size()
if "flash" in tag:
B, N, H, D = q.size()
TFLOPS = get_mha_tflops(B, H, N, D, mean_secs)
out_info = f"{tag}"
out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
@@ -133,10 +173,11 @@ def run_benchmark(perf_func: callable,
out_val = out_val_first[:2]
out_val.append(out_val_last[-1])
out_val = [f"{v:<12}" for v in out_val]
print(f"{out_info:>20}: {out_val}, time:{mean_time:.6f}ms")
print(f"{out_info:>25}: {out_val}, time:{mean_time:<.6f}ms, TFLOPS:{TFLOPS:<6.2f}")
if show_all:
print(out)
time.sleep(0.05)
torch.cuda.synchronize()
return out.clone(), mean_time


@@ -159,18 +200,38 @@ def get_qkvo(B, H, N, D):
v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
tk = k.transpose(-2, -1).contiguous()
fq = q.transpose(1, 2).contiguous()
fk = k.transpose(1, 2).contiguous()
fv = v.transpose(1, 2).contiguous()
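# Layout notes (descriptive only, based on how these tensors are used below):
# q/k/v/o are [B, H, N, D] fp16 tensors; tk = k.transpose(-2, -1) is the
# [B, H, D, N] K^T layout fed to the mma split-kv/split-q kernels, while
# fq/fk/fv are the [B, N, H, D] layout expected by flash_attn_func.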

return q, k, v, o
return q, k, v, o, tk, fq, fk, fv


# un-fused naive attn
def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
att = (q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))))
att = F.softmax(att, dim=-1)
y = att @ v
return y
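# (This reference path materializes the full [N, N] score matrix and applies
#  softmax(q @ k^T / sqrt(D)) row-wise before multiplying by v; keeping that
#  intermediate out of global memory is what the fused mma kernels are for.)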


def check_all_close(out_flash: torch.Tensor, out_mma: torch.Tensor,
tag: str = "out_mma", show_all: bool = False):
out_flash = out_flash.transpose(1, 2)
if show_all:
for i in range(int(N/8)):
if i < 4:
print("-" * 120)
print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
print(out_flash[:, :, (i*8):(i+1)*8, :].float())
print(f"{tag}[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
print(out_mma[:, :, (i*8):(i+1)*8, :].float())
print("-" * 120)
all_close = torch.allclose(out_flash.float(), out_mma.float(), atol=1e-2)
print(f"out_flash vs {tag}: {all_close}")


Bs = [1, 2, 4] if not args.B else [args.B]
Hs = [1, 4, 8] if not args.H else [args.H]
Ns = [1024, 2048] if not args.N else [args.N]
@@ -180,42 +241,28 @@ def naive_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

seed = args.seed if args.seed else random.choice(range(10000))
set_rand_seed(seed)
print("-" * 100)
print(" "* 10 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
print("-" * 120)
print(" "* 20 + f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}")

for (B, H, N, D) in BHNDs:
print("-" * 100)
print(" " * 25 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
q, k, v, o = get_qkvo(B, H, N, D)
tk = k.transpose(-2, -1).contiguous()
fq = q.transpose(1, 2).contiguous()
fk = k.transpose(1, 2).contiguous()
fv = v.transpose(1, 2).contiguous()
print("-" * 120)
print(" " * 30 + f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}")
q, k, v, o, tk, fq, fk, fv = get_qkvo(B, H, N, D)
torch.cuda.synchronize()

if args.naive:
out_naive, _ = run_benchmark(naive_attn, q, k, v, "naive(unfused)")

# using fp16 Tesor Core MMA instruction
out_mma_naive, _ = run_benchmark(lib.flash_attn_mma_naive, q, k, v, "mma(naive)", o)
out_mma_stage1, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage1)", o, stages=1)
out_mma_stage2, _ = run_benchmark(lib.flash_attn_mma_stages, q, tk, v, "mma(stage2)", o, stages=2)
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")

if args.sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
print("-" * 100)
if args.run_torch_unfused:
out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "torch(unfused)")
out_mma_split_kv1, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage1)", o, stages=1)
out_mma_split_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_kv, q, tk, v, "mma(split-kv+stage2)", o, stages=2)
out_mma_split_q1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage1)", o, stages=1)
out_mma_split_q2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q, q, tk, v, "mma(split-q+stage2)", o, stages=2)
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
print("-" * 120)

torch.cuda.synchronize()
if args.check:
out_flash = out_flash.transpose(1, 2)
for i in range(int(N/8)):
if i < 4:
print("-" * 100)
print(f"out_flash[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
print(out_flash[:, :, (i*8):(i+1)*8, :].float())
print(f"out_mma_stage1[:, :, {(i*8)}:{(i+1)*8}, :]:\n")
print(out_mma_stage1[:, :, (i*8):(i+1)*8, :].float())
print("-" * 100)
print(f"{torch.allclose(out_flash.float(), out_mma_naive.float(), atol=1e-2)}")
check_all_close(out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.show_all)
check_all_close(out_flash, out_mma_split_q1, "out_mma_split_q1", args.show_all)