
Commit 67de472
add value residual learning
lucidrains committed Dec 20, 2024
1 parent aca155d
Showing 4 changed files with 67 additions and 25 deletions.
README.md — 9 additions, 0 deletions
@@ -159,3 +159,12 @@ out = model(x)
url = {https://api.semanticscholar.org/CorpusID:235458262}
}
```

+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
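The cited technique blends each attention layer's value projections with the values computed at the first layer, counteracting attention concentration in deeper layers. A minimal sketch of the core update with toy tensor shapes (names like `mix_logits` are illustrative, not from this repo):

```python
import torch

# toy shapes: batch 2, heads 8, sequence 16, head dim 64
v = torch.randn(2, 8, 16, 64)             # values at the current layer
first_values = torch.randn(2, 8, 16, 64)  # values cached from the first layer

# per-head, per-position mixing weight in [0, 1]
# (the commit predicts this from the token features with nn.Linear(dim, heads))
mix_logits = torch.randn(2, 8, 16, 1)
mix = mix_logits.sigmoid()

# v.lerp(first_values, mix) == (1 - mix) * v + mix * first_values
v_mixed = v.lerp(first_values, mix)
```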
bs_roformer/bs_roformer.py — 51 additions, 19 deletions
@@ -1,3 +1,4 @@
+from __future__ import annotations
from functools import partial

import torch
@@ -7,7 +8,7 @@

from bs_roformer.attend import Attend

-from beartype.typing import Tuple, Optional, List, Callable
+from beartype.typing import Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding
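The new `from __future__ import annotations` line makes annotations lazily evaluated (PEP 563), which is what lets the builtin generics and unions used throughout this diff (`tuple[int, ...]`, `Callable | None`) parse on interpreters predating PEP 585/604. A small sketch; the `band_split` signature is hypothetical, for illustration only:

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings, not evaluated

from beartype.typing import Callable

# tuple[int, ...] (PEP 585) and Callable | None (PEP 604) now parse in
# annotation position even on older interpreters
def band_split(dim_inputs: tuple[int, ...], window_fn: Callable | None = None) -> None:
    ...
```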
@@ -70,7 +71,8 @@ def __init__(
dim_head = 64,
dropout = 0.,
rotary_embed = None,
-        flash = True
+        flash = True,
+        learned_value_residual_mix = False
):
super().__init__()
self.heads = heads
@@ -84,18 +86,29 @@ def __init__(
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

+        self.to_value_residual_mix = nn.Linear(dim, heads) if learned_value_residual_mix else None
+
self.to_gates = nn.Linear(dim, heads)

self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)

-    def forward(self, x):
+    def forward(self, x, value_residual = None):
x = self.norm(x)

q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)

+        orig_v = v
+
+        if exists(self.to_value_residual_mix):
+            mix = self.to_value_residual_mix(x)
+            mix = rearrange(mix, 'b n h -> b h n 1').sigmoid()
+
+            assert exists(value_residual)
+            v = v.lerp(value_residual, mix)
+
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
@@ -106,7 +119,8 @@ def forward(self, x):
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+
+        return self.to_out(out), orig_v
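Note the unmixed values are cached as `orig_v` before the lerp and returned alongside the output; reusing the shadowed `value_residual` name here would make the lerp blend `v` with itself. A minimal usage sketch, assuming the `Attention` class above with its defaults (hypothetical dims for illustration):

```python
import torch

# first layer produces values; a later layer consumes them via the learned mix
attn1 = Attention(dim = 512)                                     # learned_value_residual_mix = False
attn2 = Attention(dim = 512, learned_value_residual_mix = True)

x = torch.randn(1, 10, 512)

out1, values = attn1(x)                      # values: (1, heads, 10, dim_head)
out2, _ = attn2(x, value_residual = values)  # v blended toward first-layer values
```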

class Transformer(Module):
def __init__(
@@ -121,26 +135,33 @@ def __init__(
ff_mult = 4,
norm_output = True,
rotary_embed = None,
-        flash_attn = True
+        flash_attn = True,
+        add_value_residual = False
):
super().__init__()
self.layers = ModuleList([])

for _ in range(depth):
self.layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))

self.norm = RMSNorm(dim) if norm_output else nn.Identity()

-    def forward(self, x):
+    def forward(self, x, value_residual = None):

+        first_values = None
+
for attn, ff in self.layers:
-            x = attn(x) + x
+            attn_out, next_values = attn(x, value_residual = value_residual)
+
+            first_values = default(first_values, next_values)
+
+            x = attn_out + x
x = ff(x) + x

-        return self.norm(x)
+        return self.norm(x), first_values
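`first_values = default(first_values, next_values)` latches onto the values of the transformer's first attention block and returns them alongside the output. For reference, a sketch of the helper semantics, assuming the usual definitions in this codebase:

```python
def exists(val):
    # true when a value was actually provided
    return val is not None

def default(val, fallback):
    # keep val when present, otherwise take the fallback
    return val if exists(val) else fallback
```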

# bandsplit module

@@ -149,7 +170,7 @@ class BandSplit(Module):
def __init__(
self,
dim,
-        dim_inputs: Tuple[int, ...]
+        dim_inputs: tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
@@ -202,7 +223,7 @@ class MaskEstimator(Module):
def __init__(
self,
dim,
-        dim_inputs: Tuple[int, ...],
+        dim_inputs: tuple[int, ...],
depth,
mlp_expansion_factor = 4
):
@@ -257,7 +278,7 @@ def __init__(
num_stems = 1,
time_transformer_depth = 2,
freq_transformer_depth = 2,
-        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,  # in the paper, they divide into ~60 bands, test with 1 for starters
+        freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,  # in the paper, they divide into ~60 bands, test with 1 for starters
dim_head = 64,
heads = 8,
attn_dropout = 0.,
@@ -268,10 +289,10 @@ def __init__(
stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
-        stft_window_fn: Optional[Callable] = None,
+        stft_window_fn: Callable | None = None,
mask_estimator_depth = 2,
multi_stft_resolution_loss_weight = 1.,
-        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+        multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size = 147,
multi_stft_normalized = False,
multi_stft_window_fn: Callable = torch.hann_window
@@ -297,10 +318,12 @@ def __init__(
time_rotary_embed = RotaryEmbedding(dim = dim_head)
freq_rotary_embed = RotaryEmbedding(dim = dim_head)

-        for _ in range(depth):
+        for layer_index in range(depth):
+            is_first = layer_index == 0

self.layers.append(nn.ModuleList([
-                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, **transformer_kwargs),
-                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, **transformer_kwargs)
+                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, add_value_residual = not is_first, **transformer_kwargs),
+                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, add_value_residual = not is_first, **transformer_kwargs)
]))

self.final_norm = RMSNorm(dim)
@@ -391,20 +414,29 @@ def forward(

x = self.band_split(x)

+        # value residuals
+
+        time_v_residual = None
+        freq_v_residual = None
+
# axial / hierarchical attention

for time_transformer, freq_transformer in self.layers:

x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')

-            x = time_transformer(x)
+            x, next_time_v_residual = time_transformer(x, value_residual = time_v_residual)

+            time_v_residual = default(time_v_residual, next_time_v_residual)
+
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')

-            x = freq_transformer(x)
+            x, next_freq_v_residual = freq_transformer(x, value_residual = freq_v_residual)

+            freq_v_residual = default(freq_v_residual, next_freq_v_residual)
+
x, = unpack(x, ps, '* f d')

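Since the first axial layer is constructed with `add_value_residual = False`, it only produces the residual values that every later time and frequency layer blends toward; nothing changes at the call site. A usage sketch, with hyperparameters mirroring the README example:

```python
import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)
out = model(x)  # value residuals are created and consumed internally
```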
bs_roformer/mel_band_roformer.py — 6 additions, 5 deletions
@@ -1,3 +1,4 @@
+from __future__ import annotations
from functools import partial

import torch
@@ -7,7 +8,7 @@

from bs_roformer.attend import Attend

-from beartype.typing import Tuple, Optional, List, Callable
+from beartype.typing import Callable
from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding
@@ -219,7 +220,7 @@ class BandSplit(Module):
def __init__(
self,
dim,
-        dim_inputs: Tuple[int, ...]
+        dim_inputs: tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
@@ -272,7 +273,7 @@ class MaskEstimator(Module):
def __init__(
self,
dim,
-        dim_inputs: Tuple[int, ...],
+        dim_inputs: tuple[int, ...],
depth,
mlp_expansion_factor = 4
):
@@ -330,10 +331,10 @@ def __init__(
stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
-        stft_window_fn: Optional[Callable] = None,
+        stft_window_fn: Callable | None = None,
mask_estimator_depth = 1,
multi_stft_resolution_loss_weight = 1.,
-        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+        multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size = 147,
multi_stft_normalized = False,
multi_stft_window_fn: Callable = torch.hann_window,
setup.py — 1 addition, 1 deletion
@@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
-    version = '0.4.1',
+    version = '0.5.0',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
