Question about long sequence lengths with Hyena #10
The runtime numbers in the paper do not use the optimized fftconv kernel precisely because of the temporary 8k limitation.
Figure 4.3 (left) goes up to 100k; I think you're referencing the one on the right (which is only a zoomed-in portion of the left figure)?
Can you give more details on your benchmarking workload? Hyena should already be much faster at 32k tokens, so I suspect there might be other factors at play.
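For context on the `fftconv` kernel mentioned above: the core operation it accelerates is a long causal convolution between the input and a filter as long as the sequence, which an FFT computes in O(L log L) instead of O(L^2). The following is a minimal single-channel NumPy sketch of that technique, assumed to be similar in spirit to the reference path in `standalone_hyena.py` but not the repo's CUDA kernel. The function names are hypothetical.

```python
import numpy as np


def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal long convolution y[t] = sum_{s<=t} k[s] * u[t - s] via FFT.

    u: input of length L; k: filter of length L.
    Zero-padding both to 2L makes the FFT's circular convolution equal to
    the linear one, so the first L outputs are exactly the causal result.
    """
    L = u.shape[-1]
    n = 2 * L
    u_f = np.fft.rfft(u, n=n)
    k_f = np.fft.rfft(k, n=n)
    y = np.fft.irfft(u_f * k_f, n=n)
    return y[..., :L]


def direct_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """O(L^2) reference: keep the first L entries of the full convolution."""
    L = u.shape[-1]
    return np.convolve(u, k)[:L]


rng = np.random.default_rng(0)
u = rng.standard_normal(1024)
k = rng.standard_normal(1024)
print(np.allclose(fft_causal_conv(u, k), direct_causal_conv(u, k)))  # True
```

The 2L padding is what doubles the FFT size relative to the sequence length, which is one reason a fused kernel's maximum FFT size bounds the sequence length it supports.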
Yeah, my bad. I saw 10^5 and thought 10,000 for some reason 🤦
My benchmarking baseline was a model with:
I compared it to a model replacing FlashAttention with a […]

I ran into memory issues using the […]

Even with the much smaller embedding dim, the network with the […]

Do you think the memory usage issues are just an artifact of the standalone implementation?
Are you saying that the code in […]

The benchmark above was just a quick test to get some rough numbers, so there were some other differences between the baseline and the Hyena test:
Do you have any suggestions for things to try? Thanks!
Thanks for all the info! I pushed a small benchmarking script here for both forward and backward passes; what numbers do you see when you run it? On my end (on a single A100) I see Hyena as 5x/6x faster at batch size 1 and seqlen 32k. If you use the same script and benchmark at batch size 64, you should get the FlashAttention runtime numbers of the original paper. Regarding memory: yes, at the moment, without the custom kernel, the memory scaling is slightly worse for Hyena w.r.t. FlashAttention, though both are linear. Doing a bit more recomputation on the backward pass helps; we're working on these optimizations.
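The benchmarking script mentioned above is not reproduced in this thread. As a rough sketch of what any fair forward/backward timing comparison needs (untimed warmup runs, device synchronization before reading the clock, several repeats), here is a stdlib-only harness. It assumes nothing about the actual script; `benchmark` and the workloads are hypothetical.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 3, repeats: int = 10,
              sync: Callable[[], None] = lambda: None) -> float:
    """Return the median wall-clock time of fn() in milliseconds.

    warmup: untimed calls first, so one-time costs (kernel compilation,
        allocator growth, autotuning) do not skew the measurement.
    sync: called before each clock read. On a GPU this should block until
        queued work finishes (e.g. torch.cuda.synchronize), because kernel
        launches are asynchronous and the timer would otherwise stop early.
    """
    for _ in range(warmup):
        fn()
    sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        sync()
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)


# Hypothetical CPU-only stand-in workload.
ms = benchmark(lambda: sum(i * i for i in range(100_000)))
print(f"{ms:.2f} ms")
```

With two hypothetical closures, `attn_step` and `hyena_step`, each running one forward and backward pass at the same batch size and sequence length, the speedup would be `benchmark(attn_step) / benchmark(hyena_step)`. The median is used rather than the mean so that an occasional slow iteration does not dominate.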
I was just about to run a benchmark I wrote when you posted your comment :) I modified your script to import from […]

Full output:
That said, the speed difference at 64k is still "only" 7.3x vs the 100x from the paper. Any thoughts on what could be causing that? Thanks again!
Awesome! It's all a game of batch sizes: try running at batch sizes 16, 32 and 64, and you should see the speedup get larger.
Hmm, that doesn't seem to work. At a batch size of 16 and sequence length of 32k, FlashAttention takes 3.48 times longer than Hyena (see details below). Thoughts:
As far as I can tell, there isn't really a good reason to use a large batch size here vs. a batch size of 1 + gradient accumulation. Any ideas on what's going on?

Details:

GPU:

Hyena implementation permalink

Batch size of 1

Benchmark script with batch size of 1:

Batch size of 16

Benchmark script with batch size of 16 (limited to seq len of 32k because Hyena runs out of memory at larger seq lengths):
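The batch-size-1 plus gradient accumulation argument above rests on a property of mean losses: averaging per-sample gradients gives exactly the full-batch gradient (layers that mix samples, such as batch norm, break this). A toy pure-Python check with a hypothetical one-parameter model follows. It shows that the two choices agree numerically, and says nothing about speed, which depends on how well each batch size saturates the GPU.

```python
def grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of the mean squared error of the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5

# One large batch.
full_batch = grad(w, xs, ys)
# Micro-batches of size 1, gradients accumulated and then averaged.
accumulated = sum(grad(w, [x], [y]) for x, y in zip(xs, ys)) / len(xs)

print(abs(full_batch - accumulated) < 1e-12)  # True
```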
Something else I noticed is that the paper says "Hyena speedups reach 100x at sequence length 64K" and references Figure 4.3, but if you look at the LaTeX for Figure 4.3, it's actually only an 11.4x difference. I know the paper is still a draft, so is the figure (or text) outdated? Or are we interpreting the meaning of "speedup" differently? Thanks!

Figure 4.3 from the paper:

```latex
\addplot [line width=1pt, indianred]
table {%
1024 0.9
2048 1.16
4096 1.47
8192 1.5
16384 2.84
32768 5.41
65536 11.32
};
\addplot [line width=1pt, cornflowerblue]
table {%
1024 0.4
2048 1.25
4096 2.16
8192 6.17
16384 21.74
32768 90.71
};
\addplot [line width=1pt, lightseagreen, dashed]
table {%
1024 0.29
2048 0.3
4096 0.63
8192 2.1
16384 8.33
32768 32.85
65536 129.07
};
```
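The ratio quoted above can be recomputed from the plot data. A small sketch, assuming (as the comment does) that the `indianred` series is Hyena and the dashed `lightseagreen` series is FlashAttention; the legend is not part of the excerpt, so the `cornflowerblue` series, which stops at 32768, is left out.

```python
# Runtimes copied from the Figure 4.3 pgfplots source (units as in the plot).
# Assumption: indianred = Hyena, dashed lightseagreen = FlashAttention.
hyena = {1024: 0.9, 2048: 1.16, 4096: 1.47, 8192: 1.5,
         16384: 2.84, 32768: 5.41, 65536: 11.32}
flash = {1024: 0.29, 2048: 0.3, 4096: 0.63, 8192: 2.1,
         16384: 8.33, 32768: 32.85, 65536: 129.07}

# Speedup of Hyena over FlashAttention at each sequence length.
speedup = {L: flash[L] / hyena[L] for L in hyena}
for L, s in speedup.items():
    print(f"{L:>6}: {s:5.2f}x")
```

Under that assumption the ratio is below 1 up to 4096, crosses 1 before 8192, and reaches about 11.4x at 65536, matching the figure quoted in the comment.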
Section 4.4 says (emphasis mine):
Interesting finds, a few things here:
At saturation (width large enough), take the numbers in the figure and what you see running this benchmark as ground truth, and expect a few more […]

If you plan to run models at […]
Thanks! That makes sense. I think it would be super useful to have a sweep over […]

Thanks again!
Hello!
In the Hyena paper, section 4.4 says "Hyena speedups reach 100x at sequence length 64K."
The figure referenced by that section (Figure 4.3) stops at a sequence length short of 10k, and the optimized implementation in this repo appears to be limited to an 8k sequence length.
There are a few other references to a 100x speedup over FlashAttention in the paper (and in blog posts). Are these measured speedups or extrapolated from smaller sequence lengths?
I've experimented with the implementation in standalone_hyena.py but it appears to be ~3x slower than FlashAttention at sequence lengths > 32k tokens.
Do you have an estimate for when the fftconv implementation in this repo will support longer sequence lengths (or a pointer to another Hyena codebase if the speedups in the paper were measured)?
Thanks for the great work!