Question on the use of the Update! method and is_same_except() #212

Open
pasq-cat opened this issue Sep 24, 2024 · 26 comments

Comments

@pasq-cat

Hi, I was trying to implement the update method for LaplaceRedux, but I am running into a problem.

This is the model:

MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilistic
    model::Union{Flux.Chain,Nothing} = nothing
    flux_loss = Flux.Losses.mse
    optimiser = Adam()
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer = 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    fit_prior_nsteps::Int = 100::(_ > 0)
end

This is the fit function I have written:

function MMI.fit(m::LaplaceRegressor, verbosity, X, y)

    #X = MLJBase.matrix(X) |> permutedims
    #y = reshape(y, 1, :)

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    # Train a deep copy so that the hyperparameter m.model is never mutated
    chain = deepcopy(m.model)

    data_loader = Flux.DataLoader((X, y); batchsize=m.batch_size)
    opt_state = Flux.setup(m.optimiser, chain)
    loss_history = Float64[]
    push!(loss_history, m.flux_loss(chain(X), y))

    for epoch in 1:(m.epochs)

        loss_per_epoch = 0.0

        for (X_batch, y_batch) in data_loader
            # Forward pass, loss and gradients in a single call
            loss, grads = Flux.withgradient(chain) do model
                m.flux_loss(model(X_batch), y_batch)
            end

            # Update parameters using the optimiser state and the computed gradients
            Flux.update!(opt_state, chain, grads[1])

            # Accumulate the (scalar) loss for this batch
            loss_per_epoch += loss
        end

        push!(loss_history, loss_per_epoch)

        # Print loss every 100 epochs if verbosity is 1 or more
        if verbosity >= 1 && epoch % 100 == 0
            println("Epoch $epoch: Loss: $loss_per_epoch")
        end
    end

    la = LaplaceRedux.Laplace(
        chain;
        likelihood=:regression,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # Fit the Laplace approximation
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    fitresult = la
    report = (loss_history=loss_history,)
    cache = (deepcopy(m), opt_state, loss_history)
    return fitresult, cache, report
end

Here is the incomplete update function I was trying. I have removed the training loop, since it's not relevant here.

function MMI.update(m::LaplaceRegressor, verbosity, old_fitresult, old_cache, X, y)
    println("running MMI.update")

    # The first element of the cache is the copy of the model saved by fit
    old_model = old_cache[1]

    if Tables.istable(X)
        X = Tables.matrix(X) |> permutedims
    end

    # Reshape y if necessary
    y = reshape(y, 1, :)

    println(MMI.is_same_except(m, old_model, :epochs))

    # (training loop removed)
    report = ()
    return old_fitresult, old_cache, report
end
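For reference, here is a minimal plain-Julia sketch of the warm-restart logic such an update usually aims for. Everything in it is hypothetical:

- ToyRegressor, toy_fit, toy_update and train_epochs stand in for LaplaceRegressor and the MMI methods.
- A one-weight gradient descent stands in for the Flux loop.

fit stores a copy of the hyperparameters in the cache. update trains only the extra epochs when epochs is the only field that changed and it increased. Otherwise it retrains from scratch.

```julia
# Hypothetical stand-in for LaplaceRegressor, with only two hyperparameters.
mutable struct ToyRegressor
    lr::Float64
    epochs::Int
end

# Run n epochs of gradient descent on the single weight w for the loss mean((w*x - y)^2).
function train_epochs(w::Float64, m::ToyRegressor, x, y, n::Integer)
    for _ in 1:n
        grad = 2 * sum((w .* x .- y) .* x) / length(x)
        w -= m.lr * grad
    end
    return w
end

# Fit from scratch; the cache keeps a copy of the hyperparameters and the learned weight.
function toy_fit(m::ToyRegressor, x, y)
    w = train_epochs(0.0, m, x, y, m.epochs)
    return w, (deepcopy(m), w)
end

# True when every field except :epochs is unchanged (what is_same_except(old, new, :epochs) should report).
only_epochs_changed(old::ToyRegressor, new::ToyRegressor) =
    all(getfield(old, f) == getfield(new, f) for f in fieldnames(ToyRegressor) if f != :epochs)

function toy_update(m::ToyRegressor, old_cache, x, y)
    old_m, old_w = old_cache
    if only_epochs_changed(old_m, m) && m.epochs >= old_m.epochs
        # Warm restart: continue from the old weight for the extra epochs only
        w = train_epochs(old_w, m, x, y, m.epochs - old_m.epochs)
        return w, (deepcopy(m), w)
    end
    # Another hyperparameter changed, or epochs decreased: retrain from scratch
    return toy_fit(m, x, y)
end

x = [1.0, 2.0, 3.0]
y = 2 .* x

m = ToyRegressor(0.05, 50)
w, cache = toy_fit(m, x, y)

m.epochs = 100                          # change only the number of epochs
w_warm, cache = toy_update(m, cache, x, y)
```

Because the warm restart replays exactly the same gradient steps, it gives the same weight as fitting 100 epochs from scratch.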



The issue is that if I rerun the model after changing only the number of epochs, is_same_except still returns

false

even though :epochs is listed as an exception:

using MLJ, Flux
import MLJBase

flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y)
MLJBase.fit!(mach)

# Change only the number of epochs, then refit (this should call update)
model.epochs = 2000

MLJBase.fit!(mach)

So what is the correct way to implement is_same_except? Thank you.

@ablaom
Member

ablaom commented Sep 25, 2024

Not sure what the problem might be. Can you provide an MWE demonstrating that is_same_except is not working as you expect? I.e., some variation of this (which works for me):

using MLJModelInterface
import MLJModelInterface as MMI

mutable struct Classifier <: Probabilistic
    x::Int
    y::Int
end

model = Classifier(1, 2)
model2 = deepcopy(model)
model2.y = 7

@assert MMI.is_same_except(model, model2, :y)

Or, if you suspect some other problem, a more self-contained MWE would be helpful.

@pasq-cat
Author

For example, using this:

flux_model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=flux_model)
copy_model= deepcopy(model)
copy_model.epochs= 2000
MLJBase.is_same_except(model , copy_model, :epochs)

returns false, even though I have only changed epochs.

For a simpler example that does not need LaplaceRedux, consider this:

using MLJModelInterface, Flux
import MLJModelInterface as MMI

mutable struct LaplaceRegr <: MMI.Probabilistic
    model::Flux.Chain
    epochs::Integer
    batch_size::Integer
end

model = LaplaceRegr(flux_model, 1000, 2)   # flux_model as defined above

model2 = deepcopy(model)

model2.epochs = 2000

MMI.is_same_except(model, model2, :epochs)

It's because one of the fields holds a Flux chain. If I remove that field, I get true.

@ablaom
Member

ablaom commented Sep 25, 2024

Thanks, this helps me see the problem:

julia> c = Flux.Chain(Dense(2,3))
julia> c == deepcopy(c)
false

Unfortunately, MLJ was not designed with this kind of behaviour in mind, for hyperparameter values. This has occurred once before and a hack was introduced, the trait MMI.deep_properties. However, this does not fix your issue, because we need equality all the way down, and the hack only goes down one level (read the doc-string if you are interested). We could try to fix the hack, but it's such a corner case and technically breaking, so I'm not a big fan. (In fact, for orthogonal reasons, the hack is no longer used anyway.)
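The same behaviour can be reproduced without Flux. This is a minimal sketch assuming that Flux layers do not define their own ==, which is consistent with the false above. In that case Julia's fallback == (which is ===) applies. ToyDense is a hypothetical stand-in for a Dense layer.

```julia
# Hypothetical stand-in for a Flux layer: an immutable struct holding arrays.
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end

layer = ToyDense([1.0 2.0; 3.0 4.0], [0.0, 0.0])
copy_layer = deepcopy(layer)

# Fallback == is ===, which compares the array fields by identity;
# deepcopy allocated new arrays, so this is false
same_by_default = layer == copy_layer

# Comparing the field values explicitly gives the intended answer (true)
same_by_value = layer.W == copy_layer.W && layer.b == copy_layer.b
```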

Another possible resolution is for you to explicitly add an overloading is_same_except(model1::LaplaceRegr, model2::LaplaceRegr, ...) = ... to get the actual behaviour you want (this will automatically carry over to == because the latter is defined in terms of the former). Feel free to lift most of the code you need from here. Or, you could just manually check in update whether the only hyperparameter that has "truly changed" is :epochs. We used to do that before is_same_except was added.
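To illustrate the overloading route without depending on MLJ or Flux, here is a plain-Julia sketch of the logic such a method would contain. ToyChain, ToyLaplaceRegr, field_equal and same_except are hypothetical names. In the real interface, the body would go in a method of MMI.is_same_except(m1::LaplaceRegr, m2::LaplaceRegr, exceptions::Symbol...). The chain comparison would then compare parameter values rather than the toy weights field.

```julia
# Hypothetical stand-in for a Flux.Chain: an immutable struct holding weight arrays.
struct ToyChain
    weights::Vector{Matrix{Float64}}
end

# Hypothetical model struct mirroring LaplaceRegr above.
mutable struct ToyLaplaceRegr
    model::ToyChain
    epochs::Int
    batch_size::Int
end

# Field comparison: plain == by default, parameter values for chain-like fields.
field_equal(a, b) = a == b
field_equal(a::ToyChain, b::ToyChain) = a.weights == b.weights

# is_same_except-style check: every field not listed in exceptions must match by value.
function same_except(m1::ToyLaplaceRegr, m2::ToyLaplaceRegr, exceptions::Symbol...)
    for name in fieldnames(ToyLaplaceRegr)
        name in exceptions && continue
        field_equal(getfield(m1, name), getfield(m2, name)) || return false
    end
    return true
end

model = ToyLaplaceRegr(ToyChain([[1.0 2.0; 3.0 4.0], [0.5 0.5]]), 1000, 2)
model2 = deepcopy(model)
model2.epochs = 2000

same_except(model, model2, :epochs)   # true: only epochs differs
same_except(model, model2)            # false: epochs is not excepted
```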

In any case, make sure neither fit nor update actually mutates any hyperparameter value, as this is definitely disallowed (except for RNGs). So if you use :model to build the learned parameters (aka fitresult), be sure to create a deep copy first. (This is an instance of a wider limitation of Flux, due to its conflation of hyperparameters and learned parameters, which is fixed in Lux.jl.)
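Here is a minimal sketch of the "deep copy first" rule. It uses a hypothetical ToyModel whose init_weights hyperparameter plays the role of :model. Training mutates only the copy, so the hyperparameter is unchanged after fit.

```julia
# Hypothetical model: init_weights plays the role of the :model hyperparameter.
mutable struct ToyModel
    init_weights::Vector{Float64}
    epochs::Int
end

function toy_fit(m::ToyModel)
    w = deepcopy(m.init_weights)   # train a copy, never the hyperparameter itself
    for _ in 1:m.epochs
        w .*= 0.5                  # stand-in for an in-place optimiser step
    end
    return w                       # learned parameters (the fitresult)
end

m = ToyModel([1.0, 2.0], 3)
w = toy_fit(m)
```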

@pasq-cat
Author

Couldn't something like this replace the default is_same_except function?

# Define the function is_same_except
function is_same_except(m1::M1, m2::M2, exceptions::Symbol...) where {M1<:MLJType, M2<:MLJType}
    typeof(m1) === typeof(m2) || return false
    names = propertynames(m1)
    propertynames(m2) === names || return false

    for name in names
        if !(name in exceptions)
            if !_isdefined(m1, name)
               !_isdefined(m2, name) || return false
            elseif _isdefined(m2, name)
                if name in deep_properties(M1)
                    _equal_to_depth_one(
                        getproperty(m1,name),
                        getproperty(m2, name)
                    ) || return false
                else
                    (
                        is_same_except(
                            getproperty(m1, name),
                            getproperty(m2, name)
                        ) ||
                        getproperty(m1, name) isa AbstractRNG ||
                        getproperty(m2, name) isa AbstractRNG ||
                        (getproperty(m1, name) isa Flux.Chain && getproperty(m2, name) isa Flux.Chain && _equal_flux_chain(getproperty(m1, name), getproperty(m2, name)))
                    ) || return false
                end
            else
                return false
            end
        end
    end
    return true
end

with a helper function:

function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
    if length(chain1.layers) != length(chain2.layers)
        return false
    end
    for (layer1, layer2) in zip(chain1.layers, chain2.layers)
        if typeof(layer1) != typeof(layer2)
            return false
        end
        params1 = Flux.params(layer1)
        params2 = Flux.params(layer2)
        if length(params1) != length(params2)
            return false
        end
        for (p1, p2) in zip(params1, params2)
            if !isequal(p1, p2)
                return false
            end
        end
    end
    return true
end

It should work for every MLJ model that wraps a Flux model. If not, I will use it only for LaplaceRegressor.

@ablaom
Member

ablaom commented Sep 26, 2024

Great progress.

I think your test for equality of Chains is not correct, because it will not behave as expected for nested chains, like Chain(Chain(...), ...). Rather, just apply Flux.params directly to the whole of chain1 and chain2, i.e., there's no need to deconstruct. You may want to add a test, as this kind of logic is a bug-magnet.
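The nesting problem can be seen with a plain-Julia sketch of the "collect all parameter arrays, then compare" approach suggested here. NestedChain, ToyDense, collect_arrays! and params_equal are hypothetical stand-ins for Flux.Chain, Flux.Dense and Flux.params. The recursion descends into nested chains, tuples and struct fields, so a chain inside a chain is handled.

```julia
# Hypothetical stand-ins for Flux.Dense and Flux.Chain.
struct ToyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end

struct NestedChain
    layers::Tuple
end

# Recursively gather every numeric array (a rough stand-in for Flux.params).
collect_arrays!(acc, x::AbstractArray{<:Number}) = push!(acc, x)
collect_arrays!(acc, x::Union{Tuple,AbstractArray}) =
    (foreach(el -> collect_arrays!(acc, el), x); acc)
function collect_arrays!(acc, x)
    for f in fieldnames(typeof(x))   # descend into struct fields, e.g. nested chains
        collect_arrays!(acc, getfield(x, f))
    end
    return acc
end

# Element-wise == on the collected arrays compares parameter values, not identities.
params_equal(a, b) = collect_arrays!(Any[], a) == collect_arrays!(Any[], b)

inner = NestedChain((ToyDense([1.0 2.0], [0.5]),))
c1 = NestedChain((inner, ToyDense(ones(1, 1), [0.0])))
c2 = deepcopy(c1)                                        # same values, new arrays
c3 = NestedChain((inner, ToyDense(zeros(1, 1), [0.0])))  # one weight differs
```

Comparing only the flattened arrays ignores layer types. Two different architectures with identical arrays in the same order would therefore compare equal. With Flux, where (to my understanding) a Chain's type parameters encode its layer types, also checking typeof(a) == typeof(b) would close that gap.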

I suggest you just overload locally and we not add complexity to MLJModelInterface for this one corner case. There is probably a more generic way to handle this, maybe by fixing deep_properties, but I don't have time to look at it just now.

@pasq-cat
Author

You may want to add a test, as this kind of logic is a bug-magnet.

Indeed, I just found out that the models don't pass the test if optimiser = Adam() is included in the struct. How should I handle this case? Should I always add it to the exceptions?

@ablaom
Member

ablaom commented Sep 29, 2024

Can you please provide some more detail. I don't see any problem at my end:

julia> using Optimisers, Flux

julia> import MLJModelInterface as MMI

julia> model = NeuralNetworkClassifier();

julia> model2 = deepcopy(model);

julia> MMI.is_same_except(model, model2)
true

julia> model2.optimiser = Adam(42)
Adam(42.0, (0.9, 0.999), 1.0e-8)

julia> MMI.is_same_except(model, model2)
false

julia> model.optimiser = Adam(42)
Adam(42.0, (0.9, 0.999), 1.0e-8)

julia> MMI.is_same_except(model, model2)
true

Are you perhaps using Flux.jl optimisers instead of Optimisers.jl optimisers?
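Here is a plain-Julia sketch of why the optimiser type matters. It assumes, based on my understanding of Flux's older API, that Flux.Adam was a mutable struct that also carried its own state. Optimisers.Adam, by contrast, is an immutable struct holding only hyperparameters. OldAdam and NewAdam are hypothetical stand-ins, not the real types.

```julia
# Hypothetical stand-ins, not the real Flux or Optimisers.jl types.
# Old style: a mutable optimiser that also stores per-parameter state.
mutable struct OldAdam
    eta::Float64
    state::IdDict{Any,Any}
end
OldAdam(eta = 0.001) = OldAdam(eta, IdDict{Any,Any}())

# New style: an immutable rule holding only hyperparameters; state lives elsewhere.
struct NewAdam
    eta::Float64
    beta::Tuple{Float64,Float64}
end
NewAdam(eta = 0.001) = NewAdam(eta, (0.9, 0.999))

old_rule = OldAdam()
new_rule = NewAdam()

old_equal = old_rule == deepcopy(old_rule)   # false: mutable structs compare by identity
new_equal = new_rule == deepcopy(new_rule)   # true: immutable fields compare by value
```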

@pasq-cat
Author

pasq-cat commented Sep 30, 2024

Are you perhaps using Flux.jl optimisers instead of Optimisers.jl optimisers?

Yes, I think this is the issue, because

using MLJModelInterface, Flux
import MLJModelInterface as MMI

mutable struct LaplaceRegr <: MMI.Probabilistic
    model::Flux.Chain
    epochs::Integer
    batch_size::Integer
    optimiser
end

model = LaplaceRegr(flux_model, 1000, 2, Flux.Adam())

model2 = deepcopy(model)

model2.epochs = 2000

MMI.is_same_except(model, model2, :epochs)

returns false.

It looks like the tutorial I followed to write the training loop is outdated: Flux now prefers optimisers from the Optimisers.jl package, but the documentation available online is a confusing mix of old and new rules...

@ablaom
Member

ablaom commented Sep 30, 2024

Looks like the tutorial i have read to write the training loop is outdated and now Flux prefer optimisers from the Optimisers.jl package but the documentation available online is a confusing mix of old and new rules...

Well, MLJFlux now definitely requires Optimisers.jl optimisers only. If any of the MLJ/MLJFlux docs are out of date in this respect, please point them out.

@pasq-cat
Author

pasq-cat commented Oct 1, 2024

Looks like the tutorial i have read to write the training loop is outdated and now Flux prefer optimisers from the Optimisers.jl package but the documentation available online is a confusing mix of old and new rules...

Well MLJFlux now definitely requires only Optimiser.jl optimisers. If any of the MLJ/MLJFlux docs are out-of-date in this respect, please point them out.

Ah, but it was not the official documentation; I think it was a Medium page or something like that. Anyway, I think I have fixed the update loop. If you don't mind, I would like to keep this issue open a bit longer, just in case I encounter another problem; otherwise I will close it myself. OK? Thank you.

@ablaom
Member

ablaom commented Oct 1, 2024

Happy to support your work on an MLJ interface, and thanks for your persistence.

@pasq-cat
Author

pasq-cat commented Oct 4, 2024

@ablaom Hi Anthony, I think it's done now: https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/blob/direct_mlj_interface/src/direct_mlj.jl

In addition to the mandatory fit and predict methods, I have also implemented training_losses, fitted_params and reformat. However, reformat still has a minor inefficiency that I was unable to solve.

The reformat functions I have implemented are:

# for fit:
MMI.reformat(::Laplace_Models, X, y) = (MLJBase.matrix(X) |> permutedims, reshape(y, 1, :))
#for predict:
MMI.reformat(::Laplace_Models, X) = (MLJBase.matrix(X) |> permutedims,)

They simply convert the input X into a matrix, permute its dimensions, and reshape y.
The problem is y: for classification I have to add the following to both the fit and the update functions:

    if !(m isa LaplaceRegressor)
        # Convert labels to integer codes starting from 0 for one-hot encoding
        y_plain = MLJBase.int(y[1, :]) .- 1

        # One-hot encoding of labels; rows follow the order in which labels first appear
        unique_labels = unique(y_plain)
        y = Flux.onehotbatch(y_plain, unique_labels)
    end

This is rather ugly and inefficient. It seems I cannot move this part into the reformat function (even if I specialize it for LaplaceClassifier), because if I do, I lose access to the labels that the predict method needs.
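One side note on the snippet above: unique(y_plain) has two side effects:

- It orders the one-hot rows by first appearance.
- A resampling fold that lacks a class produces fewer rows.

The plain-Julia sketch below encodes against a fixed, ordered class list instead. onehot_columns is a hypothetical helper. In MLJ, the complete class pool can be recovered with MMI.classes applied to a categorical element of y.

```julia
# Hypothetical helper: one-hot encode y against a fixed, ordered list of classes,
# so row i always corresponds to class_list[i], even if that class is absent from y.
function onehot_columns(y::AbstractVector, class_list::AbstractVector)
    Y = falses(length(class_list), length(y))   # classes × observations
    for (j, label) in enumerate(y)
        i = findfirst(==(label), class_list)
        i === nothing && throw(ArgumentError("unknown label: $label"))
        Y[i, j] = true
    end
    return Y
end

class_list = ["setosa", "versicolor", "virginica"]
y = ["virginica", "setosa", "virginica"]   # "versicolor" missing from this subset
Y = onehot_columns(y, class_list)          # still 3 rows, in class order

# Decoding only needs the same ordered class list
decoded = [class_list[argmax(col)] for col in eachcol(Y)]
```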

Is there a better way, or is this how it has to be done? Thank you.

@ablaom
Member

ablaom commented Oct 8, 2024

An option is to put the labels into the output data of reformat. However, as selectrows will not work on such objects (needed for resampling operations, like CV), you will need to overload selectrows(::Laplace_Models, I, data) to say how to resample these objects. There's an example in the docs.

Does this address your query?

@pasq-cat
Author

@ablaom Do you mean returning the labels as a third element of the reformat output? If so, I guess I will leave it as it is, because I would also have to change fit and predict, together with the selectrows function.

@ablaom
Member

ablaom commented Oct 14, 2024

Yes, reformat would return an object that includes the labels. You would then need to overload selectrows for that object, and also have fit/update include the labels in the fitresult to make them available to predict.
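Here is a plain-Julia sketch of that idea. The reformatted object carries the features (features × observations), the one-hot targets and the class labels. Subsetting observations keeps the labels intact. ReformattedData and select_observations are hypothetical; in MLJ the subsetting would be an overload of MMI.selectrows for Laplace_Models. Because observations are columns here, the indices I select columns, not rows.

```julia
# Hypothetical reformatted data object.
struct ReformattedData
    X::Matrix{Float64}       # features × observations
    Y::Matrix{Bool}          # one-hot targets, classes × observations
    labels::Vector{String}   # ordered class labels, needed later by predict
end

# Stand-in for a selectrows overload: keep observations I, carry labels unchanged.
select_observations(d::ReformattedData, I) =
    ReformattedData(d.X[:, I], d.Y[:, I], d.labels)

d = ReformattedData(
    [1.0 2.0 3.0 4.0; 5.0 6.0 7.0 8.0],
    Bool[1 0 0 1; 0 1 1 0],
    ["no", "yes"],
)
train = select_observations(d, [1, 3])
```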

I'm indifferent as to how you proceed. If you'd like a new review of the interface, please (re-)open the issue at MLJModels.jl and ping me, thanks.

@pasq-cat
Author

@ablaom OK, thank you. I will first ask Patrick to check the code I have written, then I will open the issue.

@pasq-cat
Author

pasq-cat commented Oct 16, 2024

@ablaom Hi Anthony, sorry for asking again. We have almost completed the interface, but we are facing one last issue.
Is it possible to overload evaluate!?
I tried to run evaluate!(mach, resampling=cv, measure=LogLoss, verbosity=0) for the LaplaceRegressor, but I got this error:

  Got exception outside of a @test
  MethodError: no method matching LogLoss(::Matrix{Normal{Float64}}, ::Vector{Float64})

  Closest candidates are:
    LogLoss(::Any)
     @ StatisticalMeasures C:\Users\Pasqu\.julia\packages\StatisticalMeasures\UTtxb\src\probabilistic.jl:220

It seems to me that LogLoss is the appropriate measure according to https://juliaai.github.io/StatisticalMeasures.jl/dev/examples_of_usage/#Probabilistic-regression but evaluate! throws an error.

Should I overload evaluate! too?

@ablaom
Member

ablaom commented Oct 17, 2024

I think StatisticalMeasures' LogLoss() (and some others, e.g. BrierLoss()) should work. But it looks like you are comparing a matrix of normal distributions with a vector of ground truth values. Is your predict not outputting vectors of distributions?

In any case, I probably need more context to be of any help.

One does not overload evaluate or evaluate!. You can, however, define new resampling strategies.

@pasq-cat
Author

Hmm, so I guess the problem is in the format of the output.
I am updating the predict function so that it returns a vector of distributions.
Right now it looks like this:

function MMI.predict(m::LaplaceModels, fitresult, Xnew)
    la, decode = fitresult
    if typeof(m) == LaplaceRegressor
        yhat = LaplaceRedux.predict(la, Xnew; ret_distr=false)
        # Extract mean and variance matrices
        means, variances = yhat

        # Create Normal distributions from the means and variances
        return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]

    else
        predictions =
            LaplaceRedux.predict(la, Xnew; link_approx=m.link_approx, ret_distr=false) |>
            permutedims

        return MLJBase.UnivariateFinite(decode, predictions; pool=missing)
    end
end

but I still get this error message:

Regression: Error During Test at C:\Users\Pasqu\Documents\julia_projects\LaplaceRedux.jl\test\direct_mlj_interface.jl:12
Got exception outside of a @test
MethodError: no method matching pdf(::Vector{Normal{Float64}}, ::Float64)

Closest candidates are:
pdf(::DiscreteUniform, ::Real)
@ Distributions C:\Users\Pasqu\.julia\packages\Distributions\uuqsE\src\univariate\discrete\discreteuniform.jl:73

@ablaom
Member

ablaom commented Oct 17, 2024

And what is at LaplaceRedux.jl\test\direct_mlj_interface.jl:12?

@ablaom
Member

ablaom commented Oct 17, 2024

I can't see anything wrong with LogLoss:

using MLJ

model = ConstantRegressor()
data = make_regression()
evaluate(model, data...; measure = LogLoss())

# Evaluating over 6 folds: 100%[=========================] Time: 0:00:02
# PerformanceEvaluation object with these fields:
#   model, measure, operation,
#   measurement, per_fold, per_observation,
#   fitted_params_per_fold, report_per_fold,
#   train_test_rows, resampling, repeats
# Extract:
# ┌──────────────────────┬───────────┬─────────────┐
# │ measure              │ operation │ measurement │
# ├──────────────────────┼───────────┼─────────────┤
# │ LogLoss(             │ predict   │ 2.02        │
# │   tol = 2.22045e-16) │           │             │
# └──────────────────────┴───────────┴─────────────┘
# ┌─────────────────────────────────────┬─────────┐
# │ per_fold                            │ 1.96*SE │
# ├─────────────────────────────────────┼─────────┤
# │ [2.27, 2.23, 1.82, 1.89, 1.81, 2.1] │ 0.182   │
# └─────────────────────────────────────┴─────────┘

X, y = data
mach = machine(model, X, y) |> fit!
yhat = predict(mach, X);
yhat[1:3]
# 3-element Vector{Distributions.Normal{Float64}}:
#  Distributions.Normal{Float64}(μ=-0.8146301697352831, σ=1.759433773949422)
#  Distributions.Normal{Float64}(μ=-0.8146301697352831, σ=1.759433773949422)
#  Distributions.Normal{Float64}(μ=-0.8146301697352831, σ=1.759433773949422)

LogLoss()(yhat, y)

#1.9839305711450421
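For Normal predictions, the log loss of one observation is the negative log predictive density at the observed target. The reported value is the mean over observations. To my understanding, StatisticalMeasures also clamps densities below tol (shown in the table above); that detail is left out here.

Here is a plain-Julia sketch without Distributions.jl, with hypothetical helper names. It needs exactly one (μ, σ) prediction per target, which is why predict has to return a vector of distributions matching y.

```julia
# Log density of y under Normal(μ, σ), written out by hand.
normal_logpdf(μ, σ, y) = -0.5 * log(2π) - log(σ) - 0.5 * ((y - μ) / σ)^2

# Mean log loss: one (μ, σ) prediction per observed target.
function mean_log_loss(μs::AbstractVector, σs::AbstractVector, ys::AbstractVector)
    length(μs) == length(σs) == length(ys) ||
        throw(DimensionMismatch("need one prediction per target"))
    return -sum(normal_logpdf.(μs, σs, ys)) / length(ys)
end

loss = mean_log_loss([0.0, 1.0], [1.0, 2.0], [0.0, 1.0])
```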

@pasq-cat
Author

pasq-cat commented Oct 18, 2024

@ablaom

It's just a bunch of tests to placate the authoritarian codebot that Patrick has unleashed on every pull request. Everything works except the last line.
Update: there were some errors.

@testset "Regression" begin
    flux_model = Chain(
        Dense(4, 10, relu),
        Dense(10, 10, relu),
        Dense(10, 1)
    )
    model = LaplaceRegressor(model=flux_model,epochs=50)
    
    X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
    #train, test = partition(eachindex(y), 0.7); # 70:30 split
    mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
    MLJBase.fit!(mach, verbosity=1)
    #Xnew, ynew = make_regression(3, 4; rng=123)
    yhat = MLJBase.predict(mach, X) # probabilistic predictions
    MLJBase.predict_mode(mach, X)   # point predictions
    MLJBase.fitted_params(mach)   #fitted params function 
    MLJBase.training_losses(mach) #training loss history
    model.epochs= 100 #changing number of epochs
    MLJBase.fit!(mach) #testing update function
    model.epochs= 50 #changing number of epochs to a lower number
    MLJBase.fit!(mach) #testing update function
    model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
    MLJBase.fit!(mach) #testing update function (the laplace part)
    yhat = MLJBase.predict(mach, X) # probabilistic predictions
    println( typeof(yhat) )
    println( size(yhat) )
    println( typeof(y) )
    println( size(y) )

    #evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)
    LogLoss()(yhat, y)
end

If I run this, I get:

Matrix{Normal{Float64}}
(1, 100)
Vector{Float64}
(100,)
Regression: Error During Test at C:\Users\Pasqu\Documents\julia_projects\LaplaceRedux.jl\test\direct_mlj_interface.jl:12
Got exception outside of a @test
MethodError: no method matching pdf(::Vector{Normal{Float64}}, ::Float64)

@pasq-cat
Author

OK, solved: I had removed a vec and forgotten about it. Sorry for the mess I made.
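This matches the 1 × 100 Matrix{Normal{Float64}} printed earlier. A comprehension over zip of two 1 × n matrices keeps their shape, so predict returned a matrix instead of a vector. Here is a minimal plain-Julia illustration, with tuples standing in for Normal distributions and made-up values:

```julia
# Predictive means and variances as 1 × n row matrices (illustrative values).
means = [0.0 1.0 2.0]
variances = [1.0 4.0 9.0]

# A comprehension over zip keeps the 1 × n shape, giving a Matrix of tuples...
as_matrix = [(μ, sqrt(σ²)) for (μ, σ²) in zip(means, variances)]

# ...whereas flattening the inputs with vec first gives a Vector, one entry per observation.
as_vector = [(μ, sqrt(σ²)) for (μ, σ²) in zip(vec(means), vec(variances))]
```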

@pasq-cat
Author

@ablaom Hi, I was reading "The document string standard"
and realized that there is no mention of reformat or selectrows among the functions that need to be documented.
Should I add a description of these two functions to the "Operations" section, since I have overloaded them? I doubt users would be interested in them.

@ablaom
Member

ablaom commented Oct 21, 2024

No, they are not public API.

@pasq-cat
Author

pasq-cat commented Nov 7, 2024

Hi, sorry for the delay, I had an exam. I think (I hope) the interface is ready for review; Patrick has also taken a look. I will reopen the issue on the official MLJ page. @ablaom

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants