Skip to content

Commit

Permalink
Merge pull request #18 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.2.4 release
  • Loading branch information
ablaom authored Jan 5, 2024
2 parents 3bc4425 + 045363d commit 65facf9
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 17 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTestInterface"
uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.2.3"
version = "0.2.4"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand All @@ -10,4 +10,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
MLJBase = "0.20, 0.21, 1"
Test = "<0.0.1, 1"
julia = "1.6"
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ failures, summary = MLJTestInterface.test(
The following commands generate small datasets of the form `(X, y)` suitable for interface
tests:

- `MLJTestInterface.make_binary`
- `MLJTestInterface.make_binary(; row_table=false)`

- `MLJTestInterface.make_multiclass`
- `MLJTestInterface.make_multiclass(; row_table=false)`

- `MLJTestInterface.make_regression`
- `MLJTestInterface.make_regression(; row_table=false)`

- `MLJTestInterface.make_count`
- `MLJTestInterface.make_count(; row_table=false)`

2 changes: 1 addition & 1 deletion src/attemptors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
methods = MLJBase.implemented_methods(fitted_machine.model)
if model isa Static && !(:transform in methods)
push!(methods, :transform)
end
end
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
if :predict in methods
predict(fitted_machine, first(data))
Expand Down
34 changes: 23 additions & 11 deletions src/datasets.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,59 @@
"""
make_binary()
make_binary(; row_table=false)
Return data `(X, y)` for the crabs dataset, restricted to the two features `:FL`,
`:RW`. Target is `OrderedFactor{2}`.
The table `X` is a named tuple of vectors. For a vector of named tuples, set
`row_table=true`.
"""
function make_binary()
function make_binary(; row_table=false)
data = MLJBase.load_crabs()
y_, X = unpack(data, ==(:sp), col->col in [:FL, :RW])
y = coerce(y_, MLJBase.OrderedFactor)
return X, y
row_table ? (MLJBase.Tables.rowtable(X), y) : (X, y)
end

"""
make_multiclass()
make_multiclass(; row_table=false)
Return data `(X, y)` for the unshuffled iris dataset. Target is `Multiclass{3}`.
"""
make_multiclass() = MLJBase.@load_iris
function make_multiclass(; row_table=false)
X, y = MLJBase.@load_iris
row_table ? (MLJBase.Tables.rowtable(X), y) : (X, y)
end

"""
make_regression()
make_regression(; row_table=false)
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
`:Rm`. Target is `Continuous`.
The table `X` is a named tuple of vectors. For a vector of named tuples, set
`row_table=true`.
"""
function make_regression()
function make_regression(; row_table=false)
data = MLJBase.load_boston()
y, X = unpack(data, ==(:MedV), col->col in [:LStat, :Rm])
return X, y
row_table ? (MLJBase.Tables.rowtable(X), y) : (X, y)
end

"""
make_count()
make_count(; row_table=false)
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
`:Rm`, with the `Continuous` target converted to `Count` (integer).
The table `X` is a named tuple of vectors. For a vector of named tuples, set
`row_table=true`.
"""
function make_count()
function make_count(; row_table=false)
X, y_ = make_regression()
y = map(η -> round(Int, η), y_)
return X, y
row_table ? (MLJBase.Tables.rowtable(X), y) : (X, y)
end
39 changes: 39 additions & 0 deletions test/datasets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
@testset "loading of datasets" begin
    # Each dataset generator, paired with the element scitype its target is
    # expected to have.
    generators = (
        (MTI.make_binary, MLJBase.OrderedFactor{2}),
        (MTI.make_multiclass, MLJBase.Multiclass{3}),
        (MTI.make_regression, MLJBase.Continuous),
        (MTI.make_count, MLJBase.Count),
    )

    for (generate, target_scitype) in generators
        # Default call: features are returned as a named tuple of vectors.
        features, target = generate()
        @test features isa NamedTuple
        @test first(features) isa AbstractVector{Float64}
        @test MLJBase.scitype(target) == AbstractVector{target_scitype}

        # With `row_table=true` the features must equal the row-table form of
        # the default output, and the target must be unchanged.
        row_features, row_target = generate(row_table=true)
        @test row_features isa AbstractVector
        @test MLJBase.Tables.rowtable(features) == row_features
        @test row_target == target
    end
end

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ macro conditional_testset(name, expr)
end)
end

@conditional_testset "datasets" begin
include("datasets.jl")
end

@conditional_testset "attemptors" begin
include("attemptors.jl")
end
Expand Down

0 comments on commit 65facf9

Please sign in to comment.