Merge pull request #776 from yonishelach/transcribe-v2
[Transcribe] Moving from whisper to torch
aviaIguazio authored Jan 15, 2024
2 parents 1a57426 + 4ad02e3 commit ae24f1e
Showing 5 changed files with 1,541 additions and 403 deletions.
323 changes: 243 additions & 80 deletions transcribe/function.yaml

Large diffs are not rendered by default.
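The source embedded in function.yaml is not rendered here, but the PR title and the new dependencies point to replacing the openai-whisper API with a Hugging Face transformers pipeline running on torch. Below is a minimal sketch of that style of call, for illustration only (the actual code inside function.yaml may differ); the model id "openai/whisper-tiny" and the sample file path are taken from the updated test further down.

# Illustration only -- not the code embedded in function.yaml.
# Typical transformers/torch replacement for openai-whisper's transcription call.
import torch
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",  # model id used by the updated test
    device=device,
)

# Decoding an mp3 path through the pipeline requires ffmpeg, which is why the
# updated test skips when ffmpeg is missing.
result = asr("data/speech_01.mp3")
print(result["text"])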

8 changes: 5 additions & 3 deletions transcribe/item.yaml
@@ -21,8 +21,10 @@ spec:
   image: mlrun/mlrun
   kind: job
   requirements:
-  - openai-whisper
+  - transformers
   - tqdm
+  - torchaudio
+  - torch
+  - accelerate
 url: ''
-version: 0.0.2
-test_valid: True
+version: 1.0.0
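With the item updated, the function would typically be pulled from the hub and run with the new-style parameters. A sketch only: the hub URL and the handler name are assumptions, while the parameter names and values mirror the updated test shown below.

# Sketch of invoking the updated function; "hub://transcribe" and handler="transcribe"
# are assumptions, the parameter names mirror test_transcribe.py below.
import mlrun

transcribe_fn = mlrun.import_function("hub://transcribe")

run = transcribe_fn.run(
    handler="transcribe",
    params={
        "data_path": "./data/speech_01.mp3",
        "model_name": "openai/whisper-tiny",  # Hugging Face id replaces the old whisper names
        "device": "cpu",
        "output_directory": "./output",
    },
    local=True,
)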
6 changes: 4 additions & 2 deletions transcribe/requirements.txt
@@ -1,3 +1,5 @@
-faster-whisper
+transformers
+torch
+torchaudio
 tqdm
-librosa
+accelerate
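When the function is built as a non-local image rather than run from a local environment, the same dependency set would typically be attached through MLRun's requirements mechanism. A small sketch assuming the standard with_requirements helper; the project name is a placeholder.

# Sketch assuming MLRun's with_requirements helper; "transcribe-demo" is a placeholder.
import mlrun

project = mlrun.get_or_create_project("transcribe-demo")
fn = project.set_function("transcribe.py", "transcribe", kind="job", image="mlrun/mlrun")
fn.with_requirements(["transformers", "torch", "torchaudio", "tqdm", "accelerate"])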
22 changes: 9 additions & 13 deletions transcribe/test_transcribe.py
@@ -14,13 +14,13 @@
 #
+import os
 import pathlib
-import sys
 import tempfile
 from difflib import SequenceMatcher

 import mlrun
 import pytest


 expected_outputs = [
     "This is a speech to text test.",
     "In the heart of the stadium, "
@@ -29,24 +29,21 @@
     "The crowd roars, a symphony of passion, "
     "as the game writes its unpredictable story on the field of destiny.",
 ]
-whisper_models = [
-    "tiny.en",
-    "tiny",
-    "base.en",
-    "base",
+models = [
+    "openai/whisper-tiny",
 ]


-@pytest.mark.skipif(
-    condition=sys.version_info[:2] < (3, 8),
-    reason="whisper requires python 3.8 and above"
-)
-@pytest.mark.parametrize("model_name", whisper_models)
+@pytest.mark.skipif(os.system("which ffmpeg") != 0, reason="ffmpeg not installed")
+@pytest.mark.parametrize("model_name", models)
 @pytest.mark.parametrize("audio_path", ["./data", "./data/speech_01.mp3"])
 def test_transcribe(model_name: str, audio_path: str):
     # Setting variables and importing function:
-    artifact_path = tempfile.mkdtemp()
-    transcribe_function = mlrun.import_function("function.yaml")
+    project = mlrun.get_or_create_project("test")
+    transcribe_function = project.set_function("transcribe.py", "transcribe", kind="job", image="mlrun/mlrun")
+    # transcribe_function = mlrun.import_function("function.yaml")
     temp_dir = tempfile.mkdtemp()

     # Running transcribe function:
@@ -56,7 +53,6 @@ def test_transcribe(model_name: str, audio_path: str):
             "data_path": audio_path,
             "model_name": model_name,
             "device": "cpu",
-            "compute_type": "int8",
             "output_directory": temp_dir,
         },
         local=True,
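The tail of the test is not shown in this diff, but given the SequenceMatcher import and the expected_outputs list, the check presumably scores the produced transcriptions against the expected text by similarity. A hedged sketch of that kind of comparison follows; the 0.8 threshold and the *.txt output layout are assumptions, not the test's actual assertion.

# Hypothetical sketch of the kind of check the truncated test tail could perform,
# based on the SequenceMatcher import; threshold and output layout are assumptions.
from difflib import SequenceMatcher
from pathlib import Path

expected_outputs = ["This is a speech to text test."]  # full list appears in the test above


def similar_enough(actual: str, expected: str, threshold: float = 0.8) -> bool:
    # Ratio of matching characters between transcribed and expected text.
    return SequenceMatcher(None, actual.strip(), expected.strip()).ratio() >= threshold


for txt_file, expected in zip(sorted(Path("output").glob("*.txt")), expected_outputs):
    assert similar_enough(txt_file.read_text(), expected)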