use cautious lion for parity task
lucidrains committed Feb 6, 2025
1 parent 80d5653 commit bdd8047
Showing 3 changed files with 30 additions and 12 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "x-transformers"
-version = "2.0.1"
+version = "2.0.2"
description = "X-Transformers"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
@@ -34,7 +34,11 @@ Homepage = "https://pypi.org/project/x-transformers/"
Repository = "https://github.com/lucidrains/x-transformers"

[project.optional-dependencies]
-examples = ["tqdm", "torchvision"]
+examples = [
+    "lion-pytorch",
+    "tqdm",
+    "torchvision"
+]

test = [
"pytest",
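Since `train_parity.py` now imports `Lion` from `lion-pytorch`, the package is added to the `examples` extra above; presumably installing with `pip install 'x-transformers[examples]'` pulls it in alongside `tqdm` and `torchvision`.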
24 changes: 16 additions & 8 deletions train_parity.py
@@ -7,12 +7,16 @@

# constants

NUM_BATCHES = 100000
BATCH_SIZE = 256
LEARNING_RATE = 3e-4
EVAL_EVERY = 500
-TRAIN_MAX_LENGTH = 64

EVAL_LENGTHS = (16, 32, 64, 128, 256, 512)
+TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2]

+LOSS_THRES_INCREASE_LEN = 1e-3
+MEET_CRITERIA_THRES_INCREASE_LEN = 10

HYBRIDIZE_WITH_RNN = True

# rnn for fully resolving state tracking by hybridization
@@ -28,6 +32,7 @@

decoder_kwargs = dict(
attn_hybrid_fold_axial_dim = 4, # even if recurrence is every 4 tokens, can generalize for parity
+attn_hybrid_learned_mix = True,
attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
)

@@ -48,7 +53,9 @@

# optimizer

-adam = optim.Adam(model.parameters(), lr = LEARNING_RATE)
+from lion_pytorch.cautious_lion import Lion

+optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1)

# data generator

@@ -73,7 +80,8 @@ def cycle(length):
train_seq_len = 1
stop_length = EVAL_LENGTHS[-2]

-with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar:
+with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar:

while train_seq_len < stop_length:
model.train()

@@ -90,21 +98,21 @@ def cycle(length):
last_loss = loss[:, -1].mean()
loss.mean().backward()

-if last_loss.item() < 0.001:
+if last_loss.item() < LOSS_THRES_INCREASE_LEN:
meet_criteria += 1
else:
meet_criteria = 0

-if meet_criteria >= 10:
+if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN:
meet_criteria = 0
train_seq_len += 1
print(f'criteria met, incrementing to {train_seq_len}')

print(f'({train_seq_len})| {i}: {last_loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

-adam.step()
-adam.zero_grad()
+optimizer.step()
+optimizer.zero_grad()

last_step = train_seq_len == stop_length

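For orientation, here is a minimal sketch of how the pieces changed in `train_parity.py` fit together: a decoder hybridized with a `GRU` (now with the new `attn_hybrid_learned_mix` flag) and the cautious `Lion` optimizer in place of Adam. Everything not shown in the diff itself (model sizes, sequence length, the loss formulation) is an illustrative assumption, not the script's actual values.

```python
# Sketch only: wiring a hybrid-RNN decoder and the cautious Lion optimizer,
# mirroring the kwargs in this commit. Sizes and the parity targets below are
# illustrative assumptions, not copied from train_parity.py.

import torch
from torch import nn
from torch.nn import GRU

from x_transformers import TransformerWrapper, Decoder
from lion_pytorch.cautious_lion import Lion

dim, dim_head, heads = 64, 16, 4       # assumed toy sizes

model = TransformerWrapper(
    num_tokens = 2,                    # parity over binary tokens
    max_seq_len = 256,
    attn_layers = Decoder(
        dim = dim,
        depth = 3,
        heads = heads,
        attn_dim_head = dim_head,
        attn_hybrid_fold_axial_dim = 4,   # recurrence every 4 tokens, as in the diff
        attn_hybrid_learned_mix = True,   # new flag introduced by this commit
        attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
    )
)

optimizer = Lion(model.parameters(), lr = 3e-4, cautious_factor = 0.1)

# one illustrative training step on running-parity targets

seq = torch.randint(0, 2, (4, 64))
labels = seq.cumsum(dim = -1) % 2      # parity of the prefix at each position

logits = model(seq)                    # (batch, seq, num_tokens)
loss = nn.functional.cross_entropy(logits.transpose(1, 2), labels)
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
optimizer.zero_grad()
```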
10 changes: 8 additions & 2 deletions x_transformers/x_transformers.py
Expand Up @@ -1204,6 +1204,7 @@ def __init__(
hybrid_module: Module | None = None,
hybrid_mask_kwarg: str | None = None,
hybrid_fold_axial_dim: int | None = None,
+hybrid_learned_mix = False,
one_kv_head = False,
kv_heads = None,
value_dim_head = None,
@@ -1446,7 +1447,7 @@ def __init__(

if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
-hybrid_mix = LinearNoBias(dim, heads)
+hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None

hybrid_norms = ModuleList([
MultiheadRMSNorm(dim_head, heads = heads),
@@ -1779,7 +1780,12 @@ def forward(
out = out_norm(out)
hybrid_out = hybrid_out_norm(hybrid_out)

-out = 0.5 * (out + hybrid_out)
+if exists(self.hybrid_mix):
+    mix = self.hybrid_mix(x)
+    mix = rearrange(mix, 'b n h -> b h n 1')
+    out = out.lerp(hybrid_out, mix.sigmoid())
+else:
+    out = 0.5 * (out + hybrid_out)

# merge heads

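To make the new learned-mix path above concrete, here is a standalone sketch (plain `torch` + `einops`, not using x-transformers internals) of the gating it performs: a bias-free linear layer maps each token to one logit per head, and the attention and hybrid branch outputs are blended with a per-head, per-position sigmoid lerp. Shapes and variable names are illustrative assumptions.

```python
# Standalone sketch of the learned hybrid mix: one gate per head per token,
# blending the attention output with the hybrid (RNN) output.

import torch
from torch import nn
from einops import rearrange

b, n, dim, heads, dim_head = 2, 8, 64, 4, 16   # illustrative shapes

x          = torch.randn(b, n, dim)               # token representations
attn_out   = torch.randn(b, heads, n, dim_head)   # attention branch output
hybrid_out = torch.randn(b, heads, n, dim_head)   # hybrid (RNN) branch output

hybrid_mix = nn.Linear(dim, heads, bias = False)  # stands in for LinearNoBias(dim, heads)

mix = hybrid_mix(x)                               # (b, n, heads): one logit per head
mix = rearrange(mix, 'b n h -> b h n 1')          # align with the per-head outputs
out = attn_out.lerp(hybrid_out, mix.sigmoid())    # gate 0 -> attention, 1 -> hybrid

# without hybrid_learned_mix, the commit keeps the previous fixed average
out_fixed = 0.5 * (attn_out + hybrid_out)
```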
