run.py
import os
from glob import glob
import torch
import colored_traceback
import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter
from src.corpus.iest_corpus import IESTCorpus
from src.corpus.embeddings import Embeddings
from src.utils.logger import Logger
from src.utils.ops import np_softmax
from src.train import Trainer
from src.optim.optim import OptimWithDecay, ScheduledOptim
from src.optim.schedulers import SlantedTriangularScheduler, TransformerScheduler
from src import config
from src.models.iest import (
IESTClassifier,
WordEncodingLayer,
WordCharEncodingLayer,
SentenceEncodingLayer,
)
from src.layers.pooling import PoolingLayer
from base_args import base_parser, CustomArgumentParser
colored_traceback.add_hook(always=True)
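
# Command-line interface. Flags not defined here (e.g. --epochs, --batch_size,
# --seed, --no_cuda) are expected to come from base_args.base_parser.
# Example invocation using the defaults declared below:
#     python run.py --model bilstm --corpus iest_emoji --embeddings random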
arg_parser = CustomArgumentParser(parents=[base_parser],
                                  description='Implicit Emotion Classifier')
arg_parser.add_argument('--model', type=str, default="bilstm",
                        choices=SentenceEncodingLayer.SENTENCE_ENCODING_METHODS,
                        help='Model to use')
arg_parser.add_argument('--corpus', type=str, default="iest_emoji",
                        choices=list(config.corpora_dict.keys()),
                        help='Name of the corpus to use.')
arg_parser.add_argument('--embeddings', type=str, default="random",
                        choices=list(config.embedding_dict.keys()) + ["random"],
                        help='Name of the embeddings to use.')
arg_parser.add_argument('--lstm_hidden_size', type=int, default=2048,
                        help='Hidden dimension size for the word-level LSTM')
arg_parser.add_argument('--sent_enc_layers', type=int, default=1,
                        help='Number of layers for the word-level LSTM')
arg_parser.add_argument('--force_reload', action='store_true',
                        help='Whether to reload pickles or not (makes the '
                             'process slower, but ensures data coherence)')
arg_parser.add_argument('--char_emb_dim', '-ced', type=int, default=200,
                        help='Char embedding dimension')
arg_parser.add_argument('--pooling_method', type=str, default='max',
                        choices=PoolingLayer.POOLING_METHODS,
                        help='Pooling scheme to use as raw sentence '
                             'representation method.')
arg_parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout applied to word representations and '
                             'the final MLP. 0 means no dropout.')
arg_parser.add_argument('--lstm_layer_dropout', type=float, default=0.0,
                        help='Dropout between sentence encoding LSTM layers. '
                             '0 means no dropout.')
arg_parser.add_argument('--sent_enc_dropout', type=float, default=0.1,
                        help='Dropout after the sentence encoding LSTM. '
                             '0 means no dropout.')
arg_parser.add_argument('--model_hash', type=str, default=None,
                        help='Hash of the model to load; can be a partial hash')
arg_parser.add_argument('--word_encoding_method', '-wem', type=str, default="elmo",
                        choices=WordEncodingLayer.WORD_ENCODING_METHODS,
                        help='How to obtain word representations')
arg_parser.add_argument('--word_char_aggregation_method', '-wcam',
                        choices=WordCharEncodingLayer.AGGREGATION_METHODS,
                        default=None,
                        help='Way in which character-level and word-level '
                             'representations are aggregated')
arg_parser.add_argument('--lowercase', '-lc', action='store_true',
                        help='Whether to lowercase data or not. WARNING: '
                             'REMEMBER TO CLEAR THE CACHE BY PASSING '
                             '--force_reload or deleting .cache')
arg_parser.add_argument("--warmup_iters", "-wup", default=4000, type=int,
                        help="For how many iterations to increase the "
                             "learning rate. To be used with TransformerScheduler")
arg_parser.add_argument("--test", action="store_true",
                        help="Run this script in test mode")
arg_parser.add_argument("--save_sent_reprs", "-ssr", action="store_true",
                        default=False,
                        help='Save sentence representations in the experiment '
                             'directory. This is intended to be used with the '
                             '--test flag')
arg_parser.add_argument("--pos_emb_dim", "-pem", default=None, type=int,
                        help="The dimension to use for the POS embeddings. "
                             "If None, POS tags will not be used")
arg_parser.add_argument("--max_lr", "-mlr", default=1e-3, type=float,
                        help="Max learning rate to be used with Ruder's "
                             "slanted triangular learning rate schedule")
arg_parser.add_argument("--spreadsheet", "-ss", action='store_true',
                        help="Save results in a Google spreadsheet")


def validate_args(hp):
    """hp: argparser parsed arguments. type: Namespace"""
    if hp.word_encoding_method == 'char_lstm' and not hp.word_char_aggregation_method:
        raise ValueError(f'Need to pass a word_char_aggregation_method when '
                         f'using char_lstm word_encoding_method. '
                         f'Choose one from {WordCharEncodingLayer.AGGREGATION_METHODS}')
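

# main() wires everything together: corpus and embedding matrices, the
# classifier, the optimizer/scheduler pair, and then either a test-mode
# evaluation or the training loop.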
def main():
    hp = arg_parser.parse_args()
    validate_args(hp)
    logger = Logger(hp, model_name='Baseline', write_mode=hp.write_mode)
    if not hp.test:
        print(f"Running experiment {logger.model_hash}.")
        if hp.write_mode != 'NONE':
            logger.write_hyperparams()
            print(
                f"Hyperparameters and checkpoints will be saved in "
                f"{logger.run_savepath}"
            )

    torch.manual_seed(hp.seed)
    torch.cuda.manual_seed_all(hp.seed)  # silently ignored if there are no GPUs

    CUDA = torch.cuda.is_available() and not hp.no_cuda
    USE_POS = hp.pos_emb_dim is not None
    SAVE_IN_SPREADSHEET = hp.spreadsheet
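
    # Data: load the corpus and build an embedding matrix for its vocabulary,
    # either from pretrained vectors or randomly initialized.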
    corpus = IESTCorpus(config.corpora_dict, hp.corpus,
                        force_reload=hp.force_reload,
                        train_data_proportion=hp.train_data_proportion,
                        dev_data_proportion=hp.dev_data_proportion,
                        batch_size=hp.batch_size,
                        lowercase=hp.lowercase,
                        use_pos=USE_POS)

    if hp.embeddings != "random":
        # Load pre-trained embeddings
        embeddings = Embeddings(
            config.embedding_dict[hp.embeddings],
            k_most_frequent=None,
            force_reload=hp.force_reload,
        )
        # Get the subset of embeddings corresponding to our vocabulary
        embedding_matrix = embeddings.generate_embedding_matrix(corpus.lang.token2id)
        print(
            f"{len(embeddings.unknown_tokens)} words from vocabulary not found "
            f"in {hp.embeddings} embeddings."
        )
    else:
        word_vocab_size = len(corpus.lang.token2id)
        embedding_matrix = np.random.uniform(
            -0.05, 0.05, size=(word_vocab_size, 300)
        )

    # Character embeddings are not pretrained, so their matrix is always
    # initialized randomly
    char_vocab_size = len(corpus.lang.char2id)
    char_embedding_matrix = np.random.uniform(-0.05, 0.05,
                                              size=(char_vocab_size,
                                                    hp.char_emb_dim))

    # Same for POS embeddings, which are only used when --pos_emb_dim is given
    pos_embedding_matrix = None
    if USE_POS:
        pos_vocab_size = len(corpus.pos_lang.token2id)
        pos_embedding_matrix = np.random.uniform(-0.05, 0.05,
                                                 size=(pos_vocab_size,
                                                       hp.pos_emb_dim))
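
    # Model: either load a previously trained model by (partial) experiment
    # hash, or build a fresh IESTClassifier from the parsed hyperparameters.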
    if hp.model_hash:
        # WARNING: This feature should be used only when testing. We
        # haven't implemented a proper way to resume training yet.
        experiment_path = os.path.join(config.RESULTS_PATH, hp.model_hash + '*')
        ext_experiment_path = glob(experiment_path)
        assert len(ext_experiment_path) == 1, 'Try providing a longer model hash'
        ext_experiment_path = ext_experiment_path[0]
        model_path = os.path.join(ext_experiment_path, 'best_model.pth')
        model = torch.load(model_path)
    else:
        num_classes = len(corpus.label2id)
        batch_size = corpus.train_batches.batch_size
        hidden_sizes = hp.lstm_hidden_size
        model = IESTClassifier(
            num_classes,
            batch_size,
            embedding_matrix=embedding_matrix,
            char_embedding_matrix=char_embedding_matrix,
            pos_embedding_matrix=pos_embedding_matrix,
            word_encoding_method=hp.word_encoding_method,
            word_char_aggregation_method=hp.word_char_aggregation_method,
            sent_encoding_method=hp.model,
            hidden_sizes=hidden_sizes,
            use_cuda=CUDA,
            pooling_method=hp.pooling_method,
            batch_first=True,
            dropout=hp.dropout,
            lstm_layer_dropout=hp.lstm_layer_dropout,
            sent_enc_dropout=hp.sent_enc_dropout,
            sent_enc_layers=hp.sent_enc_layers
        )

    if CUDA:
        model.cuda()

    if hp.write_mode != 'NONE':
        logger.write_architecture(str(model))
        logger.write_current_run_details(str(model))
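
    # Optimization: Adam wrapped in ScheduledOptim. The transformer encoder
    # gets a warmup-based TransformerScheduler; every other encoder gets a
    # slanted triangular learning rate schedule.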
    if hp.model == 'transformer':
        core_optimizer = torch.optim.Adam(
            model.parameters(),
            lr=0,
            betas=(0.9, 0.98),
            eps=1e-9
        )
        transformer_scheduler = TransformerScheduler(
            1024,
            factor=1,
            warmup_steps=hp.warmup_iters
        )
        optimizer = ScheduledOptim(core_optimizer, transformer_scheduler)
    else:
        # optimizer = OptimWithDecay(model.parameters(),
        #                            method=hp.optim,
        #                            initial_lr=hp.learning_rate,
        #                            max_grad_norm=hp.grad_clipping,
        #                            lr_decay=hp.learning_rate_decay,
        #                            start_decay_at=hp.start_decay_at,
        #                            decay_every=hp.decay_every)
        core_optimizer = torch.optim.Adam(
            [param for param in model.parameters() if param.requires_grad],
            lr=0,
        )
        max_iter = corpus.train_batches.num_batches * hp.epochs
        slanted_triangular_scheduler = SlantedTriangularScheduler(
            max_iter,
            max_lr=hp.max_lr,
            cut_fraction=0.1,
            ratio=32
        )
        optimizer = ScheduledOptim(core_optimizer, slanted_triangular_scheduler)

    loss_function = torch.nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, loss_function, num_epochs=hp.epochs,
                      use_cuda=CUDA, log_interval=hp.log_interval)
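
    # Test mode: evaluate the loaded checkpoint on the test set and dump
    # probabilities, predicted labels and (optionally) sentence
    # representations into the original experiment directory.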
    # FIXME: This test block of code looks ugly here <2018-06-29 11:41:51, Jorge Balazs>
    if hp.test:
        if hp.model_hash is None:
            raise RuntimeError(
                'You should have provided a pre-trained model hash with the '
                '--model_hash flag'
            )
        print(f'Testing model {model_path}')
        eval_dict = trainer.evaluate(corpus.test_batches)

        probs = np_softmax(eval_dict['output'])
        probs_filepath = os.path.join(ext_experiment_path,
                                      'test_probabilities.csv')
        np.savetxt(probs_filepath, probs,
                   delimiter=',', fmt='%.8f')
        print(f'Saved prediction probs in {probs_filepath}')

        labels_filepath = os.path.join(ext_experiment_path,
                                       'test_predictions.txt')
        labels = [label + '\n' for label in eval_dict['labels']]
        with open(labels_filepath, 'w', encoding='utf-8') as f:
            f.writelines(labels)
        print(f'Saved prediction file in {labels_filepath}')

        representations_filepath = os.path.join(
            ext_experiment_path,
            'sentence_representations.txt'
        )
        if hp.save_sent_reprs:
            with open(representations_filepath, 'w', encoding='utf-8') as f:
                np.savetxt(representations_filepath, eval_dict['sent_reprs'],
                           delimiter=' ', fmt='%.8f')
        exit()

    # Main Training Loop
    writer = None
    if hp.write_mode != 'NONE':
        writer = SummaryWriter(logger.run_savepath)
    try:
        best_accuracy = None
        for epoch in tqdm(range(hp.epochs), desc='Epoch'):
            trainer.train_epoch(corpus.train_batches, epoch, writer)
            corpus.train_batches.shuffle_examples()
            eval_dict = trainer.evaluate(corpus.dev_batches, epoch, writer)

            if hp.update_learning_rate and hp.model != 'transformer':
                # hp.update_learning_rate is not supposed to be used with
                # scheduled learning rates
                optim_updated, new_lr = trainer.optimizer.updt_lr_accuracy(
                    epoch, eval_dict['accuracy'])
                # TODO: lr_threshold shouldn't be hardcoded
                lr_threshold = 1e-5
                if new_lr < lr_threshold:
                    tqdm.write(f'Learning rate smaller than {lr_threshold}, '
                               f'stopping.')
                    break
                if optim_updated:
                    tqdm.write(f'Learning rate decayed to {new_lr}')

            accuracy = eval_dict['accuracy']
            if not best_accuracy or accuracy > best_accuracy:
                best_accuracy = accuracy
                logger.update_results({'best_valid_acc': best_accuracy,
                                       'best_epoch': epoch})

                if hp.write_mode != 'NONE':
                    probs = np_softmax(eval_dict['output'])
                    probs_filepath = os.path.join(logger.run_savepath,
                                                  'best_dev_probabilities.csv')
                    np.savetxt(probs_filepath, probs,
                               delimiter=',', fmt='%.8f')

                    labels_filepath = os.path.join(logger.run_savepath,
                                                   'best_dev_predictions.txt')
                    labels = [label + '\n' for label in eval_dict['labels']]
                    with open(labels_filepath, 'w', encoding='utf-8') as f:
                        f.writelines(labels)

                    if hp.save_model:
                        logger.torch_save_file('best_model_state_dict.pth',
                                               model.state_dict(),
                                               progress_bar=tqdm)
                        logger.torch_save_file('best_model.pth',
                                               model,
                                               progress_bar=tqdm)
    except KeyboardInterrupt:
        pass
    finally:
        if SAVE_IN_SPREADSHEET:
            logger.insert_in_googlesheets()


if __name__ == '__main__':
    main()