
Commit

minor fix
AndyZhou952 committed Feb 6, 2025
1 parent 849fa47 commit abc5c8e
Showing 1 changed file with 1 addition and 1 deletion.

tools/captioners/PLLaVA/models/llama/network.py
@@ -200,7 +200,7 @@ def _update_causal_mask(self, attention_mask: Tensor, input_tensor: Tensor, cach
                 raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
             causal_mask = attention_mask
         else:
-            fill_value = -ms.numpy.inf if self.attn_implementation == "flash_attention" else 1.0
+            fill_value = -ms.numpy.inf if self.attn_implementation == "eager" else 1.0
             causal_mask = ops.full((sequence_length, target_length), fill_value=fill_value, dtype=dtype)
             exclude_mask = ops.arange(target_length) > cache_position.reshape(-1, 1)
             causal_mask = ops.masked_fill(causal_mask, ~exclude_mask, Tensor(0, dtype=dtype))
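For context on the one-line change: the two attention implementations consume the causal mask differently. The eager path adds the mask to the attention logits before the softmax, so blocked positions must be filled with -inf, whereas a flash-attention kernel typically takes a 0/1 mask in which 1.0 marks positions to drop. The original condition built the -inf mask for the flash_attention branch; the fix routes it to the eager branch instead. Below is a minimal sketch (an illustration under the assumptions just stated, not code from this repository) that builds the mask the way the patched eager branch does and applies it in an eager-style attention step:

    import mindspore as ms
    from mindspore import Tensor, ops

    def eager_attention_weights(q, k, additive_mask):
        # Eager path: the additive mask (0 = visible, -inf = blocked) is summed
        # into the logits, so blocked slots get zero weight after the softmax.
        scores = ops.matmul(q, k.swapaxes(-1, -2)) / (q.shape[-1] ** 0.5)
        return ops.softmax(scores + additive_mask, axis=-1)

    seq_len = 4
    q = ops.ones((1, seq_len, 8), ms.float32)
    k = ops.ones((1, seq_len, 8), ms.float32)

    # Build the causal mask the way the patched code does on the eager branch:
    # fill everything with -inf, then zero out the visible (non-future) slots.
    # (ops.arange stands in for cache_position in the no-cache case.)
    fill_value = -ms.numpy.inf
    causal_mask = ops.full((seq_len, seq_len), fill_value=fill_value, dtype=ms.float32)
    exclude_mask = ops.arange(seq_len) > ops.arange(seq_len).reshape(-1, 1)
    causal_mask = ops.masked_fill(causal_mask, ~exclude_mask, Tensor(0, dtype=ms.float32))

    print(eager_attention_weights(q, k, causal_mask)[0])
    # Row i has non-zero weight only on columns 0..i (the causal pattern).

On the flash-attention branch the same mask is instead filled with 1.0, matching kernels that expect a boolean-style "1 = masked" tensor rather than an additive one; passing -inf there would corrupt the scores.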
