Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ipiszy committed Sep 3, 2024
1 parent 223b148 commit 7c3b111
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 97 deletions.
6 changes: 6 additions & 0 deletions .eggs/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins.

This directory caches those eggs to prevent repeated downloads.

However, it is safe to delete this directory.

7 changes: 2 additions & 5 deletions flash_attn/bert_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,13 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
res = (
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
if unused_mask is not None:
return res + (used_seqlens_in_batch, )
else:
return res


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/flash_blocksparse_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad,
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
hidden_states = hidden_states[subset_mask]
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
Expand Down
45 changes: 27 additions & 18 deletions hopper/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,18 @@ void set_params_fprop(Flash_fwd_params &params,

// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
params.is_causal = window_size_left < 0 && window_size_right == 0;
if ((window_size_left >= 0 || window_size_right >= 0) && !params.is_causal) {
params.is_local = true;
}
window_size_left = std::min(int(seqlen_k), window_size_left);
window_size_right = std::min(int(seqlen_k), window_size_right);
if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
if (window_size_left < 0) { window_size_left = seqlen_k; }
if (window_size_right < 0) { window_size_right = seqlen_k; }
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;

params.is_causal = window_size_left == seqlen_k && window_size_right == 0;
if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) {
params.is_local = true;
}

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
"This flash attention build does not support local attention.");
Expand Down Expand Up @@ -356,6 +357,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
Expand All @@ -381,8 +384,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right);
/*window_size_left=*/window_size_left,
/*window_size_right=*/window_size_right);

auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
Expand Down Expand Up @@ -536,6 +539,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
Expand All @@ -559,8 +564,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
is_causal ? -1 : window_size_left,
is_causal ? 0 : window_size_right,
window_size_left,
window_size_right,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/true);
params.total_q = total_q;
Expand Down Expand Up @@ -621,8 +626,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
const float softmax_scale,
const bool is_causal,
const int window_size_left,
const int window_size_right,
int window_size_left,
int window_size_right,
const bool deterministic) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
Expand Down Expand Up @@ -739,6 +744,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv_expanded = dv;
}

if (is_causal) { window_size_right = 0; }

Flash_bwd_params params;

set_params_dgrad(params,
Expand All @@ -762,8 +769,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_d.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right,
/*window_size_left=*/window_size_left,
/*window_size_right=*/window_size_right,
deterministic);
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

Expand Down Expand Up @@ -814,8 +821,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
const int max_seqlen_k, // max sequence length to choose the kernel
const float softmax_scale,
const bool is_causal,
const int window_size_left,
const int window_size_right,
int window_size_left,
int window_size_right,
const bool deterministic) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
Expand Down Expand Up @@ -932,6 +939,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
dout_padded = dout;
}

if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
Expand Down Expand Up @@ -978,8 +987,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
softmax_d.data_ptr(),
/*p_dropout=*/0.f,
softmax_scale,
/*window_size_left=*/is_causal ? -1 : window_size_left,
/*window_size_right=*/is_causal ? 0 : window_size_right,
/*window_size_left=*/window_size_left,
/*window_size_right=*/window_size_right,
deterministic);
params.total_q = total_q;
params.total_k = total_k;
Expand Down
4 changes: 2 additions & 2 deletions hopper/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None
return dq, dk, dv, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
Expand Down Expand Up @@ -289,7 +289,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None


