allow each token to decide how much of input to reinject
lucidrains committed Jan 23, 2025
1 parent b15815e commit 1f7ea12
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.5',
+  version = '1.44.6',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
8 changes: 8 additions & 0 deletions x_transformers/x_transformers.py
@@ -1736,6 +1736,7 @@ def __init__(
         unet_skips = False,
         num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may be becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
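For context, a minimal usage sketch showing how the new flag would be turned on next to reinject_input; the model hyperparameters below are illustrative and not taken from this commit:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        reinject_input = True,               # reinject the embedded input at every layer
        learned_reinject_input_gate = True   # new: each token learns how much of that input to take back in
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)   # (1, 1024, 256)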
@@ -1993,6 +1994,7 @@ def __init__(

         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None
 
         # add the value from the first self attention block to all latter projected self attention values as a residual
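Since the gate is a dim → 1 linear layer, it produces a single scalar per token (broadcast across the feature dimension) rather than a per-channel mask. A quick shape check with toy tensors, outside the library:

import torch
import torch.nn as nn

dim = 512
gate_proj = nn.Linear(dim, 1, bias = False)   # same shape as the new learned_reinject_input_gate

x = torch.randn(2, 1024, dim)                 # (batch, seq, dim) hidden states
gate = gate_proj(x).sigmoid()                 # (2, 1024, 1) - one value in (0, 1) per token
inject = torch.randn(2, 1024, dim)            # stand-in for the projected reinjected input
print((inject * gate).shape)                  # torch.Size([2, 1024, 512]) - gate broadcasts over features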

@@ -2224,7 +2226,9 @@ def forward(
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
 
+        inp_inject = None
+
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2237,10 @@
             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
 
+        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+            inp_inject = inp_inject * inp_inject_gate
+
         # store all hiddens for skips
 
         skip_hiddens = []
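Taken together, the reinjected input is no longer applied at full strength at every position: each token scales it by a sigmoid gate computed from its own hidden state. Below is a toy, self-contained block sketching that idea; it assumes, as input-injection schemes generally do, that inp_inject is added back into the hidden state at each block, and none of the names are x-transformers internals:

import torch
import torch.nn as nn

class ToyGatedReinjectBlock(nn.Module):
    # illustrative only - per-token gated input reinjection, not the library's actual layer code
    def __init__(self, dim, learned_gate = True):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.inject_proj = nn.Linear(dim, dim, bias = False)                    # analogous to reinject_input_proj
        self.gate = nn.Linear(dim, 1, bias = False) if learned_gate else None   # analogous to learned_reinject_input_gate

    def forward(self, x, inp):
        inject = self.inject_proj(inp)                  # project the original embedded input
        if self.gate is not None:
            inject = inject * self.gate(x).sigmoid()    # each token decides how much input to reinject
        x = x + inject                                  # input injection into the residual stream
        return x + self.ff(x)                           # then the usual block

dim = 64
block = ToyGatedReinjectBlock(dim)
inp = torch.randn(2, 16, dim)    # embedded input tokens
out = block(inp.clone(), inp)    # (2, 16, 64)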
