From bcbbfabf9fb33aa14534e38a33d6ebab77911963 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 24 Mar 2024 18:34:18 +1100 Subject: [PATCH 1/5] only project nothings in gradient --- src/compiler/interface.jl | 12 +++++++----- test/gradcheck.jl | 6 ++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 80fd9b477..5e5239743 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -114,8 +114,10 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") # Preserves output as tuple when gradients are collapsed -_project_all(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) -_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx) +_project_nothings(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) +_project_nothings(x::Tuple, dx::Tuple) = map(x, dx) do _x, _dx + return _dx === nothing ? _project(_x, _dx) : _dx +end """ gradient(f, args...) @@ -146,7 +148,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_all(args, grad) + return _project_nothings(args, grad) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -212,7 +214,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_all(args, grad) + results = _project_nothings(args, grad) (val=y, grad=results) end @@ -473,7 +475,7 @@ function pullback(f, ps::Params) end # No conversion required here -_project_all(_, dx::Grads) = dx +_project_nothings(_, dx::Grads) = dx # Code Reflection diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 8cb7e6e1a..5be70b1c6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -2125,3 +2125,9 @@ end @test gradient(x -> @.(x * x * x), 2.0) == gradient(x -> x * (x * x), 2.0) @test gradient(x -> @.(3.0*x*2.0*x), 2.0) == gradient(x -> 6(x^2), 2.0) end + +@testset "Sparse input" begin + g1 = Zygote.gradient(sum, zeros(1,1))[1] + g2 = Zygote.gradient(sum, spzeros(1,1))[1] + @test g1 == g2 +end From b82621d5ed2c47509008815510a78e44a459139d Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 24 Mar 2024 18:36:29 +1100 Subject: [PATCH 2/5] bump version to 0.7 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8ca9b5c1d..ff8195039 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" +version = "0.7.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From b3e5b5207c4c071a1ba6e6b4cc235ffbb7271eed Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 24 Mar 2024 19:03:16 +1100 Subject: [PATCH 3/5] fix implementation --- src/compiler/interface.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 5e5239743..dac1221ce 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -114,10 +114,11 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") # Preserves output as tuple when gradients are collapsed -_project_nothings(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) -_project_nothings(x::Tuple, dx::Tuple) = map(x, dx) do _x, _dx - return _dx === nothing ? _project(_x, _dx) : _dx -end +_project_sentinel(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) +_project_sentinel(x::Tuple, dx::Tuple) = map(_project_sentinel, x, dx) +_project_sentinel(::Any, ::NoTangent) = nothing +_project_sentinel(::Any, ::ZeroTangent) = nothing +_project_sentinel(::Any, ::Nothing) = nothing """ gradient(f, args...) @@ -148,7 +149,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_nothings(args, grad) + return _project_sentinel(args, grad) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -214,7 +215,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_nothings(args, grad) + results = _project_sentinel(args, grad) (val=y, grad=results) end @@ -475,7 +476,7 @@ function pullback(f, ps::Params) end # No conversion required here -_project_nothings(_, dx::Grads) = dx +_project_sentinel(_, dx::Grads) = dx # Code Reflection From 6c5b17b6a0449aa00e0687fd0d144750794935eb Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 24 Mar 2024 19:05:07 +1100 Subject: [PATCH 4/5] add fallback --- src/compiler/interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index dac1221ce..1d825b8a9 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -119,6 +119,7 @@ _project_sentinel(x::Tuple, dx::Tuple) = map(_project_sentinel, x, dx) _project_sentinel(::Any, ::NoTangent) = nothing _project_sentinel(::Any, ::ZeroTangent) = nothing _project_sentinel(::Any, ::Nothing) = nothing +_project_sentinel(::Any, dx::Any) = dx """ gradient(f, args...) From ccae706687f45afce50af15de2b2f7a50b479330 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 24 Mar 2024 21:13:07 +1100 Subject: [PATCH 5/5] rename to _project_grad and fix tests --- src/compiler/interface.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 1d825b8a9..4708c0cc9 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -114,12 +114,14 @@ sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is no sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") # Preserves output as tuple when gradients are collapsed -_project_sentinel(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) -_project_sentinel(x::Tuple, dx::Tuple) = map(_project_sentinel, x, dx) -_project_sentinel(::Any, ::NoTangent) = nothing -_project_sentinel(::Any, ::ZeroTangent) = nothing -_project_sentinel(::Any, ::Nothing) = nothing -_project_sentinel(::Any, dx::Any) = dx +_project_grad(::NTuple{N}, ::Nothing) where {N} = ntuple(_ -> nothing, N) +_project_grad(x::Tuple, dx::Tuple) = map(_project_grad, x, dx) +_project_grad(::Any, ::NoTangent) = nothing +_project_grad(::Any, ::ZeroTangent) = nothing +_project_grad(::Any, ::Nothing) = nothing +_project_grad(::Any, dx::Any) = dx +_project_grad(x::AbstractArray, dx::Tuple) = _project(x, dx) +_project_grad(x::Any, dx::Base.RefValue) = _project(x, dx) """ gradient(f, args...) @@ -150,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_sentinel(args, grad) + return _project_grad(args, grad) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -216,7 +218,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_sentinel(args, grad) + results = _project_grad(args, grad) (val=y, grad=results) end @@ -477,7 +479,7 @@ function pullback(f, ps::Params) end # No conversion required here -_project_sentinel(_, dx::Grads) = dx +_project_grad(_, dx::Grads) = dx # Code Reflection