# lstm.py
import os
import shutil
import torch
import torch.nn as nn
from auto_LiRPA.utils import logger


class LSTMCore(nn.Module):
    """An LSTM cell unrolled over time slices, with mean-pooled hidden states
    feeding a linear classifier."""

    def __init__(self, args):
        super(LSTMCore, self).__init__()
        # Each input example is split into `num_slices` time steps.
        self.input_size = args.input_size // args.num_slices
        self.hidden_size = args.hidden_size
        self.num_classes = args.num_classes
        self.device = args.device
        self.cell_f = nn.LSTMCell(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, X):
        # X has shape (batch_size, length, input_size).
        batch_size, length = X.shape[0], X.shape[1]
        h_f = torch.zeros(batch_size, self.hidden_size).to(X.device)
        c_f = h_f.clone()
        h_f_sum = h_f.clone()
        # Unroll the LSTM cell over time, accumulating hidden states.
        for i in range(length):
            h_f, c_f = self.cell_f(X[:, i], (h_f, c_f))
            h_f_sum = h_f_sum + h_f
        # Mean-pool the hidden states over time, then classify.
        states = h_f_sum / float(length)
        logits = self.linear(states)
        return logits
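
# A minimal usage sketch (hypothetical hyperparameters, not from this file):
# LSTMCore expects input of shape (batch_size, num_slices, input_size // num_slices),
# e.g. an MNIST image sliced into 28 rows of 28 pixels:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(input_size=784, num_slices=28, hidden_size=64,
#                          num_classes=10, device="cpu")
#   core = LSTMCore(args)
#   logits = core(torch.randn(8, 28, 28))  # -> shape (8, 10)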


class LSTM(nn.Module):
    """Wrapper around LSTMCore that handles checkpointing, optimizer setup,
    and batch preparation."""

    def __init__(self, args):
        super(LSTM, self).__init__()
        self.args = args
        self.device = args.device
        self.lr = args.lr
        self.num_slices = args.num_slices
        self.dir = args.dir
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
        self.checkpoint = 0
        self.model = LSTMCore(args)
        if args.load:
            # args.load is assumed to be a checkpoint path; load its state dict.
            self.model.load_state_dict(torch.load(args.load, map_location=self.device))
            logger.info(f"Model loaded: {args.load}")
        else:
            logger.info("Model initialized")
        self.model = self.model.to(self.device)
        self.core = self.model

    def save(self, epoch):
        output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)
        path = os.path.join(output_dir, "model")
        torch.save(self.core.state_dict(), path)
        with open(os.path.join(self.dir, "checkpoint"), "w") as file:
            file.write(str(epoch))
        logger.info("LSTM saved: %s" % output_dir)

    def build_optimizer(self):
        params = [param for _, param in self.core.named_parameters()]
        param_group = [{"params": params, "weight_decay": 0.}]
        return torch.optim.Adam(param_group, lr=self.lr)

    def get_input(self, batch):
        # Reshape each flat example into (1, num_slices, input_size // num_slices)
        # time slices, stack the batch, and collect the integer labels.
        X = torch.cat([example[0].reshape(1, self.num_slices, -1) for example in batch])
        y = torch.tensor([example[1] for example in batch], dtype=torch.long)
        return X.to(self.device), y.to(self.device)

    def train(self, mode=True):
        # Accept `mode` to stay compatible with nn.Module.train(mode).
        self.core.train(mode)

    def eval(self):
        self.core.eval()
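

# A minimal smoke-test sketch (hypothetical hyperparameters; SimpleNamespace
# stands in for the argparse namespace the original training script would pass).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        input_size=784, num_slices=28, hidden_size=64, num_classes=10,
        device="cpu", lr=1e-3, dir="./model_lstm", load=None)
    model = LSTM(args)
    optimizer = model.build_optimizer()

    # A fake batch of (flattened_input, label) pairs, as consumed by get_input().
    batch = [(torch.randn(args.input_size), i % args.num_classes) for i in range(8)]
    X, y = model.get_input(batch)

    # One training step: forward, cross-entropy loss, backward, update, save.
    optimizer.zero_grad()
    logits = model.core(X)
    loss = nn.CrossEntropyLoss()(logits, y)
    loss.backward()
    optimizer.step()
    model.save(0)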