diff --git a/Project.toml b/Project.toml index d9a6ce453..ab4e1c905 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.71" +version = "0.6.72" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 7b070f730..c89c6ae66 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -282,7 +282,8 @@ z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case z2d(dx, ::Any) = dx z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx -z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal) +z2d(dx::AbstractArray, primal::AbstractArray) = isempty(dx) ? dx : map(Zygote.z2d, dx, primal) + #= # As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P} diff --git a/test/chainrules.jl b/test/chainrules.jl index 3d5fcb035..c3809b992 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -419,6 +419,10 @@ end @test z2d_compiled.d === z2d_fallback.d @test z2d_compiled.c.a === z2d_fallback.c.a @test z2d_compiled.c.b === z2d_fallback.c.b + + # empty dx => returns the dx + @test @inferred(Zygote.z2d(ones(1, 0), ones(16, 0))) === ones(1, 0) + @test @inferred(Zygote.z2d(Union{Nothing, Float64}[], ones(16, 0))) === Union{Nothing, Float64}[] end @testset "ChainRules translation" begin