Using torch.bfloat16 to prevent overflow instead of default fp16 in AMP #345
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Using torch.bfloat16 to prevent overflow. Float16 has three less integer bits compared to bfloat16 which causes NaN loss and NaN grad norms during AMP training. This seems to be a common issue while training the Swin Transformer.
BFloat16 has same integer bits compared to FP32 but less precision bits. If we want higher precision but also want to save GPU memory, then TensorFloat32 or tfloat32 can be used instead.
TF32 has less precision bits when compared to FP32, but 3 more integer bits compared to FP16. But TF32 can only be used on latest NVIDIA ampere gpus or newer.