# train.py
import sys
import os
import json
import numpy as np
# import model_ha as model
import torch
import torch.nn as nn
from pre_process import load_config
from pre_process import build_dict
from pre_process import load_word2vec_embedding
from pre_process import generate_examples
from utils import generate_batch_data
from utils import extract_data
from utils import cal_acc
from utils import evaluate_result
from HopAttention import HAQA
# model save path
torch_model_p = "model/haqa.pkl"
# check CPU or GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("using " + str(device))
# config path
config_path = "config.json"
# use GloVe pre-trained embedding
word_embedding_path = "GloVe/word2vec_glove.txt"
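# (word2vec_glove.txt is assumed to be GloVe vectors converted to word2vec text
# format: a "vocab_size dim" header line followed by one "token v_1 ... v_dim"
# line per token, e.g. via gensim's glove2word2vec script.)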
# vocab file for tokens in a specific dataset
vocab_path = "data/wikihop/vocab.txt"
# vocab file for chars in a specific dataset
vocab_char_path = "data/wikihop/vocab.txt.chars"
# train and dev set
train_path = "data/wikihop/training.json"
valid_path = "data/wikihop/validation.json"
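# For reference, a minimal sketch of the keys this file reads from config.json
# (the numbers below are illustrative placeholders, not the settings from the
# paper; pre_process and utils may read additional keys not shown here):
#
#   {
#     "nhidden": 128,
#     "batch_size": 32,
#     "learning_rate": 0.001,
#     "num_epochs": 5,
#     "model_save_frequency": 100
#   }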
def main():
    # load config file
    config = load_config(config_path)
    # build dicts for tokens (vocab_dict) and chars (vocab_c_dict)
    vocab_dict, vocab_c_dict = build_dict(vocab_path, vocab_char_path)
    # load pre-trained embedding
    # W_init: token index -> token embedding matrix
    # embed_dim: embedding dimension
    W_init, embed_dim = load_word2vec_embedding(word_embedding_path, vocab_dict)
    # number of hops used by the hop-attention model (HAQA)
    K = 3
    # generate train/valid examples
    train_data, sen_cut_train, max_sen_len_train = generate_examples(train_path, vocab_dict, vocab_c_dict, config, "train")
    dev_data, sen_cut_dev, max_sen_len_dev = generate_examples(valid_path, vocab_dict, vocab_c_dict, config, "dev")
    max_sen_len = max(max_sen_len_train, max_sen_len_dev)
    print("max sentence len: " + str(max_sen_len))

    # ------------------------------------------------------------------------
    # training process begins
    hidden_size = config['nhidden']
    batch_size = config['batch_size']
    coref_model = HAQA(hidden_size, batch_size, K, W_init, config, max_sen_len).to(device)
    if len(sys.argv) > 4 and sys.argv[4] == "load":
        try:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine
            coref_model.load_state_dict(torch.load(torch_model_p, map_location=device))
            print("saved model loaded")
        except Exception:
            print("no saved model")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(coref_model.parameters(), lr=config['learning_rate'])  # TODO: use hyper-params in paper
    iter_index = 0
    batch_acc_list = []
    batch_loss_list = []
    dev_acc_list = []
    max_iter = int(config['num_epochs'] * len(train_data) / batch_size)
    print("max iteration number: " + str(max_iter))
    while True:
        # build batch data
        # batch_train_data is a list of 15 items:
        # [dw, m_dw, qw, m_qw, dc, m_dc, qc, m_qc, cd, m_cd, a, dei, deo, dri, dro]
        batch_train_data, sen_cut_batch = generate_batch_data(train_data, config, "train", -1, sen_cut_train)  # -1 means random sampling
        # dw, m_dw, qw, m_qw, dc, m_dc, qc, m_qc, cd, m_cd, a, dei, deo, dri, dro = batch_train_data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward pass
        dw, dc, qw, qc, cd, cd_m = extract_data(batch_train_data)
        cand_probs = coref_model(dw, dc, qw, qc, cd, cd_m, sen_cut_batch)  # B x Cmax
        # gold answer indices (batch_train_data[10] is "a"); keep them on the same device as the model output
        answer = torch.tensor(batch_train_data[10], dtype=torch.long, device=device)  # B
        loss = criterion(cand_probs, answer)

        # evaluation process
        acc_batch = cal_acc(cand_probs, answer, batch_size)
        batch_acc_list.append(acc_batch)
        batch_loss_list.append(loss.detach())  # detach so the stored losses do not keep the autograd graph alive
        dev_acc_list = evaluate_result(iter_index, config, dev_data, batch_acc_list, batch_loss_list, dev_acc_list, coref_model, sen_cut_dev)

        # save model
        if iter_index % config['model_save_frequency'] == 0 and len(sys.argv) > 4:
            torch.save(coref_model.state_dict(), torch_model_p)

        # back-prop
        loss.backward()
        optimizer.step()

        # check stopping criterion
        iter_index += 1
        if iter_index > max_iter:
            break
if __name__ == "__main__":
    main()
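# Usage note: this file only inspects sys.argv[4] (the fourth positional
# command-line argument). If at least four arguments are supplied, the model is
# checkpointed to model/haqa.pkl every config['model_save_frequency'] iterations,
# and if the fourth argument is the string "load" a previously saved checkpoint
# is restored before training. The earlier arguments are not interpreted in this file.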