Pass correct num_items_in_batch value into the training_step function #35438

Open
wants to merge 1 commit into
base: main

hiyouga
Contributor

@hiyouga hiyouga commented Dec 27, 2024

What does this PR do?

This PR follows #34511 and #34915 and partially reverts #35121. We want the loss to be calculated correctly for models that don't accept loss_kwargs when gradient accumulation is enabled.

In #35121, num_items_in_batch is always passed into the training_step function. This defeats the "num_items_in_batch is None" check in the following lines and consequently leads to incorrect loss calculation.

# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
    loss = loss / self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)
return loss.detach()

Many models still don't accept loss_kwargs (Qwen2-VL, for example). For these models, num_items_in_batch must be set to None when calling training_step so that the loss is scaled correctly during training.
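The intended behaviour can be sketched as follows. This is a minimal illustration and not the Trainer's actual code: the helper names `num_items_for_training_step` and `loss_for_backward` and the flag `model_accepts_loss_kwargs` are assumptions made for this example, and losses are plain floats rather than tensors.

```python
from typing import Optional


def num_items_for_training_step(
    model_accepts_loss_kwargs: bool, num_items_in_batch: Optional[int]
) -> Optional[int]:
    """Forward the item count only to models that can use it (hypothetical helper)."""
    return num_items_in_batch if model_accepts_loss_kwargs else None


def loss_for_backward(
    loss: float, num_items_in_batch: Optional[int], gradient_accumulation_steps: int
) -> float:
    """Mirror the quoted branch: divide only when no item count was passed."""
    if num_items_in_batch is None:
        return loss / gradient_accumulation_steps
    return loss


# A model without loss_kwargs returns a per-micro-batch mean loss,
# so the count is dropped and the loss is divided before backward.
n = num_items_for_training_step(False, 1024)
scaled = loss_for_backward(2.0, n, gradient_accumulation_steps=8)  # 0.25

# Passing the count unconditionally skips the division for the same model.
unscaled = loss_for_backward(2.0, 1024, gradient_accumulation_steps=8)  # 2.0
```

With 8 accumulation steps, the second call leaves the loss 8 times larger than the first, which is the kind of inflation this PR describes.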

However, we've found that since transformers 4.46.0, the loss of the models that don't accept loss_kwargs hasn't been calculated correctly!

Specifically, in transformers 4.46.0-4.46.1, the loss was wrongly multiplied by the gradient accumulation steps, resulting in a large loss value (see https://twitter.com/TheZachMueller/status/1851677628656423347 ).

https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

loss *= self.args.gradient_accumulation_steps
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps

In transformers 4.46.2-4.47.1, the loss was scaled after the backward pass, which also led to larger gradients (also mentioned in #35207 ).

https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/trainer.py#L3689

self.accelerator.backward(loss, **kwargs)
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
    return loss.detach() / self.args.gradient_accumulation_steps
return loss.detach()
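Why the position of the division matters can be shown with a toy example. The sketch below is an assumption-laden illustration, not Trainer code: it uses a one-parameter least-squares model with an analytic gradient, treats "backward" as plain gradient accumulation with no extra scaling from any other library, and all function names are invented for this example.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y)


def mean_loss_and_grad(w: float, batch: List[Sample]) -> Tuple[float, float]:
    """Mean squared error of the toy model y_hat = w * x, and d(loss)/dw."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def accumulate(w: float, micro_batches: List[List[Sample]], scale_before_backward: bool):
    """Return (reported loss, accumulated gradient) for one optimizer step."""
    steps = len(micro_batches)
    reported, accumulated = 0.0, 0.0
    for batch in micro_batches:
        loss, grad = mean_loss_and_grad(w, batch)
        if scale_before_backward:
            accumulated += grad / steps  # backward sees loss / steps
        else:
            accumulated += grad  # backward sees the unscaled loss
        reported += loss / steps  # the returned value is divided in both variants
    return reported, accumulated


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
micro_batches = [[sample] for sample in data]  # micro batch size 1, 4 accumulation steps
w = 0.5

full_loss, full_grad = mean_loss_and_grad(w, data)
early_loss, early_grad = accumulate(w, micro_batches, scale_before_backward=True)
late_loss, late_grad = accumulate(w, micro_batches, scale_before_backward=False)
```

Under these assumptions both variants report the same loss, but only the variant that divides before backward reproduces the full-batch gradient. Dividing afterwards leaves the accumulated gradient larger by a factor equal to the number of accumulation steps.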

On the latest main branch, because of the issue described in this PR, the loss still isn't scaled correctly and the final loss remains too large. We hope the maintainers will prioritize this problem, and we look forward to a timely fix.

We also ran a set of experiments to verify this conclusion. We trained Qwen2-VL (which does not accept loss_kwargs) with transformers 4.45.2, the main branch (5c75087), and the main branch with this PR applied. Both 4.45.2 and the fixed version produce reasonable loss values and grad norms, while the unpatched main branch does not.

| transformers version | micro batch size | grad accumulation steps | avg loss | avg grad norm |
| --- | --- | --- | --- | --- |
| 4.45.2 | 4 | 2 | 1.17 | 0.32 |
| 4.45.2 | 1 | 8 | 0.99 | 0.43 |
| main (4.48.0.dev0) | 4 | 2 | 2.34 | 0.65 |
| main (4.48.0.dev0) | 1 | 8 | 7.93 | 3.52 |
| main + this PR | 4 | 2 | 1.17 | 0.32 |
| main + this PR | 1 | 8 | 0.99 | 0.43 |
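As a quick arithmetic check on the table above, the ratio between the main-branch numbers and the 4.45.2 numbers can be computed directly. The values are copied from the table; the variable names are only illustrative.

```python
# (micro batch size, grad accumulation steps) -> (avg loss, avg grad norm)
baseline = {(4, 2): (1.17, 0.32), (1, 8): (0.99, 0.43)}  # transformers 4.45.2
main = {(4, 2): (2.34, 0.65), (1, 8): (7.93, 3.52)}  # main (4.48.0.dev0)

ratios = {}
for key, (base_loss, base_norm) in baseline.items():
    main_loss, main_norm = main[key]
    ratios[key] = (main_loss / base_loss, main_norm / base_norm)

for (micro_bs, steps), (loss_ratio, norm_ratio) in ratios.items():
    print(f"grad accumulation steps={steps}: loss x{loss_ratio:.2f}, grad norm x{norm_ratio:.2f}")
```

In both configurations the inflation factor is close to the number of gradient accumulation steps (about 2 and about 8), which is consistent with a missing division by that number.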


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @muellerzr

@hiyouga hiyouga changed the title from "Pass correct num_items_in_batch value into the training_step method" to "Pass correct num_items_in_batch value into the training_step function" Dec 27, 2024
@hiyouga
Contributor Author

hiyouga commented Dec 27, 2024

Nearly all test jobs passed; the few failures were timeout errors.

pipelines_tf - Failed
tests_exotic_models - Success
collection_job - Success
tests_onnx - Success
tests_torch_and_tf - Success
tests_torch - Success
tests_hub - Success
tests_processors - Success
examples_torch - Success
tests_custom_tokenizers - Success
tests_generate - Success
tests_tf - Success
tests_non_model - Success
examples_tensorflow - Success
tests_torch_and_flax - Success
tests_tokenization - Success
pipelines_torch - Success

@yzhangcs

yzhangcs commented Dec 29, 2024

Yes, I observed a similar phenomenon in #35207 (comment).

@yzhangcs

@hiyouga Hi, could you try this out?

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

num_batch = 32
gradient_accumulation_steps = 2
per_device_train_batch_size = 3
seq_len = 5

eff_batch_size = per_device_train_batch_size * gradient_accumulation_steps
dataset_len = num_batch * eff_batch_size

data = torch.arange(0, dataset_len * seq_len)
data = data.reshape(dataset_len, seq_len)
data = data.tolist()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B").to("cuda")
dataset = Dataset.from_dict({"input_ids": data, "labels": data})

args = TrainingArguments(
    output_dir=f"out_bs_{per_device_train_batch_size}_grad_{gradient_accumulation_steps}_before",
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    logging_steps=2,
)

trainer = Trainer(model=model, args=args, train_dataset=dataset)

trainer.train()

Looks like the loss is still incorrect for gradient_accumulation_steps>1.

@LysandreJik
Member

cc @ArthurZucker and @muellerzr

@hiyouga
Contributor Author

hiyouga commented Dec 29, 2024

@yzhangcs Could you provide detailed experimental results and, if possible, a patch with a fix? Thanks.

@yzhangcs

@hiyouga Hello, the results are similar to v4.47.0: the loss for gradient_accumulation_steps=2 is erroneously doubled.

gradient_accumulation_steps=1, per_device_train_batch_size=6

{'loss': 10.5949, 'grad_norm': 188.29049682617188, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.25}                                                                                                             
{'loss': 8.2312, 'grad_norm': 87.74673461914062, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}                                                                                                                 
{'loss': 7.8988, 'grad_norm': 74.35655212402344, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.75}                                                                                                               
{'loss': 8.0197, 'grad_norm': 52.913124084472656, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}                                                                                                               
{'loss': 6.0704, 'grad_norm': 51.57878875732422, 'learning_rate': 2.916666666666667e-05, 'epoch': 1.25}                                                                                                                
{'loss': 5.8247, 'grad_norm': 56.05608367919922, 'learning_rate': 2.5e-05, 'epoch': 1.5}                                                                                                                               
{'loss': 5.9932, 'grad_norm': 50.10258483886719, 'learning_rate': 2.0833333333333336e-05, 'epoch': 1.75}                                                                                                               
{'loss': 5.3769, 'grad_norm': 65.85218811035156, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}                                                                                                                
{'loss': 4.1261, 'grad_norm': 62.63885498046875, 'learning_rate': 1.25e-05, 'epoch': 2.25}                                                                                                                             
{'loss': 3.5956, 'grad_norm': 67.01679229736328, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}                                                                                                                 
{'loss': 3.5619, 'grad_norm': 72.61600494384766, 'learning_rate': 4.166666666666667e-06, 'epoch': 2.75}                                                                                                                
{'loss': 3.6611, 'grad_norm': 67.8575439453125, 'learning_rate': 0.0, 'epoch': 3.0}                                                                                                                                    
{'train_runtime': 56.4532, 'train_samples_per_second': 10.203, 'train_steps_per_second': 0.425, 'train_loss': 6.079538067181905, 'epoch': 3.0}                                                                         
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:55<00:00,  2.29s/it]

However, the results change significantly with more gradient accumulation steps:

gradient_accumulation_steps=2, per_device_train_batch_size=3

{'loss': 21.1898, 'grad_norm': 376.5796203613281, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.25}                                                                                                              
{'loss': 16.4623, 'grad_norm': 175.4936065673828, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}                                                                                                                
{'loss': 15.7976, 'grad_norm': 148.7130889892578, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.75}                                                                                                              
{'loss': 16.0394, 'grad_norm': 105.82550811767578, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}                                                                                                              
{'loss': 12.1408, 'grad_norm': 103.15972137451172, 'learning_rate': 2.916666666666667e-05, 'epoch': 1.25}                                                                                                              
{'loss': 11.6494, 'grad_norm': 112.11367797851562, 'learning_rate': 2.5e-05, 'epoch': 1.5}                                                                                                                             
{'loss': 11.9865, 'grad_norm': 100.20618438720703, 'learning_rate': 2.0833333333333336e-05, 'epoch': 1.75}                                                                                                             
{'loss': 10.7537, 'grad_norm': 131.70639038085938, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}                                                                                                              
{'loss': 8.2522, 'grad_norm': 125.27814483642578, 'learning_rate': 1.25e-05, 'epoch': 2.25}                                                                                                                            
{'loss': 7.1912, 'grad_norm': 134.03387451171875, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}                                                                                                                
{'loss': 7.1239, 'grad_norm': 145.23382568359375, 'learning_rate': 4.166666666666667e-06, 'epoch': 2.75}                                                                                                               
{'loss': 7.3222, 'grad_norm': 135.71575927734375, 'learning_rate': 0.0, 'epoch': 3.0}                                                                                                                                  
{'train_runtime': 64.6658, 'train_samples_per_second': 8.907, 'train_steps_per_second': 0.371, 'train_loss': 12.159087419509888, 'epoch': 3.0}                                                                         
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:03<00:00,  2.63s/it]

I still haven't found a solution, though :-(
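The doubling can be checked directly from the two logs above. The snippet below only copies the logged loss values and divides them pairwise; the variable names are illustrative.

```python
# Logged 'loss' values from the gradient_accumulation_steps=1 run
loss_ga1 = [10.5949, 8.2312, 7.8988, 8.0197, 6.0704, 5.8247,
            5.9932, 5.3769, 4.1261, 3.5956, 3.5619, 3.6611]
# Logged 'loss' values from the gradient_accumulation_steps=2 run
loss_ga2 = [21.1898, 16.4623, 15.7976, 16.0394, 12.1408, 11.6494,
            11.9865, 10.7537, 8.2522, 7.1912, 7.1239, 7.3222]

# Ratio of the two runs at each logging step
ratios = [b / a for a, b in zip(loss_ga1, loss_ga2)]

# Ratio of the final 'train_loss' values reported by the two runs
train_loss_ratio = 12.159087419509888 / 6.079538067181905

for step, ratio in enumerate(ratios, start=1):
    print(f"log entry {step}: x{ratio:.4f}")
print(f"train_loss: x{train_loss_ratio:.4f}")
```

Every ratio is 2 to within rounding of the logged values, matching the number of accumulation steps in the second run.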

@hiyouga
Contributor Author

hiyouga commented Dec 30, 2024

@yzhangcs We cannot reproduce this with the latest main branch (5cabc75). The model we used is Qwen2.5-7B-Instruct.

| transformers version | micro batch size | grad accumulation steps | avg loss | avg grad norm |
| --- | --- | --- | --- | --- |
| 4.45.2 | 8 | 1 | 0.92 | 0.45 |
| main (4.48.0.dev0) | 8 | 1 | 0.92 | 0.45 |
| main (4.48.0.dev0) | 1 | 8 | 0.92 | 0.44 |


3 participants