train_copy.py
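"""
Toy encoder-decoder copy task for x-transformers: the model is trained to read a
random token sequence and reproduce it (the decoder target is the source repeated
twice, preceded by a start token), with a periodic greedy-decoding check of how
many positions it gets wrong.
"""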
import tqdm
import torch

from x_transformers import XTransformer

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 32
LEARNING_RATE = 3e-4
GENERATE_EVERY = 100
NUM_TOKENS = 16 + 2     # 16 copyable token ids (2..17) plus 2 reserved ids (1 is used as the start token)
ENC_SEQ_LEN = 32
DEC_SEQ_LEN = 64 + 1    # decoder target is the source repeated twice, plus the start token

# helpers

def cycle():
    # endlessly yield random copy-task batches: the decoder target is the
    # source sequence repeated twice, preceded by the start token (token 1)
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        tgt = torch.cat((prefix, src, src), 1)
        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
        yield (src, tgt, src_mask)

# instantiate model

model = XTransformer(
    dim = 512,
    tie_token_emb = True,
    return_tgt_loss = True,
    enc_num_tokens = NUM_TOKENS,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = ENC_SEQ_LEN,
    dec_num_tokens = NUM_TOKENS,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = DEC_SEQ_LEN
).cuda()

# optimizer

optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    src, tgt, src_mask = next(cycle())

    loss = model(src, tgt, mask = src_mask)
    loss.backward()
    print(f'{i}: {loss.item()}')

    optim.step()
    optim.zero_grad()

    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        src, _, src_mask = next(cycle())
        src, src_mask = src[:1], src_mask[:1]

        # greedily decode ENC_SEQ_LEN tokens, starting from the start token (token 1),
        # then count how many positions differ from the source the model was asked to copy
        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()
        sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
        incorrects = (src != sample).sum()

        print("input:  ", src)
        print("predicted output:  ", sample)
        print(f"incorrects: {incorrects}")