From 78faebf5d715a973edd2bc3fe0b4f007c8df3c99 Mon Sep 17 00:00:00 2001
From: Ian Butterworth
Date: Thu, 14 Dec 2023 12:37:38 -0500
Subject: [PATCH 1/8] set up training benchmarking script

---
 benchmarking/Project.toml |  21 ++++++++
 benchmarking/benchmark.jl |  71 ++++++++++++++++++++++++++
 benchmarking/tooling.jl   | 102 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 194 insertions(+)
 create mode 100644 benchmarking/Project.toml
 create mode 100644 benchmarking/benchmark.jl
 create mode 100644 benchmarking/tooling.jl

diff --git a/benchmarking/Project.toml b/benchmarking/Project.toml
new file mode 100644
index 00000000..4bbcd793
--- /dev/null
+++ b/benchmarking/Project.toml
@@ -0,0 +1,21 @@
+[deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
+UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+
+[compat]
+CUDA = "5"
+Flux = "0.14"
+MLDatasets = "0.7"
+Metalhead = "0.9"
+Optimisers = "0.3"
+ProgressMeter = "1.9"
+TimerOutputs = "0.5"
+UnicodePlots = "3.6"
+cuDNN = "1.2"
diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl
new file mode 100644
index 00000000..9c48248f
--- /dev/null
+++ b/benchmarking/benchmark.jl
@@ -0,0 +1,71 @@
+
+using CUDA, cuDNN
+using Flux
+using Flux: logitcrossentropy, onecold, onehotbatch
+using Metalhead
+using MLDatasets
+using Optimisers
+using ProgressMeter
+using TimerOutputs
+using UnicodePlots
+
+include("tooling.jl")
+
+epochs = 45
+batchsize = 1000
+device = gpu
+allow_skips = true
+
+train_loader, test_loader, labels = load_cifar10(; batchsize)
+nlabels = length(labels)
+firstbatch = first(first(train_loader))
+imsize = size(firstbatch)[1:2]
+
+to = TimerOutput()
+
+# these should all be the smallest variant of each that is tested in `/test`
+modelstrings = (
+    "AlexNet()",
+    "VGG(11, batchnorm=true)",
+    "SqueezeNet()",
+    "ResNet(18)",
+    "WideResNet(50)",
+    "ResNeXt(50, cardinality=32, base_width=4)",
+    "SEResNet(18)",
+    "SEResNeXt(50, cardinality=32, base_width=4)",
+    "Res2Net(50, base_width=26, scale=4)",
+    "Res2NeXt(50)",
+    "GoogLeNet(batchnorm=true)",
+    "DenseNet(121)",
+    "Inceptionv3()",
+    "Inceptionv4()",
+    "InceptionResNetv2()",
+    "Xception()",
+    "MobileNetv1(0.5)",
+    "MobileNetv2(0.5)",
+    "MobileNetv3(:small, 0.5)",
+    "MNASNet(MNASNet, 0.5)",
+    "EfficientNet(:b0)",
+    "EfficientNetv2(:small)",
+    "ConvMixer(:small)",
+    "ConvNeXt(:small)",
+    # "MLPMixer()", # found no tests
+    # "ResMLP()", # found no tests
+    # "gMLP()", # found no tests
+    "ViT(:tiny)",
+    "UNet()"
+    )
+
+for (i, modstring) in enumerate(modelstrings)
+    @timeit to "$modstring" begin
+        @info "Evaluating $i/$(length(modelstrings)) $modstring"
+        @timeit to "First Load" eval(Meta.parse(modstring))
+        @timeit to "Second Load" model=eval(Meta.parse(modstring))
+        @timeit to "Training" train(model,
+                                    train_loader,
+                                    test_loader;
+                                    to,
+                                    device)||(allow_skips || break)
+    end
+end
+print_timer(to; sortby = :firstexec)
diff --git a/benchmarking/tooling.jl b/benchmarking/tooling.jl
new file mode 100644
index 00000000..72f0e180
--- /dev/null
+++ b/benchmarking/tooling.jl
@@ -0,0 +1,102 @@
+function loss_and_accuracy(data_loader, model, device; limit = nothing)
+    acc = 0
+    ls = 0.0f0
+    num = 0
+    i = 0
+    for (x, y) in data_loader
+        x, y = x |> device, y |> device
+        ŷ = model(x)
+        ls += logitcrossentropy(ŷ, y, agg=sum)
+        acc += sum(onecold(ŷ) .== onecold(y))
+        num += size(x)[end]
+        if limit !== nothing
+            i == limit && break
+            i += 1
+        end
+    end
+    return ls / num, acc / num
+end
+
+function load_cifar10(; batchsize=1000)
+    @info "loading CIFAR-10 dataset"
+    train_dataset, test_dataset = CIFAR10(split=:train), CIFAR10(split=:test)
+    train_x, train_y = train_dataset[:]
+    test_x, test_y = test_dataset[:]
+    @assert train_dataset.metadata["class_names"] == test_dataset.metadata["class_names"]
+    labels = train_dataset.metadata["class_names"]
+
+    # CIFAR10 label indices seem to be zero-indexed
+    train_y .+= 1
+    test_y .+= 1
+
+    train_y_ohb = Flux.onehotbatch(train_y, eachindex(labels))
+    test_y_ohb = Flux.onehotbatch(test_y, eachindex(labels))
+
+    train_loader = Flux.DataLoader((data=train_x, labels=train_y_ohb); batchsize, shuffle=true)
+    test_loader = Flux.DataLoader((data=test_x, labels=test_y_ohb); batchsize)
+
+    return train_loader, test_loader, labels
+end
+
+function _train(model, train_loader, test_loader; epochs = 45, device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput())
+
+    model = model |> device
+
+    opt = Optimisers.Adam()
+    state = Optimisers.setup(opt, model)
+
+    train_loss_hist, train_acc_hist = Float64[], Float64[]
+    test_loss_hist, test_acc_hist = Float64[], Float64[]
+
+    @info "starting training"
+    for epoch in 1:epochs
+        i = 0
+        @showprogress "training epoch $epoch/$epochs" for (x, y) in train_loader
+            x, y = x |> device, y |> device
+            @timeit to "batch step" begin
+                gs, _ = gradient(model, x) do m, _x
+                    logitcrossentropy(m(_x), y)
+                end
+                state, model = Optimisers.update(state, model, gs)
+            end
+
+            device === gpu && gpu_stats && CUDA.memory_status()
+            if limit !== nothing
+                i == limit && break
+                i += 1
+            end
+        end
+
+        @info "epoch $epoch complete. Testing..."
+        train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit)
+        @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit)
+        @info map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))
+
+        if show_plots
+            push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc);
+            push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc);
+            plt = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss")
+            lineplot!(plt, 1:epoch, test_loss_hist, name = "test_loss")
+            display(plt)
+            plt = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc")
+            lineplot!(plt, 1:epoch, test_acc_hist, name = "test_acc")
+            display(plt)
+        end
+        if device === gpu && gpu_gc
+            GC.gc() # GPU will OOM without this
+        end
+    end
+end
+
+# because Flux stacktraces are ludicrously big on <1.10 so don't show them
+function train(args...;kwargs...)
+    try
+        _train(args...; kwargs...)
+    catch ex
+        rethrow()
+        println()
+        @error sprint(showerror, ex)
+        GC.gc()
+        return false
+    end
+end
\ No newline at end of file

