train_utils.py
import torch
from torch import nn
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader


def training_model_once(model: nn.Module, train_data_loader: DataLoader, optimizer: Optimizer):
    """Trains the model for one epoch and returns the per-batch training losses."""
    model.train()
    training_losses = []
    for batch in train_data_loader:
        batch = batch.cuda()  # assumes a CUDA device is available
        outputs = model(batch)
        loss = model.loss(outputs, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_losses.append(loss.item())
    return training_losses


def get_validation_loss(model: nn.Module, data_loader: DataLoader):
    """Returns the average per-example loss over the dataset."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for minibatch in data_loader:
            minibatch = minibatch.cuda()  # assumes a CUDA device is available
            outputs = model(minibatch)
            loss = model.loss(outputs, minibatch)
            # Weight by batch size so the final average is per example, not per batch.
            total_loss += loss * minibatch.shape[0]
    avg_loss = total_loss / len(data_loader.dataset)
    return avg_loss.item()


def training_loop(model, train_data_loader, test_data_loader, epochs=10, lr=1e-3, optimizer=Adam, silent=False):
    """Trains for `epochs` epochs; returns (per-batch train losses, per-epoch test losses)."""
    optimizer = optimizer(model.parameters(), lr=lr)
    train_losses = []
    # Record the test loss once before any training so the curve starts from the untrained model.
    test_losses = [get_validation_loss(model, test_data_loader)]
    for epoch in range(epochs):
        train_losses.extend(training_model_once(model, train_data_loader, optimizer))
        test_loss = get_validation_loss(model, test_data_loader)
        test_losses.append(test_loss)
        if not silent:
            print(f"Epoch {epoch}, Test loss {test_loss:.4f}")
    return train_losses, test_losses
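

# Usage sketch (assumes `model`, `train_loader`, and `test_loader` are set up as in the example
# at the bottom of this file):
#   train_losses, test_losses = training_loop(model, train_loader, test_loader, epochs=5)
#   # train_losses has one entry per batch; test_losses has epochs + 1 entries, the first of
#   # which is measured before any training.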


def debug_grad(module, grad_input, grad_output):
    """Backward hook that asserts every gradient flowing through the module is finite."""
    for grad in grad_input:
        if grad is not None:  # grad_input can contain None for inputs that don't require grad
            assert torch.isfinite(grad).all()
    for grad in grad_output:
        if grad is not None:
            assert torch.isfinite(grad).all()
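

# Usage sketch for the hook above (assumes `model` is whatever nn.Module you want to inspect):
#   handle = model.register_full_backward_hook(debug_grad)
#   ...  # run forward/backward passes; the hook fires during backward
#   handle.remove()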


def get_toy_dataset():
    """A tiny 2D dataset: 160 copies of (0, 1), 40 copies of (1, 0), and 10 copies of (1, 1)."""
    return torch.tensor(([(0, 1)] * 16 + [(1, 0)] * 4 + [(1, 1)]) * 10)


def get_paramed_dataset(topright, topleft, bottomright, bottomleft):
    # Assumed completion (the original body was just `return tensor()`): each argument is taken to
    # be the number of copies of the corresponding corner point, mirroring get_toy_dataset.
    return torch.tensor([(1, 1)] * topright + [(0, 1)] * topleft + [(1, 0)] * bottomright + [(0, 0)] * bottomleft)
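

# Minimal end-to-end usage sketch. Nothing below is part of the original utilities:
# `ToyAutoencoder` and its MSE loss are assumptions that illustrate the `.loss(outputs, batch)`
# interface the helpers above expect, and a CUDA device is assumed to be available.
if __name__ == "__main__":
    class ToyAutoencoder(nn.Module):
        """Hypothetical model matching the interface used by training_model_once / get_validation_loss."""

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 2))

        def forward(self, x):
            return self.net(x)

        def loss(self, outputs, batch):
            # Reconstruction loss: the batch itself serves as both input and target.
            return nn.functional.mse_loss(outputs, batch)

    data = get_toy_dataset().float()  # cast to float so the linear layers accept it
    loader = DataLoader(data, batch_size=32, shuffle=True)
    model = ToyAutoencoder().cuda()
    train_losses, test_losses = training_loop(model, loader, loader, epochs=3)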