Skip to content

Commit

Permalink
fix call flash_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
chenkenbio committed Jan 21, 2024
1 parent a461ef0 commit 32a2654
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/splicebert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

if self.flash and attention_mask is None:
if self.flash:
# query_layer, key_layer shape: (batch_size, num_heads, seq_len, head_size)
# value_layer shape: (batch_size, num_heads, seq_len, head_size)
query_layer = query_layer.permute(0, 2, 1, 3).contiguous() # (batch_size, seq_len, num_heads, head_size)
Expand Down

0 comments on commit 32a2654

Please sign in to comment.