From 67de47226810b3d90357fb7933995290d9477c52 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 20 Dec 2024 10:13:49 -0800
Subject: [PATCH] add value residual learning

---
 README.md                        |  9 ++++
 bs_roformer/bs_roformer.py       | 70 +++++++++++++++++++++++---------
 bs_roformer/mel_band_roformer.py | 11 ++---
 setup.py                         |  2 +-
 4 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 73bcc0c..33f7e3a 100644
--- a/README.md
+++ b/README.md
@@ -159,3 +159,12 @@ out = model(x)
     url = {https://api.semanticscholar.org/CorpusID:235458262}
 }
 ```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```
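For orientation before the code changes: the paper cited above mixes each later layer's attention values back toward the values computed by the first layer, through a learned per-token, per-head gate. Below is a minimal sketch of that mixing rule as the diff realizes it with `Tensor.lerp`; the tensor names and shapes are illustrative only, not taken from the patch.

```python
# Sketch of the value residual mixing rule (Zhou et al. 2024) implemented
# inside Attention in the diff below. Names and shapes are illustrative.
import torch

batch, heads, seq, dim_head = 2, 8, 1024, 64

v_first = torch.randn(batch, heads, seq, dim_head)  # values from the first block
v       = torch.randn(batch, heads, seq, dim_head)  # values of the current block
mix     = torch.rand(batch, heads, seq, 1)          # learned gate in [0, 1], per head and token

# torch lerp: v + mix * (v_first - v) == (1 - mix) * v + mix * v_first
v_mixed = v.lerp(v_first, mix)
assert torch.allclose(v_mixed, (1 - mix) * v + mix * v_first, atol = 1e-5)
```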
diff --git a/bs_roformer/bs_roformer.py b/bs_roformer/bs_roformer.py
index edbb4a7..916d21c 100644
--- a/bs_roformer/bs_roformer.py
+++ b/bs_roformer/bs_roformer.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from functools import partial
 
 import torch
@@ -7,7 +8,7 @@
 
 from bs_roformer.attend import Attend
 
-from beartype.typing import Tuple, Optional, List, Callable
+from beartype.typing import Callable
 from beartype import beartype
 
 from rotary_embedding_torch import RotaryEmbedding
@@ -70,7 +71,8 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         rotary_embed = None,
-        flash = True
+        flash = True,
+        learned_value_residual_mix = False
     ):
         super().__init__()
         self.heads = heads
@@ -84,6 +86,8 @@
         self.norm = RMSNorm(dim)
         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
 
+        self.to_value_residual_mix = nn.Linear(dim, heads) if learned_value_residual_mix else None
+
         self.to_gates = nn.Linear(dim, heads)
 
         self.to_out = nn.Sequential(
@@ -91,11 +95,20 @@ def __init__(
             nn.Dropout(dropout)
         )
 
-    def forward(self, x):
+    def forward(self, x, value_residual = None):
         x = self.norm(x)
 
         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
 
+        orig_v = v
+
+        if exists(self.to_value_residual_mix):
+            mix = self.to_value_residual_mix(x)
+            mix = rearrange(mix, 'b n h -> b h n 1').sigmoid()
+
+            assert exists(value_residual)
+            v = v.lerp(value_residual, mix)
+
         if exists(self.rotary_embed):
             q = self.rotary_embed.rotate_queries_or_keys(q)
             k = self.rotary_embed.rotate_queries_or_keys(k)
@@ -106,7 +119,8 @@ def forward(self, x):
         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
 
         out = rearrange(out, 'b h n d -> b n (h d)')
-        return self.to_out(out)
+
+        return self.to_out(out), orig_v
 
 class Transformer(Module):
     def __init__(
@@ -121,26 +135,33 @@ def __init__(
         ff_mult = 4,
         norm_output = True,
         rotary_embed = None,
-        flash_attn = True
+        flash_attn = True,
+        add_value_residual = False
     ):
         super().__init__()
         self.layers = ModuleList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x, value_residual = None):
+
+        first_values = None
 
         for attn, ff in self.layers:
-            x = attn(x) + x
+            attn_out, next_values = attn(x, value_residual = value_residual)
+
+            first_values = default(first_values, next_values)
+
+            x = attn_out + x
             x = ff(x) + x
 
-        return self.norm(x)
+        return self.norm(x), first_values
 
 # bandsplit module
 
@@ -149,7 +170,7 @@ class BandSplit(Module):
     def __init__(
         self,
         dim,
-        dim_inputs: Tuple[int, ...]
+        dim_inputs: tuple[int, ...]
     ):
         super().__init__()
         self.dim_inputs = dim_inputs
@@ -202,7 +223,7 @@ class MaskEstimator(Module):
     def __init__(
         self,
         dim,
-        dim_inputs: Tuple[int, ...],
+        dim_inputs: tuple[int, ...],
         depth,
         mlp_expansion_factor = 4
     ):
@@ -257,7 +278,7 @@
         num_stems = 1,
         time_transformer_depth = 2,
         freq_transformer_depth = 2,
-        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,  # in the paper, they divide into ~60 bands, test with 1 for starters
+        freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,  # in the paper, they divide into ~60 bands, test with 1 for starters
         dim_head = 64,
         heads = 8,
         attn_dropout = 0.,
@@ -268,10 +289,10 @@
         stft_hop_length = 512,  # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
         stft_win_length = 2048,
         stft_normalized = False,
-        stft_window_fn: Optional[Callable] = None,
+        stft_window_fn: Callable | None = None,
         mask_estimator_depth = 2,
         multi_stft_resolution_loss_weight = 1.,
-        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+        multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
         multi_stft_hop_size = 147,
         multi_stft_normalized = False,
         multi_stft_window_fn: Callable = torch.hann_window
@@ -297,10 +318,12 @@
         time_rotary_embed = RotaryEmbedding(dim = dim_head)
         freq_rotary_embed = RotaryEmbedding(dim = dim_head)
 
-        for _ in range(depth):
+        for layer_index in range(depth):
+            is_first = layer_index == 0
+
             self.layers.append(nn.ModuleList([
-                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, **transformer_kwargs),
-                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, **transformer_kwargs)
+                Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, add_value_residual = not is_first, **transformer_kwargs),
+                Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, add_value_residual = not is_first, **transformer_kwargs)
             ]))
 
         self.final_norm = RMSNorm(dim)
@@ -391,6 +414,11 @@ def forward(
 
         x = self.band_split(x)
 
+        # value residuals
+
+        time_v_residual = None
+        freq_v_residual = None
+
         # axial / hierarchical attention
 
         for time_transformer, freq_transformer in self.layers:
@@ -398,13 +426,17 @@
             x = rearrange(x, 'b t f d -> b f t d')
             x, ps = pack([x], '* t d')
 
-            x = time_transformer(x)
+            x, next_time_v_residual = time_transformer(x, value_residual = time_v_residual)
+
+            time_v_residual = default(time_v_residual, next_time_v_residual)
 
             x, = unpack(x, ps, '* t d')
             x = rearrange(x, 'b f t d -> b t f d')
             x, ps = pack([x], '* f d')
 
-            x = freq_transformer(x)
+            x, next_freq_v_residual = freq_transformer(x, value_residual = freq_v_residual)
+
+            freq_v_residual = default(freq_v_residual, next_freq_v_residual)
 
             x, = unpack(x, ps, '* f d')
 
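The plumbing above follows one pattern: every attention layer returns the values it computed before any mixing, the caller keeps only the first returned set (via `default(first_values, next_values)`), and each later block lerps its own values toward that residual. Below is a self-contained toy of just that control flow; `ToyBlock` is a hypothetical stand-in for the patched `Transformer`, with a fixed 0.5 mix in place of the learned gate.

```python
# Toy version of the value residual plumbing above; ToyBlock and the 0.5 mix
# are illustrative stand-ins, not the patched modules themselves.
import torch
from torch import nn

def default(val, fallback):
    return val if val is not None else fallback

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, value_residual = None):
        values = self.proj(x)  # stand-in for the attention values
        mixed = values if value_residual is None else values.lerp(value_residual, 0.5)
        return x + mixed, values  # return the pre-mix values, as the patched Attention does

blocks = nn.ModuleList([ToyBlock(32) for _ in range(4)])

x = torch.randn(2, 16, 32)
value_residual = None

for block in blocks:
    x, values = block(x, value_residual = value_residual)
    value_residual = default(value_residual, values)  # keep only the first block's values
```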
diff --git a/bs_roformer/mel_band_roformer.py b/bs_roformer/mel_band_roformer.py
index 9062412..3d6c55f 100644
--- a/bs_roformer/mel_band_roformer.py
+++ b/bs_roformer/mel_band_roformer.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from functools import partial
 
 import torch
@@ -7,7 +8,7 @@
 
 from bs_roformer.attend import Attend
 
-from beartype.typing import Tuple, Optional, List, Callable
+from beartype.typing import Callable
 from beartype import beartype
 
 from rotary_embedding_torch import RotaryEmbedding
@@ -219,7 +220,7 @@ class BandSplit(Module):
     def __init__(
         self,
         dim,
-        dim_inputs: Tuple[int, ...]
+        dim_inputs: tuple[int, ...]
     ):
         super().__init__()
         self.dim_inputs = dim_inputs
@@ -272,7 +273,7 @@ class MaskEstimator(Module):
     def __init__(
         self,
         dim,
-        dim_inputs: Tuple[int, ...],
+        dim_inputs: tuple[int, ...],
         depth,
         mlp_expansion_factor = 4
     ):
@@ -330,10 +331,10 @@ def __init__(
         stft_hop_length = 512,  # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
         stft_win_length = 2048,
         stft_normalized = False,
-        stft_window_fn: Optional[Callable] = None,
+        stft_window_fn: Callable | None = None,
         mask_estimator_depth = 1,
         multi_stft_resolution_loss_weight = 1.,
-        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+        multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
         multi_stft_hop_size = 147,
         multi_stft_normalized = False,
         multi_stft_window_fn: Callable = torch.hann_window,
diff --git a/setup.py b/setup.py
index 21b4ed7..3b393e3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'BS-RoFormer',
   packages = find_packages(exclude=[]),
-  version = '0.4.1',
+  version = '0.5.0',
   license='MIT',
   description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
   author = 'Phil Wang',
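Nothing changes at the call site: the value residuals are threaded through the axial time and frequency transformers internally, with the first pair of blocks supplying the residual that later pairs mix toward. The existing usage from this repo's README still applies unchanged:

```python
# Usage is unaffected by this patch; value residual learning is wired up
# internally. This mirrors the README example for BSRoformer.
import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)       # (batch, raw audio samples)
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()
```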