diff --git a/dataHelper.py b/dataHelper.py index b11383d..9d9e0a8 100644 --- a/dataHelper.py +++ b/dataHelper.py @@ -11,7 +11,7 @@ import pickle from utils import log_time_delta from tqdm import tqdm -from dataloader import Dataset +from dataloader import Dataset, Glove import torch from torch.autograd import Variable from codecs import open @@ -129,7 +129,7 @@ def load_text_vec(alphabet,filename="",embedding_size=-1): def getEmbeddingFile(name): #"glove" "w2v" - return os.path.join( ".vector_cache","glove.6B.300d.txt") + return os.path.join( ".vector_cache", "6b", "glove.6B.300d.txt") def getDataSet(opt): @@ -177,6 +177,7 @@ def loadData(opt): # from functools import reduce # word_set=set(reduce(lambda x,y :x+y,df["text"])) + Glove(corpus="6b", dim=300).process() glove_file = getEmbeddingFile(opt.__dict__.get("embedding","glove_6b_300")) loaded_vectors,embedding_size = load_text_vec(word_set,glove_file) word_set = word_set & set(loaded_vectors.keys()) diff --git a/main.py b/main.py index 50b589b..cd9daa8 100644 --- a/main.py +++ b/main.py @@ -38,7 +38,7 @@ if "CUDA_VISIBLE_DEVICES" not in os.environ.keys(): os.environ["CUDA_VISIBLE_DEVICES"] =opt.gpu #opt.model ='lstm' -opt.model ='capsule' +#opt.model ='capsule' if from_torchtext: train_iter, test_iter = utils.loadData(opt)