forked from jcjohnson/torch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LanguageModel.lua
171 lines (139 loc) · 4.47 KB
/
LanguageModel.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
require 'torch'
require 'nn'
require 'VanillaRNN'
require 'LSTM'
local utils = require 'util.utils'
-- Declare the nn.LanguageModel class as a subclass of nn.Module.
-- `LM` is the class table the methods below are attached to; `parent` is the
-- nn.Module class table, used to reach base-class implementations.
local LM, parent = torch.class('nn.LanguageModel', 'nn.Module')
--[[
Build the network:
  LookupTable -> num_layers recurrent layers (each optionally followed by
  nn.Dropout) -> View -> Linear -> View.

kwargs (read with utils.get_kwarg):
- idx_to_token: table mapping index -> token. The vocabulary size V is the
  number of entries; the inverse map is stored in self.token_to_idx.
- model_type: 'rnn' (nn.VanillaRNN) or 'lstm' (nn.LSTM).
- wordvec_size: embedding dimension D.
- rnn_size: hidden dimension H of every recurrent layer.
- num_layers: number of stacked recurrent layers.
- dropout: dropout probability after each recurrent layer; 0 disables it.

Raises an error if model_type is not 'rnn' or 'lstm'.
--]]
function LM:__init(kwargs)
  -- Run nn.Module's constructor, as nn modules conventionally do, so the
  -- base-class fields (e.g. output / gradInput) are initialized.
  parent.__init(self)

  self.idx_to_token = utils.get_kwarg(kwargs, 'idx_to_token')
  self.token_to_idx = {}
  self.vocab_size = 0
  for idx, token in pairs(self.idx_to_token) do
    self.token_to_idx[token] = idx
    self.vocab_size = self.vocab_size + 1
  end

  self.model_type = utils.get_kwarg(kwargs, 'model_type')
  self.wordvec_dim = utils.get_kwarg(kwargs, 'wordvec_size')
  self.rnn_size = utils.get_kwarg(kwargs, 'rnn_size')
  self.num_layers = utils.get_kwarg(kwargs, 'num_layers')
  self.dropout = utils.get_kwarg(kwargs, 'dropout')

  -- Fail early with a clear message; otherwise an unknown type would leave
  -- `rnn` nil below and crash with an unhelpful "attempt to index nil".
  if self.model_type ~= 'rnn' and self.model_type ~= 'lstm' then
    error(string.format("Unsupported model_type %s; expected 'rnn' or 'lstm'",
                        tostring(self.model_type)))
  end

  local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size

  self.net = nn.Sequential()
  self.rnns = {}
  self.net:add(nn.LookupTable(V, D))

  for i = 1, self.num_layers do
    -- The first layer consumes word vectors (D); later layers consume the
    -- previous layer's hidden states (H).
    local prev_dim = H
    if i == 1 then prev_dim = D end

    local rnn
    if self.model_type == 'rnn' then
      rnn = nn.VanillaRNN(prev_dim, H)
    else
      rnn = nn.LSTM(prev_dim, H)
    end
    -- Presumably makes the layer keep its state between forward calls;
    -- LM:resetStates clears it.
    rnn.remember_states = true
    table.insert(self.rnns, rnn)
    self.net:add(rnn)

    if self.dropout > 0 then
      self.net:add(nn.Dropout(self.dropout))
    end
  end

  -- After all the RNNs run, we have a tensor of shape (N, T, H). We want to
  -- apply a 1D temporal convolution to predict scores for each vocab element,
  -- giving a tensor of shape (N, T, V). nn.TemporalConvolution is very slow,
  -- so instead we use a pair of views (N, T, H) -> (NT, H) and
  -- (NT, V) -> (N, T, V) with an nn.Linear in between. N and T can change on
  -- every minibatch, so the view sizes are set in the forward pass.
  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
  self.view2 = nn.View(1, -1):setNumInputDims(2)
  self.net:add(self.view1)
  self.net:add(nn.Linear(H, V))
  self.net:add(self.view2)
end
--[[
Forward pass.

Input:
- input: tensor of token indices, indexed as (N, T).
Returns:
- The wrapped network's output, of shape (N, T, V) (unnormalized scores over
  the vocabulary). It is also stored in self.output, as the nn.Module contract
  expects (containers read module.output).
--]]
function LM:updateOutput(input)
  local N, T = input:size(1), input:size(2)
  -- The views around the final nn.Linear depend on N and T, which can change
  -- from one minibatch to the next, so resize them before every forward.
  self.view1:resetSize(N * T, -1)
  self.view2:resetSize(N, T, -1)
  self.output = self.net:forward(input)
  return self.output
end
--[[
Backward pass, delegated to the wrapped network.

Inputs:
- input: the same input passed to the preceding forward call.
- gradOutput: gradient of the loss with respect to this module's output.
- scale: optional scale factor forwarded to the network's backward.
Returns:
- The gradient with respect to input. It is also stored in self.gradInput,
  as the nn.Module contract expects.
--]]
function LM:backward(input, gradOutput, scale)
  self.gradInput = self.net:backward(input, gradOutput, scale)
  return self.gradInput
end
--[[
Return the learnable parameters and their gradients by delegating to the
wrapped network, so that getParameters() / optimizers see every layer.
--]]
function LM:parameters()
  return self.net:parameters()