def flash_attn_func(
Expand Down
26 changes: 15 additions & 11 deletions hopper/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class FlashAttnBwd {

// Type Aliases
static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
static constexpr bool Is_local = CollectiveMainloop_::Is_local;
static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
static constexpr bool Varlen = CollectiveMainloop_::Varlen;

Expand Down Expand Up @@ -155,6 +156,7 @@ class FlashAttnBwd {
static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});

using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
Expand Down Expand Up @@ -218,12 +220,12 @@ class FlashAttnBwd {
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
auto [n_block, bidh, bidb] = block_coord;
if constexpr (Varlen) {
if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) {
if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) {
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
continue;
}
}
if constexpr (Is_causal) {
if constexpr (Is_causal || Is_local) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) {
Expand All @@ -247,13 +249,13 @@ class FlashAttnBwd {
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
auto [n_block, bidh, bidb] = block_coord;
if constexpr (Varlen) {
if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
}
if constexpr (Is_causal) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
if (m_block_min >= m_block_max) { continue; }
if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
}
// if constexpr (Is_causal || Is_local) {
// int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
// int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
// if (m_block_min >= m_block_max) { continue; }
// }
collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord);
}
}
Expand All @@ -277,11 +279,14 @@ class FlashAttnBwd {
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
auto [n_block, bidh, bidb] = block_coord;
if constexpr (Varlen) {
if (n_block * kBlockM >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
if (n_block * kBlockN >= collective_mainloop.get_seqlen_k(params.mainloop, bidb)) { continue; }
}
if constexpr (Is_causal) {
if constexpr (Is_causal || Is_local) {
int const m_block_min = collective_mainloop.get_m_block_min(params.mainloop, n_block, bidb);
int const m_block_max = collective_mainloop.get_m_block_max(params.mainloop, n_block, bidb);
auto seqlen_q = collective_mainloop.get_seqlen_q(params.mainloop, bidb);
auto seqlen_k = collective_mainloop.get_seqlen_k(params.mainloop, bidb);
auto original_m_block_max = cute::ceil_div(seqlen_q, kBlockM);
if (m_block_min >= m_block_max) { // We exit early and write 0 to dK and dV
collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
continue;
Expand All @@ -300,7 +305,6 @@ class FlashAttnBwd {
}
collective_epilogue.store_tail();
}

}

};
Expand Down
3 changes: 2 additions & 1 deletion hopper/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
params.b,
params.dq_semaphore,
params.cu_seqlens_q, params.cu_seqlens_k,
params.seqused_q, params.seqused_k
params.seqused_q, params.seqused_k,
params.window_size_left, params.window_size_right
};
typename CollectiveEpilogue::Arguments epilogue_args {
static_cast<Element*>(params.dk_ptr),
Expand Down
6 changes: 3 additions & 3 deletions hopper/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Is_local, Seqlen_traits>;
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
using Scheduler = std::conditional_t<
Seqlen_traits::kUseVarSeqLen,
Seqlen_traits::kUseVarSeqLen || Is_local,
flash::SingleTileScheduler,
std::conditional_t<!Is_causal && !Is_local,
std::conditional_t<!Is_causal,
flash::StaticPersistentTileScheduler,
flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, Kernel_traits::NumProducerThreads>
>>;
Expand Down Expand Up @@ -137,7 +137,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
// Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Is_local && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
run_flash_fwd<
Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
Flash_fwd_kernel_traits<Headdim, 128, (Is_causal || Is_local) ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>,
Is_causal, Is_local, Seqlen_traits
>(params, stream);
});
Expand Down
27 changes: 15 additions & 12 deletions hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ struct CollectiveMainloopBwd {
int const seqlen_q = get_seqlen_q(params, bidb);
int const seqlen_k = get_seqlen_k(params, bidb);
int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
if constexpr (Is_causal || Is_local) {
if constexpr (Is_local) {
return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
} else {
return m_block_max;
Expand Down Expand Up @@ -613,6 +613,15 @@ struct CollectiveMainloopBwd {
}
cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
}
if constexpr (Deterministic) {
constexpr int kBlockM = get<0>(TileShape_MNK{});
int const seqlen_q = get_seqlen_q(params, bidb);
int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);
#pragma unroll 2
for (; m_block < m_block_global_max; ++m_block) {
Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
}
}
}

CUTLASS_DEVICE void
Expand Down Expand Up @@ -745,7 +754,7 @@ struct CollectiveMainloopBwd {

// We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
// this helps quite a bit to not have to do causal masking for most of the iterations.
if constexpr (Is_causal || Is_local) {
if constexpr (Is_causal) {
static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1;
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) {
Expand All @@ -761,17 +770,12 @@ struct CollectiveMainloopBwd {
warpgroup_wait<1>();
Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
Tensor taccScS = thread_mma_SdP.partition_C(cS);
int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right;
int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left;
int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<0>(taccScS(i))) >=
std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) {
std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) {
tSrS(i) = -INFINITY;
} else if constexpr (Is_local) {
if (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left)) {
tSrS(i) = -INFINITY;
}
}
}
// Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
Expand Down Expand Up @@ -836,9 +840,8 @@ struct CollectiveMainloopBwd {
int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left;
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if ((int(get<0>(taccScS(i))) >=
std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)
) || (int(get<0>(taccScS(i))) < std::max(0, local_row_offset_left))) {
if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) ||
(int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) {
tSrS(i) = -INFINITY;
}
}
Expand Down
Loading

0 comments on commit 7c3b111

Please sign in to comment.