Skip to content

Commit

Permalink
Merge pull request #32 from JuliaAI/add-serialization-tests
Browse files Browse the repository at this point in the history
Add extra serialization test
  • Loading branch information
ablaom authored Mar 4, 2024
2 parents e63eadc + f9ab489 commit dda3b43
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions test/serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,50 @@ end
@test predict(smach, X) == predict(mach, X)

rm(filename)
end

# define a supervised model with ephemeral `fitresult`, but which overcomes this by
# overloading `save`/`restore`:
thing = []
struct EphemeralRegressor <: Deterministic end
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
# if I serialize/deserialized `thing` then `view` below changes:
view = objectid(thing)
fitresult = (thing, view, mean(y))
return fitresult, nothing, NamedTuple()
end
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
thing, view, μ = fitresult
return view == objectid(thing) ? fill(μ, nrows(X)) :
throw(ErrorException("dead fitresult"))
end
MLJBase.target_scitype(::Type{<:EphemeralRegressor}) = AbstractVector{Continuous}
function MLJBase.save(::EphemeralRegressor, fitresult)
thing, _, μ = fitresult
return (thing, μ)
end
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
thing, μ = serialized_fitresult
view = objectid(thing)
return (thing, view, μ)
end

@testset "serialization for atomic models with non-persistent fitresults" begin
# https://github.com/alan-turing-institute/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
ensemble = EnsembleModel(
EphemeralRegressor(),
bagging_fraction=0.7,
n=2,
)
mach = machine(ensemble, X, y)
fit!(mach, verbosity=0)
io = IOBuffer()
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
close(io)
@test MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)
end

end
Expand Down

0 comments on commit dda3b43

Please sign in to comment.