
machines should check input shape #1023

Open
adienes opened this issue May 7, 2023 · 4 comments

@adienes

adienes commented May 7, 2023

When I train a machine on a subset of the columns of a DataFrame, I can still predict on the full DataFrame, but it silently uses the wrong columns. This should either (1) select the correct columns or (2) throw an error.

It also lets you predict on fewer columns than it was trained on, which should certainly error.

```julia
julia> using DataFrames, MLJ

julia> LinearRegressor = @load LinearRegressor pkg=MLJLinearModels;

julia> lin = LinearRegressor();

julia> df, y = DataFrame([:a, :b, :c, :d] .=> eachrow(rand(4, 10))), rand(10);

julia> mach = machine(lin, df[!, Not(:a)], y);

julia> fit!(mach);

julia> predict(mach, df)
10-element Vector{Float64}:
 0.08774429862475279
 0.4458262368370531
 0.32256136877587177
 0.5354542680166791
 0.434305050281646
 0.5517689881226184
 0.04561962641849866
 0.5249170653913241
 0.3048487015504525
 0.214881835774883

julia> predict(mach, df[:, Not(:a)])
10-element Vector{Float64}:
 0.5571764736289844
 0.5025268094746667
 0.48111746876821104
 0.4831649084669946
 0.4668772596220155
 0.5380930276566298
 0.5973458808921681
 0.6215717735791058
 0.8709687793078741
 0.6683648141076388

julia> predict(mach, rand(10, 1))
10-element Vector{Float64}:
 1.0692661960466232
 1.028867758819888
 1.0258819339237566
 1.0436911424814095
 1.0479030848086417
 1.0010825211829961
 1.0164762898143949
 1.0363124651504347
 1.0775702045518292
 1.035356208601051
```

Versions
0.19.1

@ablaom
Member

ablaom commented May 7, 2023

While I agree the suggested behaviour would be an improvement, I'm not sure I'd characterise this as a bug. Also, it's actually quite difficult to implement, because there is a tension with the desire for API genericity. Not all models consume tables, and some inputs don't even have a concept of columns (e.g., MLJFlux.jl's ImageClassifier). This means the responsibility for making these checks probably has to be left with individual model implementations of MLJModelInterface. We could add tooling to make this easier, and I would welcome that. But rolling it out to all models would be a substantial task.
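To make the genericity point concrete, here is a minimal sketch of the kind of tooling that might help implementers, assuming Base Julia only and using a NamedTuple of vectors to stand in for a table. `check_nfeatures` is a hypothetical name, not an existing MLJ or MLJModelInterface function. Because the check dispatches on the input type, inputs with no notion of columns (for example, a 3-d array of images) fall through to a no-op method, which is one way to keep such a check compatible with non-tabular models.

```julia
# Hypothetical helper, NOT part of MLJ or MLJModelInterface.
# `n_fit` is the number of features the model saw during training.

# Fallback: inputs without a notion of columns (e.g. image arrays) are not checked.
check_nfeatures(n_fit::Int, X) = nothing

# Matrix input: features are columns.
function check_nfeatures(n_fit::Int, X::AbstractMatrix)
    size(X, 2) == n_fit || throw(DimensionMismatch(
        "model was trained on $n_fit features but got $(size(X, 2))"))
    return nothing
end

# Column table represented as a NamedTuple of vectors: one entry per column.
function check_nfeatures(n_fit::Int, X::NamedTuple)
    length(X) == n_fit || throw(DimensionMismatch(
        "model was trained on $n_fit columns but got $(length(X))"))
    return nothing
end
```

If a model's `predict` called a check like this first, the third example in the issue (`rand(10, 1)` passed to a machine trained on 3 features) would raise a `DimensionMismatch` instead of silently returning predictions.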

But perhaps I'm missing a simpler solution?

@adienes
Author

adienes commented May 7, 2023

I see. The first example, where predict(mach, df) takes the first 3 columns, is not ideal, but I can see an argument for it being an acceptable side effect of genericity. However, in my honest opinion, the third example, where predict(mach, rand(10, 1)) accepts a single column despite the machine being trained on 3, is more dangerous and must error for any model (unless I am lacking imagination for some scenario where this would be desired).

As I mentioned in Slack, I would like to help improve how MLJ and DataFrames work together, as I think they have the potential for some really elegant workflows (even more so than the status quo), so I would love to understand better what's happening here. I imagine that at some point the input table is consumed or transformed into a Matrix? Where does that reinterpretation happen?

@ablaom
Member

ablaom commented May 7, 2023

> As I mentioned in Slack, I would like to help improve how MLJ and DataFrames work together, as I think they have the potential for some really elegant workflows (even more so than the status quo), so I would love to understand better what's happening here. I imagine that at some point the input table is consumed or transformed into a Matrix? Where does that reinterpretation happen?

Most models that consume tabular input will convert it to a matrix inside their MLJModelInterface.fit method, and then convert back to a table inside MLJModelInterface.predict (or transform). But there is an option (preferred but not required) to implement the "data front-end" separately, which is one way to avoid repeating this coercion unnecessarily in, say, iterative models controlled externally by wrapping them in IteratedModel. The data front-end is described here. I'd say this data front-end looks much the same for a large class of models (see the EvoTrees.jl example cited), so that would be the thing to generalize, and where your "column name" checks would live.
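As a rough illustration of where a "column name" check could sit, here is a toy sketch in Base Julia, assuming column tables are NamedTuples of vectors. It is not MLJModelInterface code: in a real data front-end the table-to-matrix conversion would be done by overloading `MLJModelInterface.reformat`, and `ToyFitResult`, `to_matrix`, `toy_fit` and `toy_predict` are all hypothetical names. The point is only that the conversion step already touches the column names, so it can record them at fit time and then select (or reject) columns by name at predict time.

```julia
# Toy sketch, NOT MLJModelInterface code; every name here is hypothetical.
# A "table" is a NamedTuple of vectors.

struct ToyFitResult
    names::Tuple{Vararg{Symbol}}  # column names seen during fit
    coefs::Vector{Float64}        # least-squares coefficients (no intercept)
end

# Table -> matrix conversion, taking columns in the order given by `names`.
# `hcat` of a single vector still returns an n×1 matrix.
to_matrix(X::NamedTuple, names) = hcat([float.(X[n]) for n in names]...)

function toy_fit(X::NamedTuple, y::AbstractVector{<:Real})
    names = keys(X)
    A = to_matrix(X, names)
    return ToyFitResult(names, A \ y)
end

function toy_predict(fr::ToyFitResult, Xnew::NamedTuple)
    # Reject tables missing any training column (option 2 in the issue)...
    absent = [n for n in fr.names if !haskey(Xnew, n)]
    isempty(absent) || throw(ArgumentError("missing training columns: $absent"))
    # ...otherwise select the training columns by name, not by position
    # (option 1), so extra or reordered columns are harmless.
    return to_matrix(Xnew, fr.names) * fr.coefs
end
```

With this design, a table with an extra column `a` and the training columns in a different order still gets the right predictions, while a table missing a training column raises an `ArgumentError`.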

@OkonSamuel
Member

> I see. The first example, where predict(mach, df) takes the first 3 columns, is not ideal, but I can see an argument for it being an acceptable side effect of genericity. However, in my honest opinion, the third example, where predict(mach, rand(10, 1)) accepts a single column despite the machine being trained on 3, is more dangerous and must error for any model (unless I am lacking imagination for some scenario where this would be desired).
>
> As I mentioned in Slack, I would like to help improve how MLJ and DataFrames work together, as I think they have the potential for some really elegant workflows (even more so than the status quo), so I would love to understand better what's happening here. I imagine that at some point the input table is consumed or transformed into a Matrix? Where does that reinterpretation happen?

While I agree with you that predict(mach, rand(10, 1)) running is kinda weird, I support @ablaom's earlier point about leaving this to the model implementer and providing a utility function that helps the implementer check it easily.
Maybe something like a `check_tabular_X_y` utility function, which could live in MLJBase.jl or in a utils file in MLJModelInterface.jl.
I have been thinking of doing this for some time now, but I haven't found the time to get to it.
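A minimal sketch of what a `check_tabular_X_y` utility might verify, assuming column tables are NamedTuples of vectors. Only the name comes from the comment above; the signature, the specific checks, and returning the column names are assumptions, and nothing here is an existing MLJBase.jl or MLJModelInterface.jl function.

```julia
# Hypothetical utility; the name is from the discussion, everything else is assumed.
# Checks that X is a non-empty column table whose columns all have the same
# length, and that this length matches length(y). Returns the column names so
# the caller can store them in its fit result and check them again in predict.
function check_tabular_X_y(X::NamedTuple, y::AbstractVector)
    isempty(X) && throw(ArgumentError("X has no columns"))
    lens = map(length, values(X))
    all(==(first(lens)), lens) ||
        throw(DimensionMismatch("columns of X have differing lengths: $lens"))
    first(lens) == length(y) ||
        throw(DimensionMismatch("X has $(first(lens)) rows but y has $(length(y)) elements"))
    return keys(X)
end
```

An implementer could call this at the top of `fit` and keep the returned names for a matching check at prediction time.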

Labels: None yet
Projects: Status: priority low / involved
Development: No branches or pull requests
3 participants