-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ModernBERT bug fixes #35404
Open
warner-benjamin
wants to merge
3
commits into
huggingface:main
Choose a base branch
from
AnswerDotAI:modernbert_bug_fixes
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+53
−19
Open
ModernBERT bug fixes #35404
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
# limitations under the License. | ||
|
||
import math | ||
from contextlib import nullcontext | ||
from typing import Dict, Literal, Optional, Tuple, Union | ||
|
||
import torch | ||
|
@@ -141,6 +142,9 @@ class ModernBertConfig(PretrainedConfig): | |
the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not | ||
shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may | ||
be faster in some scenarios. | ||
repad_logits_with_grad (`bool`, *optional*, defaults to `False`): | ||
When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only | ||
applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient. | ||
|
||
Examples: | ||
|
||
|
@@ -196,6 +200,7 @@ def __init__( | |
sparse_prediction=False, | ||
sparse_pred_ignore_index=-100, | ||
reference_compile=None, | ||
repad_logits_with_grad=False, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
|
@@ -235,6 +240,7 @@ def __init__( | |
self.sparse_prediction = sparse_prediction | ||
self.sparse_pred_ignore_index = sparse_pred_ignore_index | ||
self.reference_compile = reference_compile | ||
self.repad_logits_with_grad = repad_logits_with_grad | ||
|
||
if self.classifier_pooling not in ["cls", "mean"]: | ||
raise ValueError( | ||
|
@@ -852,12 +858,14 @@ def _autoset_attn_implementation( | |
): | ||
# If the user didn't specify anything, try to use flash_attention_2 if available. | ||
# Otherwise we fall back to the default SDPA -> Eager from the super() method. | ||
# ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't | ||
# need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. | ||
if config._attn_implementation_internal is None: | ||
config._attn_implementation_internal = "flash_attention_2" | ||
try: | ||
return cls._check_and_enable_flash_attn_2( | ||
config, | ||
torch_dtype=torch_dtype, | ||
torch_dtype=torch.float16, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a cleaner solution to avoid the unnecessary FP32 warning than I figured was possible, nice. |
||
device_map=device_map, | ||
hard_check_only=False, | ||
check_device_map=check_device_map, | ||
|
@@ -867,7 +875,7 @@ def _autoset_attn_implementation( | |
return super()._autoset_attn_implementation( | ||
config, | ||
use_flash_attention_2=use_flash_attention_2, | ||
torch_dtype=torch_dtype, | ||
torch_dtype=torch.float16, | ||
device_map=device_map, | ||
check_device_map=check_device_map, | ||
) | ||
|
@@ -892,6 +900,14 @@ def _maybe_set_compile(self): | |
) | ||
self.config.reference_compile = False | ||
|
||
if self.device.type == "cpu": | ||
if self.config.reference_compile: | ||
logger.warning_once( | ||
"Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. " | ||
"Falling back to non-compiled mode." | ||
) | ||
self.config.reference_compile = False | ||
|
||
if self.config.reference_compile is None: | ||
self.config.reference_compile = is_triton_available() | ||
|
||
|
@@ -911,8 +927,8 @@ def resize_token_embeddings(self, *args, **kwargs): | |
MODERNBERT_INPUTS_DOCSTRING = r""" | ||
Args: | ||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | ||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide | ||
it. | ||
Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored | ||
by default should you provide it. | ||
|
||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | ||
[`PreTrainedTokenizer.__call__`] for details. | ||
|
@@ -941,7 +957,7 @@ def resize_token_embeddings(self, *args, **kwargs): | |
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers | ||
perform global attention, while the rest perform local attention. This mask is used to avoid attending to | ||
far-away tokens in the local attention layers. | ||
far-away tokens in the local attention layers when not using Flash Attention. | ||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | ||
config.n_positions - 1]`. | ||
|
@@ -952,11 +968,11 @@ def resize_token_embeddings(self, *args, **kwargs): | |
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): | ||
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. | ||
max_seqlen (`int`, *optional*): | ||
Maximum sequence length in the batch. Used to pad the output tensors. | ||
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. | ||
batch_size (`int`, *optional*): | ||
Batch size of the input sequences. Used to pad the output tensors. | ||
seq_len (`int`, *optional*): | ||
Sequence length of the input sequences. Used to pad the output tensors. | ||
Sequence length of the input sequences including padding tokens. Used to pad the output tensors. | ||
output_attentions (`bool`, *optional*): | ||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | ||
tensors for more detail. | ||
|
@@ -1246,8 +1262,9 @@ def forward( | |
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) | ||
|
||
if self.config._attn_implementation == "flash_attention_2": | ||
with torch.no_grad(): | ||
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): | ||
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) | ||
|
||
if not return_dict: | ||
output = (logits,) | ||
return ((loss,) + output) if loss is not None else output | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call - I got carried away with the Python class naming