From 22221bf1f78453f06419d1809ed679cf64d86bc0 Mon Sep 17 00:00:00 2001 From: Andrew Wason Date: Wed, 2 Oct 2024 14:23:31 -0400 Subject: [PATCH] cleanup --- llm_transformers.py | 8 ++++---- tests/test_transformers.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/llm_transformers.py b/llm_transformers.py index b94d6c6..e1fc3a9 100644 --- a/llm_transformers.py +++ b/llm_transformers.py @@ -22,6 +22,8 @@ from transformers.pipelines import Pipeline, check_task, get_supported_tasks from transformers.utils import get_available_devices +log = logging.getLogger(__name__) + TASK_BLACKLIST = ( "feature-extraction", "image-feature-extraction", @@ -50,7 +52,6 @@ def save_audio(audio: numpy.ndarray, sample_rate: int, output: pathlib.Path | No def save(f: ta.BinaryIO) -> None: # musicgen is shape (batch_size, num_channels, sequence_length) # https://huggingface.co/docs/transformers/v4.45.1/en/model_doc/musicgen#unconditional-generation - # XXX check shape of other audio pipelines sf.write(f, audio[0].T, sample_rate) if output is None: @@ -380,9 +381,8 @@ def handle_result( }: response.response_json = {task: result} yield "\n".join(f"{label} ({score})" for label, score in zip(labels, scores, strict=True)) - case _, _: - breakpoint() # XXX log an error and try json - print("DEFAULT CASE") # XXX + case str(task), _: + log.error("Unhandled pipeline task '%s'. Attempting to show results as JSON.", task) yield json.dumps(result, indent=4) def execute( diff --git a/tests/test_transformers.py b/tests/test_transformers.py index ee34b61..ed3b4df 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -30,21 +30,21 @@ def validate(out: str): path = out.strip() actual_sample_rate = sf.read(path)[1] pathlib.Path(path).unlink(missing_ok=True) - assert actual_sample_rate == sample_rate + assert sample_rate == actual_sample_rate return validate def equals_validator(value): def validate(out): - assert value == out + assert out == value return validate def json_validator(value: dict): def validate(out): - assert value == json.loads(out) + assert json.loads(out) == value return validate