Skip to content

Commit

Permalink
feat: integrate new diarization strategy into rest of project
Browse the repository at this point in the history
  • Loading branch information
linozen committed Jan 13, 2025
1 parent 3eb5f9f commit 2689384
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 16 deletions.
2 changes: 2 additions & 0 deletions verbatim/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __call__(self, parser, namespace, values, option_string=None):
dest="stop_time",
)
parser.add_argument("-o", "--outdir", help="Path to the output directory", default=".")
parser.add_argument("--diarization-strategy", choices=["pyannote", "stereo"], default="pyannote", help="Diarization strategy to use")
parser.add_argument(
"-d",
"--diarization",
Expand Down Expand Up @@ -312,6 +313,7 @@ def __call__(self, parser, namespace, values, option_string=None):
isolate=args.isolate,
diarize=args.diarize,
diarization_file=args.diarization,
diarization_strategy=args.diarization_strategy,
)

from .audio.sources.factory import (
Expand Down
6 changes: 6 additions & 0 deletions verbatim/verbatim.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ def advance_audio_window(self, offset: int):
self.window_ts += offset

def append_audio_to_window(self, audio_chunk: np.array):
# Convert stereo to mono if necessary
LOG.debug(f"Audio chunk shape before mono conversion: {audio_chunk.shape}")
if len(audio_chunk.shape) > 1 and audio_chunk.shape[1] > 1:
audio_chunk = np.mean(audio_chunk, axis=1)
LOG.debug(f"Audio chunk shape after mono conversion: {audio_chunk.shape}")

chunk_size = len(audio_chunk)
window_size = len(self.rolling_window.array)
if self.audio_ts + chunk_size <= self.window_ts + window_size:
Expand Down
74 changes: 58 additions & 16 deletions verbatim/voices/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyannote.core.annotation import Annotation

from ..audio.audio import wav_to_int16
from .diarize.factory import create_diarizer

# Configure logger
LOG = logging.getLogger(__name__)
Expand All @@ -17,6 +18,7 @@
class SpeakerSeparation:
def __init__(self, device: str, huggingface_token: str):
LOG.info("Initializing Separation Pipeline.")
self.device = device
self.huggingface_token = huggingface_token
self.pipeline = Pipeline.from_pretrained(
"pyannote/speech-separation-ami-1.0",
Expand Down Expand Up @@ -54,27 +56,67 @@ def separate_speakers(
out_rttm_file: str = None,
out_speaker_wav_prefix="",
nb_speakers: int = None,
diarization_strategy: str = "pyannote",
) -> Tuple[Annotation, Dict[str, str]]:
"""
Separate speakers in an audio file.
Args:
file_path: Path to input audio file
out_rttm_file: Path to output RTTM file
out_speaker_wav_prefix: Prefix for output WAV files
nb_speakers: Optional number of speakers
diarization_strategy: Diarization strategy to use ('pyannote' or 'stereo')
Returns:
Tuple of (diarization annotation, dictionary mapping speaker IDs to WAV files)
"""
if not out_rttm_file:
out_rttm_file = "out.rttm"

with ProgressHook() as hook:
diarization, sources = self.pipeline(file_path, hook=hook)
# For stereo strategy, we might want to handle separation differently
if diarization_strategy == "stereo":
# For stereo files, we can simply split the channels
sample_rate, audio_data = scipy.io.wavfile.read(file_path)
if audio_data.ndim != 2 or audio_data.shape[1] != 2:
raise ValueError("Stereo separation requires stereo audio input")

# dump the diarization output to disk using RTTM format
with open(out_rttm_file, "w", encoding="utf-8") as rttm:
diarization.write_rttm(rttm)
# Create diarization annotation
diarizer = create_diarizer(strategy="stereo", device=self.device, huggingface_token=self.huggingface_token)
diarization = diarizer.compute_diarization(file_path=file_path, out_rttm_file=out_rttm_file, nb_speakers=nb_speakers)

# dump sources to disk as SPEAKER_XX.wav files
speaker_wav_files = {}
for s, speaker in enumerate(diarization.labels()):
if s < sources.data.shape[1]:
speaker_data = sources.data[:, s]
if speaker_data.dtype != np.int16:
speaker_data = wav_to_int16(speaker_data)
# Split channels into separate files
speaker_wav_files = {}
for channel, speaker in enumerate(["SPEAKER_0", "SPEAKER_1"]):
channel_data = audio_data[:, channel]
if channel_data.dtype != np.int16:
channel_data = wav_to_int16(channel_data)
file_name = f"{out_speaker_wav_prefix}-{speaker}.wav" if out_speaker_wav_prefix else f"{speaker}.wav"
speaker_wav_files[speaker] = file_name
scipy.io.wavfile.write(file_name, 16000, speaker_data)
else:
LOG.debug(f"Skipping speaker {s} as it is out of bounds.")
return diarization, speaker_wav_files
scipy.io.wavfile.write(file_name, sample_rate, channel_data)

return diarization, speaker_wav_files

else:
# Use PyAnnote's neural separation for mono files
with ProgressHook() as hook:
diarization, sources = self.pipeline(file_path, hook=hook)

# Save diarization to RTTM file
with open(out_rttm_file, "w", encoding="utf-8") as rttm:
diarization.write_rttm(rttm)

# Save separated sources to WAV files
speaker_wav_files = {}
for s, speaker in enumerate(diarization.labels()):
if s < sources.data.shape[1]:
speaker_data = sources.data[:, s]
if speaker_data.dtype != np.int16:
speaker_data = wav_to_int16(speaker_data)
file_name = f"{out_speaker_wav_prefix}-{speaker}.wav" if out_speaker_wav_prefix else f"{speaker}.wav"
speaker_wav_files[speaker] = file_name
scipy.io.wavfile.write(file_name, 16000, speaker_data)
else:
LOG.debug(f"Skipping speaker {s} as it is out of bounds.")

return diarization, speaker_wav_files

0 comments on commit 2689384

Please sign in to comment.