Skip to content

Commit

Permalink
Merge pull request #39 from JuliaAI/acceleration-testing
Browse files Browse the repository at this point in the history
Add tests for CPUThreads and CPUProceses
  • Loading branch information
ablaom authored May 10, 2024
2 parents 77c2ad7 + 9e8684d commit 0483ff2
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 24 deletions.
30 changes: 8 additions & 22 deletions test/ensembles.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
module TestEnsembles

using Test
using Random
using StableRNGs
using MLJEnsembles
using MLJBase
using ..Models
using CategoricalArrays
import Distributions
using StatisticalMeasures

## HELPER FUNCTIONS

@test MLJEnsembles._reducer([1, 2], [3, ]) == [1, 2, 3]
Expand Down Expand Up @@ -187,10 +175,10 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))

@testset "further test of sample weights" begin
## Note: This testset also indirectly tests for compatibility with the data-front end
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
# implemented by `KNNClassifier` as calls to `fit`/`predict` on an `Ensemble` model
# with `atom=KNNClassifier` would error if the ensemble implementation doesn't handle
# data front-end conversions properly.

rng = StableRNG(123)
N = 20
X = (x = rand(rng, 3N), );
Expand Down Expand Up @@ -224,18 +212,18 @@ predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))
end


## MACHINE TEST
## MACHINE TEST
## (INCLUDES TEST OF UPDATE.
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
## ALSO INCLUDES COMPATIBILITY TESTS FOR ENSEMBLES WITH ATOM MODELS HAVING A
## DIFFERENT DATA FRONT-END SEE #16)

@testset "machine tests" begin
@testset_accelerated "machine tests" acceleration begin
N =100
X = (x1=rand(N), x2=rand(N), x3=rand(N))
y = 2X.x1 - X.x2 + 0.05*rand(N)

atom = KNNRegressor(K=7)
ensemble_model = EnsembleModel(model=atom)
ensemble_model = EnsembleModel(; model=atom, acceleration)
ensemble = machine(ensemble_model, X, y)
train, test = partition(eachindex(y), 0.7)
fit!(ensemble, rows=train, verbosity=0)
Expand Down Expand Up @@ -264,15 +252,13 @@ end
atom;
bagging_fraction=0.6,
rng=123,
out_of_bag_measure = [log_loss, brier_score]
out_of_bag_measure = [log_loss, brier_score],
acceleration,
)
ensemble = machine(ensemble_model, X_, y_)
fit!(ensemble)
@test length(ensemble.fitresult.ensemble) == ensemble_model.n

end


end

true
26 changes: 24 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
include("_models.jl")
using Distributed
# Thanks to https://stackoverflow.com/a/70895939/5056635 for the exeflags tip.
addprocs(; exeflags="--project=$(Base.active_project())")

@info "nprocs() = $(nprocs())"
import .Threads
@info "nthreads() = $(Threads.nthreads())"

include("test_utilities.jl")
include_everywhere("_models.jl")

@everywhere begin
using Test
using Random
using StableRNGs
using MLJEnsembles
using MLJBase
using ..Models
using CategoricalArrays
import Distributions
using StatisticalMeasures
import Distributed
end

include("ensembles.jl")
include("serialization.jl")

50 changes: 50 additions & 0 deletions test/test_utilities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Test

using ComputationalResources

macro testset_accelerated(name::String, var, ex)
testset_accelerated(name, var, ex)
end
macro testset_accelerated(name::String, var, opts::Expr, ex)
testset_accelerated(name, var, ex; eval(opts)...)
end
function testset_accelerated(name::String, var, ex; exclude=[])
final_ex = quote
local $var = CPU1()
@testset $name $ex
end

resources = AbstractResource[CPUProcesses(), CPUThreads()]

for res in resources
if any(x->typeof(res)<:x, exclude)
push!(final_ex.args, quote
local $var = $res
@testset $(name*" ($(typeof(res).name))") begin
@test_broken false
end
end)
else
push!(final_ex.args, quote
local $var = $res
@testset $(name*" ($(typeof(res).name))") $ex
end)
end
end
# preserve outer location if possible
if ex isa Expr && ex.head === :block && !isempty(ex.args) &&
ex.args[1] isa LineNumberNode
final_ex = Expr(:block, ex.args[1], final_ex)
end
return esc(final_ex)
end

function include_everywhere(filepath)
include(filepath) # Load on Node 1 first, triggering any precompile
if nprocs() > 1
fullpath = joinpath(@__DIR__, filepath)
@sync for p in workers()
@async remotecall_wait(include, p, fullpath)
end
end
end

0 comments on commit 0483ff2

Please sign in to comment.