end

--[[
Switch to training mode. nn.Module:training only flags this wrapper module,
so the call must also be forwarded to self.net; otherwise the nn.Dropout
layers inside it never change mode.
--]]
function LM:training()
  self.net:training()
  parent.training(self)
end

--[[
Switch to evaluation mode, forwarding to self.net so that its nn.Dropout
layers are disabled during validation and sampling.
--]]
function LM:evaluate()
  self.net:evaluate()
  parent.evaluate(self)
end
--[[
Clear the remembered state of every recurrent layer collected in self.rnns.
The layers are built with remember_states = true, so this must be called
whenever a new, unrelated sequence begins.
--]]
function LM:resetStates()
  local layers = self.rnns
  for layer_idx = 1, #layers do
    layers[layer_idx]:resetStates()
  end
end
--[[
Encode a string as a 1D LongTensor of vocabulary indices, one entry per byte
of s (s:sub(i, i)).

NOTE(review): iteration is byte-wise, so a multi-byte UTF-8 character only
encodes if each of its bytes is itself a vocabulary token.

Raises an error naming the offending token and its position if any byte is
not in self.token_to_idx.
--]]
function LM:encode_string(s)
  local encoded = torch.LongTensor(#s)
  for i = 1, #s do
    local token = s:sub(i, i)
    local idx = self.token_to_idx[token]
    -- Build the message only on failure rather than on every iteration.
    if idx == nil then
      error(string.format(
        'Got invalid idx: token %q at position %d is not in the vocabulary',
        token, i))
    end
    encoded[i] = idx
  end
  return encoded
end
--[[
Decode a 1D tensor of vocabulary indices back into a string by looking each
index up in self.idx_to_token and concatenating the tokens.

Raises an error if encoded is not a 1D tensor or contains an index with no
corresponding token.
--]]
function LM:decode_string(encoded)
  assert(torch.isTensor(encoded) and encoded:dim() == 1,
         'decode_string expects a 1D tensor')
  -- Collect the tokens and join them once; repeated `..` in a loop is O(n^2).
  local tokens = {}
  for i = 1, encoded:size(1) do
    local idx = encoded[i]
    local token = self.idx_to_token[idx]
    if token == nil then
      error(string.format('Got invalid token index %d at position %d', idx, i))
    end
    tokens[i] = token
  end
  return table.concat(tokens)
end
--[[
Sample text from the language model. The states of the underlying RNNs are
reset both before and after sampling.

Inputs (read from the kwargs table via utils.get_kwarg):
- length: total number of characters to produce, including start_text
  (default 100).
- start_text: string used to seed the model (default ''). Its characters form
  the beginning of the output. If empty, the first character is chosen from
  uniform scores.
- verbose: if > 0, print how the model is being seeded (default 0).
- sample: if 0, pick the highest-scoring character at every step (greedy);
  otherwise draw from the temperature-scaled softmax (default 1).
- temperature: divisor applied to the scores before the softmax when
  sampling (default 1); lower values make the output more conservative.

Returns:
- A string of `length` characters that begins with start_text.

Raises an error if start_text is longer than `length`, or if it contains a
character outside the vocabulary.
--]]
function LM:sample(kwargs)
  local T = utils.get_kwarg(kwargs, 'length', 100)
  local start_text = utils.get_kwarg(kwargs, 'start_text', '')
  local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
  local sample = utils.get_kwarg(kwargs, 'sample', 1)
  local temperature = utils.get_kwarg(kwargs, 'temperature', 1)

  local sampled = torch.LongTensor(1, T)
  self:resetStates()

  local scores, first_t
  if #start_text > 0 then
    if verbose > 0 then
      print('Seeding with: "' .. start_text .. '"')
    end
    local x = self:encode_string(start_text):view(1, -1)
    local T0 = x:size(2)
    if T0 > T then
      error(string.format(
        'start_text has %d characters but length is only %d', T0, T))
    end
    sampled[{{}, {1, T0}}]:copy(x)
    -- Keep only the scores after the last seed character: shape (1, 1, V).
    scores = self:forward(x)[{{}, {T0, T0}}]
    first_t = T0 + 1
  else
    if verbose > 0 then
      print('Seeding with uniform probabilities')
    end
    -- Use the embedding weights' tensor type so the scores live on the same
    -- device / dtype as the model.
    local w = self.net:get(1).weight
    scores = w.new(1, 1, self.vocab_size):fill(1)
    first_t = 1
  end

  -- Declared outside the if/else below so that both branches assign this
  -- local (not an accidental global) and it is still in scope for the copy.
  local next_char
  for t = first_t, T do
    if sample == 0 then
      local _, argmax = scores:max(3)
      next_char = argmax[{{}, {}, 1}]
    else
      local logits = torch.div(scores, temperature):double():squeeze()
      -- Subtract the largest logit before exponentiating so big scores or a
      -- small temperature cannot overflow exp() to inf; the normalized
      -- distribution is unchanged.
      logits:add(-logits:max())
      local probs = logits:exp()
      probs:div(probs:sum())
      next_char = torch.multinomial(probs, 1):view(1, 1)
    end
    sampled[{{}, {t, t}}]:copy(next_char)
    scores = self:forward(next_char)
  end

  self:resetStates()
  return self:decode_string(sampled[1])
end