From 7dac4fcce79b3e38ff87b6cd344c4f6b158cf616 Mon Sep 17 00:00:00 2001
From: Ian Butterworth
Date: Thu, 14 Dec 2023 12:43:33 -0500
Subject: [PATCH 2/8] comments etc.

--- benchmarking/benchmark.jl | 7 ++++--- benchmarking/tooling.jl | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index 9c48248f..0b664216 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -49,9 +49,9 @@ modelstrings = ( "EfficientNetv2(:small)", "ConvMixer(:small)", "ConvNeXt(:small)", - # "MLPMixer()", # found no tests - # "ResMLP()", # found no tests - # "gMLP()", # found no tests + # "MLPMixer()", # no tests found + # "ResMLP()", # no tests found + # "gMLP()", # no tests found "ViT(:tiny)", "UNet()" ) @@ -60,6 +60,7 @@ for (i, modstring) in enumerate(modelstrings) @timeit to "$modstring" begin @info "Evaluating $i/$(length(modelstrings)) $modstring" @timeit to "First Load" eval(Meta.parse(modstring)) + # second load simulates what might be possible with a proper set-up pkgimage workload @timeit to "Second Load" model=eval(Meta.parse(modstring)) @timeit to "Training" train(model, train_loader, diff --git a/benchmarking/tooling.jl b/benchmarking/tooling.jl index 72f0e180..94b45b7e 100644 --- a/benchmarking/tooling.jl +++ b/benchmarking/tooling.jl @@ -93,10 +93,10 @@ function train(args...;kwargs...) try _train(args...; kwargs...) catch ex - rethrow() + # rethrow() println() @error sprint(showerror, ex) GC.gc() return false end -end \ No newline at end of file +end From 49c769f5fc85e6b39ca14b9e4c99b4521b5206a2 Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 14 Dec 2023 14:59:54 -0500 Subject: [PATCH 3/8] fixes --- benchmarking/benchmark.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index 0b664216..d68b56c5 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -21,6 +21,8 @@ nlabels = length(labels) firstbatch = first(first(train_loader)) imsize = size(firstbatch)[1:2] +@info "Benchmarking" epochs batchsize device imsize + to = TimerOutput() # these should all be the smallest variant of each that is tested in `/test` @@ -43,8 +45,8 @@ modelstrings = ( "Xception()", "MobileNetv1(0.5)", "MobileNetv2(0.5)", - "MobileNetv3(:small, 0.5)", - "MNASNet(MNASNet, 0.5)", + "MobileNetv3(:small, width_mult=0.5)", + "MNASNet(:A1, width_mult=0.5)", "EfficientNet(:b0)", "EfficientNetv2(:small)", "ConvMixer(:small)", @@ -58,13 +60,15 @@ modelstrings = ( for (i, modstring) in enumerate(modelstrings) @timeit to "$modstring" begin - @info "Evaluating $i/$(length(modelstrings)) $modstring" - @timeit to "First Load" eval(Meta.parse(modstring)) + @info "Evaluating $i/$(length(modelstrings)): $modstring" + # Initial precompile is variable based on what came before, so don't time first load + eval(Meta.parse(modstring)) # second load simulates what might be possible with a proper set-up pkgimage workload - @timeit to "Second Load" model=eval(Meta.parse(modstring)) + @timeit to "Load" model=eval(Meta.parse(modstring)) @timeit to "Training" train(model, train_loader, test_loader; + limit = 1, to, device)||(allow_skips || break) end From 62511b0d91f6122e694086b9cb3e6a46c065bf5a Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 21 Dec 2023 12:39:10 -0500 Subject: [PATCH 4/8] wip --- benchmarking/Project.toml | 2 ++ benchmarking/benchmark.jl | 73 ++++++++++++++++++++++----------------- benchmarking/tooling.jl | 19 +++++----- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/benchmarking/Project.toml b/benchmarking/Project.toml index 4bbcd793..1036fa64 100644 --- 
a/benchmarking/Project.toml +++ b/benchmarking/Project.toml @@ -1,5 +1,6 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" @@ -11,6 +12,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] CUDA = "5" +DataFrames = "1" Flux = "0.14" MLDatasets = "0.7" Metalhead = "0.9" diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index d68b56c5..e39763b2 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -1,5 +1,6 @@ using CUDA, cuDNN +using DataFrames using Flux using Flux: logitcrossentropy, onecold, onehotbatch using Metalhead @@ -14,6 +15,7 @@ include("tooling.jl") epochs = 45 batchsize = 1000 device = gpu +CUDA.allowscalar(false) allow_skips = true train_loader, test_loader, labels = load_cifar10(; batchsize) @@ -25,39 +27,41 @@ imsize = size(firstbatch)[1:2] to = TimerOutput() +common = "pretrain=false, inchannels=3, nclasses=$(length(labels))" + # these should all be the smallest variant of each that is tested in `/test` modelstrings = ( - "AlexNet()", - "VGG(11, batchnorm=true)", - "SqueezeNet()", - "ResNet(18)", - "WideResNet(50)", - "ResNeXt(50, cardinality=32, base_width=4)", - "SEResNet(18)", - "SEResNeXt(50, cardinality=32, base_width=4)", - "Res2Net(50, base_width=26, scale=4)", - "Res2NeXt(50)", - "GoogLeNet(batchnorm=true)", - "DenseNet(121)", - "Inceptionv3()", - "Inceptionv4()", - "InceptionResNetv2()", - "Xception()", - "MobileNetv1(0.5)", - "MobileNetv2(0.5)", - "MobileNetv3(:small, width_mult=0.5)", - "MNASNet(:A1, width_mult=0.5)", - "EfficientNet(:b0)", - "EfficientNetv2(:small)", - "ConvMixer(:small)", - "ConvNeXt(:small)", - # "MLPMixer()", # no tests found - # "ResMLP()", # no tests found - # "gMLP()", # no tests found - "ViT(:tiny)", - "UNet()" + "AlexNet(; $common)", + "VGG(11, batchnorm=true; $common)", + "SqueezeNet(; $common)", + "ResNet(18; $common)", + "WideResNet(50; $common)", + "ResNeXt(50, cardinality=32, base_width=4; $common)", + "SEResNet(18; $common)", + "SEResNeXt(50, cardinality=32, base_width=4; $common)", + "Res2Net(50, base_width=26, scale=4; $common)", + "Res2NeXt(50; $common)", + "GoogLeNet(batchnorm=true; $common)", + "DenseNet(121; $common)", + "Inceptionv3(; $common)", + "Inceptionv4(; $common)", + "InceptionResNetv2(; $common)", + "Xception(; $common)", + "MobileNetv1(0.5; $common)", + "MobileNetv2(0.5; $common)", + "MobileNetv3(:small, width_mult=0.5; $common)", + "MNASNet(:A1, width_mult=0.5; $common)", + "EfficientNet(:b0; $common)", + "EfficientNetv2(:small; $common)", + "ConvMixer(:small; $common)", + "ConvNeXt(:small; $common)", + # "MLPMixer(; $common)", # no tests found + # "ResMLP(; $common)", # no tests found + # "gMLP(; $common)", # no tests found + "ViT(:tiny; $common)", + "UNet(; $common)" ) - +df = DataFrame(; model=String[], train_loss=Float64[], train_acc=Float64[], test_loss=Float64[], test_acc=Float64[]) for (i, modstring) in enumerate(modelstrings) @timeit to "$modstring" begin @info "Evaluating $i/$(length(modelstrings)): $modstring" @@ -65,12 +69,17 @@ for (i, modstring) in enumerate(modelstrings) eval(Meta.parse(modstring)) # second load simulates what might be possible with a proper set-up pkgimage workload @timeit to "Load" model=eval(Meta.parse(modstring)) - @timeit to "Training" train(model, + @timeit to "Training" ret = train(model, train_loader, test_loader; limit = 1, to, 
- device)||(allow_skips || break) + device) + isnothing(ret) && !allow_skips ? break : continue + train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret + push!(df, (modstring, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end])) + end end +display(df) print_timer(to; sortby = :firstexec) diff --git a/benchmarking/tooling.jl b/benchmarking/tooling.jl index 94b45b7e..2ee7de35 100644 --- a/benchmarking/tooling.jl +++ b/benchmarking/tooling.jl @@ -48,10 +48,9 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim train_loss_hist, train_acc_hist = Float64[], Float64[] test_loss_hist, test_acc_hist = Float64[], Float64[] - @info "starting training" - for epoch in 1:epochs + @showprogress "training" for epoch in 1:epochs i = 0 - @showprogress "training epoch $epoch/$epochs" for (x, y) in train_loader + for (x, y) in train_loader x, y = x |> device, y |> device @timeit to "batch step" begin gs, _ = gradient(model, x) do m, _x @@ -67,14 +66,11 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim end end - @info "epoch $epoch complete. Testing..." train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit) @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit) - @info map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc)) - + push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc); + push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc); if show_plots - push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc); - push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc); plt = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss") lineplot!(plt, 1:epoch, test_loss_hist, name = "test_loss") display(plt) @@ -86,6 +82,9 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim GC.gc() # GPU will OOM without this end end + train_loss, train_acc, test_loss, test_acc = train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] + @info "results after $epochs epochs $(repr(map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))))" + return train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist end # because Flux stacktraces are ludicrously big on <1.10 so don't show them @@ -95,8 +94,8 @@ function train(args...;kwargs...) 
catch ex # rethrow() println() - @error sprint(showerror, ex) + println(sprint(showerror, ex)) GC.gc() - return false + return nothing end end From 3646ad8008a99c2cf4bc7dde289db5aa46365e19 Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 21 Dec 2023 13:08:16 -0500 Subject: [PATCH 5/8] disable UNet --- benchmarking/benchmark.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index e39763b2..c7b31dcc 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -59,7 +59,7 @@ modelstrings = ( # "ResMLP(; $common)", # no tests found # "gMLP(; $common)", # no tests found "ViT(:tiny; $common)", - "UNet(; $common)" + # "UNet(; $common)" # doesn't support kwargs "inchannels", "nclasses" ) df = DataFrame(; model=String[], train_loss=Float64[], train_acc=Float64[], test_loss=Float64[], test_acc=Float64[]) for (i, modstring) in enumerate(modelstrings) From 12faf419f61d3c16b761b105820ad0a9f91ead52 Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 21 Dec 2023 13:13:41 -0500 Subject: [PATCH 6/8] table tweaks --- benchmarking/benchmark.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index c7b31dcc..c8464742 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -75,10 +75,14 @@ for (i, modstring) in enumerate(modelstrings) limit = 1, to, device) - isnothing(ret) && !allow_skips ? break : continue - train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret + train_loss, train_acc, test_loss, test_acc = if isnothing(ret) + allow_skips || break + missing, missing, missing, missing + else + train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret + train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] + end push!(df, (modstring, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end])) - end end display(df) From 5b6df6ac3ad67ea119b690753d90da639c7ec5b7 Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 21 Dec 2023 14:23:09 -0500 Subject: [PATCH 7/8] tweaks --- benchmarking/benchmark.jl | 23 +++++++++++++++++------ benchmarking/tooling.jl | 23 +++++++++++++---------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index c8464742..f6ffae98 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -12,6 +12,8 @@ using UnicodePlots include("tooling.jl") +function run() + epochs = 45 batchsize = 1000 device = gpu @@ -62,6 +64,8 @@ modelstrings = ( # "UNet(; $common)" # doesn't support kwargs "inchannels", "nclasses" ) df = DataFrame(; model=String[], train_loss=Float64[], train_acc=Float64[], test_loss=Float64[], test_acc=Float64[]) +plt = lineplot([0], [0], title="test accuracy vs. 
time (s)", ylim = (0,1), xlim=(0,60)) +max_x = 0 for (i, modstring) in enumerate(modelstrings) @timeit to "$modstring" begin @info "Evaluating $i/$(length(modelstrings)): $modstring" @@ -74,16 +78,23 @@ for (i, modstring) in enumerate(modelstrings) test_loader; limit = 1, to, - device) - train_loss, train_acc, test_loss, test_acc = if isnothing(ret) + device + ) + elapsed, train_loss, train_acc, test_loss, test_acc = if isnothing(ret) allow_skips || break - missing, missing, missing, missing + missing, missing, missing, missing, missing else - train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret - train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] + elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret + max_x = max(maximum(elapsed), max_x) + lineplot!(plt, elapsed, test_acc_hist, name=modstring) + elapsed, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] end - push!(df, (modstring, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end])) + push!(df, (modstring, train_loss, train_acc, test_loss, test_acc), promote=true) end + GC.gc(true) end +display(plt) display(df) print_timer(to; sortby = :firstexec) +end +run() diff --git a/benchmarking/tooling.jl b/benchmarking/tooling.jl index 2ee7de35..5246171e 100644 --- a/benchmarking/tooling.jl +++ b/benchmarking/tooling.jl @@ -38,7 +38,8 @@ function load_cifar10(; batchsize=1000) return train_loader, test_loader, labels end -function _train(model, train_loader, test_loader; epochs = 45, device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput()) +function _train(model, train_loader, test_loader; epochs = 45, + device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput()) model = model |> device @@ -47,7 +48,8 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim train_loss_hist, train_acc_hist = Float64[], Float64[] test_loss_hist, test_acc_hist = Float64[], Float64[] - + elapsed = Float64[] + start = time() @showprogress "training" for epoch in 1:epochs i = 0 for (x, y) in train_loader @@ -56,7 +58,7 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim gs, _ = gradient(model, x) do m, _x logitcrossentropy(m(_x), y) end - state, model = Optimisers.update(state, model, gs) + state, model = Optimisers.update!(state, model, gs) end device === gpu && gpu_stats && CUDA.memory_status() @@ -70,13 +72,14 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit) push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc); push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc); + push!(elapsed, time() - start) if show_plots - plt = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss") - lineplot!(plt, 1:epoch, test_loss_hist, name = "test_loss") - display(plt) - plt = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc") - lineplot!(plt, 1:epoch, test_acc_hist, name = "test_acc") - display(plt) + plt2 = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss") + lineplot!(plt2, 1:epoch, test_loss_hist, name = "test_loss") + display(plt2) + plt2 = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc") + lineplot!(plt2, 1:epoch, 
test_acc_hist, name = "test_acc") + display(plt2) end if device === gpu && gpu_gc GC.gc() # GPU will OOM without this @@ -84,7 +87,7 @@ function _train(model, train_loader, test_loader; epochs = 45, device = gpu, lim end train_loss, train_acc, test_loss, test_acc = train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] @info "results after $epochs epochs $(repr(map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))))" - return train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist + return elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist end # because Flux stacktraces are ludicrously big on <1.10 so don't show them From 3a884c1284d406ea1a8edc51097016c2d5593f2e Mon Sep 17 00:00:00 2001 From: Ian Butterworth Date: Thu, 21 Dec 2023 15:54:56 -0500 Subject: [PATCH 8/8] swtich to GLMakie --- benchmarking/Project.toml | 6 ++++-- benchmarking/benchmark.jl | 13 +++++++++---- benchmarking/tooling.jl | 4 ++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/benchmarking/Project.toml b/benchmarking/Project.toml index 1036fa64..9607bdd9 100644 --- a/benchmarking/Project.toml +++ b/benchmarking/Project.toml @@ -1,23 +1,25 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] +Colors = "0.12" CUDA = "5" DataFrames = "1" Flux = "0.14" +GLMakie = "0.20" MLDatasets = "0.7" Metalhead = "0.9" Optimisers = "0.3" ProgressMeter = "1.9" TimerOutputs = "0.5" -UnicodePlots = "3.6" cuDNN = "1.2" diff --git a/benchmarking/benchmark.jl b/benchmarking/benchmark.jl index f6ffae98..a090a6a5 100644 --- a/benchmarking/benchmark.jl +++ b/benchmarking/benchmark.jl @@ -1,14 +1,15 @@ +using Colors using CUDA, cuDNN using DataFrames using Flux using Flux: logitcrossentropy, onecold, onehotbatch +using GLMakie using Metalhead using MLDatasets using Optimisers using ProgressMeter using TimerOutputs -using UnicodePlots include("tooling.jl") @@ -64,7 +65,10 @@ modelstrings = ( # "UNet(; $common)" # doesn't support kwargs "inchannels", "nclasses" ) df = DataFrame(; model=String[], train_loss=Float64[], train_acc=Float64[], test_loss=Float64[], test_acc=Float64[]) -plt = lineplot([0], [0], title="test accuracy vs. time (s)", ylim = (0,1), xlim=(0,60)) +cols = distinguishable_colors(length(modelstrings), [RGB(1,1,1), RGB(0,0,0)], dropseed=true) +f = Figure() +ax = Axis(f[1, 1], title="CIFAR-10 Training on a Nvidia 3090, batch 1000\nTest accuracy vs. 
time over 45 epochs", xlabel="Time (s)", ylabel="Testset Accuracy") +display(f) max_x = 0 for (i, modstring) in enumerate(modelstrings) @timeit to "$modstring" begin @@ -86,14 +90,15 @@ for (i, modstring) in enumerate(modelstrings) else elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist = ret max_x = max(maximum(elapsed), max_x) - lineplot!(plt, elapsed, test_acc_hist, name=modstring) + lines!(ax, elapsed, test_acc_hist, label=modstring, color=cols[i]) elapsed, train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] end push!(df, (modstring, train_loss, train_acc, test_loss, test_acc), promote=true) end GC.gc(true) end -display(plt) +f[1, 2] = Legend(f, ax, "Models", framevisible = false) +display(f) display(df) print_timer(to; sortby = :firstexec) end diff --git a/benchmarking/tooling.jl b/benchmarking/tooling.jl index 5246171e..06beab2a 100644 --- a/benchmarking/tooling.jl +++ b/benchmarking/tooling.jl @@ -49,7 +49,6 @@ function _train(model, train_loader, test_loader; epochs = 45, train_loss_hist, train_acc_hist = Float64[], Float64[] test_loss_hist, test_acc_hist = Float64[], Float64[] elapsed = Float64[] - start = time() @showprogress "training" for epoch in 1:epochs i = 0 for (x, y) in train_loader @@ -72,7 +71,7 @@ function _train(model, train_loader, test_loader; epochs = 45, @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit) push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc); push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc); - push!(elapsed, time() - start) + push!(elapsed, time()) if show_plots plt2 = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss") lineplot!(plt2, 1:epoch, test_loss_hist, name = "test_loss") @@ -85,6 +84,7 @@ function _train(model, train_loader, test_loader; epochs = 45, GC.gc() # GPU will OOM without this end end + elapsed = elapsed .- elapsed[1] train_loss, train_acc, test_loss, test_acc = train_loss_hist[end], train_acc_hist[end], test_loss_hist[end], test_acc_hist[end] @info "results after $epochs epochs $(repr(map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))))" return elapsed, train_loss_hist, train_acc_hist, test_loss_hist, test_acc_hist