Gradients dropped by adapt #131

Open
mcabbott opened this issue Aug 22, 2022 · 1 comment

@mcabbott
Member

Moving `y` to the "GPU" (a `JLArray`, via `jl`) inside the loss causes its gradient to be lost:

julia> using Tracker, JLArrays

julia> JLArrays.allowscalar(false)

julia> Tracker.withgradient((x,y) -> sum(x[1:2] + jl(y))^2, jl([1,2,3.0]), [4,5.0])
(val = 144.0, grad = ([24.0, 24.0, 0.0], [0.0, 0.0])) 

julia> ans.grad[1] isa JLArray
true

Zygote, by contrast, returns the expected gradient for `y`:

julia> using Zygote

julia> Zygote.withgradient((x,y) -> sum(x[1:2] + jl(y))^2, jl([1,2,3.0]), [4,5.0])
(val = 144.0, grad = ([24.0, 24.0, 0.0], [24.0, 24.0]))

@ToucheSir
Member

Riffing on a possible rule, whose pullback converts the incoming gradient back to the original array type:

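using Adapt, Tracker
using Tracker: @grad

@grad function Adapt.adapt_storage(adaptor, x::AT) where {AT <: Array}
  # no gradient for the adaptor; move Δ back to the input's array type
  adapt_storage_pullback(Δ) = (nothing, Adapt.adapt_storage(AT, Δ))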
  return Adapt.adapt_storage(adaptor, x), adapt_storage_pullback
end
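To make the idea behind such a rule concrete without depending on Tracker or JLArrays internals, here is a toy sketch in plain Julia. `ToyDeviceArray` and `move_to_device` are hypothetical names invented for this example (they are not part of Adapt, Tracker, or JLArrays). The forward pass "moves" storage and returns a pullback that, like `adapt_storage_pullback` above, only has to hand the cotangent back in the input's storage type. Working the loss from the issue through by hand gives the same `[24.0, 24.0]` gradient for `y` that Zygote reports.

```julia
# Hypothetical stand-in for a device array type (NOT JLArrays.JLArray):
# it just wraps a host Array so the example runs without any packages.
struct ToyDeviceArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::ToyDeviceArray) = size(a.data)
Base.getindex(a::ToyDeviceArray, i::Int...) = a.data[i...]

# Forward pass of a storage move, written rrule-style: return the moved
# value together with a pullback closure.
function move_to_device(x::Array)
    y = ToyDeviceArray(copy(x))
    # Moving storage leaves the values unchanged, so the pullback is the
    # identity on values; its only job is to bring Δ back to host storage.
    move_to_device_pullback(Δ) = Array(Δ)
    return y, move_to_device_pullback
end

# The loss from the issue, with the gradient for y computed by hand.
x = [1.0, 2.0, 3.0]
y = [4.0, 5.0]
ydev, back = move_to_device(y)
s = sum(x[1:2] + ydev)                       # 12.0
loss = s^2                                   # 144.0
Δydev = ToyDeviceArray(fill(2s, length(y)))  # d(loss)/d(ydev), in "device" storage
grad_y = back(Δydev)                         # host Array: [24.0, 24.0]
```

Because the move is the identity on values, the design choice is only about types: the pullback converts back to the input's storage, mirroring `Adapt.adapt_storage(AT, Δ)` in the rule above.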

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants