diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 3a5cb6aa..4a760e24 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -388,6 +388,7 @@ function internal_stack_report(
     # For each model we record the results mimicking the fields PerformanceEvaluation
     results = NamedTuple{modelnames}(
         [(
+            model = model,
             measure = stack.measures,
             measurement = Vector{Any}(undef, n_measures),
             operation = _actual_operations(nothing, stack.measures, model, verbosity),
@@ -395,7 +396,9 @@ function internal_stack_report(
             per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
             fitted_params_per_fold = [],
             report_per_fold = [],
-            train_test_pairs = tt_pairs
+            train_test_pairs = tt_pairs,
+            resampling = stack.resampling,
+            repeats = 1
             )
             for model in getfield(stack, :models)
         ]
diff --git a/src/resampling.jl b/src/resampling.jl
index 1b74e042..43483cc3 100644
--- a/src/resampling.jl
+++ b/src/resampling.jl
@@ -474,6 +474,9 @@ be interpreted with caution. See, for example, Bates et al.
 These fields are part of the public API of the `PerformanceEvaluation`
 struct.
 
+- `model`: model used to create the performance evaluation. In the case of a
+  tuning model, this is the best model found.
+
 - `measure`: vector of measures (metrics) used to evaluate performance
 
 - `measurement`: vector of measurements - one for each element of
@@ -507,15 +510,22 @@ struct.
 - `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
   where `train` and `test` are vectors of row (observation) indices for
   training and evaluation respectively.
+
+- `resampling`: the resampling strategy used to generate the train/test pairs.
+
+- `repeats`: the number of times the resampling strategy was repeated.
 """
 struct PerformanceEvaluation{M,
+                             Measure,
                              Measurement,
                              Operation,
                              PerFold,
                              PerObservation,
                              FittedParamsPerFold,
-                             ReportPerFold} <: MLJType
-    measure::M
+                             ReportPerFold,
+                             R} <: MLJType
+    model::M
+    measure::Measure
     measurement::Measurement
     operation::Operation
     per_fold::PerFold
@@ -523,6 +533,8 @@ struct PerformanceEvaluation{M,
     fitted_params_per_fold::FittedParamsPerFold
     report_per_fold::ReportPerFold
     train_test_rows::TrainTestPairs
+    resampling::R
+    repeats::Int
 end
 
 # pretty printing:
@@ -568,9 +580,9 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
     println(io, "PerformanceEvaluation object "*
                 "with these fields:")
-    println(io, "  measure, operation, measurement, per_fold,\n"*
-                "  per_observation, fitted_params_per_fold,\n"*
-                "  report_per_fold, train_test_rows")
+    println(io, "  model, measure, operation, measurement, per_fold,\n"*
+                "  per_observation, fitted_params_per_fold,\n"*
+                "  report_per_fold, train_test_rows, resampling, repeats")
     println(io, "Extract:")
     show_color = MLJBase.SHOW_COLOR[]
     color_off()
@@ -808,6 +820,22 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
 # --------------------------------------------------------------
 # User interface points: `evaluate!` and `evaluate`
 
+"""
+    log_evaluation(logger, performance_evaluation)
+
+Log a performance evaluation to `logger`, an object specific to some logging
+platform, such as mlflow. If `logger=nothing` then no logging is performed.
+The method is called at the end of every call to `evaluate`/`evaluate!`, using
+the logger provided by the `logger` keyword argument.
+
+# Implementations for new logging platforms
+
+Julia interfaces to workflow logging platforms, such as mlflow (provided by
+the MLFlowClient.jl interface), should overload
+`log_evaluation(logger::LoggerType, performance_evaluation)`,
+where `LoggerType` is a platform-specific type for logger objects. For an
+example, see the implementation provided by the MLJFlow.jl package.
+"""
+log_evaluation(logger, performance_evaluation) = nothing
+
 """
     evaluate!(mach,
               resampling=CV(),
@@ -820,7 +848,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
               acceleration=default_resource(),
               force=false,
               verbosity=1,
-              check_measure=true)
+              check_measure=true,
+              logger=nothing)
 
 Estimate the performance of a machine `mach` wrapping a supervised
 model in data, using the specified `resampling` strategy (defaulting
@@ -919,6 +948,8 @@ untouched.
 
 - `check_measure` - default is `true`
 
+- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
+
 ### Return value
 
@@ -939,7 +970,8 @@ function evaluate!(mach::Machine{<:Measurable};
                    repeats=1,
                    force=false,
                    check_measure=true,
-                   verbosity=1)
+                   verbosity=1,
+                   logger=nothing)
 
     # this method just checks validity of options, preprocess the
     # weights, measures, operations, and dispatches a
@@ -981,7 +1013,8 @@ function evaluate!(mach::Machine{<:Measurable};
     _acceleration= _process_accel_settings(acceleration)
 
     evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
-              repeats, _measures, _operations, _acceleration, force)
+              repeats, _measures, _operations, _acceleration, force, logger,
+              resampling)
 
 end
 
@@ -1158,9 +1191,10 @@ function measure_specific_weights(measure, weights, class_weights, test)
 end
 
 # Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
-function evaluate!(mach::Machine, resampling, weights,
-                   class_weights, rows, verbosity, repeats,
-                   measures, operations, acceleration, force)
+# `user_resampling` is the user-defined resampling strategy, passed along for reporting
+function evaluate!(mach::Machine, resampling, weights, class_weights, rows,
                   verbosity, repeats, measures, operations, acceleration,
+                   force, logger, user_resampling)
 
     # Note: `rows` and `repeats` are ignored here
 
@@ -1264,7 +1298,8 @@ function evaluate!(mach::Machine, resampling, weights,
         MLJBase.aggregate(per_fold[k], m)
     end
 
-    return PerformanceEvaluation(
+    evaluation = PerformanceEvaluation(
+        mach.model,
         measures,
         per_measure,
         operations,
@@ -1272,9 +1307,13 @@ function evaluate!(mach::Machine, resampling, weights,
         per_observation,
         fitted_params_per_fold |> collect,
         report_per_fold |> collect,
-        resampling
+        resampling,
+        user_resampling,
+        repeats
     )
 
+    log_evaluation(logger, evaluation)
+    evaluation
 end
 
 # ----------------------------------------------------------------
@@ -1293,7 +1332,7 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
         [train_test_pairs(resampling, _rows, train_args...) for i in 1:repeats]...
     )
 
-    return evaluate!(
+    evaluate!(
         mach,
         repeated_train_test_pairs,
         weights,
@@ -1303,7 +1342,6 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
         repeats,
         args...
     )
-
 end
 
 # ====================================================================
@@ -1319,7 +1357,8 @@ end
         operation=predict,
         repeats = 1,
         acceleration=default_resource(),
-        check_measure=true
+        check_measure=true,
+        logger=nothing
     )
 
 Resampling model wrapper, used internally by the `fit` method of
@@ -1354,7 +1393,7 @@ are not to be confused with any weights bound to a `Resampler` instance
 in a machine, used for training the wrapped `model` when supported.
 
 """
-mutable struct Resampler{S} <: Model
+mutable struct Resampler{S, L} <: Model
     model
     resampling::S # resampling strategy
     measure
@@ -1365,6 +1404,7 @@ mutable struct Resampler{S} <: Model
     check_measure::Bool
     repeats::Int
     cache::Bool
+    logger::L
 end
 
 # Some traits are markded as `missing` because we cannot determine
@@ -1403,7 +1443,8 @@ function Resampler(;
     acceleration=default_resource(),
     check_measure=true,
     repeats=1,
-    cache=true
+    cache=true,
+    logger=nothing
 )
     resampler = Resampler(
         model,
@@ -1415,7 +1456,8 @@ function Resampler(;
         acceleration,
         check_measure,
         repeats,
-        cache
+        cache,
+        logger
     )
     message = MLJModelInterface.clean!(resampler)
     isempty(message) || @warn message
@@ -1460,7 +1502,9 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
         _measures,
         _operations,
         _acceleration,
-        false
+        false,
+        resampler.logger,
+        resampler.resampling
     )
 
     fitresult = (machine = mach, evaluation = e)
@@ -1523,7 +1567,9 @@ function MLJModelInterface.update(
         measures,
         operations,
         acceleration,
-        false
+        false,
+        resampler.logger,
+        resampler.resampling
     )
     report = (evaluation = e, )
     fitresult = (machine=mach2, evaluation=e)
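
For a third-party logging package, the new `log_evaluation` hook introduced above is the only method that needs to be extended. Below is a minimal sketch of such an overload; `PrintLogger` is a hypothetical type used only for illustration (real integrations, such as MLJFlow.jl, define their own logger type), and it relies on the `model`, `resampling` and `repeats` fields added by this diff:

```julia
import MLJBase

# Hypothetical logger type for illustration only.
struct PrintLogger
    io::IO
end

# Extend the hook added in this diff. `e` is a `PerformanceEvaluation`, so the
# new `model`, `resampling` and `repeats` fields are available alongside the
# existing `measure`/`measurement` vectors.
function MLJBase.log_evaluation(logger::PrintLogger, e)
    println(logger.io, "model: ", typeof(e.model))
    println(logger.io, "resampling: ", e.resampling, " (repeats = ", e.repeats, ")")
    for (measure, measurement) in zip(e.measure, e.measurement)
        println(logger.io, measure, " = ", measurement)
    end
    return nothing
end
```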
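
On the user side, `logger` defaults to `nothing`, so existing `evaluate!` calls are unaffected, and the new `PerformanceEvaluation` fields are just extra properties on the returned object. An illustrative sketch, assuming some supervised model bound to data (`regressor`, `X`, `y` are placeholders, not part of this diff):

```julia
using MLJBase

# `regressor`, `X`, `y` stand in for any supervised model and compatible data.
mach = machine(regressor, X, y)

# `logger=nothing` (the default) disables logging; pass a logger object to
# forward the evaluation to `log_evaluation`.
e = evaluate!(mach; resampling=CV(nfolds=5), measure=rms, logger=nothing)

e.model       # the evaluated model (new field)
e.resampling  # the user-specified strategy, here CV(nfolds=5, ...) (new field)
e.repeats     # 1, since `repeats` was not overridden (new field)
```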