-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdarcy_flow_main.py
81 lines (66 loc) · 2.59 KB
/
darcy_flow_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from data_load_darcy import *
import random
import matplotlib.pyplot as plt
from darcy_flow_uno2d import UNO_9,UNO_11
import operator
from functools import reduce
from functools import partial
from train_darcy import train_model
from timeit import default_timer
from utilities3 import *
from Adam import Adam
from torchsummary import summary
import gc
import math
# Matplotlib defaults for any figures drawn later: a tall 6 x 30 figure and
# no interpolation between image pixels.
plt.rcParams.update({
    'figure.figsize': [6, 30],
    'image.interpolation': 'nearest',
})
# Fix the seeds of Python's `random` (used for the dataset shuffle below) and
# of PyTorch. The two generators are independent of each other.
random.seed(10001)
torch.manual_seed(10001)
import sys
import logging
# Hyperparameters (passed to train_model at the end of this script):
#   weight_decay    = 0.001
#   learning_rate   = 0.001
#   scheduler_gamma = 0.5
# Load the two generated Darcy-flow data files (they can be produced with the
# provided data generator). The path strings are placeholders to be replaced
# with real file locations. Each call is unpacked as
# (train_a, train_u, test_a, test_u).
# NOTE(review): the leading arguments (2, 800, 200) are presumably the
# sub-sampling rate and the train/test sample counts - confirm against
# data_load_darcy.load_data_darcy.
(
    (train_a_1, train_u_1, test_a_1, test_u_1),
    (train_a_2, train_u_2, test_a_2, test_u_2),
) = [
    load_data_darcy(2, 800, 200, data_path)
    for data_path in ("Path to data file1", "Path to data file2")
]
# ---- Experiment configuration ----
sub = 2    # subsampling rate
S = 211    # grid size / resolution after sub-sampling

# Single input and single output.
T_in, T_f = 1, 1

# Number of train, validation and test samples.
ntrain, nval, ntest = 1500, 250, 250

batch_size = 16
epochs = 700

# Forwarded to the UNO constructor further down as (inwidth, width).
width, inwidth = 32, 3
# Pool every sample from both files (their train and test parts alike); the
# train/val/test split is then redrawn from one shuffled index list.
a = torch.cat([train_a_1, train_a_2, test_a_1, test_a_2], dim=0)
u = torch.cat([train_u_1, train_u_2, test_u_1, test_u_2], dim=0)

# Fail early with a readable message instead of an opaque reshape error when
# the pooled data cannot supply the requested split sizes.
n_samples = a.shape[0]
n_required = ntrain + nval + ntest
if n_samples < n_required:
    raise ValueError(
        f"need {n_required} samples for the train/val/test split, "
        f"but only {n_samples} were loaded"
    )

# Shuffle with Python's `random` (seeded above) so the split is reproducible.
indexs = list(range(n_samples))
random.shuffle(indexs)

# Every split is bounded explicitly. The test split takes exactly `ntest`
# samples rather than "everything left over", so the reshape below stays
# valid even when the files hold more samples than the split needs.
train_idx = indexs[:ntrain]
val_idx = indexs[ntrain:ntrain + nval]
test_idx = indexs[ntrain + nval:n_required]

train_a, val_a, test_a = a[train_idx], a[val_idx], a[test_idx]
train_u, val_u, test_u = u[train_idx], u[val_idx], u[test_idx]
print(train_a.shape,val_a.shape,test_a.shape)

# Give the inputs an explicit trailing axis of size T_in: (N, S, S, T_in).
# Only the inputs are reshaped; the targets keep the layout they were loaded in.
train_a = train_a.reshape(ntrain, S, S, T_in)
val_a = val_a.reshape(nval, S, S, T_in)
test_a = test_a.reshape(ntest, S, S, T_in)
gc.collect()
def _make_loader(inputs, targets):
    """Wrap paired input/target tensors in a shuffling DataLoader of `batch_size`."""
    dataset = torch.utils.data.TensorDataset(inputs, targets)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# All three loaders share the same batch size and all of them shuffle.
train_loader = _make_loader(train_a, train_u)
test_loader = _make_loader(test_a, test_u)
val_loader = _make_loader(val_a, val_u)
# Instantiate the UNO_9 network and move it to the GPU.
# Alternative architecture: UNO_11(inwidth, width, pad=8).cuda()
model = UNO_9(inwidth, width, pad=12)
model = model.cuda()
# Report the model structure for an input of shape (S, S, 1).
summary(model, (S, S, 1))
gc.collect()
# Start training. The optimiser and scheduler settings are given inline here.
# NOTE(review): 'Darcy-D9-211.pt' is presumably the checkpoint file name
# written by train_model - confirm in train_darcy.
train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    ntrain,
    nval,
    ntest,
    S,
    'Darcy-D9-211.pt',
    T_f=T_f,
    batch_size=batch_size,
    epochs=epochs,
    learning_rate=0.001,
    scheduler_step=100,
    scheduler_gamma=0.5,
    weight_decay=0.001,
)