torch_colab.py
import argparse
import torch
import torchaudio
import os
import gc
import glob


def frequency_blend_phases(phase1, phase2, freq_bins, low_cutoff=420, high_cutoff=4200, base_factor=0.25, scale_factor=1.85):
"""
Blend two phase arrays with different weights depending on frequency.
Parameters:
phase1: Tensor of shape (frequency_bins, time_frames) - First phase matrix.
phase2: Tensor of shape (frequency_bins, time_frames) - Second phase matrix.
freq_bins: Tensor of shape (frequency_bins,) - Frequencies corresponding to bins.
low_cutoff: int - Frequency below which blend_factor is base_factor.
high_cutoff: int - Frequency above which blend_factor is base_factor + scale_factor.
base_factor: float - The starting blend factor for low frequencies.
scale_factor: float - The difference in blend factor between low and high frequencies.
Returns:
blended_phase: Tensor of shape (frequency_bins, time_frames).
"""
# Validate input dimensions
if phase1.shape != phase2.shape:
raise ValueError("phase1 and phase2 must have the same shape.")
if len(freq_bins) != phase1.shape[0]:
raise ValueError("freq_bins must have the same length as the number of frequency bins in phase1 and phase2.")
if low_cutoff >= high_cutoff:
raise ValueError("low_cutoff must be less than high_cutoff.")
# Initialize blended phase
blended_phase = torch.zeros_like(phase1)
# Compute blend factors for all frequencies
blend_factors = torch.zeros_like(freq_bins)
# Below low_cutoff: blend factor is base_factor
blend_factors[freq_bins < low_cutoff] = base_factor
# Above high_cutoff: blend factor is base_factor + scale_factor
blend_factors[freq_bins > high_cutoff] = base_factor + scale_factor
# Between low_cutoff and high_cutoff: interpolate linearly
in_range_mask = (freq_bins >= low_cutoff) & (freq_bins <= high_cutoff)
blend_factors[in_range_mask] = base_factor + scale_factor * (
(freq_bins[in_range_mask] - low_cutoff) / (high_cutoff - low_cutoff)
)
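    # With the default base_factor=0.25 and scale_factor=1.85 the blend factor
    # spans 0.25 to 2.1, so above roughly 1.95 kHz it exceeds 1 and the blend
    # extrapolates past phase2 rather than interpolating between the two phases.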
# Apply blend factors to each frequency bin
for i in range(phase1.shape[0]):
blended_phase[i, :] = (1 - blend_factors[i]) * phase1[i, :] + blend_factors[i] * phase2[i, :]
# Wrap phase to the range [-π, π]
blended_phase = torch.remainder(blended_phase + torch.pi, 2 * torch.pi) - torch.pi
return blended_phase
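

# A minimal usage sketch for frequency_blend_phases, assuming 44.1 kHz audio
# analysed with n_fft=2048 (1025 frequency bins) and 100 STFT frames; the
# tensors below are placeholders rather than data from this project:
#
#     freqs = torch.linspace(0, 44100 // 2, steps=2048 // 2 + 1)
#     phase_a = torch.rand(1025, 100) * 2 * torch.pi - torch.pi
#     phase_b = torch.rand(1025, 100) * 2 * torch.pi - torch.pi
#     blended = frequency_blend_phases(phase_a, phase_b, freqs)
#
# Below 420 Hz the result is weighted toward phase_a (factor 0.25); above
# 4200 Hz it is pushed strongly toward (and past) phase_b, since the default
# blend factor reaches 2.1 there.
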
def transfer_magnitude_phase(source_file, target_file, transfer_magnitude=True, transfer_phase=True, low_cutoff=500, high_cutoff=5000, output_32bit=False, output_folder=None):
    # Derive the output file name from the target file
    target_name, target_ext = os.path.splitext(os.path.basename(target_file))
    # Strip common separation-stem suffixes (_other, _vocals, _instrumental) from the name
    target_name = target_name.replace("_other", "").replace("_vocals", "").replace("_instrumental", "").replace("_Other", "").replace("_Vocals", "").replace("_Instrumental", "").strip()
    # Build the output path: "(Fixed Instrumental)" suffix when an output folder is given, otherwise "(Corrected)" next to the target file
    output_file = os.path.join(output_folder, f"{target_name} (Fixed Instrumental){target_ext}") if output_folder else os.path.join(os.path.dirname(target_file), f"{target_name} (Corrected){target_ext}")
# Load audio using torchaudio
print(f"Phase Fixing {target_name}...")
source_waveform, source_sr = torchaudio.load(source_file)
target_waveform, target_sr = torchaudio.load(target_file)
# Ensure sample rates match
if source_sr != target_sr:
raise ValueError("Sample rates of source and target audio files must match.")
# STFT settings
n_fft = 2048
hop_length = 512
window = torch.hann_window(n_fft)
# Compute STFTs for each channel
source_stfts = torch.stft(source_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")
target_stfts = torch.stft(target_waveform, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, pad_mode="reflect")
# Frequency bins
freqs = torch.linspace(0, source_sr // 2, steps=n_fft // 2 + 1)
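    # For an even sample rate this matches the onesided STFT bin centres
    # (k * source_sr / n_fft), i.e. torch.fft.rfftfreq(n_fft, d=1 / source_sr)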
# Process each channel independently
modified_stfts = []
for source_stft, target_stft in zip(source_stfts, target_stfts):
source_mag, source_phs = torch.abs(source_stft), torch.angle(source_stft)
target_mag, target_phs = torch.abs(target_stft), torch.angle(target_stft)
# Transfer magnitude
modified_stft = target_stft.clone()
if transfer_magnitude:
modified_stft = source_mag * torch.exp(1j * torch.angle(modified_stft))
# Transfer or blend phase
if transfer_phase:
blended_phase = frequency_blend_phases(target_phs, source_phs, freqs, low_cutoff, high_cutoff)
modified_stft = torch.abs(modified_stft) * torch.exp(1j * blended_phase)
modified_stfts.append(modified_stft)
# Convert modified STFTs back to time domain
modified_waveform = torch.istft(
torch.stack(modified_stfts),
n_fft=n_fft,
hop_length=hop_length,
window=window,
length=source_waveform.size(1)
)
    # Save the modified audio as signed PCM (32-bit if requested, otherwise 16-bit)
    torchaudio.save(output_file, modified_waveform, target_sr, encoding="PCM_S", bits_per_sample=32 if output_32bit else 16)
print(f"Corrected file saved as {output_file}")
def process_files(base_folder, unwa_folder, output_folder, low_cutoff, high_cutoff, output_32bit):
# Find all files in the unwa folder (with any suffix)
unwa_files = glob.glob(os.path.join(unwa_folder, "*"))
# Sort the files alphabetically (optional)
unwa_files.sort()
# Iterate over the files in the unwa folder
for unwa_file in unwa_files:
# Get the base file name (without the extension)
base_name_with_suffix = os.path.splitext(os.path.basename(unwa_file))[0]
        # Strip whitespace and remove common separation-stem suffixes to get the base name
        base_name = base_name_with_suffix.strip().replace("_other", "").replace("_vocals", "").replace("_instrumental", "").replace("_Other", "").replace("_Vocals", "").replace("_Instrumental", "")
# Handle the case where the suffix is part of the file name (like 'cedo 2_instrumental')
instrumental_file = os.path.join(base_folder, f"{base_name}_instrumental{os.path.splitext(unwa_file)[1]}")
# Check if the corresponding instrumental file exists in the base folder
if os.path.exists(instrumental_file):
# Process the pair
transfer_magnitude_phase(
source_file=instrumental_file,
target_file=unwa_file,
transfer_magnitude=False,
transfer_phase=True,
low_cutoff=low_cutoff,
high_cutoff=high_cutoff,
output_32bit=output_32bit,
output_folder=output_folder
)
else:
print(f"Warning: No matching instrumental file found for {unwa_file}, skipping.")
gc.collect()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Transfer magnitude and/or phase between audio files.")
parser.add_argument("--base_folder", required=True, help="Path to the base folder containing instrumental files (kim).")
parser.add_argument("--unwa_folder", required=True, help="Path to the folder containing corresponding unwa files.")
parser.add_argument("--output_folder", required=True, help="Path to the output folder for corrected files.")
parser.add_argument("--low_cutoff", type=int, default=420, help="Low cutoff frequency for phase blending.")
parser.add_argument("--high_cutoff", type=int, default=4200, help="High cutoff frequency for phase blending.")
parser.add_argument("--output_32bit", action="store_true", help="Save the output as a 32-bit file.")
args = parser.parse_args()
# Process all matching files in the base folder and the corresponding unwa folder
process_files(
base_folder=args.base_folder,
unwa_folder=args.unwa_folder,
output_folder=args.output_folder,
low_cutoff=args.low_cutoff,
high_cutoff=args.high_cutoff,
output_32bit=args.output_32bit
)
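
# Example invocation (hypothetical paths; the output folder is assumed to
# already exist, since torchaudio.save does not create directories):
#
#     python torch_colab.py \
#         --base_folder /content/kim \
#         --unwa_folder /content/unwa \
#         --output_folder /content/corrected \
#         --low_cutoff 420 --high_cutoff 4200 --output_32bit
#
# Each file in --unwa_folder is matched against "<name>_instrumental.<ext>" in
# --base_folder; files without a match are skipped with a warning.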