-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPredictionDual_dataset
49 lines (37 loc) · 1.46 KB
/
PredictionDual_dataset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import matplotlib.pyplot as plt
from Model import *
import numpy as np
from Dataholder import FROGdata,transforms
from torch.utils.data import DataLoader
# Inference setup: pick the device, build the evaluation dataloader and
# restore the trained network from its checkpoint.
#
# NOTE(review): `torch`, `resnet9_with_attention` are not imported explicitly
# in this file; they are expected to come from `from Model import *`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset built with train=False; image directories and the CSV file are
# relative to the current working directory.
dataset = FROGdata(
    csv_file='Eim_20k_shg.csv',
    img_dir_shg='SHG_crop/',
    img_dir_thg='THG_crop/',
    transforms=transforms,
    train=False,
)
# One sample per batch and no shuffling, so batches follow dataset order.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

model_path = 'logs_im_11/Resnet_epoch_121.pt'
# NOTE(review): meaning of (6, 64, 128) is defined in Model; presumably 128 is
# the output length, since predictions are later reshaped to 128 points.
Cnn = resnet9_with_attention(6, 64, 128)
# map_location remaps tensors that were saved from a CUDA device onto
# `device`, so the checkpoint also loads on CPU-only machines.
Cnn.load_state_dict(torch.load(model_path, map_location=device))
Cnn = Cnn.to(device)
# eval() switches layers such as dropout / batch-norm to inference behaviour.
Cnn = Cnn.eval()
# Run the network on the first few samples of the dataloader and plot each
# prediction against its ground-truth curve.
NUM_PREVIEWS = 4  # stop after this many plots have been shown

# Loop-invariant x-axis: 128 integer positions from -64 to 63.
time_axis = np.arange(-64, 64, 1)

shown = 0
for dual_channel_image, real_points, filename in dataloader:
    if dual_channel_image is None:
        continue

    dual_channel_image = dual_channel_image.to(device)
    with torch.no_grad():
        # Drop the batch dimension and bring the result back to the CPU as a
        # NumPy array; no detach() is needed because gradients are disabled.
        output = Cnn(dual_channel_image).squeeze(0).cpu().numpy()

    # NOTE(review): with batch_size=1 and the default collate function the
    # file name presumably arrives as a one-element list/tuple; unwrap it so
    # the title shows the bare name. Any other type is used unchanged.
    if isinstance(filename, (list, tuple)) and len(filename) == 1:
        sample_name = filename[0]
    else:
        sample_name = filename

    plt.plot(time_axis, output.reshape(128, 1), '--r', label='pred')
    plt.plot(time_axis, real_points.reshape(128, 1), 'b', label='real')
    plt.xlabel('Time (fs)')
    plt.ylabel('E(t) (arb.units)')
    plt.title(f'Dual Channel Prediction for {sample_name} vs Ground Truth')
    plt.grid()
    plt.legend()
    plt.show()

    # Optional export of the prediction (requires pandas to be in scope):
    # output = pd.DataFrame(output)
    # output.to_csv(f'logs_im_11/{sample_name}.csv', index=False, header=False)

    # Count plotted samples rather than mutating a loop index, and use >= so
    # the limit cannot be stepped over when an iteration is skipped above.
    shown += 1
    if shown >= NUM_PREVIEWS:
        break