From f9ab4896c9586122637f5ac68a5257d773c630c7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 15:11:12 +1300 Subject: [PATCH] add extra serialization test --- test/serialization.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/serialization.jl b/test/serialization.jl index d77370e..ab07097 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -58,7 +58,50 @@ end @test predict(smach, X) == predict(mach, X) rm(filename) +end +# define a supervised model with ephemeral `fitresult`, but which overcomes this by +# overloading `save`/`restore`: +thing = [] +struct EphemeralRegressor <: Deterministic end +function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) + # if I serialize/deserialized `thing` then `view` below changes: + view = objectid(thing) + fitresult = (thing, view, mean(y)) + return fitresult, nothing, NamedTuple() +end +function MLJBase.predict(::EphemeralRegressor, fitresult, X) + thing, view, μ = fitresult + return view == objectid(thing) ? fill(μ, nrows(X)) : + throw(ErrorException("dead fitresult")) +end +MLJBase.target_scitype(::Type{<:EphemeralRegressor}) = AbstractVector{Continuous} +function MLJBase.save(::EphemeralRegressor, fitresult) + thing, _, μ = fitresult + return (thing, μ) +end +function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) + thing, μ = serialized_fitresult + view = objectid(thing) + return (thing, view, μ) +end + +@testset "serialization for atomic models with non-persistent fitresults" begin + # https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + X, y = (; x = rand(10)), fill(42.0, 3) + ensemble = EnsembleModel( + EphemeralRegressor(), + bagging_fraction=0.7, + n=2, + ) + mach = machine(ensemble, X, y) + fit!(mach, verbosity=0) + io = IOBuffer() + MLJBase.save(io, mach) + seekstart(io) + mach2 = machine(io) + close(io) + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) end end