Skip to content

Commit

Permalink
address #24 with better defaults
Browse files: browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 18, 2023
1 parent 98d15d2 commit 46ff976
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
4 changes: 3 additions & 1 deletion bs_roformer/bs_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __init__(
stft_hop_length = 512, # approximates the 10 ms hop from sections 4.1, 4.4 in the paper (512 samples is ~11.6 ms at 44100 Hz; exactly 10 ms would be 441) - @faroit recommends a smaller hop (// 2 or // 4) for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
stft_window = None,
mask_estimator_depth = 2,
multi_stft_resolution_loss_weight = 1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
Expand Down Expand Up @@ -305,7 +306,8 @@ def __init__(
n_fft = stft_n_fft,
hop_length = stft_hop_length,
win_length = stft_win_length,
normalized = stft_normalized
normalized = stft_normalized,
window = default(stft_window, torch.hann_window(stft_win_length))
)

freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1]
Expand Down
6 changes: 4 additions & 2 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def __init__(
stft_hop_length = 512, # approximates the 10 ms hop from sections 4.1, 4.4 in the paper (512 samples is ~11.6 ms at 44100 Hz; exactly 10 ms would be 441) - @faroit recommends a smaller hop (// 2 or // 4) for better reconstruction
stft_win_length = 2048,
stft_normalized = False,
mask_estimator_depth = 1, # Number of hidden layers in each of the mask estimator MLPs
stft_window = None,
mask_estimator_depth = 1,
multi_stft_resolution_loss_weight = 1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size = 147,
Expand Down Expand Up @@ -300,7 +301,8 @@ def __init__(
n_fft = stft_n_fft,
hop_length = stft_hop_length,
win_length = stft_win_length,
normalized = stft_normalized
normalized = stft_normalized,
window = default(stft_window, torch.hann_window(stft_win_length))
)

freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.3.5',
version = '0.3.6',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit 46ff976

Please sign in to comment.