Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: specialize some varargs #651

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

RFC: specialize some varargs #651

wants to merge 1 commit into from

Conversation

oxinabox
Copy link
Member

@MasonProtter

benchmark script

@btime (x->last(map(sin, fill(x,100))))'(0.1);
@btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
@btime sin'(0.5);
@btime sin''(0.5);
@btime sin'''(0.5);

Julia 1.4

Without this PR

julia> @btime (x->last(map(sin, fill(x,100))))'(0.1);
  17.075 μs (769 allocations: 22.88 KiB)

julia> @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
  23.213 μs (988 allocations: 37.02 KiB)

julia> @btime sin'(0.5);
  1.346 ns (0 allocations: 0 bytes)

julia> @btime sin''(0.5);
  12.902 ns (0 allocations: 0 bytes)

julia> @btime sin'''(0.5);
  703.346 μs (2664 allocations: 68.30 KiB)

Only with interface.jl changes

julia> @btime (x->last(map(sin, fill(x,100))))'(0.1);
  15.689 μs (770 allocations: 22.89 KiB)

julia> @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
  21.278 μs (988 allocations: 37.02 KiB)

julia> @btime sin'(0.5);
  1.247 ns (0 allocations: 0 bytes)

julia> @btime sin''(0.5);
  10.701 ns (0 allocations: 0 bytes)

julia> @btime sin'''(0.5);
  679.070 μs (2626 allocations: 67.69 KiB)

With interface.jl and interface2.jl changes

julia> @btime (x->last(map(sin, fill(x,100))))'(0.1);
  17.219 μs (770 allocations: 22.89 KiB)

julia> @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
  22.907 μs (988 allocations: 37.02 KiB)

julia> @btime sin'(0.5);
  1.247 ns (0 allocations: 0 bytes)

julia> @btime sin''(0.5);
  11.414 ns (0 allocations: 0 bytes)

julia> @btime sin'''(0.5);
  674.351 μs (2617 allocations: 67.55 KiB)

Julia 1.5

Without this PR

julia> @btime (x->last(map(sin, fill(x,100))))'(0.1);
  14.080 μs (770 allocations: 23.55 KiB)

julia> @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
  19.151 μs (989 allocations: 37.94 KiB)

julia> @btime sin'(0.5);
  1.246 ns (0 allocations: 0 bytes)

julia> @btime sin''(0.5);
  11.812 ns (0 allocations: 0 bytes)

julia> @btime sin'''(0.5);
  531.809 μs (2410 allocations: 86.27 KiB)

With interface.jl and interface2.jl changes

julia> @btime (x->last(map(sin, fill(x,100))))'(0.1);
  14.947 μs (770 allocations: 23.55 KiB)

julia> @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1);
  19.142 μs (989 allocations: 37.94 KiB)

julia> @btime sin'(0.5);
  1.279 ns (0 allocations: 0 bytes)

julia> @btime sin''(0.5);
  10.987 ns (0 allocations: 0 bytes)

julia> @btime sin'''(0.5);
  509.693 μs (2310 allocations: 83.59 KiB)

@MasonProtter
Copy link
Contributor

Thanks for running these. Much less of an improvement that I had hoped, but it could be that there's just different bottlenecks to be tackled. Perhaps something in IRTools.jl itself.

What branch do you use for "without this PR"? Here's what I see on master:

julia> begin
       using Zygote
       @btime (x->last(map(sin, fill(x,100))))'(0.1)
       @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1)
       @btime sin'(0.5)
       @btime sin''(0.5)
       @btime sin'''(0.5)
       end;
  18.740 μs (769 allocations: 22.88 KiB)
  24.790 μs (988 allocations: 37.02 KiB)
  1.299 ns (0 allocations: 0 bytes)
  9.058 ns (0 allocations: 0 bytes)
  1.317 ms (4652 allocations: 114.19 KiB)

julia> versioninfo()
Julia Version 1.4.1
Commit 381693d3df* (2020-04-14 17:20 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD Ryzen 5 2600 Six-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-8.0.1 (ORCJIT, znver1)
Environment:
  JULIA_NUM_THREADS = 6

and on your branch:

julia> begin
       using Zygote
       @btime (x->last(map(sin, fill(x,100))))'(0.1)
       @btime (x->last(map(atan, fill(x,100), fill(2x, 100))))'(0.1)
       @btime sin'(0.5)
       @btime sin''(0.5)
       @btime sin'''(0.5)
       end;
  17.320 μs (770 allocations: 22.89 KiB)
  23.600 μs (988 allocations: 37.02 KiB)
  1.299 ns (0 allocations: 0 bytes)
  9.018 ns (0 allocations: 0 bytes)
  740.627 μs (2617 allocations: 67.67 KiB)

so sin''' seems way worse for me on master than your "without this PR".

But pursuing SpecializeVarargs.jl just to make higher order derivatives faster seems like putting a bandaid on a headwound. For that, we either need ways to simplify IR after a derivative so that the IR doesn't combinatorially explode with higher order derivatives, or we need a way to generate less IR bloat in the first place using something like Taylor rules.

@oxinabox
Copy link
Member Author

Interesting, I thought i was using master...

The whole thing of nested Zygote can't be solved via Taylor mode.
Because the only reason you were ever nest Zygote is my mistake.
E.g. using Zygote on come large bit of code from a package that happens to use Zygote internally (and you didn't know about it).
The majority of times you want a second derivative you want to do forward over reverse, or forward over forward. If you do want reverse over reverse i hear Zygote over ReverseDiff works great already.
The sin'''(x) code is interesting for benchmarking reasons only.
It is some of the most pathological code one might want to do with AD.
Almost nothing in the real world that someone does on purpose is going to be that hard

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants