CUDA OOM When Using Flash Attention #163
Comments
I just tried training the 7B model, and it works fine both with and without flash attention. It would be great if I could train models larger than 7B with flash attention. I think flash attention will only get me to the 13B model; beyond that I'll need to look into memory-efficient tuning methods such as LoRA.
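For context, a LoRA setup of the kind mentioned above is typically wired up with the Hugging Face peft library; the model path, rank, alpha, and target-module names below are illustrative assumptions, not settings from this repository:

```python
# Minimal LoRA fine-tuning sketch using Hugging Face peft.
# The model path, rank, alpha, and target modules are illustrative assumptions.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("path/to/llama-13b")  # hypothetical local checkpoint

lora_config = LoraConfig(
    r=8,                                   # low-rank dimension (assumption)
    lora_alpha=16,                         # scaling factor (assumption)
    target_modules=["q_proj", "v_proj"],   # attention projections in LLaMA-style models
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small LoRA adapters are trainable
```

Because only the adapter weights receive gradients, optimizer state shrinks dramatically, which is what makes this attractive beyond 13B.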
How many GPUs do you use?
@zhisbug, thank you for your reply. I was using 4 A100-80GB GPUs but then realized the repository uses 8. I also have a question regarding the batch size: how do the batch-size and gradient-accumulation settings in the 8-GPU run command compare to the ones in the Alpaca repo? Your input will be very much appreciated!
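For reference when comparing such configurations, the effective (global) batch size is the product of the per-device batch size, the gradient accumulation steps, and the number of GPUs; the concrete numbers below are placeholders, not values taken from this repository or from Alpaca:

```python
# Effective (global) batch size = per-device batch size * gradient accumulation steps * number of GPUs.
per_device_train_batch_size = 4   # placeholder value
gradient_accumulation_steps = 4   # placeholder value
num_gpus = 8

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 128 with the placeholders above
```

Two launch configurations process the same amount of data per optimizer step whenever this product matches, even if the individual factors differ.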
@Michaelvll might be the right person to clarify the training config.
@HaniItani is your problem solved?
Yes, I figured it out. Thank you very much.
@HaniItani what did you do to fix it? I'm also running into OOM training 13B models with flash attention, wondering if it's the same problem. Thanks!
Hi @alwayshalffull, Flash Attention was not the problem; I had the wrong parameters. The 13B model requires 8 80GB A100 GPUs to finetune with a per-device train batch size of 4, as reflected in the training script in the repo.
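As a rough sketch of what such a setup looks like when expressed as Hugging Face transformers training arguments (the FSDP, precision, and epoch settings here are assumptions, not the repository's exact script; only the per-device batch size of 4 comes from the discussion above):

```python
# Illustrative training arguments for 8x A100-80GB full fine-tuning of a 13B model.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vicuna-13b-finetune",        # hypothetical output path
    per_device_train_batch_size=4,             # per-GPU batch size mentioned above
    gradient_accumulation_steps=1,             # assumption
    num_train_epochs=3,                        # assumption
    bf16=True,                                 # A100s support bfloat16
    fsdp="full_shard auto_wrap",               # shard parameters/optimizer state across the 8 GPUs
    fsdp_transformer_layer_cls_to_wrap="LlamaDecoderLayer",  # assumption for LLaMA-style models
)
```

With full sharding across 8 GPUs, the optimizer state and parameters no longer have to fit on a single card, which is what makes the per-device batch of 4 feasible at 13B.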
Original issue
Hello,
Thank you for sharing your awesome work!
I'm trying to train Vicuna on my own dataset. I walked through the installation process from source. I had to install pytorch with cuda 11.7.0 support instead of 11.6, since my server only supports cuda 11.2.2/11.4.4/11.5.2/11.7.0/11.8.0, but not 11.6.

When I try to train the 13B model with flash attention, I get a CUDA OOM error even when per_device_train_batch_size is set to 1. I think there might be a memory leak. I also tried building flash attention from source and still got the same error.

I know this is probably a flash attention problem, but do you have any insights? Any guidance will be very much appreciated.
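For reference, an environment like the one described above can be sanity-checked with a few lines of Python (a minimal sketch; nothing in it is specific to this repository):

```python
# Quick environment sanity check: confirm the PyTorch/CUDA pairing and the flash-attn build.
import torch

print(torch.__version__)              # PyTorch build
print(torch.version.cuda)             # CUDA version PyTorch was compiled against (expected 11.7 here)
print(torch.cuda.is_available())      # True if the driver and runtime are usable
print(torch.cuda.get_device_name(0))  # e.g. an A100

import flash_attn                     # raises ImportError if the flash-attn build is broken

# Peak GPU memory allocated so far, useful when chasing an apparent leak or OOM.
print(torch.cuda.max_memory_allocated(0) / 2**30, "GiB")
```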
Best regards,
Hani