Sasha Rush committed Jan 6, 2024
1 parent c0ed723 commit 941567e
Showing 4 changed files with 173 additions and 100 deletions.
35 changes: 35 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,35 @@
repos:
-   repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
    -   id: trailing-whitespace
    -   id: end-of-file-fixer
    -   id: check-yaml
    -   id: debug-statements
    -   id: double-quote-string-fixer
    -   id: name-tests-test
    -   id: requirements-txt-fixer
-   repo: https://github.com/asottile/setup-cfg-fmt
    rev: v2.5.0
    hooks:
    -   id: setup-cfg-fmt
-   repo: https://github.com/asottile/reorder-python-imports
    rev: v3.12.0
    hooks:
    -   id: reorder-python-imports
        exclude: ^(pre_commit/resources/|testing/resources/python3_hooks_repo/)
        args: [--py39-plus, --add-import, 'from __future__ import annotations']
-   repo: https://github.com/asottile/add-trailing-comma
    rev: v3.1.0
    hooks:
    -   id: add-trailing-comma
-   repo: https://github.com/PyCQA/flake8
    rev: 6.1.0
    hooks:
    -   id: flake8
-   repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
    -   id: mypy
        additional_dependencies: [types-all]
        exclude: ^testing/resources/
33 changes: 33 additions & 0 deletions Makefile
@@ -0,0 +1,33 @@
SHELL := /bin/bash
.PHONY: help notebook html md blog clean
.DEFAULT: help

# Generates a useful overview/help message for various make features
help:
	@echo "make notebook"
	@echo "    Use jupytext to build a notebook (mamba.ipynb) from the mamba.py script."
	@echo "make html"
	@echo "    Use jupytext & jupyter nbconvert for the two-step conversion from mamba.py to HTML."
	@echo "make md"
	@echo "    Use jupytext to convert mamba.py to markdown."
	@echo "make blog"
	@echo "    Use pandoc to render the markdown into the HTML blog post."
	@echo "make clean"
	@echo "    Delete the generated, top-level mamba.ipynb notebook."

notebook: mamba.py
	jupytext --to notebook mamba.py -o mamba.ipynb

html: mamba.py
	jupytext --to notebook mamba.py -o mamba.ipynb
	jupyter nbconvert --to html mamba.ipynb

md: mamba.py
	jupytext --to markdown mamba.py

blog: md
	pandoc docs/header-includes.yaml mamba.md --katex=/usr/local/lib/node_modules/katex/dist/ --output=docs/index.html --to=html5 --css=docs/github.min.css --css=docs/tufte.css --no-highlight --self-contained --metadata pagetitle="The Annotated Mamba"

clean:
	rm -f mamba.ipynb
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
Eventual version of the annotated mamba paper. Might take some time ...

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
https://arxiv.org/abs/2312.00752
203 changes: 104 additions & 99 deletions mamba.py
@@ -9,13 +9,8 @@
# <p>Albert Gu and Tri Dao.</p>
# </center>
# <img src="mamba.png" width="100%"/>


# *Blog Post and [Library](https://github.com/srush/annotated-mamba/) by [Sasha Rush](http://rush-nlp.com/)*
# ## Table of Contents


# * [Part 1: Time-Varying State Space Models] (Modeling)
# - [Discrete-time SSM: The Recurrent Representation]
# - [Tangent: A Mechanics Example]
@@ -36,68 +31,76 @@
# - [Experiments: QuickDraw]
# - [Experiments: Spoken Digits]
# * [Conclusion]


# <nav id="TOC">
from __future__ import annotations

import torch
import torch.nn as nn
from jaxtyping import Float
from torch import arange
from torch import Tensor as T
# ## Part 1: State Space Models



# > The [state space model](https://en.wikipedia.org/wiki/State-space_representation) is defined by this simple equation.
# > It maps a 1-D input signal $u(t)$ to an $N$-D latent state $x(t)$
# > before projecting to a 1-D output signal $y(t)$.
# $$
# \begin{aligned}
# h'(t) &= \boldsymbol{A}h(t) + \boldsymbol{B}x(t) \\
# y(t) &= \boldsymbol{C}h(t)
# \end{aligned}
# $$

A_ = Float[T, '#B #L #D N 1']
B_ = Float[T, '#B #L #D N 1']
C_ = Float[T, '#B #L #D 1 N']
SSM_ = tuple[A_, B_, C_]


# A constructor for the SSM tuple. In the shape annotations above, the
# `#`-prefixed dimensions (`#B #L #D`) are broadcastable batch, length, and
# channel dimensions.


def SSM(A: A_, B: B_, C: C_) -> SSM_:
    return (A, B, C)
# A random SSM, for checking shapes.


def random_ssm(N: int, L: int = 1) -> SSM_:
    A = torch.rand(1, 1, 1, N, 1)
    B = torch.rand(1, 1, 1, N, 1)
    C = torch.rand(1, 1, 1, 1, N)
    return SSM(A, B, C)
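
# As a quick sanity check — a hypothetical snippet, not from the commit
# (`N=4` is arbitrary):

Ar, Br, Cr = random_ssm(4)
assert Ar.shape == (1, 1, 1, 4, 1)
assert Br.shape == (1, 1, 1, 4, 1)
assert Cr.shape == (1, 1, 1, 1, 4)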


# Discretization produces a value of the same type: a tuple
# $(\boldsymbol{\overline{A}}, \boldsymbol{\overline{B}}, \boldsymbol{C})$ obeying the recurrence

# $$
# \begin{aligned}
# h_t &= \boldsymbol{\overline{A}}h_{t-1} + \boldsymbol{\overline{B}}x_t \\
# y_t &= \boldsymbol{C}h_t
# \end{aligned}
# $$


X_ = Float[T, 'B L D']
Y_ = Float[T, 'B L D']
H_ = Float[T, 'B D N 1']


def ssm_rnn(A_bar: A_, B_bar: B_, C: C_, x: X_) -> Y_:
    B, L, D = x.shape
    N = A_bar.shape[-2]
    ys = []
    h_l_1: H_ = torch.zeros(B, D, N, 1)
    # Step the recurrence once per position l along the length dimension.
    for l, x_l in enumerate(x.unbind(1)):
        # A is diagonal, so applying A_bar is an elementwise product.
        h_l: H_ = A_bar[:, l] * h_l_1 + B_bar[:, l] @ x_l[..., None, None]
        y_l: Float[T, 'B D'] = C[:, l] @ h_l
        ys.append(y_l[..., 0, 0])
        h_l_1 = h_l
    return torch.stack(ys, dim=1)
#
@@ -109,17 +112,22 @@


Delta_ = Float[T, '#B #L #D 1 1']
# $$
# \begin{aligned}
# \boldsymbol{\overline{A}} &= \exp(\Delta \boldsymbol{A}) \\
# \boldsymbol{\overline{B}} &= (\Delta \boldsymbol{A})^{-1} (\boldsymbol{\overline{A}} - \boldsymbol{I})\, \Delta \boldsymbol{B}
# \end{aligned}
# $$
def discretize_zoh(A: A_, B: B_, C: C_, delta: Delta_) -> SSM_:
    dA: A_ = delta * A
    A_bar: A_ = torch.exp(dA)
    # A is diagonal, so the inverse is elementwise and I is elementwise 1.
    dA_inv = 1.0 / dA
    B_bar: B_ = (dA_inv * (A_bar - 1.0) * delta) * B
    return SSM(A_bar, B_bar, C)
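
# A hypothetical smoke test, not from the commit (all sizes arbitrary): build
# full-size tensors, discretize, and run the RNN.

Bb, Ls, Ds, Ns = 2, 4, 3, 5
Ad = -(torch.rand(Bb, Ls, Ds, Ns, 1) + 0.1)   # strictly negative, as in S4D
Bd = torch.rand(Bb, Ls, Ds, Ns, 1)
Cd = torch.rand(Bb, Ls, Ds, 1, Ns)
dd = torch.rand(Bb, Ls, Ds, 1, 1) + 0.1       # strictly positive step sizes
xd = torch.rand(Bb, Ls, Ds)
yd = ssm_rnn(*discretize_zoh(Ad, Bd, Cd, dd), xd)
assert yd.shape == (Bb, Ls, Ds)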


@@ -128,95 +136,92 @@


# Structured State Space: $A$ is a structured (here diagonal) $D \times N$
# matrix, initialized with the S4D-real scheme $A_n = -(n + 1)$.

# ## Selective SSM

# $$
# \begin{aligned}
# h'(t) &= \boldsymbol{A}h(t) + \boldsymbol{B}(x) x(t) \\
# y(t) &= \boldsymbol{C}(x) h(t)
# \end{aligned}
# $$


class Scan:
    def scan(self, ssm: SSM_, delta: Delta_, x: X_) -> Y_:
        discrete_ssm = discretize_zoh(*ssm, delta)
        return ssm_rnn(*discrete_ssm, x)
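
# Equivalently, via the `Scan` wrapper (reusing the hypothetical tensors from
# the smoke test above):

y2 = Scan().scan((Ad, Bd, Cd), dd, xd)
assert torch.allclose(yd, y2)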


class SelectiveStructuredSSM(nn.Module):
    def __init__(self, D, N, scanner: Scan):
        super().__init__()
        # S4D-real initialization: diagonal entry n of A is -(n + 1),
        # repeated across the D channels.
        init_A: Float[T, '1 1 D N 1'] = \
            -(arange(N) + 1.0)[None, None, None, :, None].repeat(1, 1, D, 1, 1)
        self.A: A_ = nn.Parameter(init_A)
        self.s_B, self.s_C = nn.Linear(D, N), nn.Linear(D, N)
        self.s_Delta = nn.Linear(D, 1)
        # Initialized to zeros here; the commit left this tensor uninitialized.
        self.p_Delta = nn.Parameter(torch.zeros(D))
        self.scanner = scanner

    def forward(self, x: X_) -> Y_:
        B: Float[T, 'B L 1 N 1'] = self.s_B(x)[..., None, :, None]
        C: Float[T, 'B L 1 1 N'] = self.s_C(x)[..., None, None, :]
        ssm: SSM_ = SSM(self.A, B, C)
        Delta: Float[T, 'B L D 1 1'] = nn.functional.softplus(
            self.s_Delta(x) + self.p_Delta)[..., None, None]
        return self.scanner.scan(ssm, Delta, x)
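
# A hypothetical instantiation (D=4, N=2, chosen arbitrarily) to check shapes:

layer = SelectiveStructuredSSM(4, 2, Scan())
assert layer(torch.rand(1, 6, 4)).shape == (1, 6, 4)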


# $$
# \begin{aligned}
# (\Delta(x), \boldsymbol{A}, \boldsymbol{B}(x), \boldsymbol{C}(x)) \mapsto (\boldsymbol{\overline{A}}(x), \boldsymbol{\overline{B}}(x), \boldsymbol{C}(x))
# \end{aligned}
# $$
# Full Selective SSM: together, the pieces above form the complete layer.

# ## Mamba Architecture

#
#

# ![](images/arch.png)

class Mamba(nn.Module):
    def __init__(self, N, D):
        super().__init__()
        D_2 = D // 2
        self.s6 = SelectiveStructuredSSM(D, N, Scan())
        self.p_up1 = nn.Linear(D_2, D)
        self.p_up2 = nn.Linear(D_2, D)
        self.p_down = nn.Linear(D, D_2)
        # Assumption: the commit leaves nn.Conv1d() unconfigured; a depthwise
        # causal convolution of width 4 (as in the Mamba paper) is used here.
        self.conv = nn.Conv1d(D, D, kernel_size=4, padding=3, groups=D)

    def forward(self, x):
        _, L, _ = x.shape
        x1 = self.p_up1(x)
        # Conv1d expects (B, D, L); trim the padded tail to stay causal.
        x1 = torch.relu(self.conv(x1.transpose(1, 2))[..., :L]).transpose(1, 2)
        x1 = self.s6(x1)
        x2 = self.p_up2(x)
        return self.p_down(x1 * torch.relu(x2))
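
# Hypothetical usage (sizes arbitrary): the block maps D // 2 channels to
# D // 2 channels, expanding to D internally.

model = Mamba(N=8, D=16)
assert model(torch.rand(2, 32, 8)).shape == (2, 32, 8)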

# ## Efficient Implementation

# Mamba's efficiency hinges on replacing the sequential RNN above with a
# parallel associative scan over the per-step transitions
# $(\boldsymbol{\overline{A}}_l, \boldsymbol{\overline{B}}_l x_l)$.


def op(d1: tuple[A_, B_], d2: tuple[A_, B_]):
    A1, b1 = d1
    A2, b2 = d2
    # Compose h -> A1 h + b1 followed by h -> A2 h + b2.
    # A is diagonal, so the products are elementwise.
    return (A2 * A1, A2 * b1 + b2)


def pscan(op, inp):
    # Tree reduction: combine (even, odd) neighbors along the length
    # dimension, halving it until a single composed transition remains.
    # (This is the up-sweep half of a full parallel scan.)
    if inp[0].shape[1] == 1:
        return inp
    evens = [i[:, 0::2] for i in inp]
    odds = [i[:, 1::2] for i in inp]
    return pscan(op, op(evens, odds))
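
# A hypothetical check, not from the commit: stepping the recurrence
# $h_l = A_l h_{l-1} + b_l$ from $h = 0$ makes the composed offset `b_tot`
# exactly the final state, so it should match the tree reduction.

torch.manual_seed(0)
A_steps = torch.rand(1, 4, 1, 2, 1)
b_steps = torch.rand(1, 4, 1, 2, 1)
A_tot, b_tot = pscan(op, (A_steps, b_steps))
h = torch.zeros(1, 1, 1, 2, 1)
for l in range(4):
    h = A_steps[:, l:l + 1] * h + b_steps[:, l:l + 1]
assert torch.allclose(b_tot, h)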


class Scan2:
    def scan(self, ssm: SSM_, delta: Delta_, x: X_) -> Y_:
        # Placeholder for the parallel-scan version of `Scan`.
        # discrete_ssm = discretize_zoh(*ssm, delta)
        # return ssm_rnn(*discrete_ssm, x)
        pass
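
# The parallel version needs all prefix states, not just the total. A hedged
# sketch (my addition, not the post's kernel): for diagonal, positive
# `A_bar = exp(dA)` with a full L dimension, every hidden state follows from
# cumulative products. Numerically fragile (it divides by running products),
# but it exposes the parallel structure.

def scan_cumprod(A_bar: A_, B_bar: B_, C: C_, x: X_) -> Y_:
    b = B_bar * x[..., None, None]        # per-step offsets (B, L, D, N, 1)
    A_cum = torch.cumprod(A_bar, dim=1)   # prod_{j <= l} A_j
    # h_l = A_cum_l * sum_{k <= l} b_k / A_cum_k, all positions at once.
    h = A_cum * torch.cumsum(b / A_cum, dim=1)
    return (C @ h)[..., 0, 0]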
