From 32a2654dd85d2818654307966ebc926546169621 Mon Sep 17 00:00:00 2001 From: chenkenbio Date: Mon, 22 Jan 2024 00:58:46 +0800 Subject: [PATCH] fix call flash_attn --- src/splicebert_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/splicebert_model.py b/src/splicebert_model.py index d91702c..25d48ef 100644 --- a/src/splicebert_model.py +++ b/src/splicebert_model.py @@ -229,7 +229,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_layer, value_layer) - if self.flash and attention_mask is None: + if self.flash: # query_layer, key_layer shape: (batch_size, num_heads, seq_len, head_size) # value_layer shape: (batch_size, num_heads, seq_len, head_size) query_layer = query_layer.permute(0, 2, 1, 3).contiguous() # (batch_size, seq_len, num_heads, head_size)