Use the command `python main.py --device=cuda --dataset=tiny_stories.py`. Set `--dataset=shakespeare_char` for the character-level Shakespeare dataset.
This project replaces attention in a traditional GPT-2-based transformer with my idea, the linear-complexity matrix recurrent unit (MRU). This repo is forked from my repo transformer-train-script. Based on testing on the shakespeare_char toy dataset, the MRU seems to work well as a replacement for attention.
The loss plot above is from the first training attempt, using the independent-heads branch of this repo and of my other repo https://github.com/mikayahlevi/transformer-train-script.
I have limited compute and experience with data science, so I haven't been able to test the LM on much other than the toy dataset. Firstly, I would like to test this on larger and more informative datasets; if anyone wants to help me with this, reach out to me at [email protected] or by any other means. Secondly, the MRU is still relatively slow compared to the theoretical number of operations it should take, so I would like to investigate writing a CUDA kernel or simply optimizing the PyTorch code.
The idea of a matrix recurrent unit is dictated by the update rule

$$H_t = H_{t-1} X_t, \qquad H_0 = I,$$

where each $X_t$ is a square matrix formed from the input at position $t$, so the state $H_t$ is the cumulative matrix product $X_1 X_2 \cdots X_t$.
- Matrix multiplication is associative but not commutative. The associativity means I can compute the cumulative matrix product using an (inclusive) parallel scan. The lack of commutativity means that the order of tokens is automatically incorporated into the MRU. A reference implementation of the cumulative matrix product is sketched after this list.
- When you try to do this scan on a traditional RNN, the number of operations scales cubically with the number of elements in the output state, meaning that limited information is retained compared to the amount of computation. On the other hand, if the states are matrices, the number of operations as a function of the number of elements in the output state is $((d_o)^2)^\frac{3}{2}$, where $(d_o)^2$ is the number of elements in the square $d_o \times d_o$ output matrix state. Some more info here: https://arxiv.org/abs/1709.04057.
- When processing the tokens sequentially or in parallel with the Brent-Kung parallel scan, the network scales linearly with the sequence length, in contrast to attention, which scales quadratically.
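To make the recurrence concrete, here is a minimal PyTorch sketch of the cumulative matrix product computed sequentially, under the $H_t = H_{t-1} X_t$, $H_0 = I$ convention used above. The function name is mine and this is only a reference implementation, not the code used in this repo.

```python
import torch

def mru_states_naive(x: torch.Tensor) -> torch.Tensor:
    """Sequential reference for the MRU states.

    x: (seq_len, d_o, d_o) input matrices X_1..X_n.
    Returns (seq_len, d_o, d_o) states H_t = X_1 @ X_2 @ ... @ X_t.
    """
    state = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
    states = []
    for x_t in x:            # one matrix multiplication per position
        state = state @ x_t
        states.append(state)
    return torch.stack(states)

# Example: 16 positions, 8x8 matrix states (kept near the identity so the
# products stay well-conditioned).
x = torch.eye(8) + 0.1 * torch.randn(16, 8, 8)
h = mru_states_naive(x)      # h[t] is the state after position t
```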
For the rest of this document, let's call the sequence length $n$ and keep the square $d_o \times d_o$ state, so one matrix multiplication costs on the order of $(d_o)^3$ operations.
The number of operations for the MRU itself is:
- Using recurrence: about $n (d_o)^3$ ($n - 1$ matrix multiplications, one after another)
- Using the Brent-Kung scan: about $2 n (d_o)^3$
- Using the Hillis-Steele scan: about $n \log_2(n) (d_o)^3$
The parallel scans take more computation, but they have the advantage of using parallel hardware more efficiently. While the plain recurrence, like an RNN, would take $n$ sequential steps, the Brent-Kung scan takes about $2 \log_2(n)$ and the Hillis-Steele scan about $\log_2(n)$ sequential steps, with the matrix multiplications inside each step running in parallel.
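To illustrate, here is a sketch of the Hillis-Steele variant as an inclusive scan over matrices in PyTorch, where each of the roughly $\log_2(n)$ steps is a single batched matrix multiplication. The function name is my own, and the repo's actual scan code may look different.

```python
import torch

def hillis_steele_matrix_scan(x: torch.Tensor) -> torch.Tensor:
    """Inclusive Hillis-Steele scan under matrix multiplication.

    x: (seq_len, d_o, d_o). Returns h with h[t] = x[0] @ x[1] @ ... @ x[t].
    About log2(seq_len) steps, each a single batched matmul.
    """
    h = x.clone()
    shift = 1
    while shift < h.shape[0]:
        # Combine every position with the partial product ending `shift`
        # positions earlier; the right-hand side is fully materialized before
        # the slice assignment, so all reads see the previous step's values.
        h[shift:] = h[:-shift] @ h[shift:]
        shift *= 2
    return h

# Sanity check against the plain sequential recurrence.
torch.manual_seed(0)
x = torch.eye(4, dtype=torch.float64) + 0.1 * torch.randn(32, 4, 4, dtype=torch.float64)
ref, state = [], torch.eye(4, dtype=torch.float64)
for x_t in x:
    state = state @ x_t
    ref.append(state)
assert torch.allclose(hillis_steele_matrix_scan(x), torch.stack(ref))
```

With $n$ positions this performs roughly $n \log_2(n)$ matrix multiplications in total but only about $\log_2(n)$ dependent steps, which is exactly the trade-off described above.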
The MRU should take in a sequence of vectors and return a sequence of vectors, like any other traditional operation in a neural network. For now I'll be ignoring the batch and sequence dimensions and only focusing on the last dimension.
Therefore, each input vector of size $(d_o)^2$ is reshaped into a $d_o \times d_o$ matrix $X_t$, the cumulative matrix products $H_t$ are computed along the sequence, and each state $H_t$ is flattened back into an output vector of size $(d_o)^2$.
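Here is a sketch of that vector-in, vector-out interface. The assumption that the feature size is exactly $(d_o)^2$ and the absence of input/output projections, heads, or normalization are simplifications on my part; the real layer in this repo presumably wraps more machinery around this core.

```python
import math
import torch

def mru_forward(v: torch.Tensor) -> torch.Tensor:
    """Minimal MRU core: a sequence of vectors in, a sequence of vectors out.

    v: (seq_len, d) with d = d_o * d_o.
    Each vector is reshaped into a d_o x d_o matrix, cumulative matrix
    products are taken along the sequence, and each state is flattened back.
    """
    seq_len, d = v.shape
    d_o = math.isqrt(d)
    assert d_o * d_o == d, "feature dimension must be a perfect square"
    x = v.reshape(seq_len, d_o, d_o)
    state = torch.eye(d_o, dtype=v.dtype, device=v.device)
    states = []
    for x_t in x:            # sequential reference; the parallel scan goes here
        state = state @ x_t
        states.append(state)
    return torch.stack(states).reshape(seq_len, d)

y = mru_forward(torch.randn(10, 64))   # d_o = 8, output shape (10, 64)
```

With multiple heads, each head would presumably carry its own $d_o \times d_o$ state, so only the per-head feature size needs to be a perfect square, and the batch and head dimensions simply become extra leading batch dimensions for the matmuls.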
After finishing this project, I've been informed that it actually has quite a bit of overlap with DeltaNet (https://arxiv.org/abs/2102.11174) and RWKV7 (https://x.com/BlinkDL_AI/status/1833863117480280528). Note that I may misunderstand these other projects. The recurrence relation of RWKV7 and DeltaNet is almost a subset of the MRU, with additional structure imposed on the state-transition matrices.
- RWKV7 and DeltaNet don't derive an efficient scan like I do in the next section. The paper Parallelizing Linear Transformers with the Delta Rule over Sequence Length (https://arxiv.org/pdf/2406.06484) does derive a chunkwise form, though it is less parallel, if I'm not mistaken.
- The MRU deconstructs the states to extract one output feature per state matrix element by reshaping them. DeltaNet and RWKV, on the other hand, only extract the square root of the number of elements per state matrix by using the matrices as the weights of a linear map, leading to orders of magnitude more computation for an equivalent number of features.
For the MRU, I've derived an efficient algorithm using a parallel scan to compute it. Sorry for my most likely incorrect mathematical notation; I am not well versed in the math fields that this computation involves. Note that the matrix products below are written in sequence order, since matrix multiplication does not commute.
The forward pass can be computed using a parallel scan.
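As a sketch of the work-efficient option named earlier, here is an inclusive Brent-Kung scan over matrices written with strided batched matmuls. It assumes the sequence length is a power of two (other lengths can be padded with identity matrices); the function name and these details are mine, not necessarily how this repo implements it.

```python
import torch

def brent_kung_matrix_scan(x: torch.Tensor) -> torch.Tensor:
    """Inclusive Brent-Kung scan under matrix multiplication.

    x: (seq_len, d_o, d_o) with seq_len a power of two.
    Returns h with h[t] = x[0] @ x[1] @ ... @ x[t].
    Roughly 2 * seq_len matrix products over about 2 * log2(seq_len) steps.
    """
    h = x.clone()
    n = h.shape[0]
    # Up-sweep (reduce): build products over blocks of doubling size.
    half = 1
    while half < n:
        stride = 2 * half
        h[stride - 1::stride] = h[half - 1::stride] @ h[stride - 1::stride]
        half = stride
    # Down-sweep: fill in the prefixes the up-sweep skipped.
    half = n // 4
    while half >= 1:
        stride = 2 * half
        h[stride - 1 + half::stride] = h[stride - 1:n - half:stride] @ h[stride - 1 + half::stride]
        half //= 2
    return h

# Sanity check against the plain sequential recurrence.
torch.manual_seed(0)
x = torch.eye(4, dtype=torch.float64) + 0.1 * torch.randn(64, 4, 4, dtype=torch.float64)
ref, state = [], torch.eye(4, dtype=torch.float64)
for x_t in x:
    state = state @ x_t
    ref.append(state)
assert torch.allclose(brent_kung_matrix_scan(x), torch.stack(ref))
```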
The backwards pass for the MRU is considerably more complicated.

The gradient of the loss with respect to an input matrix $X_t$ collects a contribution from every state $H_s$ with $s \geq t$, because $X_t$ appears as a factor in each of them.

If we define $G_t = \frac{\partial L}{\partial H_t}$ as the gradient the rest of the network passes into the state $H_t$, the expanded gradient of $X_t$ is

$$\frac{\partial L}{\partial X_t} = \sum_{s=t}^{n} H_{t-1}^\top \, G_s \left( \prod_{i=t+1}^{s} X_i \right)^{\top},$$

using the fact that for $H_s = H_{t-1} \, X_t \, (X_{t+1} \cdots X_s)$ the contribution through $H_s$ is $H_{t-1}^\top G_s (X_{t+1} \cdots X_s)^\top$; the empty product (for $s = t$) is the identity, and $H_0 = I$.

The factor $H_{t-1}^\top$ does not depend on $s$, so it can be pulled out of the sum. I'll call the second part of the gradient a new variable,

$$D_t = \sum_{s=t}^{n} G_s \left( \prod_{i=t+1}^{s} X_i \right)^{\top}.$$

You can see that $D_t$ satisfies the reverse recurrence

$$D_n = G_n, \qquad D_t = G_t + D_{t+1} X_{t+1}^\top.$$

If we let $A_t = X_{t+1}^\top$ and $B_t = G_t$, then we can express the recurrence with the block matrices

$$M_t = \begin{pmatrix} A_t & 0 \\ B_t & I \end{pmatrix}, \qquad \begin{pmatrix} D_t & I \end{pmatrix} = \begin{pmatrix} D_{t+1} & I \end{pmatrix} M_t,$$

so, starting from $\begin{pmatrix} D_n & I \end{pmatrix} = \begin{pmatrix} G_n & I \end{pmatrix}$, every $D_t$ comes from a reverse (suffix) cumulative product of the $M_t$, which can be computed with a reverse parallel scan because matrix multiplication is associative.

Combining all of this, we get the final gradient for the input matrices,

$$\frac{\partial L}{\partial X_t} = H_{t-1}^\top D_t.$$
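To convince myself the formula is right (under the $H_t = H_{t-1} X_t$, $H_0 = I$ convention assumed throughout), here is a small sketch that evaluates the reverse recurrence with a plain loop and compares the result with PyTorch autograd. The function name is mine, and in the actual implementation the loop over $D_t$ would be replaced by the reverse parallel scan.

```python
import torch

def mru_input_grads(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Gradient of L = sum_t <G_t, H_t> with respect to the input matrices X_t,
    where H_t = H_{t-1} @ X_t and H_0 = I.

    x, g: (seq_len, d_o, d_o). Returns (seq_len, d_o, d_o) holding dL/dX_t.
    """
    n, d_o, _ = x.shape
    eye = torch.eye(d_o, dtype=x.dtype, device=x.device)
    # Forward states H_1..H_n (sequential reference).
    h, state = [], eye
    for x_t in x:
        state = state @ x_t
        h.append(state)
    # Reverse recurrence D_n = G_n, D_t = G_t + D_{t+1} @ X_{t+1}^T
    # (this loop is what the reverse parallel scan replaces).
    d = [None] * n
    d[-1] = g[-1]
    for t in range(n - 2, -1, -1):
        d[t] = g[t] + d[t + 1] @ x[t + 1].transpose(-1, -2)
    # dL/dX_t = H_{t-1}^T @ D_t, with H_0 = I.
    h_prev = [eye] + h[:-1]
    return torch.stack([hp.transpose(-1, -2) @ d_t for hp, d_t in zip(h_prev, d)])

# Compare with autograd on a small random example.
torch.manual_seed(0)
x = torch.randn(6, 4, 4, dtype=torch.float64, requires_grad=True)
g = torch.randn(6, 4, 4, dtype=torch.float64)

state, h = torch.eye(4, dtype=torch.float64), []
for x_t in x:
    state = state @ x_t
    h.append(state)
(g * torch.stack(h)).sum().backward()     # L = sum_t <G_t, H_t>

assert torch.allclose(x.grad, mru_input_grads(x.detach(), g))
```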