Skip to content

Commit

Permalink
attention should be able to attend to nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 20, 2023
1 parent 75ebb06 commit 65f3178
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,14 @@ out = model(x)
year = {2022}
}
```

```bibtex
@article{Bondarenko2023QuantizableTR,
title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
```
5 changes: 5 additions & 0 deletions bs_roformer/bs_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def __init__(
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

self.to_gates = nn.Linear(dim, heads)

self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
Expand All @@ -98,6 +100,9 @@ def forward(self, x):

out = self.attend(q, k, v)

gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

Expand Down
5 changes: 5 additions & 0 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def __init__(
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

self.to_gates = nn.Linear(dim, heads)

self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
Expand All @@ -105,6 +107,9 @@ def forward(self, x):

out = self.attend(q, k, v)

gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.2.8',
version = '0.3.0',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit 65f3178

Please sign in to comment.