Zygote's compilation scales badly with the number of ~ statements #1754

Closed
torfjelde opened this issue Dec 21, 2021 · 41 comments

Comments

@torfjelde
Member

I'm not certain whether this can be considered a "Turing.jl issue", but I figured I would at least raise it here so people are aware.

The compilation time of Zygote scales quite badly with the number of ~ statements.

TL;DR: it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.

Demo

using Turing, Zygote
Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();
results = []
num_tildes = 0

Running the following snippet a couple of times we get a sense of the compilation times:

num_tildes += 1

Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
results
14-element Vector{Any}:
  15.917412882
   9.600213651
  14.911253238
  24.699050206
  71.811476745
  49.158314248
  46.059394601
  57.44494627
  75.514564551
  94.927956369
 134.165383535
 156.079943416
 202.745837585
 273.93479382

That is, it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.
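One way to make "scales quite badly" concrete is to estimate a power-law exponent from the timings listed above. The sketch below is an illustration added here, not part of the original report. It assumes t(n) ≈ c·n^k and uses only the n = 7 and n = 14 points, because the first few timings are noisy (n = 5 is slower than n = 6, 7 and 8).

```julia
# First-gradient times (s) from the `results` vector above, for n = 1, ..., 14 `~` statements.
times = [15.917412882, 9.600213651, 14.911253238, 24.699050206,
         71.811476745, 49.158314248, 46.059394601, 57.44494627,
         75.514564551, 94.927956369, 134.165383535, 156.079943416,
         202.745837585, 273.93479382]

# If t(n) ≈ c * n^k, then t(2n) / t(n) ≈ 2^k, so a two-point estimate
# between n = 7 and n = 14 is k ≈ log2(t(14) / t(7)).
k = log2(times[14] / times[7])
println("empirical exponent k ≈ ", round(k; digits = 2))  # ≈ 2.57
```

Under this crude fit, doubling the number of ~ statements multiplies the first-gradient time by about six. A least-squares fit of log(t) against log(n) over more points would be more robust.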

Additional info

versioninfo()
Julia Version 1.6.2
Commit 1b93d53fc4 (2021-07-14 15:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10710U CPU @ 1.10GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Pkg.status()
    Status `/tmp/jl_PHYdgF/Project.toml`
[fce5fe82] Turing v0.19.2
[e88e6eb3] Zygote v0.6.32
@wupeifan
Contributor

Thanks a lot for posting this!

@ParadaCarleton
Member

Do we have code that checks the compile times of models? How long has Zygote compilation been taking this long?

@wupeifan
Contributor

How long has Zygote compilation been taking this long?

Since Julia 1.6 afaik.

@ParadaCarleton
Member

How long has Zygote compilation been taking this long?

Since Julia 1.6 afaik.

As in, the 1.6 update caused it? Is the compilation faster on 1.5?

@jlperla

jlperla commented Dec 26, 2021

@torfjelde can probably get the right answer, if it matters. There are a lot of packages and versions interacting, so I'm not sure it necessarily matches a particular Julia version. It might take just as long to find the exact combination of Julia, Zygote, ChainRules, Turing, DynamicPPL, etc. versions that caused it as it would to actually solve the problem.

@wupeifan
Contributor

As in, the 1.6 update caused it? Is the compilation faster on 1.5?

I don't think so, but that is based only on my experience rather than on any benchmarking/profiling (which we need!), so I don't have a good answer to this question.

@jlperla

jlperla commented Jan 7, 2022

Sorry to bump so soon, @torfjelde, but do we have any sense of a timeline on this? Are we thinking a week, a month, etc.? We would help out, but I fear this is a little too tightly connected to the PPL macros for us to contribute.

@torfjelde
Member Author

I'm sorry, I can't put a timeline on this right now.

And I worry and suspect it's not particularly related to Turing, unfortunately 😕 There was one change we made in Turing that I worried could have caused it, but when I tried with an older version the performance was still horrible.

The best way to identify what's wrong is to just run the test above for different versions of Turing and Zygote.

@torfjelde
Member Author

torfjelde commented Jan 10, 2022

As of right now I have the following results:

| Julia | Args | Turing | Zygote | 1 | 5 | 10 | 15 | 20 |
|---|---|---|---|---|---|---|---|---|
| 1.6.1 |  | 0.19 | 0.6 | 16.387332656 | 89.504752487 | 101.654405869 | 320.501781565 | N/A |
| 1.6.1 |  | 0.18 | 0.6 | 15.740250148 | 72.75134579 | 92.373805461 | 310.690419646 | N/A |
| 1.6.1 |  | 0.16 | 0.6 | 14.098175146 | 77.075566728 | 86.093118485 | 310.681374129 | N/A |
| 1.6.1 |  | 0.15 | 0.6 | 13.674902932 | 75.142343745 | 83.177268954 | 309.479831351 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.32 | 16.519126809 | 76.212964076 | 102.056456525 | 331.835923261 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.30 | 16.416369038 | 73.552713013 | 98.009949908 | 325.177503251 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.28 | 15.907693813 | 73.859609986 | 99.65385215 | 320.406822231 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.25 | 16.155023105 | 78.113216025 | 91.386990818 | 323.182170592 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.20 | 16.18978755 | 76.922928732 | 92.863883057 | 329.472662078 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.17 | 16.674484174 | 45.536406584 | 91.975029367 | 324.107966298 | N/A |
| 1.6.1 |  | 0.19.3 | 0.6.15 | 16.647893857 | 77.246173583 | 92.547574565 | 325.843565657 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.32 | 16.507995835 | 75.232973774 | 102.738249637 | 320.804476864 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.30 | 16.442945162 | 75.552669923 | 100.079247493 | 315.767280417 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.28 | 15.696215339 | 41.383477294 | 93.911315279 | 329.02145826 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.27 | 9.442187946 | 13.78550593 | 27.309818329 | 61.662899872 |  |
| 1.6.5 | -O1 | 0.19.3 | 0.6.26 | 9.525658865 | 15.737620623 | 37.788878332 | 154.77089758 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.25 | 16.103407208 | 78.827783495 | 92.628645679 | 323.749846052 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.20 | 16.081257806 | 45.193458428 | 91.298156568 | 324.887103193 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.17 | 16.661767166 | 117.089430786 | 91.768988429 | 322.823669619 | N/A |
| 1.6.5 |  | 0.19.3 | 0.6.15 | 16.698116787 | 99.878925099 | 91.444128363 | 322.810816215 | N/A |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.949009976 | 14.697673851 | 29.464735085 | 66.522595927 | 132.225028216 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.33 | 9.952505584 | 14.642345133 | 29.487756023 | 66.405126925 | 132.858871417 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.32 | 10.213924716 | 14.90292511 | 29.43141585 | 65.774123029 | 132.584280038 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.30 | 9.952360792 | 14.604632874 | 29.263546539 | 65.256550668 | 131.766387963 |
| 1.6.5 | -O1 | 0.19.3 | 0.6.28 | 9.61100625 | 14.16199726 | 28.106580684 | 61.752193716 | 128.178242579 |
| 1.6.5 |  | 0.19.3 | #master | 16.492658709 | 76.261172697 | 99.347007007 | 318.410333441 |  |
| 1.6.5 |  | 0.19.3 | mcabbott:opt_level | 14.510376473 | 15.581238372 | 30.034043196 | 67.305733329 |  |
| 1.8.2 |  | 0.21.13 | 0.6.49 (#66cc60) | 16.799219872 | 48.980447826 | 182.013203064 | 1359.453378186 | N/A |
| 1.8.2 |  | 0.21.13 | 0.6.49 (#ee8945) | 32.313442694 | 104.451519939 | 284.157696506 | N/A | N/A |

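The numeric column headers give the number of ~ statements. Each entry is the time in seconds of the first gradient evaluation (as measured by the script below), which is dominated by compilation. Rows with -O1 in the Args column were run with Julia started at optimization level 1.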

This is on the following system:

julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)

The sad news is that I can't really glean anything useful from the above 😕 Here it seems to be "fine" (taking 5 minutes to compile is not good, but it's much better than the reported numbers).

Comments:

  • When I tried running this benchmark on my laptop (which has 32 GB of memory), Zygote versions earlier than 0.6.17 blew up (on Julia 1.6.3), i.e. it seems as if the memory usage of Zygote changed significantly from 0.6.17 to 0.6.18?
  • Rows whose versions lack a patch number (e.g. Turing 0.19, Zygote 0.6) use the most recent Zygote version compatible with the corresponding Turing version. I can find these out if need be (the numbers are from an earlier version of the script I'm using).

[2022-01-10 Mon 00:36]: I'm now running the same benchmarks on Julia 1.6.5 just to make sure it has nothing to do with weird interactions between the particular Julia version and Zygote.

Script I'm running
using Pkg; Pkg.activate(mktempdir())
TURING_VERSION = ENV["TURING_VERSION"]
ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]

@info "Trying to install Turing@$(TURING_VERSION) and Zygote@$(ZYGOTE_VERSION)"
Pkg.add(name="Turing", version=TURING_VERSION)
Pkg.add(name="Zygote", version=ZYGOTE_VERSION)

using Turing, Zygote

pkgversion(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]["version"]
@info "Installed Turing@$(pkgversion(Turing)) and Zygote@$(pkgversion(Zygote))"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

t = @elapsed Turing.Core.gradient_logp(
   adbackend,
   vi[spl],
   vi,
   model,
   spl
);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end
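The four blocks in this script differ only in `num_tildes`. As a sketch, the expression-building step could be factored into a helper. `demo_expr` is a hypothetical name, and the snippet only constructs the `Expr` (no Turing needed); the `DynamicPPL.model(...)`, `@eval` and timing steps would still follow as in the script.

```julia
# Hypothetical helper: build the (not yet macro-expanded) model definition
# `function demoN() x1 ~ Normal(); ...; xN ~ Normal() end`, the same way each
# block of the script does inline. Only an Expr is built, so no packages are needed.
function demo_expr(num_tildes::Integer)
    expr = Base.remove_linenums!(:(function $(Symbol(:demo, num_tildes))() end))
    mainbody = last(expr.args)
    append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j in 1:num_tildes])
    return expr
end

println(demo_expr(3))
```

Keeping the `@eval` and the first gradient call as separate top-level statements, as the script does, sidesteps the world-age errors that can arise when a freshly `@eval`ed model is called from inside the same function.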
Script for Turing >= 0.21
using Pkg

if any(Base.Fix1(haskey, ENV), ["TURING_VERSION", "ZYGOTE_VERSION"])
    # In this case, we create a new env and install the corresponding package versions.
    Pkg.activate(mktempdir())

    if haskey(ENV, "TURING_VERSION")
        TURING_VERSION = ENV["TURING_VERSION"]
        @info "Trying to install Turing@$(TURING_VERSION)"
        Pkg.add(name="Turing", version=TURING_VERSION)
    end

    if haskey(ENV, "ZYGOTE_VERSION")
        ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
        @info "Trying to install Zygote@$(ZYGOTE_VERSION)"
        Pkg.add(name="Zygote", version=ZYGOTE_VERSION)
    end
end

using Turing, Zygote
using Turing: LogDensityProblems


# Manifest format 2.0 (with a top-level "deps" table) was introduced in Julia 1.7.
if VERSION < v"1.7"
    pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]
else
    pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))["deps"][string(mod)][1]
end

pkgversion(mod) = pkginfo(mod)["version"]
pkghash(mod) = pkginfo(mod)["git-tree-sha1"]

@info "Installed Turing@$(pkgversion(Turing)) [#$(pkghash(Turing))] and Zygote@$(pkgversion(Zygote)) [#$(pkghash(Zygote))]"

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

results = []

num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t

println(join(results, " | "))

if haskey(ENV, "OUTPUT_FILE")
    open(ENV["OUTPUT_FILE"], "a") do io
        write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
        write(io, "\n")
    end
end


@jlperla

jlperla commented Jan 10, 2022

Thanks @torfjelde

Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?

@torfjelde
Member Author

Is there any chance this stuff gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7?

I can, but it will take a bit of work (I'm currently using Turing functionality to compute the gradient).

Do you know when you started experiencing these issues btw?

Also, maybe some of the Zygote people have an idea what's going on here, @mcabbott? TL;DR: the compilation time of Zygote.gradient blows up with the number of ~ statements in a Turing model.

@torfjelde
Member Author

torfjelde commented Jan 10, 2022

Possibly related: FluxML/Zygote.jl#1119 and FluxML/Zygote.jl#1126

EDIT: Seems like it. -O1 helps a lot.

@torfjelde
Member Author

torfjelde commented Jan 10, 2022

Wow, the -O1 really helps. See the results in the comment above.

EDIT: Even n=20 only results in ~2min of compilation.
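To put a number on "really helps", the matching rows with and without -O1 in the table above (Julia 1.6.5, Turing 0.19.3, Zygote 0.6.32) can be divided column by column. The timings are copied from that table; the ratio computation itself is an illustration added here.

```julia
# Timings (s) from the table for Julia 1.6.5, Turing 0.19.3, Zygote 0.6.32,
# at n = 1, 5, 10, 15 `~` statements (the default-level row has N/A for n = 20).
ns          = [1, 5, 10, 15]
default_opt = [16.507995835, 75.232973774, 102.738249637, 320.804476864]
with_O1     = [10.213924716, 14.90292511, 29.43141585, 65.774123029]

# A ratio > 1 means the -O1 run was faster.
speedup = default_opt ./ with_O1
for (n, s) in zip(ns, speedup)
    println("n = ", n, ": ", round(s; digits = 1), "x faster with -O1")
end
```

For this pair of rows the gain is about 5x at n = 5 and n = 15, and smallest (about 1.6x) for a single ~ statement.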

@jlperla

jlperla commented Jan 10, 2022

@torfjelde Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?

@torfjelde
Member Author

Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity?

No, specifically you need the -O1 optimization flag (the default is -O2). It seems as if Zygote + Julia 1.6 leads to some insane compilation times when the default optimizations are used.
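Because -O1 is a command-line flag (e.g. `julia -O1 script.jl`), it applies to the whole session, including code unrelated to Zygote. A minimal sketch for checking the active level from inside a script before launching a long benchmark. It assumes only `Base.JLOptions().opt_level`, a long-standing but not publicly documented field holding the `-O` level; `optlevel_hint` is a hypothetical helper name.

```julia
# Hypothetical helper: report whether this session was started with reduced
# LLVM optimization (e.g. `julia -O1 script.jl`), the setting that cut Zygote
# compile times in this thread. Without a flag, Julia uses level 2.
function optlevel_hint(level::Integer = Base.JLOptions().opt_level)
    if level <= 1
        return "running at -O$(level): reduced LLVM optimization"
    end
    return "running at -O$(level): consider restarting with `julia -O1`"
end

println(optlevel_hint())   # depends on how this Julia session was started
println(optlevel_hint(1))
```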

@mcabbott

mcabbott commented Jan 10, 2022

See if FluxML/Zygote.jl#1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

@wupeifan
Contributor

wupeifan commented Jan 10, 2022

I can confirm that -O1 helps with our original problem by significantly decreasing the compilation time. Previously we had to wait around 30-40 min, and now it takes 2-3 Chopin preludes (around 6 min) to compile. I would say this is similar to what I had in Julia 1.5.
Now the question is whether -O1 generates much less efficient code than -O3.

EDIT: preliminary experiments seem to give similar run times.

@torfjelde
Member Author

See if FluxML/Zygote.jl#1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

Will give it a go 👍

Btw, not sure if this is useful information, but when I try an older Zygote version, even with -O1 the runtime blows up again (the nearest more recent version I tried which had reasonable compile time was 0.6.28).

@ToucheSir

I've never heard of Chopin preludes being used as a unit of measurement, but people should do that more often :)

Now the question is whether -O1 generates much less efficient code than -O3.

Betting on no, as most performance-sensitive stuff is gated behind a rule. Maybe scalar- or control flow-heavy code, though the generated pullbacks for the latter are likely type unstable anyhow.

but when I try an older Zygote version, even with -O1 the runtime blows up again

Fixed by FluxML/Zygote.jl#909 perhaps? That was in 0.6.27.

@torfjelde
Member Author

See if FluxML/Zygote.jl#1147 works as well as -O1 for this purpose. If you have any other benchmarks of runtime performance (i.e. @btime) it would also be interesting to see if those get worse.

Gave it a try; seems to do the trick! Benchmarks in table above.

I don't know what effect it has on runtime performance, but it seems like it would be worth it.

@torfjelde
Member Author

torfjelde commented Jan 10, 2022

Fixed by FluxML/Zygote.jl#909 perhaps? That was in 0.6.27.

Ah, probably! I'll give it a go.

EDIT: Seems like indeed 0.6.27 improved things significantly 👍

@jlperla

jlperla commented Mar 2, 2022

Any progress on this issue by chance? I did a check with Julia 1.7 and the latest Turing, DynamicPPL, and Zygote, and am still getting > 30 minute TTFG (time to first gradient) for my model with 20ish parameters. -O1 makes things more reasonable, as it did before.

Just want to make sure that everyone knows those two Zygote issues linked did not fix things.

@ParadaCarleton
Member

@ToucheSir @Keno Any progress on this?

@ToucheSir

Unfortunately I have nothing concrete to report, but I have been looking into this over the past couple of months. Any help on grokking compilation latency + Zygote's internals would be much appreciated. I can't speak for Keno, but to my knowledge working on this is not on anyone else's plate.

@ToucheSir

@torfjelde I'm not able to repro your latest timings on 1.8.2 locally with the following reduced MWE:

using Turing, Zygote
using Turing: LogDensityProblems
using SnoopCompileCore

# This helps a bit, ~4s
# @eval Turing.DynamicPPL begin
#   ChainRulesCore.@non_differentiable is_flagged(::VarInfo, ::VarName, ::String)
# end

Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();

num_tildes = 5
# num_tildes = 10

expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);

@info "starting eval"
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
tinf = @snoopi_deep LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
@info "done eval"

using SnoopCompile, ProfileView
@show tinf

Results:

# v0.6.49
tinf = InferenceTimingNode: 38.138395/69.026222 on Core.Compiler.Timings.ROOT() with 388 direct children
# https://github.com/FluxML/Zygote.jl/pull/1195
tinf = InferenceTimingNode: 37.486037/65.298454 on Core.Compiler.Timings.ROOT() with 423 direct children

Versioninfo:

Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, haswell)
  Threads: 1 on 8 virtual cores

@ToucheSir

ToucheSir commented Nov 12, 2022

I had a crack at figuring out why inference times also scale so badly. All outputs of that exploration may be found in this gist.

It turns out that we can capture a good chunk of the slowness and poor scaling with just pullback(model.f, ...). SnoopCompile didn't feel very helpful because it just reports high exclusive inference time for that function, but digging into the IR with SnoopCompile's Cthulhu integration did turn up something interesting.

To demonstrate, we should first examine @code_warntype output for the un-transformed function. DynamicPPL does generate a decent amount of code, but the compiler should be able to manage 600ish statements and 100ish slots[1] without too much trouble.
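For readers who want to reproduce this kind of count on their own functions, a sketch: the unoptimized typed `CodeInfo` (what `@code_warntype` prints) exposes its statements as `.code` and its slots as `.slotnames`. `ir_size` and `toy` are hypothetical names, and measuring Zygote's augmented primal, as done in the gist, needs Zygote internals not shown here.

```julia
# Hypothetical helper: size of a method's unoptimized typed IR, i.e. the
# CodeInfo that @code_warntype displays (statements and local slots).
function ir_size(f, argtypes::Type{<:Tuple})
    ci, _ = only(code_typed(f, argtypes; optimize = false))
    return (statements = length(ci.code), slots = length(ci.slotnames))
end

# Toy function standing in for a model body.
function toy(x)
    a = x + 1
    b = a * 2
    c = b - x
    return c
end

println(ir_size(toy, Tuple{Float64}))  # slots include #self#, x, a, b, c
```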

Now let's see the output for the augmented primal function from pullback. Yes, you read that correctly. Zygote generates a function with over 4,300 statements and over 19,000 slots! I have no idea why both numbers are so high, but I suspect there is something up with the IRTools IR -> Julia IR translation that happens in Zygote[2] (Edit: IRTools.slots! is the cause of the blow-up. See the output of each IRTools pass Zygote uses for the gory details). Given the sheer amount of code, I'm not surprised that LLVM times seem to be rather horrendous as well.

So what is to be done? The first thing that comes to mind is to figure out why this is happening, how to fix the slot explosion on the IRTools side, and how much Zygote's own passes might be responsible. That, however, would require someone quite familiar with IRTools internals. The more ambitious plan would be to ditch IRTools in Zygote completely and use CodeInfo/IRCode like other libraries (including TuringLang's own Libtask) have done. Having given this a try some months back, I think the biggest missing piece is someone who understands Zygote's pre-IRTools AD transform (or implementing AD transforms on SSA IR in general) well enough to do the port. It's not clear how a lot of the logic in Zygote now would look if passes went from working on block args to Phi nodes, for example.

Footnotes

  1. I count slots because someone mentioned before that they can consume quite a bit of memory during compilation. So having too many of them in a function may be the culprit behind the memory blowup documented here and in https://github.com/TuringLang/Turing.jl/issues/1754#issuecomment-1008460319.

  2. Zygote used to operate on native Julia IR, but switched to IRTools a couple of years back. This means every function goes through a native IR (really CodeInfo) -> IRTools IR -> AD transform -> native IR pipeline.

@yebai
Member

yebai commented Nov 12, 2022

Diffractor will hopefully fix these issues, right?

@ToucheSir

ToucheSir commented Nov 12, 2022

Last I asked there were no plans to support setfield! (on any type) or setindex! on Dicts/RefValues, both of which Turing seems to need. So I'm not terribly optimistic...

@ParadaCarleton
Member

Briefly going to comment on this to say--the solution to this issue is to use ReverseDiff or ForwardDiff.jl or (a few years down the line when it's mature) maybe some other autodiff solution like Enzyme.jl. Development on Zygote/IRTools and source-to-source AD in Julia (rather than LLVM) is effectively dead now.

@devmotion
Member

To be honest, ReverseDiff has other issues and ForwardDiff is not always an option. Actually, Zygote is developed much more actively (https://github.com/FluxML/Zygote.jl/commits/master) than ReverseDiff (https://github.com/JuliaDiff/ReverseDiff.jl/commits/master) or ForwardDiff (https://github.com/JuliaDiff/ForwardDiff.jl/commits/master). For ForwardDiff, no new release has been tagged from master since a breaking change, which downstream packages (correctly, IMO) rejected when it appeared in a non-breaking release, was reapplied to the master branch. Nobody wants to deal with, and possibly fix, the downstream issues that still exist, so nobody is willing to tag a new release from the ForwardDiff master branch.

@ToucheSir

Actually, Zygote is developed much more actively (https://github.com/FluxML/Zygote.jl/commits/master) than ReverseDiff (https://github.com/JuliaDiff/ReverseDiff.jl/commits/master) or ForwardDiff (https://github.com/JuliaDiff/ForwardDiff.jl/commits/master)

Although this is true, it's not quite an apples-to-apples comparison because of what those commits are doing. The majority of Zygote work these days is maintenance work, and much of that is mandatory because some change in Julia internals broke the source-to-source AD bit. The rest are primarily filling in holes/edge cases/robustness issues in the existing rule system, which has many, many more of them than either ForwardDiff or ReverseDiff.

To be honest, ReverseDiff has other issues...

IMO they should be listed out, because my impression is that many of them are more tractable than the fundamental ones facing Zygote. Some offline experimentation suggests that even the infamous lack of GPU support could be addressed without a complete package rewrite. That said, my personal deal-breaker with ReverseDiff is not any technical issue—it's that there appears to be no appetite for anything other than the most urgent maintenance. I can absolutely appreciate why this is the case, but it is a little disheartening to treat ReverseDiff as some technological dead-end when other languages/libraries have managed to take similar ideas further.

@yebai
Member

yebai commented Oct 20, 2023

It would be great to have a ReverseDiff2. In my experience, it is very performant and made a good tradeoff between simplicity, generality and performance. There are some weakly-justified pushes to differentiate through everything. However, it is very hard to differentiate through everything, and I am not sure that one wants to do that. A performant, well-tested, and maintainable AD is what's needed.

cc @willtebbutt

@devmotion
Member

IMO they should be listed out

ReverseDiff e.g. only supports differentiation of vectors and real numbers, has problems with wrapper types (e.g., JuliaDiff/ReverseDiff.jl#223), faces the general problem of arrays of tracked reals vs tracked arrays, and probably defines too many methods (JuliaDiff/ReverseDiff.jl#226). It has multiple correctness issues (JuliaDiff/ReverseDiff.jl#145, JuliaDiff/ReverseDiff.jl#168, JuliaDiff/ReverseDiff.jl#239, JuliaDiff/ReverseDiff.jl#233) and its ChainRules macro support has multiple bugs (e.g. JuliaDiff/ReverseDiff.jl#221). Some of these issues might require more changes, some of them might be more easily fixable, but IMO it's really not as good currently as some people seem to think, and it would also require time and effort that nobody seems to be willing to invest (as you can see from the recent commit history).

The main intention of my comment was just this: all Julia AD packages have problems, and I think one thing that contributed to the current situation was that people suggested abandoning certain AD packages as soon as a new promising alternative appeared, and only later realized that the alternatives also have their own sets of problems and limitations. I think it would be better

  • to instruct users about the strengths and weaknesses of the different alternatives and provide recommendations for which package to choose for which applications (also so that they don't pick a backend without knowing its limitations and then become frustrated once they discover them later), and
  • to maintain and improve the different AD packages for their intended target use, keeping in mind their limitations and not focusing on a single best AD system.

@thorek1

thorek1 commented Feb 27, 2024

Any update on this?
I am stuck with 36 ~ statements and haven't seen the end of the compilation yet (Julia 1.10.1, Zygote 0.6.69, Turing 0.30.5).
What are feasible workarounds? Does filldist help?

@yebai
Member

yebai commented Feb 28, 2024

I don't think there is a good solution yet; this is a general issue with Zygote: FluxML/Zygote.jl#1119

We are rewriting Zygote/ReverseDiff. Hopefully the issue will be resolved when that is complete.

@thorek1

thorek1 commented Feb 28, 2024

OK, thanks for the info.

What I found to work reasonably well is wrapping all distributions in an arraydist. Starting Julia with -O1 also helps, but the improvement from using arraydist was large enough that I don't need to modify the startup args.

For example (ignore the 4-element inputs and μσ; I wrote my own wrappers):

# Handling distributions with varying parameters using arraydist
dists = [
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_ea
InverseGamma(0.1, 2.0, 0.025,5.0, μσ = true), # z_eb
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_eg
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_eqs
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_em
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_epinf
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_ew
Beta(0.5, 0.20, μσ = true), # crhoa
Beta(0.5, 0.20, μσ = true), # crhob
Beta(0.5, 0.20, μσ = true), # crhog
Beta(0.5, 0.20, μσ = true), # crhoqs
Beta(0.5, 0.20, μσ = true), # crhoms
Beta(0.5, 0.20, μσ = true), # crhopinf
Beta(0.5, 0.20, μσ = true), # crhow
Beta(0.5, 0.2, μσ = true), # cmap
Beta(0.5, 0.2, μσ = true), # cmaw
Normal(4.0, 1.5,   2.0, 15.0), # csadjcost
Normal(1.50,0.375, 0.25, 3.0), # csigma
Beta(0.7, 0.1, μσ = true), # chabb
Beta(0.5, 0.1, μσ = true), # cprobw
Normal(2.0, 0.75, 0.25, 10.0), # csigl
Beta(0.5, 0.10, μσ = true), # cprobp
Beta(0.5, 0.15, μσ = true), # cindw
Beta(0.5, 0.15, μσ = true), # cindp
Beta(0.5, 0.15, μσ = true), # czcap
Normal(1.25, 0.125, 1.0, 3.0), # cfc
Normal(1.5, 0.25, 1.0, 3.0), # crpi
Beta(0.75, 0.10, μσ = true), # crr
Normal(0.125, 0.05, 0.001, 0.5), # cry
Normal(0.125, 0.05, 0.001, 0.5), # crdy
Gamma(0.625, 0.1, 0.1, 2.0, μσ = true), # constepinf
Gamma(0.25, 0.1, 0.01, 2.0, μσ = true), # constebeta
Normal(0.0, 2.0, -10.0, 10.0), # constelab
Normal(0.4, 0.10, 0.1, 0.8), # ctrend
Normal(0.5, 0.25, 0.01, 2.0), # cgy
Normal(0.3, 0.05, 0.01, 1.0), # calfa
]

Turing.@model function model_loglikelihood_function(data, m, observables,fixed_parameters)
    all_params ~ Turing.arraydist(dists)

    z_ea, z_eb, z_eg, z_eqs, z_em, z_epinf, z_ew, crhoa, crhob, crhog, crhoqs, crhoms, crhopinf, crhow, cmap, cmaw, csadjcost, csigma, chabb, cprobw, csigl, cprobp, cindw, cindp, czcap, cfc, crpi, crr, cry, crdy, constepinf, constebeta, constelab, ctrend, cgy, calfa = all_params

...
end
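One plausible reason the arraydist rewrite helps is the IR-size effect described earlier in the thread: every separate ~ statement adds code that Zygote has to transform, whereas a single vector-valued ~ keeps the model body the same size however many parameters it has. The plain-Julia sketch below (no Turing or Zygote; `normal_logpdf`, `make_unrolled`, `vectorized` and `statement_count` are hypothetical names) only illustrates the code-size side of that argument, using the same expression-splicing trick as the benchmark scripts.

```julia
# Hypothetical plain-Julia illustration: the same log density written with one
# statement per parameter (like many separate `~`) versus one vectorized call.
normal_logpdf(x) = -(x^2 + log(2π)) / 2

# Build `unrolled_n(xs)` with n separate accumulation statements by splicing
# expressions, similar to how the benchmark scripts generate their models.
function make_unrolled(n::Integer)
    name = Symbol(:unrolled_, n)
    stmts = [:(lp += normal_logpdf(xs[$j])) for j in 1:n]
    return @eval function $name(xs)
        lp = 0.0
        $(stmts...)
        return lp
    end
end

# Vectorized version: the body has the same size for any length(xs).
vectorized(xs) = sum(normal_logpdf, xs)

# Number of statements in the lowered code of the single method of `f`.
statement_count(f, argtypes) = length(only(code_lowered(f, argtypes)).code)

f5, f20 = make_unrolled(5), make_unrolled(20)
T = Tuple{Vector{Float64}}
println("unrolled, n = 5:  ", statement_count(f5, T), " statements")
println("unrolled, n = 20: ", statement_count(f20, T), " statements")
println("vectorized:       ", statement_count(vectorized, T), " statements")
```

The unrolled lowered code grows with n while the vectorized version stays constant; this does not measure Zygote's compile time directly.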

@marcobonici

@thorek1 thanks for sharing this!
I will try this later, but could you please share how much this improved the compilation time?
Thx in advance!

@yebai
Member

yebai commented Apr 30, 2024

The compilation time with Tapir seems much better now.

julia> using Turing, Tapir, ADTypes

julia> num_tildes = 50
50

julia> expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;

julia> mainbody = last(expr.args);

julia> append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);

julia> f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
demo50 (generic function with 2 methods)

julia> model = f();

julia> t = @elapsed chn = sample(model, NUTS(), 1000; adtype=AutoTapir())

┌ Info: Found initial step size
└   ϵ = 0.8500000000000001
Sampling 100%|██████████████████████████████| Time: 0:02:34
168.426989084

@yebai yebai closed this as completed Apr 30, 2024
@torfjelde
Member Author

That's awesome, but should this issue be closed? It's specifically related to Zygote, no?

@yebai yebai closed this as not planned Apr 30, 2024
@yebai
Member

yebai commented Apr 30, 2024

That's awesome, but should this issue be closed? It's specifically related to Zygote, no?

It should have been closed as "not planned" since fixing it requires some fairly fundamental changes to Zygote (e.g. working with optimised IR)

@wupeifan
Contributor

Oh, I just saw another attempt to estimate the Smets & Wouters model with NUTS in Julia...
I think with that number of ~ statements things are still tolerable... 36 should be fine, though; at least, when I raised the issue it was for a second-order Smets & Wouters model...
