Merge pull request #30 from JuliaAI/measures
Adapt to migration of measures MLJBase -> StatisticalMeasures
ablaom authored Sep 25, 2023
2 parents 2f501db + c4aba20 commit e63eadc
Showing 5 changed files with 93 additions and 70 deletions.
18 changes: 15 additions & 3 deletions Project.toml
@@ -1,29 +1,41 @@
name = "MLJEnsembles"
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.3"
version = "0.4.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
CategoricalArrays = "0.8, 0.9, 0.10"
CategoricalDistributions = "0.1.2"
ComputationalResources = "0.3"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
MLJBase = "0.20, 0.21"
MLJModelInterface = "0.4.1, 1.1"
ProgressMeter = "1.1"
ScientificTypesBase = "2,3"
StatisticalMeasuresBase = "0.1"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.6"

[extras]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distances", "MLJBase", "NearestNeighbors", "Serialization", "StableRNGs", "StatisticalMeasures", "Test"]
2 changes: 1 addition & 1 deletion src/MLJEnsembles.jl
@@ -2,7 +2,6 @@ module MLJEnsembles

using MLJModelInterface
import MLJModelInterface: predict, fit, save, restore
import MLJBase # still needed for aggregating measures in oob-estimates of error
using Random
using CategoricalArrays
using CategoricalDistributions
@@ -11,6 +10,7 @@ using Distributed
import Distributions
using ProgressMeter
import StatsBase
import StatisticalMeasuresBase

export EnsembleModel

125 changes: 76 additions & 49 deletions src/ensembles.jl
@@ -321,11 +321,10 @@ If a single measure or non-empty vector of measures is specified by
written to the training report (call `report` on the trained
machine wrapping the ensemble model).
*Important:* If sample weights `w` (not to be confused with atomic
weights) are specified when constructing a machine for the ensemble
model, as in `mach = machine(ensemble_model, X, y, w)`, then `w` is
used by any measures specified in `out_of_bag_measure` that support
sample weights.
*Important:* If per-observation or class weights `w` (not to be confused with atomic
weights) are specified when constructing a machine for the ensemble model, as in `mach =
machine(ensemble_model, X, y, w)`, then `w` is used by any measures specified in
`out_of_bag_measure` that support them.
"""
function EnsembleModel(
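For orientation, the workflow the docstring describes might look as follows after this change. This is a minimal sketch, not part of the diff: it assumes MLJDecisionTreeInterface is installed to supply an atomic model, and uses `make_regression` (from MLJBase) and the measures `rms` and `mae` (from StatisticalMeasures) as illustrative stand-ins; keyword and report field names follow the docstring above and the report constructed further below.

using MLJBase, MLJEnsembles, StatisticalMeasures
import MLJDecisionTreeInterface

# atomic model to be bagged; any supervised model would do
atom = MLJDecisionTreeInterface.DecisionTreeRegressor()

X, y = make_regression(200, 3)   # synthetic table and target
w = rand(200)                    # optional per-observation weights

ensemble = EnsembleModel(
    model=atom,
    n=50,
    bagging_fraction=0.8,
    out_of_bag_measure=[rms, mae],
)

mach = machine(ensemble, X, y, w)
fit!(mach, verbosity=0)

# aggregated out-of-bag estimates, as stored in the training report:
report(mach).oob_measurements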
@@ -395,34 +394,56 @@ function _fit(res::CPUProcesses, func, verbosity, stuff)
if i != nworkers()
func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
else
func(atom, 0, chunk_size + left_over, n_patterns, n_train, rng, progress_meter, args...)
func(
atom,
0,
chunk_size + left_over,
n_patterns,
n_train,
rng,
progress_meter,
args...,
)
end
end
end

@static if VERSION >= v"1.3.0-DEV.573"
function _fit(res::CPUThreads, func, verbosity, stuff)
atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
if verbosity > 0
println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
end
nthreads = Threads.nthreads()
chunk_size = div(n, nthreads)
left_over = mod(n, nthreads)
resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?

Threads.@threads for i = 1:nthreads
resvec[i] = if i != nworkers()
func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
else
func(atom, 0, chunk_size + left_over, n_patterns, n_train, rng, progress_meter, args...)
end
end
function _fit(res::CPUThreads, func, verbosity, stuff)
atom, n, n_patterns, n_train, rng, progress_meter, args = stuff
if verbosity > 0
println("Ensemble-building in parallel on $(Threads.nthreads()) threads.")
end
nthreads = Threads.nthreads()
chunk_size = div(n, nthreads)
left_over = mod(n, nthreads)
resvec = Vector(undef, nthreads) # FIXME: Make this type-stable?

return reduce(_reducer, resvec)
Threads.@threads for i = 1:nthreads
resvec[i] = if i != nworkers()
func(atom, 0, chunk_size, n_patterns, n_train, rng, progress_meter, args...)
else
func(
atom,
0,
chunk_size + left_over,
n_patterns,
n_train,
rng,
progress_meter,
args...,
)
end
end

return reduce(_reducer, resvec)
end

# for subsampling weights, which could be `nothing`, per-observation weights, or
# class_weights:
_view(class_weights::AbstractDict, rows) = class_weights
_view(::Nothing, rows) = nothing
_view(weights, rows) = view(weights, rows)
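# For illustration (hypothetical values), the three methods above behave as follows:
#
#   _view(Dict('a' => 2.0, 'b' => 1.0), 1:3)  # class weights: returned unchanged
#   _view(nothing, 1:3)                       # no weights: stays `nothing`
#   _view([0.1, 0.2, 0.3, 0.4], 2:3)          # per-observation weights: view of rows 2:3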

function MMI.fit(
model::EitherEnsembleModel{Atom}, verbosity::Int, args...
) where Atom<:Supervised
@@ -446,10 +467,14 @@ function MMI.fit(
acceleration = CPU1()
end

# we wrap the measures in `robust_measure` so they can be called with weights, even
# when they don't support them, and just ignore them silently.
if model.out_of_bag_measure isa Vector
out_of_bag_measure = model.out_of_bag_measure
out_of_bag_measure =
StatisticalMeasuresBase.robust_measure.(model.out_of_bag_measure)
else
out_of_bag_measure = [model.out_of_bag_measure,]
out_of_bag_measure =
[StatisticalMeasuresBase.robust_measure(model.out_of_bag_measure),]
end
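# For example (an assumption based on the comment above, not part of this diff): if a
# measure `m` does not support class weights, then
#
#   StatisticalMeasuresBase.robust_measure(m)(yhat, y, class_weights)
#
# returns the unweighted evaluation instead of throwing, which is what lets the
# out-of-bag loop below pass `wtest` unconditionally, whatever form the weights take.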

if model.rng isa Integer
@@ -484,7 +509,7 @@ function MMI.fit(

if !isempty(out_of_bag_measure)

metrics=zeros(length(ensemble),length(out_of_bag_measure))
measurements=zeros(length(ensemble),length(out_of_bag_measure))
for i= 1:length(ensemble)
#oob indices
ooB_indices= setdiff(1:n_patterns, ensemble_indices[i])
@@ -493,42 +518,44 @@ function MMI.fit(
"Data size too small or "*
"bagging_fraction too close to 1.0. ")
end
yhat = predict(atom, ensemble[i], selectrows(atom, ooB_indices, atom_specific_X)...)
yhat = predict(
atom,
ensemble[i],
selectrows(atom, ooB_indices, atom_specific_X)...,
)
Xtest = selectrows(X, ooB_indices)
ytest = selectrows(y, ooB_indices)

if w === nothing
wtest = nothing
else
wtest = selectrows(w, ooB_indices)
end
# this could be class weights OR per-observation weights, OR `nothing`:
wtest = _view(w, ooB_indices)

for k in eachindex(out_of_bag_measure)
m = out_of_bag_measure[k]
if MMI.reports_each_observation(m)
s = MLJBase.aggregate(
MLJBase.value(m, yhat, Xtest, ytest, wtest),
m
)
else
s = MLJBase.value(m, yhat, Xtest, ytest, wtest)
end
metrics[i,k] = s
s = m(yhat, ytest, wtest)
measurements[i,k] = s
end
end

# aggregate metrics across the ensembles:
aggregated_metrics = map(eachindex(out_of_bag_measure)) do k
MLJBase.aggregate(metrics[:,k], out_of_bag_measure[k])
# aggregate measurements across the ensembles:
aggregated_measurements = map(eachindex(out_of_bag_measure)) do k
StatisticalMeasuresBase.aggregate(
measurements[:,k],
mode=StatisticalMeasuresBase.external_aggregation_mode(
out_of_bag_measure[k],
)
)
end
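# Sketch of the intent (a reading of the code above, not part of the diff): each
# column `measurements[:, k]` holds one out-of-bag value per ensemble member, and
# `aggregate` reduces that column to a single number using the aggregation mode the
# measure declares via `external_aggregation_mode` (typically a mean).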

names = Symbol.(string.(out_of_bag_measure))

else
aggregated_metrics = missing
aggregated_measurements = missing
end

report=(measures=out_of_bag_measure, oob_measurements=aggregated_metrics,)
report=(
measures=out_of_bag_measure,
oob_measurements=aggregated_measurements,
)
cache = deepcopy(model)

return fitresult, cache, report
@@ -542,7 +569,7 @@ function MMI.update(model::EitherEnsembleModel,

n = model.n

if MLJBase.is_same_except(model.model, old_model.model,
if MMI.is_same_except(model.model, old_model.model,
:n, :atomic_weights, :acceleration)
if n > old_model.n
verbosity < 1 ||
16 changes: 0 additions & 16 deletions test/Project.toml

This file was deleted.

2 changes: 1 addition & 1 deletion test/ensembles.jl
@@ -8,7 +8,7 @@ using MLJBase
using ..Models
using CategoricalArrays
import Distributions

using StatisticalMeasures

## HELPER FUNCTIONS

