Skip to content

Commit

Permalink
fix mlp yet again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 20, 2023
1 parent 3e67f27 commit 004ea27
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 33 deletions.
22 changes: 6 additions & 16 deletions bs_roformer/bs_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def MLP(
dim_out,
dim_hidden = None,
depth = 1,
activation = nn.SiLU
activation = nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)

Expand Down Expand Up @@ -207,30 +207,20 @@ def __init__(
for dim_in in dim_inputs:
net = []

tanh_mlp = nn.Sequential(
MLP(dim, dim_in // 2, dim_hidden = dim_hidden, depth = depth),
nn.Tanh()
)

glu_mlp = nn.Sequential(
MLP(dim, dim_in, dim_hidden = dim_hidden, depth = depth),
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
nn.GLU(dim = -1)
)

self.to_freqs.append(ModuleList([tanh_mlp, glu_mlp]))
self.to_freqs.append(mlp)

def forward(self, x):
x = x.unbind(dim = -2)

outs = []

for band_features, (tanh_mlp, glu_mlp) in zip(x, self.to_freqs):
tanh_out = tanh_mlp(band_features)
glu_out = glu_mlp(band_features)

freq_out = torch.stack((tanh_out, glu_out), dim = -1)
freq_out = rearrange(freq_out, '... f c -> ... (f c)')

for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)

return torch.cat(outs, dim = -1)
Expand Down
22 changes: 6 additions & 16 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def MLP(
dim_out,
dim_hidden = None,
depth = 1,
activation = nn.SiLU
activation = nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)

Expand Down Expand Up @@ -214,30 +214,20 @@ def __init__(
for dim_in in dim_inputs:
net = []

tanh_mlp = nn.Sequential(
MLP(dim, dim_in // 2, dim_hidden = dim_hidden, depth = depth),
nn.Tanh()
)

glu_mlp = nn.Sequential(
MLP(dim, dim_in, dim_hidden = dim_hidden, depth = depth),
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
nn.GLU(dim = -1)
)

self.to_freqs.append(ModuleList([tanh_mlp, glu_mlp]))
self.to_freqs.append(mlp)

def forward(self, x):
x = x.unbind(dim = -2)

outs = []

for band_features, (tanh_mlp, glu_mlp) in zip(x, self.to_freqs):
tanh_out = tanh_mlp(band_features)
glu_out = glu_mlp(band_features)

freq_out = torch.stack((tanh_out, glu_out), dim = -1)
freq_out = rearrange(freq_out, '... f c -> ... (f c)')

for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)

return torch.cat(outs, dim = -1)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.2.7',
version = '0.2.8',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit 004ea27

Please sign in to comment.