Skip to content

Commit

Permalink
just use device off the raw audio tensor passed in
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 22, 2023
1 parent 7aebf44 commit 93a07dd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
10 changes: 4 additions & 6 deletions bs_roformer/bs_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,6 @@ def __init__(
normalized = multi_stft_normalized
)

@property
def device(self):
return next(self.parameters()).device

def forward(
self,
raw_audio,
Expand All @@ -371,6 +367,8 @@ def forward(
d - feature dimension
"""

device = raw_audio.device

if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

Expand All @@ -381,7 +379,7 @@ def forward(

raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

stft_window = self.stft_window_fn(device = self.device)
stft_window = self.stft_window_fn(device = device)

stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True)
stft_repr = torch.view_as_real(stft_repr)
Expand Down Expand Up @@ -462,7 +460,7 @@ def forward(
n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
win_length = window_size,
return_complex = True,
window = self.multi_stft_window_fn(window_size, device = self.device),
window = self.multi_stft_window_fn(window_size, device = device),
**self.multi_stft_kwargs,
)

Expand Down
12 changes: 5 additions & 7 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,6 @@ def __init__(

self.match_input_audio_length = match_input_audio_length

@property
def device(self):
return next(self.parameters()).device

def forward(
self,
raw_audio,
Expand All @@ -405,6 +401,8 @@ def forward(
d - feature dimension
"""

device = raw_audio.device

if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

Expand All @@ -418,7 +416,7 @@ def forward(

raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

stft_window = self.stft_window_fn(device = self.device)
stft_window = self.stft_window_fn(device = device)

stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True)
stft_repr = torch.view_as_real(stft_repr)
Expand All @@ -428,7 +426,7 @@ def forward(

# index out all frequencies for all frequency ranges across bands ascending in one go

batch_arange = torch.arange(batch, device = self.device)[..., None]
batch_arange = torch.arange(batch, device = device)[..., None]

# account for stereo

Expand Down Expand Up @@ -522,7 +520,7 @@ def forward(
n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
win_length = window_size,
return_complex = True,
window = self.multi_stft_window_fn(window_size, device = self.device),
window = self.multi_stft_window_fn(window_size, device = device),
**self.multi_stft_kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.3.8',
version = '0.3.9',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit 93a07dd

Please sign in to comment.