Skip to content

Commit

Permalink
Allow loading and unloading of models (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaspardpetit authored Jan 17, 2024
1 parent 4aa8387 commit 39173b7
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 14 deletions.
13 changes: 13 additions & 0 deletions verbatim/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod

# pylint: disable=unused-argument
class Filter(ABC):
    """
    Abstract base class for a pipeline processing step.

    Subclasses must implement ``execute``. ``load`` and ``unload`` are
    optional hooks that do nothing by default; the pipeline calls
    ``load``, then ``execute``, then ``unload`` on each filter, passing the
    same context keyword arguments to all three. Subclasses may override the
    hooks, e.g. to acquire and release a model around ``execute``.
    """

    @abstractmethod
    def execute(self, **kwargs: dict):
        """
        Run the filter.

        Args:
            **kwargs: Context values supplied by the caller. Subclasses
                typically name the keys they need as explicit parameters
                and absorb the rest through ``**kwargs``.
        """

    def load(self, **kwargs: dict):
        """
        Prepare the filter before ``execute``; the default does nothing.

        Args:
            **kwargs: Context values supplied by the caller (unused here).
        """

    def unload(self, **kwargs: dict):
        """
        Release what ``load`` acquired, after ``execute``; the default does nothing.

        Args:
            **kwargs: Context values supplied by the caller (unused here).
        """

7 changes: 4 additions & 3 deletions verbatim/language_detection/detect_language.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from numpy import ndarray
from pyannote.core import Annotation
from pyannote.database.util import load_rttm
from ..wav_conversion import ConvertToWav

from ..transcription import Transcription, Utterance
from ..filter import Filter


class DetectLanguage(ABC):
class DetectLanguage(Filter):
"""
Abstract base class for language detection.
Expand Down Expand Up @@ -241,6 +241,7 @@ def fill_language_gaps(self,
changed = True
center.confidence = confidence

# pylint: disable=arguments-differ
def execute(self, diarization_file: str, voice_file_path:str,
language_file: str, languages=None, **kwargs: dict) -> Transcription:
"""
Expand Down
2 changes: 2 additions & 0 deletions verbatim/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,6 @@ def execute(self):
] + self.transcripte_writing

for f in filters:
f.load(**self.context.to_dict())
f.execute(**self.context.to_dict())
f.unload(**self.context.to_dict())
7 changes: 5 additions & 2 deletions verbatim/speaker_diarization/diarize_speakers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from pyannote.core import Annotation

from ..filter import Filter

class DiarizeSpeakers(ABC):

class DiarizeSpeakers(Filter):
"""
Abstract class for diarization of speakers in an audio file.
Expand All @@ -13,6 +15,7 @@ class DiarizeSpeakers(ABC):
"""

@abstractmethod
# pylint: disable=arguments-differ
def execute(self, voice_file_path: str, diarization_file: str, min_speakers: int = 1, max_speakers: int = None,
**kwargs: dict) -> Annotation:
"""
Expand Down
4 changes: 4 additions & 0 deletions verbatim/speaker_diarization/diarize_speakers_pyannote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def diarize_on_silences(self, voice_file_path: str, model_pyannote_segmentation:
Returns:
Annotation: Pyannote Annotation object containing information about speaker diarization.
"""
LOG.info(f"Loading model {model_pyannote_segmentation}")
model = Model.from_pretrained(model_pyannote_segmentation, use_auth_token=huggingface_token)
if model is None:
LOG.error(f"Failed to retrieve model {model_pyannote_segmentation}")
Expand All @@ -54,6 +55,9 @@ def diarize_on_silences(self, voice_file_path: str, model_pyannote_segmentation:
pipeline.instantiate(hyper_parameters)
vad: Annotation = pipeline(voice_file_path)
vad.uri = "waveform"

LOG.info(f"Unloading model {model_pyannote_segmentation}")
del model
return vad

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions verbatim/speech_transcription/transcribe_speech.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import logging
from abc import ABC, abstractmethod
from abc import abstractmethod
import numpy as np
from numpy import ndarray
from pyannote.core import Annotation
Expand All @@ -9,11 +9,12 @@
from ..transcription import Transcription
from ..speaker_diarization import DiarizeSpeakersSpeechBrain
from ..wav_conversion import ConvertToWav
from ..filter import Filter

LOG = logging.getLogger(__name__)


class TranscribeSpeech(ABC):
class TranscribeSpeech(Filter):
"""
Abstract class for transcribing audio.
Expand Down Expand Up @@ -299,6 +300,7 @@ def execute_for_speaker_and_language(self, speech_segment_float32_16khz, sequenc
return whole_transcription

# pylint: disable=unused-argument
# pylint: disable=arguments-differ
def execute(self, voice_file_path:str, language_file:str,
transcription_path: str, diarization_file:str, languages: list, **kwargs: dict) -> Transcription:
"""
Expand Down
7 changes: 4 additions & 3 deletions verbatim/transcript_writing/write_transcript.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from ..filter import Filter


class WriteTranscript(ABC):
class WriteTranscript(Filter):
"""
Abstract base class for writing transcriptions to a file.
Expand All @@ -15,6 +15,7 @@ class WriteTranscript(ABC):
"""

@abstractmethod
# pylint: disable=arguments-differ
def execute(self, transcription_path: str, output_file: str, **kwargs: dict) -> None:
"""
Execute the transcription writing process.
Expand Down
7 changes: 5 additions & 2 deletions verbatim/voice_isolation/isolate_voices.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from numpy import ndarray

from ..filter import Filter

class IsolateVoices(ABC):

class IsolateVoices(Filter):
"""
Abstract base class for voice isolation methods.
Expand All @@ -13,6 +15,7 @@ class IsolateVoices(ABC):
"""

@abstractmethod
# pylint: disable=arguments-differ
def execute(self, audio_file_path: str, voice_file_path: str, **kwargs) -> ndarray:
"""
Execute the voice isolation process.
Expand Down
7 changes: 5 additions & 2 deletions verbatim/wav_conversion/convert_to_wav.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
import numpy as np
from numpy import ndarray
import torch
Expand All @@ -7,9 +7,12 @@
import torchaudio
from pydub import AudioSegment

from ..filter import Filter

class ConvertToWav(ABC):

class ConvertToWav(Filter):
@abstractmethod
# pylint: disable=arguments-differ
def execute(self, source_file_path: str, audio_file_path: str, **kwargs: dict) -> None:
"""
Abstract method for converting audio files to WAV format.
Expand Down

0 comments on commit 39173b7

Please sign in to comment.