diff --git a/docs/make.jl b/docs/make.jl
index 9c2ac9ef..710e479c 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -13,7 +13,7 @@ makedocs(;
         "Composition" => "composition.md",
         "Datasets" => "datasets.md",
         "Distributions" => "distributions.md",
-        "Utilities" => "utilities.md"
+        "Utilities" => "utilities.md",
     ],
     repo="https://$REPO/blob/{commit}{path}#L{line}",
     sitename="MLJBase.jl"
diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 3a5cb6aa..407a3ca0 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -388,6 +388,7 @@ function internal_stack_report(
     # For each model we record the results mimicking the fields PerformanceEvaluation
     results = NamedTuple{modelnames}(
         [(
+            model = model,
             measure = stack.measures,
             measurement = Vector{Any}(undef, n_measures),
             operation = _actual_operations(nothing, stack.measures, model, verbosity),
diff --git a/src/resampling.jl b/src/resampling.jl
index 1b74e042..d52b9b1b 100644
--- a/src/resampling.jl
+++ b/src/resampling.jl
@@ -474,6 +474,9 @@ be interpreted with caution. See, for example, Bates et al.
 These fields are part of the public API of the `PerformanceEvaluation`
 struct.
 
+- `model`: model used to create the performance evaluation. In the case of a
+  tuning model, this is the best model found.
+
 - `measure`: vector of measures (metrics) used to evaluate performance
 
 - `measurement`: vector of measurements - one for each element of
@@ -509,13 +512,15 @@ struct.
   training and evaluation respectively.
 """
 struct PerformanceEvaluation{M,
+                             Measure,
                              Measurement,
                              Operation,
                              PerFold,
                              PerObservation,
                              FittedParamsPerFold,
                              ReportPerFold} <: MLJType
-    measure::M
+    model::M
+    measure::Measure
     measurement::Measurement
     operation::Operation
     per_fold::PerFold
@@ -568,7 +573,7 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
 
     println(io, "PerformanceEvaluation object "*
                 "with these fields:")
-    println(io, "  measure, operation, measurement, per_fold,\n"*
+    println(io, "  model, measure, operation, measurement, per_fold,\n"*
                 "  per_observation, fitted_params_per_fold,\n"*
                 "  report_per_fold, train_test_rows")
     println(io, "Extract:")
@@ -807,6 +812,24 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
 
 # --------------------------------------------------------------
 # User interface points: `evaluate!` and `evaluate`
+
+"""
+    log_evaluation(logger, performance_evaluation)
+
+Log a performance evaluation to `logger`, an object specific to some logging
+platform, such as mlflow. If `logger=nothing` then no logging is performed.
+The method is called at the end of every call to `evaluate/evaluate!` using
+the logger provided by the `logger` keyword argument.
+
+# Implementations for new logging platforms
+
+Julia interfaces to workflow logging platforms, such as mlflow (provided by
+the MLFlowClient.jl interface) should overload
+`log_evaluation(logger::LoggerType, performance_evaluation)`,
+where `LoggerType` is a platform-specific type for logger objects. For an
+example, see the implementation provided by the MLJFlow.jl package.
+""" +log_evaluation(logger, performance_evaluation) = nothing """ evaluate!(mach, @@ -820,7 +843,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" * acceleration=default_resource(), force=false, verbosity=1, - check_measure=true) + check_measure=true, + logger=nothing) Estimate the performance of a machine `mach` wrapping a supervised model in data, using the specified `resampling` strategy (defaulting @@ -919,6 +943,7 @@ untouched. - `check_measure` - default is `true` +- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref)) ### Return value @@ -939,7 +964,8 @@ function evaluate!(mach::Machine{<:Measurable}; repeats=1, force=false, check_measure=true, - verbosity=1) + verbosity=1, + logger=nothing) # this method just checks validity of options, preprocess the # weights, measures, operations, and dispatches a @@ -980,9 +1006,9 @@ function evaluate!(mach::Machine{<:Measurable}; _acceleration= _process_accel_settings(acceleration) - evaluate!(mach, resampling, weights, class_weights, rows, verbosity, - repeats, _measures, _operations, _acceleration, force) - + evaluate!(mach, resampling, weights, class_weights, rows, + verbosity, repeats, _measures, _operations, + _acceleration, force, logger) end """ @@ -1160,7 +1186,7 @@ end # Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR): function evaluate!(mach::Machine, resampling, weights, class_weights, rows, verbosity, repeats, - measures, operations, acceleration, force) + measures, operations, acceleration, force, logger) # Note: `rows` and `repeats` are ignored here @@ -1264,7 +1290,8 @@ function evaluate!(mach::Machine, resampling, weights, MLJBase.aggregate(per_fold[k], m) end - return PerformanceEvaluation( + evaluation = PerformanceEvaluation( + mach.model, measures, per_measure, operations, @@ -1275,6 +1302,9 @@ function evaluate!(mach::Machine, resampling, weights, resampling ) + log_evaluation(logger, evaluation) + + evaluation end # ---------------------------------------------------------------- @@ -1319,7 +1349,8 @@ end operation=predict, repeats = 1, acceleration=default_resource(), - check_measure=true + check_measure=true, + logger=nothing, ) Resampling model wrapper, used internally by the `fit` method of @@ -1354,7 +1385,7 @@ are not to be confused with any weights bound to a `Resampler` instance in a machine, used for training the wrapped `model` when supported. """ -mutable struct Resampler{S} <: Model +mutable struct Resampler{S, L} <: Model model resampling::S # resampling strategy measure @@ -1365,6 +1396,7 @@ mutable struct Resampler{S} <: Model check_measure::Bool repeats::Int cache::Bool + logger::L end # Some traits are markded as `missing` because we cannot determine @@ -1403,7 +1435,8 @@ function Resampler(; acceleration=default_resource(), check_measure=true, repeats=1, - cache=true + cache=true, + logger=nothing ) resampler = Resampler( model, @@ -1415,7 +1448,8 @@ function Resampler(; acceleration, check_measure, repeats, - cache + cache, + logger ) message = MLJModelInterface.clean!(resampler) isempty(message) || @warn message @@ -1460,7 +1494,8 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...) _measures, _operations, _acceleration, - false + false, + resampler.logger ) fitresult = (machine = mach, evaluation = e) @@ -1523,7 +1558,8 @@ function MLJModelInterface.update( measures, operations, acceleration, - false + false, + resampler.logger ) report = (evaluation = e, ) fitresult = (machine=mach2, evaluation=e)