From 84805a411d64cf782359a231cf6f86f4eede31c1 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 16:23:56 -0400 Subject: [PATCH 1/7] check shapes, not length, to account for arrays with 0-length dims --- src/lib/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 504ef614d..48516db3b 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -55,7 +55,7 @@ end function unbroadcast(x::AbstractArray, x̄) N = ndims(x̄) - if length(x) == length(x̄) + if size(x) == size(x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) From d013ac74dc9d51ecb64c1596393ed17da08cbfc2 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 16:43:26 -0400 Subject: [PATCH 2/7] Change z2d instead --- src/compiler/chainrules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7b070f730..5d14648aa 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -282,7 +282,8 @@ z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case z2d(dx, ::Any) = dx z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx -z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal) +z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? NoTangent() : map(Zygote.z2d, dx, primal) + #= # As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P} From fc37e47097b7f097ac1968201c3ffcb5cb5b34ec Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 16:45:21 -0400 Subject: [PATCH 3/7] Revert "check shapes, not length, to account for arrays with 0-length dims" This reverts commit 84805a411d64cf782359a231cf6f86f4eede31c1. --- src/lib/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 48516db3b..504ef614d 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -55,7 +55,7 @@ end function unbroadcast(x::AbstractArray, x̄) N = ndims(x̄) - if size(x) == size(x̄) + if length(x) == length(x̄) _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors else dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄)) From 9f59e801ecb4dd9e2ab230176e86d26fa26ab434 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 16:49:17 -0400 Subject: [PATCH 4/7] Add test --- test/chainrules.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/chainrules.jl b/test/chainrules.jl index 3d5fcb035..1835e7238 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -419,6 +419,9 @@ end @test z2d_compiled.d === z2d_fallback.d @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b + + # empty arrays => NoTangent() + @test z2d(ones(1, 0), ones(16, 0)) === NoTangent() end @testset "ChainRules translation" begin From 96da4310cbeb9dedf5a8f96c6d5196a5538435c2 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 16:49:47 -0400 Subject: [PATCH 5/7] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d9a6ce453..ab4e1c905 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.71" +version = "0.6.72" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 5c6db01e6fe929173c0491435566af491565fbd2 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 4 Oct 2024 18:51:23 -0400 Subject: [PATCH 6/7] Fix test --- test/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 1835e7238..9d165d4bf 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -421,7 +421,7 @@ end @test z2d_compiled.c.b === z2d_fallback.c.b # empty arrays => NoTangent() - @test z2d(ones(1, 0), ones(16, 0)) === NoTangent() + @test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === NoTangent() end @testset "ChainRules translation" begin From 8d24a29bd4a6d47d9d1d5544d837ac2783e301a3 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Mon, 7 Oct 2024 10:37:17 -0400 Subject: [PATCH 7/7] Correct behavior --- src/compiler/chainrules.jl | 2 +- test/chainrules.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 5d14648aa..c89c6ae66 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -282,7 +282,7 @@ z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case z2d(dx, ::Any) = dx z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx -z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? NoTangent() : map(Zygote.z2d, dx, primal) +z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? dx : map(Zygote.z2d, dx, primal) #= # As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers diff --git a/test/chainrules.jl b/test/chainrules.jl index 9d165d4bf..c3809b992 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -420,8 +420,9 @@ end @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b - # empty arrays => NoTangent() - @test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === NoTangent() + # empty dx => returns the dx + @test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === ones(1, 0) + @test @inferred(Zygote.z2d(Union{Nothing, Float64}[], ones(16, 0))) === Union{Nothing, Float64}[] end @testset "ChainRules translation" begin