-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
119 lines (71 loc) · 2.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import print_function
import os, time, argparse
from tqdm import tqdm
import numpy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable as Var
#import utils
import gc
import sys
import Constants
from model import *
from tree import Tree
from vocab import Vocab
from dataset import *
from trainer import Trainer
from config import parse_args
# MAIN BLOCK
def main():
    """Train and evaluate a DomTreeLSTM classifier on the WebKB data.

    Loads (or builds and caches as .pth) the train/val datasets, optionally
    initializes the embedding layer from pretrained weights, then runs one
    train + test pass per configured epoch.
    """
    args = parse_args()
    print(args)

    num_classes = 7  # number of WebKB page categories
    data_dir = args.data_dir
    train_file = os.path.join(data_dir, 'train_data.pth')
    val_file = os.path.join(data_dir, 'val_data.pth')

    vocab_file = "../data/vocab.txt"
    vocab = Vocab(filename=vocab_file)

    # Build datasets from the raw .blk files on the first run; afterwards
    # reuse the cached serialized datasets.
    if os.path.isfile(train_file):
        train_dataset = torch.load(train_file)
    else:
        train_dataset = WebKbbDataset(
            vocab, num_classes,
            os.path.join(data_dir, 'train_texts.blk'),
            os.path.join(data_dir, 'train_labels.blk'))
        torch.save(train_dataset, train_file)
    if os.path.isfile(val_file):
        val_dataset = torch.load(val_file)
    else:
        val_dataset = WebKbbDataset(
            vocab, num_classes,
            os.path.join(data_dir, 'val_texts.blk'),
            os.path.join(data_dir, 'val_labels.blk'))
        torch.save(val_dataset, val_file)

    # Model hyperparameters (dims fixed at 200 as in the original script).
    vocab_size = vocab.size()
    in_dim = 200
    mem_dim = 200
    hidden_dim = 200
    sparsity = True
    freeze = args.freeze_emb
    epochs = args.epochs
    lr = args.lr
    pretrain = args.pretrain
    cuda_flag = torch.cuda.is_available()

    model = DomTreeLSTM(vocab_size, in_dim, mem_dim, hidden_dim,
                        num_classes, sparsity, freeze)
    criterion = nn.CrossEntropyLoss()

    if pretrain:
        emb_file = os.path.join('../data', 'emb.pth')
        if os.path.isfile(emb_file):
            emb = torch.load(emb_file)
            print(emb.size())
            print("Embedding weights loaded")
            # Fix: copy only when the file was loaded. The original called
            # copy_(emb) unconditionally, raising NameError when emb.pth
            # was missing.
            model.emb.weight.data.copy_(emb)
        else:
            print("Embedding file not found")

    # Optimize only parameters that require gradients (skips a frozen
    # embedding layer). Original had a duplicated `optimizer = optimizer =`.
    optimizer = optim.Adagrad(
        filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    trainer = Trainer(model, criterion, optimizer, train_dataset,
                      val_dataset, cuda_flag=cuda_flag)
    for epoch in range(epochs):
        trainer.train()
        trainer.test()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()