
Commit 3fe9821

allow for queries, keys, values to be derived from different combinations of residual streams for self attention

lucidrains committed Jan 31, 2025
1 parent 1f7ea12 commit 3fe9821
Showing 2 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.6',
+  version = '1.44.8',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
51 changes: 40 additions & 11 deletions x_transformers/x_transformers.py
@@ -882,6 +882,7 @@ def __init__(
         *,
         layer_index,
         num_residual_streams,
+        num_input_views = 1,
         tanh = True,
         **kwargs
     ):
@@ -900,13 +901,16 @@ def __init__(

         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
 
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+        init_alpha0[layer_index % num_residual_streams, :] = 1.
 
         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
 
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        self.num_input_views = num_input_views
+
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
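
For intuition, here is a minimal standalone sketch (illustrative numbers, not code from the library) of what the changed construction produces: the first `num_input_views` columns of the static alpha matrix route the residual streams into the branch inputs, and the trailing identity block carries the streams themselves forward.

```python
import torch

num_residual_streams = 4
num_input_views = 3   # e.g. one view each for queries, keys and values
layer_index = 1

# mirrors the construction in the diff above
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
init_alpha0[layer_index % num_residual_streams, :] = 1.

static_alpha = torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)

print(static_alpha.shape)  # torch.Size([4, 7]) -> num_input_views + num_residual_streams columns
```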

@@ -928,7 +932,13 @@ def prepare(self, residuals):

         mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
 
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        views = self.num_input_views
+
+        if views == 1:
+            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        else:
+            branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
+            branch_input = rearrange(branch_input, '... v d -> v ... d')
 
         return branch_input, residuals, dict(beta = beta)
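
A standalone sketch of the mixing step with `num_input_views = 3` (all shapes and values are illustrative, not taken from the library): `prepare` now returns a stack of three mixed branch inputs rather than a single one.

```python
import torch
from einops import rearrange

b, n, d = 2, 16, 512     # batch, sequence length, model dim (illustrative)
streams, views = 4, 3    # residual streams, input views (one each for q, k, v)

residuals = torch.randn(b, n, streams, d)
alpha = torch.randn(b, n, streams, views + streams)  # stand-in for the combined static + dynamic alpha

# mix the residual streams into (views + streams) outputs, as in the einsum above
mix_h = torch.einsum('... s t, ... s d -> ... t d', alpha, residuals)

branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
branch_input = rearrange(branch_input, '... v d -> v ... d')

print(branch_input.shape)  # torch.Size([3, 2, 16, 512]) - one mixed input each for q, k, v
print(residuals.shape)     # torch.Size([2, 16, 4, 512])  - the residual streams carried forward
```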

@@ -1200,6 +1210,7 @@ def __init__(
         learned_value_residual_mix = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
+        qkv_receive_diff_residuals = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1239,6 +1250,10 @@ def __init__(
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # whether qkv receives different residual stream combinations from hyper connections
+
+        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
+
         # enhancing gradients to attention through exponentiated values
 
         self.laser = laser
@@ -1423,14 +1438,21 @@ def forward(
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
 
-        kv_input = default(context, x)
+        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
-        q_input = x
-        k_input = kv_input
-        v_input = kv_input
-        r_input = x
+        if qkv_receive_diff_residuals:
+            assert not exists(self.to_r)
+
+            q_input, k_input, v_input = x
+        else:
+            kv_input = default(context, x)
+
+            q_input = x
+            k_input = kv_input
+            v_input = kv_input
+            r_input = x
 
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
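
A small sketch (shapes are illustrative) of what the new branch expects on the attention side: when `qkv_receive_diff_residuals` is set, `x` arrives as a stack of three view inputs from the hyper connection and is unpacked along the leading dimension, and no cross attention `context` may be passed, per the assert above.

```python
import torch

qkv_receive_diff_residuals = True
has_context = False  # the new path is self attention only

# stack of three mixed residual-stream views: (views, batch, seq, dim)
x = torch.randn(3, 2, 16, 512)

assert not (qkv_receive_diff_residuals and has_context)

q_input, k_input, v_input = x  # unpacks along the leading view dimension
print(q_input.shape)           # torch.Size([2, 16, 512])
```
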
@@ -1735,6 +1757,7 @@ def __init__(
         layerscale_init_value = 0.,
         unet_skips = False,
         num_residual_streams = 1,
+        qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
@@ -1771,6 +1794,8 @@

         assert not (num_residual_streams > 1 and gate_residual)
 
+        assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+
         # positions related
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
@@ -2020,7 +2045,7 @@ def __init__(

             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
@@ -2041,6 +2066,10 @@ def __init__(

             if num_residual_streams > 1:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+
+                if layer_type == 'a' and qkv_receive_diff_residuals:
+                    residual_fn = partial(residual_fn, num_input_views = 3)
+
             elif gate_residual:
                 residual_fn = GRUGating
             else:
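
For reference, a hedged usage sketch of the new option (hyperparameter values are illustrative; `Decoder` forwards its kwargs to `AttentionLayers`). Per the assert added above, `qkv_receive_diff_residuals` requires hyper connections, i.e. `num_residual_streams > 1`, and each self attention layer's hyper connection is then built with `num_input_views = 3`.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        num_residual_streams = 4,            # hyper connections must be enabled (> 1 stream)
        qkv_receive_diff_residuals = True    # q, k, v each receive their own mix of residual streams
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 256)
```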
