Skip to content

Commit

Permalink
Final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed Nov 10, 2018
1 parent 51d2953 commit 4f9d4ea
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
julia 0.6
julia 1.0
ProtoBuf
BSON
Flux
Expand Down
14 changes: 7 additions & 7 deletions src/graph/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,19 @@ ops[:LSTM] = function(params, ip...)
arg2 = vcall(reshape, ip[3], (4*len,3))
ip_ = vcall(reshape, ip[1], vcall(slice ,vcall(:size, ip[1]), 1, 2))

a = vcall(Flux.LSTMCell, arg1, arg2, zeros(len*4), zeros(len), zeros(len))
b = vcall(:LSTM ,a)
return vcall(b, ip_)
a = vcall(LSTM, arg1, arg2, zeros(len*4), zeros(len), zeros(len))

return vcall(a, ip_)
elseif length(ip) == 4
len = params[:hidden_size]
arg1 = vcall(reshape, ip[2], (4*len,3))
arg2 = vcall(reshape, ip[3], (4*len,4))
arg3 = ip[4][1:4*len]
b1 = vcall(reinterpret, Float32, vcall(zeros, 2))
a = vcall(Flux.LSTMCell, arg1, arg2, arg3, b1, b1)
b = vcall(:LSTM ,a)
b1 = vcall(:broadcast, Float32, vcall(reinterpret, Float32, vcall(zeros, 2)))
a = vcall(LSTM, arg1, arg2, arg3, b1, b1)

ip_ = vcall(reshape, ip[1], vcall(slice ,vcall(:size, ip[1]), 1, 2))
return vcall(b, ip_)
return vcall(a, ip_)
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ include("pooling.jl")
include("conv.jl")
include("reshape.jl")
include("arithmetic_ops.jl")
#include("lstm.jl")
include("lstm.jl")

end

0 comments on commit 4f9d4ea

Please sign in to comment.