Skip to content

Commit

Permalink
Fix GPU config
Browse files Browse the repository at this point in the history
  • Loading branch information
koki0702 committed May 10, 2018
1 parent 13ba33e commit e4bebdb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 15 deletions.
7 changes: 2 additions & 5 deletions ch04/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)

if not config.GPU:
contexts, target = create_contexts_target(corpus, window_size=window_size)
else:
corpus = to_cpu(corpus)
contexts, target = create_contexts_target(corpus, window_size=window_size)
contexts, target = create_contexts_target(corpus, window_size)
if config.GPU:
contexts, target = to_gpu(contexts), to_gpu(target)

# モデルなどの生成
Expand Down
8 changes: 7 additions & 1 deletion ch06/train_better_rnnlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# ==============================================
from common.optimizer import SGD
from common.trainer import RnnlmTrainer
from common.util import eval_perplexity
from common.util import eval_perplexity, to_gpu
from dataset import ptb
from better_rnnlm import BetterRnnlm

Expand All @@ -27,6 +27,12 @@
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_val, _, _ = ptb.load_data('val')
corpus_test, _, _ = ptb.load_data('test')

if config.GPU:
corpus = to_gpu(corpus)
corpus_val = to_gpu(corpus_val)
corpus_test = to_gpu(corpus_test)

vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]
Expand Down
9 changes: 0 additions & 9 deletions dataset/ptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
raise ImportError('Use Python3!')
import pickle
import numpy as np
from common.config import GPU


url_base = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/'
Expand Down Expand Up @@ -86,9 +85,6 @@ def load_data(data_type='train'):

if os.path.exists(save_path):
corpus = np.load(save_path)
if GPU:
from common.util import to_gpu
corpus = to_gpu(corpus)
return corpus, word_to_id, id_to_word

file_name = key_file[data_type]
Expand All @@ -99,11 +95,6 @@ def load_data(data_type='train'):
corpus = np.array([word_to_id[w] for w in words])

np.save(save_path, corpus)

if GPU:
from common.util import to_gpu
corpus = to_gpu(corpus)

return corpus, word_to_id, id_to_word


Expand Down

0 comments on commit e4bebdb

Please sign in to comment.