Skip to content

Commit

Permalink
simplified SFNO example by removing factorized versions
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Dec 13, 2024
1 parent 0ccb230 commit 8f3e7d3
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 729 deletions.
3 changes: 2 additions & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

### v0.7.3

* Changing default grid in all SHT routines to `equiangular`
* Changing default grid in all SHT routines to `equiangular`, which makes it consistent with DISCO convolutions
* Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
* Reworked DISCO filter basis data structure
* Support for new filter basis types

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ If you use `torch-harmonics` in an academic paper, please cite [1]
<a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
arXiv 2306.0383, 2023.
International Conference on Machine Learning, 2023. [arxiv link](https://arxiv.org/abs/2306.03838)

<a id="2">[2]</a>
Schaeffer N.;
Expand Down
139 changes: 91 additions & 48 deletions examples/train_sfno.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10

model.eval()

# make output
if not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)

losses = np.zeros(nics)
fno_times = np.zeros(nics)
nwp_times = np.zeros(nics)
Expand All @@ -178,18 +182,24 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
prd = prd.unsqueeze(0)
uspec = ic.clone()

# add IC to power spectrum series
prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
ref_coeffs = [prd_coeffs[0].clone()]

# ML model
start_time = time.time()
for i in range(1, autoreg_steps + 1):
# evaluate the ML model
prd = model(prd)

prd_coeffs.append(dataset.sht(prd[0, plot_channel]).detach().cpu().clone())

if iic == nics - 1 and nskip > 0 and i % nskip == 0:

# do plotting
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root + "_pred_" + str(i // nskip) + ".png")
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'pred_'+str(i//nskip)+'.png'))
plt.close()

fno_times[iic] = time.time() - start_time
Expand All @@ -201,44 +211,51 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref_coeffs.append(dataset.sht(ref[plot_channel]).detach().cpu().clone())

if iic == nics - 1 and i % nskip == 0 and nskip > 0:

fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root + "_truth_" + str(i // nskip) + ".png")
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'truth_'+str(i//nskip)+'.png'))
plt.close()

nwp_times[iic] = time.time() - start_time

# compute power spectrum and add it to the buffers
prd_coeffs = dataset.solver.sht(prd[0, plot_channel])
ref_coeffs = dataset.solver.sht(ref[plot_channel])
prd_mean_coeffs.append(prd_coeffs)
ref_mean_coeffs.append(ref_coeffs)
prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))

# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref = dataset.solver.spec2grid(uspec)
prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()

# compute the averaged powerspectra of prediction and reference
prd_mean_coeffs = torch.stack(prd_mean_coeffs).abs().pow(2).mean(dim=0)
ref_mean_coeffs = torch.stack(ref_mean_coeffs).abs().pow(2).mean(dim=0)
prd_mean_coeffs[..., 1:] *= 2.0
ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).detach().cpu()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).detach().cpu()
with torch.no_grad():
prd_mean_coeffs = torch.stack(prd_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
ref_mean_coeffs = torch.stack(ref_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)

prd_mean_coeffs[..., 1:] *= 2.0
ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).contiguous()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).contiguous()

# split the stuff
prd_mean_ps = [x.squeeze() for x in list(torch.split(prd_mean_ps, 1, dim=0))]
ref_mean_ps = [x.squeeze() for x in list(torch.split(ref_mean_ps, 1, dim=0))]

# compute the averaged powerspectrum
fig = plt.figure(figsize=(7.5, 6))
plt.loglog(prd_mean_ps, label="prediction")
plt.loglog(ref_mean_ps, label="reference")
plt.xlabel("$l$")
plt.ylabel("powerspectrum")
plt.legend()
plt.savefig(path_root + "_powerspectrum.png")
plt.close()
for step, (pps, rps) in enumerate(zip(prd_mean_ps, ref_mean_ps)):
fig = plt.figure(figsize=(7.5, 6))
plt.semilogy(pps, label="prediction")
plt.semilogy(rps, label="reference")
plt.xlabel("$l$")
plt.ylabel("powerspectrum")
plt.legend()
plt.savefig(os.path.join(path_root,f'powerspectrum_{step}.png'))
fig.clf()
plt.close()

return losses, fno_times, nwp_times

Expand Down Expand Up @@ -364,6 +381,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.manual_seed(333)
torch.cuda.manual_seed(333)

# set parameters
nfuture=0

# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
Expand All @@ -373,7 +393,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600
dt_solver = 150
nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, normalize=True)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, grid="legendre-gauss", normalize=True)
dataset.sht = RealSHT(nlat=257, nlon=512, grid= "equiangular").to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
Expand All @@ -391,38 +412,54 @@ def count_parameters(model):
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO

# models["sfno_sc2_layers6_e32"] = partial(
# models[f"sfno_sc2_layers4_e32_nomlp"] = partial(
# SFNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=6,
# scale_factor=1,
# # hard_thresholding_fraction=0.8,
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# use_mlp=False,
# normalization_layer="none",
# )

models["lsno_sc2_layers6_e32"] = partial(
LSNO,
spectral_transform="sht",
models[f"sfno_sc2_layers4_e32_nomlp_leggauss"] = partial(
SFNO,
img_size=(nlat, nlon),
grid="equiangular",
num_layers=6,
scale_factor=1,
grid="legendre-gauss",
# hard_thresholding_fraction=0.8,
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=True,
big_skip=False,
pos_embed=False,
use_mlp=True,
use_mlp=False,
normalization_layer="none",
)

# models[f"lsno_sc1_layers4_e32_nomlp"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=False,
# normalization_layer="none",
# )

# iterate over models and train each model
root_path = os.path.dirname(__file__)
for model_name, model_handle in models.items():
Expand All @@ -437,8 +474,12 @@ def count_parameters(model):
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params

exp_dir = os.path.join(root_path, 'checkpoints', model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)

if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(root_path, "checkpoints/" + model_name), weights_only=True))
model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))

# run the training
if train:
Expand All @@ -454,27 +495,27 @@ def count_parameters(model):
print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)

# # multistep training
# print(f'Training {model_name}, two step')
# optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# gscaler = torch.GradScaler(enabled=enable_amp)
# dataloader.dataset.nsteps = 2 * dt//dt_solver
# train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=5, nfuture=1, enable_amp=enable_amp)
# dataloader.dataset.nsteps = 1 * dt//dt_solver
if nfuture > 0:
print(f'Training {model_name}, {nfuture} step')
optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
dataloader.dataset.nsteps = 2 * dt//dt_solver
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
dataloader.dataset.nsteps = 1 * dt//dt_solver

training_time = time.time() - start_time

run.finish()

torch.save(model.state_dict(), os.path.join(root_path, "checkpoints/" + model_name))
torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))

# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)

with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, "figures/" + model_name), nsteps=nsteps, autoreg_steps=30)
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir,'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
metrics[model_name]["loss_mean"] = np.mean(losses)
metrics[model_name]["loss_std"] = np.std(losses)
metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
Expand All @@ -485,7 +526,9 @@ def count_parameters(model):
metrics[model_name]["training_time"] = training_time

df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, "output_data/metrics.pkl"))
if not os.path.isdir(os.path.join(exp_dir, 'output_data',)):
os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions notebooks/plot_spherical_harmonics.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 8f3e7d3

Please sign in to comment.