From 13aa49271ff868086546ebe68e3632016b5c5bee Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Thu, 23 Nov 2023 10:15:06 +0100
Subject: [PATCH] Update README.md (#20)

* Update README.md

adding autocast remark to readme

* Update README.md

small typo fix
---
 README.md | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/README.md b/README.md
index 77cb5a3..8073f3b 100644
--- a/README.md
+++ b/README.md
@@ -194,6 +194,33 @@ Detailed usage of torch-harmonics, alongside helpful analysis provided in a seri
 7. [Solving the shallow water equations](./notebooks/shallow_water_equations.ipynb)
 8. [Training Spherical Fourier Neural Operators](./notebooks/train_sfno.ipynb)
 
+## Remarks on automatic mixed precision (AMP) support
+
+Note that torch-harmonics uses Fourier transforms from `torch.fft`, which in turn uses kernels from the optimized `cuFFT` library. This library supports Fourier transforms of `float32` and `float64` (i.e. `single` and `double` precision) tensors for all input sizes. For `float16` (i.e. `half` precision) and `bfloat16` inputs, however, the transformed dimensions are restricted to powers of two. Since data is converted to one of these reduced-precision floating-point formats when `torch.cuda.amp.autocast` is used, torch-harmonics will raise an error when the transformed input dimensions are not powers of two. In these cases, we recommend disabling autocast specifically for the harmonic transform:
+
+```python
+import torch
+import torch_harmonics as th
+
+sht = th.RealSHT(512, 1024, grid="equiangular").cuda()
+
+with torch.cuda.amp.autocast(enabled=True):
+    # do some math in reduced (AMP) precision here
+    x = some_math(x)
+    # convert the tensor to float32
+    x = x.to(torch.float32)
+    # now disable autocast specifically for the transform,
+    # making sure that the tensors are not converted
+    # back to reduced precision internally
+    with torch.cuda.amp.autocast(enabled=False):
+        xt = sht(x)
+
+    # continue operating on the transformed tensor
+    xt = some_more_math(xt)
+```
+
+Depending on the problem, it might be beneficial to upcast the data to `float64` instead of `float32` for better numerical stability.
+
 ## Contributors
 
 [Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Christian Hundt](https://github.com/gravitino) (chundt@nvidia.com), [Nikola Kovachki](https://kovachki.github.io) (nkovachki@nvidia.com), [Jean Kossaifi](http://jeankossaifi.com) (jkossaifi@nvidia.com)
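
Regarding the closing remark about upcasting to `float64`: below is a minimal sketch (not part of the patch) of what that variant of the snippet above could look like. The input tensor `x` and its shape are invented for illustration, and whether the transform object itself needs an explicit `.double()` cast may depend on the torch-harmonics version.

```python
import torch
import torch_harmonics as th

# hypothetical input: a batch of 4 fields on the 512 x 1024 equiangular grid
x = torch.randn(4, 512, 1024, device="cuda")

# cast the transform to double precision as a precaution, so its internal
# buffers match the float64 input (assumption: may not be required in all versions)
sht = th.RealSHT(512, 1024, grid="equiangular").double().cuda()

with torch.cuda.amp.autocast(enabled=True):
    # ... reduced-precision (AMP) math on x would go here ...

    # upcast to float64 instead of float32 before the transform
    x64 = x.to(torch.float64)

    # disable autocast around the transform, as in the README snippet
    with torch.cuda.amp.autocast(enabled=False):
        xt = sht(x64)

    # optionally downcast again before continuing under AMP
    xt = xt.to(torch.float32)
```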