diff --git a/Project.toml b/Project.toml index a24fe91c..bf730f53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Metalhead" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.8.4" +version = "0.9.0" [deps] Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 8593f930..596563f0 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -90,12 +90,11 @@ ChainRulesCore.@non_differentiable _dropblock_checks(x, drop_block_prob, gamma_s function (m::DropBlock)(x) _dropblock_checks(x, m.drop_block_prob, m.gamma_scale) - return Flux._isactive(m) ? - dropblock(m.rng, x, m.drop_block_prob, m.block_size, m.gamma_scale) : x + return dropblock(m.rng, x, m.drop_block_prob * Flux._isactive(m, x), m.block_size, m.gamma_scale) end function Flux.testmode!(m::DropBlock, mode = true) - return (m.active = (isnothing(mode) || mode === :auto) ? nothing : !mode; m) + return (m.active = isnothing(Flux._tidy_active(mode)) ? nothing : !mode; m) end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0,