Oversampling and undersampling #661
https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html#over-sampling

edit (July 2023): an updated version of the POC below appears later in this thread.

This is just to kick off a discussion. I see oversampling/undersampling as transformers plus model wrappers. Here's a rough POC for this (see the updated version later in the thread):

cc @DilumAluthge

Comments
One drawback is that ...
Any movement on this? This sounds like a data preparation utility that could be provided by a third-party ML utility package.
We have some functionality in ClassImbalance.jl, but there is not yet an MLJ interface for that package. Help is welcome.
Show stopper.
Yeah, the package needs to be updated and modernized.
I've implemented SMOTE in ...
It's been a while since I posted the above POC. Here's an updated version, based on more recent versions of the packages, plus some other mild changes. You'll need MLJBase >= 0.21.12 and MLJDecisionTreeInterface in your environment.

```julia
using MLJ, Tables
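# Note: the environment-setup step is not spelled out in the original comment.
# One way to do it (package names assumed to be the registered ones) is:
#
#     import Pkg
#     Pkg.add(["MLJ", "MLJBase", "MLJDecisionTreeInterface", "Tables", "StatsBase"])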
import MLJBase, StatsBase
## A QUICK AND DIRTY OVERSAMPLER FOR ILLUSTRATION
mutable struct NaiveOversampler <: Static
    ratio::Float64
end
NaiveOversampler(; ratio=1.0) = NaiveOversampler(ratio)
function MLJBase.transform(oversampler::NaiveOversampler, verbosity, X, y)
    d = StatsBase.countmap(y)
    counts = sort(collect(d), by=pair->last(pair))
    minority_class = first(counts) |> first
    dominant_class = last(counts) |> first
    # number of extra minority-class rows needed so that the minority count
    # reaches `ratio` times the dominant count:
    nextras = max(
        0,
        round(Int, oversampler.ratio*d[dominant_class] - d[minority_class]),
    )
    all_indices = eachindex(y)
    minority_indices = all_indices[y .== minority_class]
    # resample minority rows with replacement and append them:
    extra_indices = rand(minority_indices, nextras)
    over_indices = vcat(all_indices, extra_indices)
    Xover = Tables.subset(X, over_indices) |> Tables.materializer(X)
    yover = y[over_indices]
    return Xover, yover
end
# demonstration:
X = (x1=1:4, x2=5:8)
y = coerce([true, false, true, true], Multiclass)
StatsBase.countmap(y)
# Dict{CategoricalArrays.CategoricalValue{Bool, UInt32}, Int64} with 2 entries:
# false => 1
# true => 3
naive = NaiveOversampler()
mach = machine(naive) # static transformers have no training arguments
Xover, yover = transform(mach, X, y)
StatsBase.countmap(yover)
# Dict{CategoricalArrays.CategoricalValue{Bool, UInt32}, Int64} with 2 entries:
# false => 3
# true => 3
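
## A NAIVE UNDERSAMPLER (added sketch, not in the original comment)

# The same `Static` transformer pattern covers undersampling: instead of
# replicating minority-class rows, we randomly drop dominant-class rows.
# The names below are illustrative only.

mutable struct NaiveUndersampler <: Static
    ratio::Float64
end

NaiveUndersampler(; ratio=1.0) = NaiveUndersampler(ratio)

function MLJBase.transform(undersampler::NaiveUndersampler, verbosity, X, y)
    d = StatsBase.countmap(y)
    counts = sort(collect(d), by=pair->last(pair))
    minority_class = first(counts) |> first
    dominant_class = last(counts) |> first
    # keep all non-dominant rows, plus only enough dominant-class rows so
    # that minority/dominant ≈ ratio:
    nkeep = min(
        d[dominant_class],
        round(Int, d[minority_class]/undersampler.ratio),
    )
    all_indices = eachindex(y)
    dominant_indices = all_indices[y .== dominant_class]
    kept_dominant = StatsBase.sample(dominant_indices, nkeep, replace=false)
    under_indices = sort(vcat(all_indices[y .!= dominant_class], kept_dominant))
    Xunder = Tables.subset(X, under_indices) |> Tables.materializer(X)
    yunder = y[under_indices]
    return Xunder, yunder
end

# demonstration on the toy data above:
under = NaiveUndersampler()
Xunder, yunder = transform(machine(under), X, y)
StatsBase.countmap(yunder)   # expect false => 1, true => 1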
## COMPOSITE FOR WRAPPING A CLASSIFIER WITH OVERSAMPLING
# default component models for the wrapper:
naive = NaiveOversampler()
dummy = ConstantClassifier()
# we restrict wrapping to `Probabilistic` models and so use
# `ProbabilisticNetworkComposite` for the "exported" learning network type:
struct BalancedModel <: ProbabilisticNetworkComposite
    model::Probabilistic
    balancer # oversampler or undersampler
end
BalancedModel(; model=dummy, balancer=naive) =
    BalancedModel(model, balancer)
BalancedModel(model; kwargs...) = BalancedModel(; model, kwargs...)
function MLJBase.prefit(over_sampled_model::BalancedModel, verbosity, _X, _y)
    # the learning network:
    X = source(_X)
    y = source(_y)
    mach1 = machine(:balancer) # `Static`, so no training arguments here
    data = transform(mach1, X, y)
    # `first` and `last` are overloaded for nodes, so we can do:
    X_over = first(data)
    y_over = last(data)
    # we use the oversampled data for training:
    mach2 = machine(:model, X_over, y_over)
    # but consume new production data from the source:
    yhat = predict(mach2, X)
    # return the learning network interface:
    return (; predict=yhat)
end
## DEMONSTRATION
# synthesize some imbalanced data:
Xraw, yraw = make_moons(1000);
for_deletion = eachindex(yraw)[yraw .== 0][1:400]
to_keep = setdiff(eachindex(yraw), for_deletion)
X = Tables.rowtable(Xraw)[to_keep]
y = coerce(yraw[to_keep], OrderedFactor)
train, test = partition(eachindex(y), 0.6)
model = (@load DecisionTreeClassifier pkg=DecisionTree)()
balanced_model = BalancedModel(model)
# BalancedModel(
# model = DecisionTreeClassifier(
# max_depth = -1,
# min_samples_leaf = 1,
# min_samples_split = 2,
# min_purity_increase = 0.0,
# n_subfeatures = 0,
# post_prune = false,
# merge_purity_threshold = 1.0,
# display_depth = 5,
# feature_importance = :impurity,
# rng = Random._GLOBAL_RNG()),
# balancer = NaiveOversampler(
# ratio = 1.0))
mach = machine(balanced_model, X, y)
fit!(mach, rows=train)
predict(mach, rows=test[1:3])
# 3-element UnivariateFiniteVector{OrderedFactor{2}, String, UInt32, Float64}:
# UnivariateFinite{OrderedFactor{2}}(0=>1.0, 1=>0.0)
# UnivariateFinite{OrderedFactor{2}}(0=>0.0, 1=>1.0)
# UnivariateFinite{OrderedFactor{2}}(0=>0.0, 1=>1.0)
```
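A natural follow-up, not in the original comment: the wrapped model can be evaluated like any other MLJ model. A minimal sketch, reusing `mach` and `train` from the demonstration above; `log_loss` and `auc` are standard MLJ measures and the resampling choice here is arbitrary:

```julia
# 6-fold cross-validation of the oversampling-wrapped classifier; probabilistic
# measures such as log_loss and auc apply directly to the probabilistic predictions:
evaluate!(mach, resampling=CV(nfolds=6, shuffle=true), measure=[log_loss, auc], rows=train)
```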
A large number of oversampling/undersampling strategies, with MLJ interfaces, are now provided by Imbalance.jl, along with a model wrapper for composing them with classifiers. Closing as complete. cc @EssamWissam
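For readers landing here later, here is a rough sketch of what that route looks like. This is not from the thread: the constructor keywords (`k`, `ratios`) and the MLJBalancing keywords (`model`, `balancer1`) are recalled from memory and should be verified against the Imbalance.jl and MLJBalancing.jl documentation.

```julia
using MLJ, Tables

# imbalanced toy data (same construction as the demonstration above):
Xraw, yraw = make_moons(1000)
keep = setdiff(eachindex(yraw), eachindex(yraw)[yraw .== 0][1:400])
X = Tables.rowtable(Xraw)[keep]
y = coerce(yraw[keep], OrderedFactor)

# balancers from Imbalance.jl are Static transformers with MLJ interfaces:
SMOTE = @load SMOTE pkg=Imbalance
Tree = @load DecisionTreeClassifier pkg=DecisionTree

oversampler = SMOTE(k=5, ratios=1.0)   # keyword names assumed; check the docs
Xover, yover = transform(machine(oversampler), X, y)

# MLJBalancing.jl supplies the wrapper playing the role of `BalancedModel` above:
using MLJBalancing
balanced = BalancedModel(model=Tree(), balancer1=oversampler)
mach = machine(balanced, X, y) |> fit!
```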