Fix #41, add tests around evaluate/tuning in MLJ #42
The fix has been tested. It works like a charm!
This looks good to me.
Some comments about how this slipped through the net: in theory (though I haven't checked), this bug should have been caught by the integration tests at MLJ. However, those tests are currently disabled for this model because of an earlier issue. That issue has since been resolved, and reinstating the tests was flagged, but it was low priority and so didn't happen.
I will check all this today and report back.
Of course, feel free to just add the extra tests if you see no harm.
Okay, I believe I was mistaken, and the integration tests would not have caught this bug after all. I suggest we add the tests as proposed. I'll open an issue shortly to investigate why the integration tests did not catch this one.
Is this what you were thinking? I am forcing `y_first` to be a `CategoricalValue` in the fitresult (the fitresult did not previously do this).
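Why storing `y_first` as a `CategoricalValue` matters can be sketched with CategoricalArrays.jl alone. This is an illustration that assumes CategoricalArrays is installed; it does not reproduce the CatBoost interface's actual fitresult. A `CategoricalValue` keeps a reference to the full pool of levels, so classes absent from the training target are not lost, whereas a plain `String` carries no such information.

```julia
using CategoricalArrays

y = categorical(["yes", "yes", "yes", "no", "no"])
ytrain = y[1:3]                    # only "yes" is observed in the training target

# Slicing keeps the original pool, so both levels are still known:
levels(ytrain)                     # ["no", "yes"]

# Unwrapping to a plain String discards the pool...
y_first_string = unwrap(ytrain[1]) # "yes", with no record of "no"

# ...whereas keeping the CategoricalValue retains every level.
y_first = ytrain[1]
y_first isa CategoricalValue       # true
levels(y_first)                    # ["no", "yes"]
```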
Yes, this appears to have addressed my concerns!

```julia
import CatBoost
using MLJBase

Xtrain = MLJBase.table(rand(3, 2))
Xtest = MLJBase.table(rand(2, 2))
y = categorical(["yes", "yes", "yes", "no", "no"])
ytrain = y[1:3]  # the training target only contains "yes", but its pool also includes "no"
model = CatBoost.MLJCatBoostInterface.CatBoostClassifier()
mach = machine(model, Xtrain, ytrain) |> MLJBase.fit!
yhat = MLJBase.predict(mach, Xtest)
classes(yhat)
# 2-element CategoricalArrays.CategoricalArray{String,1,UInt32}:
# "no"
# "yes"
```
Thanks!
I can also confirm that MLJIntegrationTests (level=3) pass.
Good to go.
It looks like the bug was simply that we called `fit` instead of `MMI.fit` in the `MMI.update` function.
Ref: https://discourse.julialang.org/t/unable-to-find-fit-in-catboost-mljcatboostinterface/123818, #41
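A minimal, self-contained sketch of why the unqualified call failed. `FakeMMI` and `FakeInterface` are hypothetical stand-ins for MLJModelInterface and `CatBoost.MLJCatBoostInterface`, and the method signatures are simplified assumptions rather than the real package code. The point is Julia's name resolution: a bare `fit` inside a module that never imported it is undefined, and the error only surfaces when the method is actually called (e.g. by `update` during evaluation or tuning).

```julia
# Hypothetical stand-in for MLJModelInterface (NOT the real package).
module FakeMMI
    function fit end
    function update end
end

# Hypothetical stand-in for the CatBoost MLJ interface module.
module FakeInterface
    import ..FakeMMI
    const MMI = FakeMMI

    struct Model end

    # Extend the interface's `fit` for our model via its qualified name.
    MMI.fit(::Model, verbosity, X, y) = (:fitresult, nothing, NamedTuple())

    # Buggy version: `fit` is not defined in this module, so calling this
    # throws an UndefVarError (defining it does not).
    update_buggy(m::Model, verbosity, old_fitresult, old_cache, X, y) =
        fit(m, verbosity, X, y)

    # Fixed version: qualify the call so it resolves to `MMI.fit`.
    MMI.update(m::Model, verbosity, old_fitresult, old_cache, X, y) =
        MMI.fit(m, verbosity, X, y)
end

m = FakeInterface.Model()
FakeMMI.update(m, 0, nothing, nothing, [1.0], [1.0])  # returns (:fitresult, nothing, NamedTuple())

# The unqualified call fails at call time, not at definition time:
buggy_failed = try
    FakeInterface.update_buggy(m, 0, nothing, nothing, [1.0], [1.0])
    false
catch err
    err isa UndefVarError
end
println(buggy_failed)  # prints "true"
```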