
Commit

Moving params from MLJFlow.jl and preparing evaluate! to work with log_evaluation
pebeto committed Aug 12, 2023
1 parent a86babf commit 979a82f
Showing 5 changed files with 113 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/MLJBase.jl
@@ -349,7 +349,7 @@ export TransformedTargetModel

# resampling.jl:
export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV,
evaluate!, Resampler, PerformanceEvaluation, log_evaluation
evaluate!, Resampler, PerformanceEvaluation

# `MLJType` and the abstract `Model` subtypes are exported from within
# src/composition/abstract_types.jl
2 changes: 1 addition & 1 deletion src/composition/models/stacking.jl
@@ -388,8 +388,8 @@ function internal_stack_report(
# For each model we record the results mimicking the fields PerformanceEvaluation
results = NamedTuple{modelnames}(
[(
measure = stack.measures,
model = model,
measure = stack.measures,
measurement = Vector{Any}(undef, n_measures),
operation = _actual_operations(nothing, stack.measures, model, verbosity),
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],
65 changes: 37 additions & 28 deletions src/resampling.jl
@@ -474,11 +474,11 @@ be interpreted with caution. See, for example, Bates et al.
These fields are part of the public API of the `PerformanceEvaluation`
struct.
- `measure`: vector of measures (metrics) used to evaluate performance
- `model`: model used to create the performance evaluation. In the case of a
tuning model, this is the best model found.
- `measure`: vector of measures (metrics) used to evaluate performance
- `measurement`: vector of measurements - one for each element of
`measure` - aggregating the performance measurements over all
train/test pairs (folds). The aggregation method applied for a given
@@ -512,15 +512,15 @@ struct.
training and evaluation respectively.
"""
struct PerformanceEvaluation{M,
Model,
Measure,
Measurement,
Operation,
PerFold,
PerObservation,
FittedParamsPerFold,
ReportPerFold} <: MLJType
measure::M
model::Model
model::M
measure::Measure
measurement::Measurement
operation::Operation
per_fold::PerFold
@@ -573,9 +573,9 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)

println(io, "PerformanceEvaluation object "*
"with these fields:")
println(io, " measure, operation, measurement, per_fold,\n"*
println(io, " model, measure, operation, measurement, per_fold,\n"*
" per_observation, fitted_params_per_fold,\n"*
" report_per_fold, train_test_rows", "model")
" report_per_fold, train_test_rows")
println(io, "Extract:")
show_color = MLJBase.SHOW_COLOR[]
color_off()
@@ -812,6 +812,24 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *

# --------------------------------------------------------------
# User interface points: `evaluate!` and `evaluate`
#
"""
log_evaluation(logger, performance_evaluation)
Log a performance evaluation to `logger`, an object specific to some logging
platform, such as mlflow. If `logger=nothing` then no logging is performed.
The method is called at the end of every call to `evaluate/evaluate!` using
the logger provided by the `logger` keyword argument.
# Implementations for new logging platforms
#
Julia interfaces to workflow logging platforms, such as mlflow (provided by
the MLFlowClient.jl interface), should overload
`log_evaluation(logger::LoggerType, performance_evaluation)`,
where `LoggerType` is a platform-specific type for logger objects. For an
example, see the implementation provided by the MLJFlow.jl package.
"""
log_evaluation(logger, performance_evaluation) = nothing
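
A minimal sketch of how a logging interface could hook in here. `MyLogger` and its `run_id` field are invented for illustration only; they are not part of MLJBase or MLJFlow.jl.

```julia
# Hypothetical overload sketch: `MyLogger` stands in for a platform-specific
# logger type; swap the `@info` call for the backend's own logging API.
struct MyLogger
    run_id::String
end

function MLJBase.log_evaluation(logger::MyLogger, performance_evaluation)
    # Record one aggregated measurement per measure.
    for (measure, measurement) in zip(performance_evaluation.measure,
                                      performance_evaluation.measurement)
        @info "logging to run $(logger.run_id)" measure measurement
    end
    return nothing
end
```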

"""
evaluate!(mach,
@@ -825,7 +843,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
acceleration=default_resource(),
force=false,
verbosity=1,
check_measure=true)
check_measure=true,
logger=nothing)
Estimate the performance of a machine `mach` wrapping a supervised
model in data, using the specified `resampling` strategy (defaulting
@@ -924,6 +943,8 @@ untouched.
- `check_measure` - default is `true`
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
### Return value
A [`PerformanceEvaluation`](@ref) object. See
@@ -943,7 +964,8 @@ function evaluate!(mach::Machine{<:Measurable};
repeats=1,
force=false,
check_measure=true,
verbosity=1)
verbosity=1,
logger=nothing)

# this method just checks validity of options, preprocess the
# weights, measures, operations, and dispatches a
@@ -984,9 +1006,12 @@ function evaluate!(mach::Machine{<:Measurable};

_acceleration= _process_accel_settings(acceleration)

evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
repeats, _measures, _operations, _acceleration, force)
evaluation = evaluate!(mach, resampling, weights, class_weights, rows,
verbosity, repeats, _measures, _operations,
_acceleration, force)
log_evaluation(logger, evaluation)

evaluation
end
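
A hedged usage sketch of the new `logger` keyword: `model`, `X`, and `y` are assumed to be already defined, and `MyLogger` is the hypothetical logger type from the sketch above, not a real MLJ type.

```julia
# Usage sketch: evaluate a machine and forward the result to a logger.
mach = machine(model, X, y)
e = evaluate!(mach;
              resampling=CV(nfolds=5),
              measure=rms,
              logger=MyLogger("run-1"))  # log_evaluation(logger, e) runs at the end
e isa PerformanceEvaluation            # true
```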

"""
@@ -1161,22 +1186,6 @@ function measure_specific_weights(measure, weights, class_weights, test)
return nothing
end

# Workflow logging interfaces, such as MLJFlow (MLFlow connection via
# MLFlowClient.jl), overload the following method but replace the `logger`
# argument with `logger::LoggerType`, where `LoggerType` is specific to the
# logging platform.
"""
log_evaluation(logger, performance_evaluation)
Logs a performance evaluation (the returning object from `evaluate!`)
to a logging platform. The default implementation does nothing.
Can be overloaded by logging interfaces, such as MLJFlow (MLFlow
connection via MLFlowClient.jl), which replace the `logger` argument
with `logger::LoggerType`, where `LoggerType` is specific to the logging
platform.
"""
log_evaluation(logger, performance_evaluation) = nothing

# Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
function evaluate!(mach::Machine, resampling, weights,
class_weights, rows, verbosity, repeats,
@@ -1285,8 +1294,8 @@ function evaluate!(mach::Machine, resampling, weights,
end

return PerformanceEvaluation(
measures,
mach.model,
measures,
per_measure,
operations,
per_fold,
63 changes: 63 additions & 0 deletions src/utilities.jl
@@ -469,3 +469,66 @@ end

generate_name!(model, existing_names; kwargs...) =
generate_name!(typeof(model), existing_names; kwargs...)

isamodel(::Any) = false
isamodel(::Model) = true

"""
deep_params(m::Model)
Recursively convert any object subtyping `Model` into a named tuple,
keyed on the property names of `m`. The named tuple is possibly nested
because `deep_params` is recursively applied to the property values, which
themselves might subtype `Model`.
For most `Model` objects, properties are synonymous with fields, but
this is not a hard requirement.
julia> deep_params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
weights = Float64[],
bagging_fraction = 0.8,
rng_seed = 0,
n = 100,
parallel = true,)
"""
deep_params(m) = deep_params(m, Val(isamodel(m)))
deep_params(m, ::Val{false}) = m
function deep_params(m, ::Val{true})
fields = propertynames(m)
NamedTuple{fields}(Tuple([deep_params(getproperty(m, field))
for field in fields]))
end
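
A hedged usage note on the `Val`-based dispatch: values that are not `Model`s fall through the `Val{false}` method unchanged, so `deep_params` is safe to apply to arbitrary property values.

```julia
# Sketch of the fallback: anything that is not a `Model` passes through unchanged.
deep_params(42)        # 42
deep_params((a = 1,))  # (a = 1,) -- a plain named tuple is not a Model
```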

"""
flat_params(t::NamedTuple)
View a nested named tuple `t` as a tree and return, as a Dict, the key subtrees
and the values at the leaves, in the order they appear in the original tuple.
```julia-repl
julia> t = (X = (x = 1, y = 2), Y = 3)
julia> flat_params(t)
LittleDict{...} with 3 entries:
"X__x" => 1
"X__y" => 2
"Y" => 3
```
"""
function flat_params(parameters::NamedTuple)
result = LittleDict{String, Any}()
for key in keys(parameters)
value = params(getproperty(parameters, key))
if value isa NamedTuple
sub_dict = flat_params(value)
for (sub_key, sub_value) in pairs(sub_dict)
new_key = string(key, "__", sub_key)
result[new_key] = sub_value
end
else
result[string(key)] = value
end
end
return result
end
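
Taken together, these two utilities presumably back the hyperparameter logging referred to in the commit title ("Moving params from MLJFlow.jl"). A hedged sketch of composing them, assuming `model` is any MLJ model instance:

```julia
# Sketch: flatten a model's hyperparameters into string-keyed entries, the form
# a tracking backend such as mlflow typically expects.
nested = deep_params(model)  # nested NamedTuple of the model's property values
flat = flat_params(nested)   # LittleDict: "outer__inner__param" => value
for (name, value) in flat
    println(name, " = ", value)  # replace with the backend's parameter-logging call
end
```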
11 changes: 11 additions & 0 deletions test/utilities.jl
@@ -20,6 +20,17 @@ struct Baz <: Foo end
@test flat_values(t) == (1, 2, 3)
end

@testset "flattening parameters" begin
t = (a = (ax = (ax1 = 1, ax2 = 2), ay = 3), b = 4)
dict_t = Dict(
"a__ax__ax1" => 1,
"a__ax__ax2" => 2,
"a__ay" => 3,
"b" => 4,
)
@test flat_params(t) == dict_t
end

mutable struct M
a1
a2
