Commit
* Update README.md adding autocast remark to readme
* Update README.md small typo fix
Showing 1 changed file with 27 additions and 0 deletions.
@@ -194,6 +194,33 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
7. [Solving the shallow water equations](./notebooks/shallow_water_equations.ipynb)
8. [Training Spherical Fourier Neural Operators](./notebooks/train_sfno.ipynb)

## Remarks on automatic mixed precision (AMP) support
Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn rely on kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. single- and double-precision) tensors for all input sizes. For `float16` (i.e. half-precision) and `bfloat16` inputs, however, the transformed dimensions are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will issue an error when the input shapes are not powers of two. In these cases, we recommend disabling autocast specifically for the harmonics transform:
```python
import torch
import torch_harmonics as th

sht = th.RealSHT(512, 1024, grid="equiangular").cuda()

with torch.cuda.amp.autocast(enabled = True):
    # do some AMP converted math here
    x = some_math(x)
    # convert tensor to float32
    x = x.to(torch.float32)
    # now disable autocast specifically for the transform,
    # making sure that the tensors are not converted
    # back to reduced precision internally
    with torch.cuda.amp.autocast(enabled = False):
        xt = sht(x)

    # continue operating on the transformed tensor
    xt = some_more_math(xt)
```
Depending on the problem, it might be beneficial to upcast data to `float64` instead of `float32` precision for numerical stability.
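For example, the sketch above could be adapted to carry out the transform in double precision. This is a minimal, hypothetical variant; `sht`, `some_math`, and `some_more_math` are the placeholders from the example above:

```python
with torch.cuda.amp.autocast(enabled = True):
    # reduced-precision math as before
    x = some_math(x)
    # upcast to float64 instead of float32 before the transform
    x = x.to(torch.float64)
    # disable autocast so the transform runs in double precision
    with torch.cuda.amp.autocast(enabled = False):
        xt = sht(x)

    # optionally cast back down before continuing in reduced precision
    xt = some_more_math(xt.to(torch.float32))
```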
## Contributors
[Boris Bonev](https://bonevbs.github.io) ([email protected]), [Thorsten Kurth](https://github.com/azrael417) ([email protected]), [Christian Hundt](https://github.com/gravitino) ([email protected]), [Nikola Kovachki](https://kovachki.github.io) ([email protected]), [Jean Kossaifi](http://jeankossaifi.com) ([email protected])