
Commit

address #15
lucidrains committed Oct 20, 2023
1 parent 1a09e86 commit 6427670
Showing 3 changed files with 83 additions and 35 deletions.
58 changes: 41 additions & 17 deletions bs_roformer/bs_roformer.py
@@ -17,6 +17,9 @@
 def exists(val):
     return val is not None

+def default(v, d):
+    return v if exists(v) else d
+
 def pack_one(t, pattern):
     return pack([t], pattern)
@@ -163,14 +166,29 @@ def forward(self, x):

         return torch.stack(outs, dim = -2)

-class LinearGLUWithTanH(Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim = -1)
-        return x.tanh() * gate.sigmoid()
+def MLP(
+    dim_in,
+    dim_out,
+    dim_hidden = None,
+    depth = 1,
+    activation = nn.SiLU
+):
+    dim_hidden = default(dim_hidden, dim_in)
+
+    net = []
+    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+        is_last = ind == (len(dims) - 2)
+
+        net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+        if is_last:
+            continue
+
+        net.append(activation())
+
+    return nn.Sequential(*net)

 class MaskEstimator(Module):
     @beartype
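For orientation, a minimal sketch of what the MLP factory above builds (illustrative only, not part of the diff; it assumes torch is installed and the MLP function above is in scope, e.g. pasted alongside): with dim_hidden defaulting to dim_in, MLP(64, 32, depth = 3) stacks three linear layers with the activation between them and no activation after the final projection.

import torch

# Illustrative sketch: expected structure is
# Linear(64, 64) -> SiLU -> Linear(64, 64) -> SiLU -> Linear(64, 32)
mlp = MLP(64, 32, depth = 3)
print(mlp)

x = torch.randn(2, 64)
print(mlp(x).shape)  # torch.Size([2, 32])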
@@ -189,24 +207,30 @@ def __init__(
         for dim_in in dim_inputs:
             net = []

-            for ind in range(depth):
-                is_first = ind == 0
-                is_last = ind == (depth - 1)
-
-                dim_layer_in = dim if is_first else dim_hidden
-                dim_layer_out = dim_hidden if not is_last else dim_in
-
-                net.append(LinearGLUWithTanH(dim_layer_in, dim_layer_out))
-
-            self.to_freqs.append(nn.Sequential(*net))
+            tanh_mlp = nn.Sequential(
+                MLP(dim, dim_in // 2, dim_hidden = dim_hidden, depth = depth),
+                nn.Tanh()
+            )
+
+            glu_mlp = nn.Sequential(
+                MLP(dim, dim_in, dim_hidden = dim_hidden, depth = depth),
+                nn.GLU(dim = -1)
+            )
+
+            self.to_freqs.append(ModuleList([tanh_mlp, glu_mlp]))

     def forward(self, x):
         x = x.unbind(dim = -2)

         outs = []

-        for band_features, to_freq in zip(x, self.to_freqs):
-            freq_out = to_freq(band_features)
+        for band_features, (tanh_mlp, glu_mlp) in zip(x, self.to_freqs):
+            tanh_out = tanh_mlp(band_features)
+            glu_out = glu_mlp(band_features)
+
+            freq_out = torch.stack((tanh_out, glu_out), dim = -1)
+            freq_out = rearrange(freq_out, '... f c -> ... (f c)')
+
             outs.append(freq_out)

         return torch.cat(outs, dim = -1)
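As a rough standalone sketch of how the two branches recombine in the new forward pass (the sizes and plain Linear layers below are hypothetical stand-ins for the MLP above, not the commit's actual dimensions): the tanh branch emits dim_in // 2 features, nn.GLU halves the other branch's dim_in output to dim_in // 2, and stacking plus rearranging interleaves the two back into dim_in features per band.

import torch
from torch import nn
from einops import rearrange

dim, dim_in = 512, 8                      # hypothetical transformer dim and per-band frequency width
band = torch.randn(2, dim)                # (batch, dim) features for one band

tanh_mlp = nn.Sequential(nn.Linear(dim, dim_in // 2), nn.Tanh())
glu_mlp = nn.Sequential(nn.Linear(dim, dim_in), nn.GLU(dim = -1))

tanh_out = tanh_mlp(band)                 # (2, 4), values squashed to (-1, 1)
glu_out = glu_mlp(band)                   # (2, 4), gated linear unit output

freq_out = torch.stack((tanh_out, glu_out), dim = -1)   # (2, 4, 2)
freq_out = rearrange(freq_out, '... f c -> ... (f c)')  # (2, 8), branches interleaved
print(freq_out.shape)  # torch.Size([2, 8])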
58 changes: 41 additions & 17 deletions bs_roformer/mel_band_roformer.py
@@ -19,6 +19,9 @@
 def exists(val):
     return val is not None

+def default(v, d):
+    return v if exists(v) else d
+
 def pack_one(t, pattern):
     return pack([t], pattern)
@@ -170,14 +173,29 @@ def forward(self, x):

         return torch.stack(outs, dim = -2)

-class LinearGLUWithTanH(Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim = -1)
-        return x.tanh() * gate.sigmoid()
+def MLP(
+    dim_in,
+    dim_out,
+    dim_hidden = None,
+    depth = 1,
+    activation = nn.SiLU
+):
+    dim_hidden = default(dim_hidden, dim_in)
+
+    net = []
+    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+        is_last = ind == (len(dims) - 2)
+
+        net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+        if is_last:
+            continue
+
+        net.append(activation())
+
+    return nn.Sequential(*net)

 class MaskEstimator(Module):
     @beartype
@@ -196,24 +214,30 @@ def __init__(
         for dim_in in dim_inputs:
             net = []

-            for ind in range(depth):
-                is_first = ind == 0
-                is_last = ind == (depth - 1)
-
-                dim_layer_in = dim if is_first else dim_hidden
-                dim_layer_out = dim_hidden if not is_last else dim_in
-
-                net.append(LinearGLUWithTanH(dim_layer_in, dim_layer_out))
-
-            self.to_freqs.append(nn.Sequential(*net))
+            tanh_mlp = nn.Sequential(
+                MLP(dim, dim_in // 2, dim_hidden = dim_hidden, depth = depth),
+                nn.Tanh()
+            )
+
+            glu_mlp = nn.Sequential(
+                MLP(dim, dim_in, dim_hidden = dim_hidden, depth = depth),
+                nn.GLU(dim = -1)
+            )
+
+            self.to_freqs.append(ModuleList([tanh_mlp, glu_mlp]))

     def forward(self, x):
         x = x.unbind(dim = -2)

         outs = []

-        for band_features, to_freq in zip(x, self.to_freqs):
-            freq_out = to_freq(band_features)
+        for band_features, (tanh_mlp, glu_mlp) in zip(x, self.to_freqs):
+            tanh_out = tanh_mlp(band_features)
+            glu_out = glu_mlp(band_features)
+
+            freq_out = torch.stack((tanh_out, glu_out), dim = -1)
+            freq_out = rearrange(freq_out, '... f c -> ... (f c)')
+
             outs.append(freq_out)

         return torch.cat(outs, dim = -1)
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'BS-RoFormer',
   packages = find_packages(exclude=[]),
-  version = '0.2.5',
+  version = '0.2.6',
   license='MIT',
   description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
   author = 'Phil Wang',
