key support. output file support. new tests. GH actions CI.
rectalogic committed Oct 1, 2024
1 parent a196b79 commit 8ac54bf
Showing 6 changed files with 689 additions and 29 deletions.
30 changes: 23 additions & 7 deletions .github/workflows/test.yml
@@ -10,18 +10,34 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v2
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
pip install -e '.[test]'
run: uv sync --all-extras --dev --python ${{ matrix.python-version }} --python-preference only-system
- name: Cache models
id: cache-models
uses: actions/cache@v4
with:
path: ~/.cache/huggingface/hub/
save-always: true
# Update cache every time since models may be added
# https://github.com/actions/cache/blob/main/tips-and-workarounds.md#update-a-cache
key: models-${{ runner.os }}-${{ github.run_id }}
restore-keys: |
models-${{ runner.os }}
- name: Run tests
run: |
python -m pytest
run: uv run pytest tests


#XXX cache models
#XXX lint
19 changes: 18 additions & 1 deletion README.md
@@ -16,8 +16,20 @@ llm install llm-transformers
## Usage

XXX document `-o verbose True`
XXX HF_TOKEN/key usage

## Transformer tasks
Most models are freely accessible, but some require accepting a license agreement and using a Hugging Face [API token](https://huggingface.co/settings/tokens) that has access to the model.
You can store the key with `llm keys set huggingface`, set the `HF_TOKEN` environment variable, or pass the `--key` option to `llm`.

```sh-session
$ llm -m transformers -o model meta-llama/Llama-3.2-1B "A dog has"
Error: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.2-1B.
$ llm --key hf_******************** -m transformers -o model meta-llama/Llama-3.2-1B "A dog has"
A dog has been named as the killer of a woman who was found dead in her home.
```

## Transformer Pipeline Tasks

### [audio-classification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.AudioClassificationPipeline)

@@ -92,6 +104,9 @@ Not supported.
$ llm -m transformers -o task image-segmentation https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png
/var/folders/b1/1j9kkk053txc5krqbh0lj5t00000gn/T/tmp0z8zvd8i.png (bird: 0.999439)
/var/folders/b1/1j9kkk053txc5krqbh0lj5t00000gn/T/tmpik_7r5qn.png (bird: 0.998787)
$ llm -m transformers -o task image-segmentation -o output /tmp/segment.png https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png
/tmp/segment-00.png (bird: 0.999439)
/tmp/segment-01.png (bird: 0.998787)
```
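
For multi-mask results, the plugin derives numbered filenames from the `output` option. A minimal sketch of that templating (the `/tmp/segment.png` path is just illustrative):

```python
import pathlib

out = pathlib.Path("/tmp/segment.png")
# Doubled braces survive the f-string, leaving a "{:02}" placeholder:
# "/tmp/segment-{:02}.png"
template = str(out.with_name(f"{out.stem}-{{:02}}{out.suffix}"))
print(template.format(0))  # /tmp/segment-00.png
print(template.format(1))  # /tmp/segment-01.png
```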

### [image-to-image](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ImageToImagePipeline)
@@ -213,6 +228,8 @@ Your question was: "What is the capital of France?"
```sh-session
$ llm -m transformers -o kwargs '{"generate_kwargs": {"max_new_tokens": 100}}' -o model facebook/musicgen-small "techno music"
/var/folders/b1/1j9kkk053txc5krqbh0lj5t00000gn/T/tmpoueh05y6.wav
$ llm -m transformers -o task text-to-audio "Hello world"
/var/folders/b1/1j9kkk053txc5krqbh0lj5t00000gn/T/tmpmpwhkd8p.wav
```
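
The `-o kwargs` JSON is passed through to the underlying pipeline call. Roughly, the musicgen invocation above corresponds to this sketch (assuming the model has been downloaded; the result keys match the `text-to-audio` handling in `llm_transformers.py`):

```python
from transformers import pipeline

pipe = pipeline("text-to-audio", model="facebook/musicgen-small")
result = pipe("techno music", generate_kwargs={"max_new_tokens": 100})
# result is {"audio": <ndarray>, "sampling_rate": <int>}; the plugin
# writes the array to a .wav file and prints the file's path
```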

### [token-classification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TokenClassificationPipeline)
76 changes: 55 additions & 21 deletions llm_transformers.py
@@ -1,7 +1,10 @@
# Copyright (C) 2024 Andrew Wason
# SPDX-License-Identifier: Apache-2.0
import csv
import itertools
import json
import logging
import pathlib
import re
import tempfile
import typing as ta
@@ -10,10 +13,11 @@

import click
import llm
import numpy
import soundfile as sf
import torch
from PIL import Image
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic import Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import Pipeline, check_task, get_supported_tasks
from transformers.utils import get_available_devices
@@ -31,10 +35,31 @@ def supported_tasks() -> ta.Iterator[str]:
yield task


def save_image(image: Image.Image) -> str:
with tempfile.NamedTemporaryFile(suffix=".png", delete=False, delete_on_close=False) as f:
image.save(f, format="png")
return f.name
def save_image(image: Image.Image, output: pathlib.Path | None) -> str:
if output is None:
with tempfile.NamedTemporaryFile(suffix=".png", delete=False, delete_on_close=False) as f:
image.save(f, format="png")
return f.name
else:
image.save(str(output))
return str(output)


def save_audio(audio: numpy.ndarray, sample_rate: int, output: pathlib.Path | None) -> str:
def save(f: ta.BinaryIO) -> None:
# musicgen is shape (batch_size, num_channels, sequence_length)
# https://huggingface.co/docs/transformers/v4.45.1/en/model_doc/musicgen#unconditional-generation
# XXX check shape of other audio pipelines
sf.write(f, audio[0].T, sample_rate)

if output is None:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, delete_on_close=False) as f:
save(f)
return f.name
else:
with open(output, "wb") as f:
save(f)
return str(output)


def handle_required_kwarg(kwargs: dict, options: llm.Options, name: str, format: str, task: str) -> None:
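
A quick usage sketch of the new `save_audio` helper (the zero-filled array and `/tmp` path are illustrative, not from the commit):

```python
import pathlib

import numpy
from llm_transformers import save_audio

# musicgen-style output has shape (batch_size, num_channels, sequence_length)
audio = numpy.zeros((1, 1, 32000), dtype=numpy.float32)

# With output=None, a temporary .wav is created and its path returned
print(save_audio(audio, 32000, None))

# With an explicit path, the audio is written there instead
print(save_audio(audio, 32000, pathlib.Path("/tmp/out.wav")))
```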
@@ -92,6 +117,8 @@ def silence(verbose: bool | None = None):

class Transformers(llm.Model):
model_id = "transformers"
needs_key = "huggingface" # only some models need a key
key_env_var = "HF_TOKEN"

pipe: Pipeline | None = None

@@ -108,8 +135,12 @@ class Options(llm.Options):
description="Additional context for transformer, often a file path or URL, required by some transformers.",
default=None,
)
output: pathlib.Path | None = Field(
description="Output file path. Some models generate binary image/audio outputs which will be saved in this file, or a temporary file if not specified.",
default=None,
)
device: str | None = Field(
description="Device name. `llm transformers list-devices`.", default=None
description="Torch device name. `llm transformers list-devices`.", default=None
)
verbose: bool | None = Field(
description="Logging is disabled by default, enable this to see transformers warnings.",
@@ -216,24 +247,20 @@ def handle_inputs(
return args, kwargs

def handle_result(
self, task: str, result: ta.Any, response: llm.Response
self, task: str, result: ta.Any, prompt: llm.Prompt, response: llm.Response
) -> ta.Generator[str, None, None]:
match task, result:
case "image-to-image", Image.Image() as image:
path = save_image(image)
path = save_image(image, prompt.options.output)
response.response_json = {task: {"output": path}}
yield path
case "automatic-speech-recognition", {"text": str(text)}:
response.response_json = {task: result}
yield text
case "text-to-audio", {"audio": audio, "sampling_rate": int(sampling_rate)}:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, delete_on_close=False) as f:
# musicgen is shape (batch_size, num_channels, sequence_length)
# https://huggingface.co/docs/transformers/v4.45.1/en/model_doc/musicgen#unconditional-generation
# XXX check shape of other audio pipelines
sf.write(f, audio[0].T, sampling_rate)
response.response_json = {task: {"output": f.name}}
yield f.name
case "text-to-audio", {"audio": numpy.ndarray() as audio, "sampling_rate": int(sample_rate)}:
path = save_audio(audio, sample_rate, prompt.options.output)
response.response_json = {task: {"output": path}}
yield path
case "object-detection", [
{
"score": float(),
@@ -245,8 +272,14 @@ def handle_result(
yield json.dumps(result, indent=4)
case "image-segmentation", [{"score": float(), "label": str(), "mask": Image.Image()}, *_]:
responses = []
for item in result:
path = save_image(item["mask"])
if prompt.options.output:
out = prompt.options.output
output_template = str(out.with_name(f"{out.stem}-{{:02}}{out.suffix}"))
else:
output_template = None
for i, item in enumerate(result):
output = output_template.format(i) if output_template else None
path = save_image(item["mask"], output)
responses.append({"score": item["score"], "label": item["label"], "output": path})
response.response_json = {task: responses}
yield "\n".join(
@@ -272,8 +305,8 @@ def handle_result(
]:
response.response_json = {task: result}
yield "\n".join(f"{item['sequence']} (score={item['score']})" for item in result)
case "depth-estimation", {"predicted_depth": torch.Tensor(), "depth": Image.Image(depth)}:
path = save_image(depth)
case "depth-estimation", {"predicted_depth": torch.Tensor(), "depth": Image.Image() as depth}:
path = save_image(depth, prompt.options.output)
response.response_json = {task: {"output": path}}
yield path
case "document-question-answering", [
@@ -361,6 +394,7 @@ def execute(
if prompt.options.device is not None
else None,
framework="pt",
token=self.key,
)
elif (prompt.options.task and self.pipe.task != prompt.options.task) or (
prompt.options.model and self.pipe.model.name_or_path != prompt.options.model
@@ -375,4 +409,4 @@

result = self.pipe(*args, **kwargs)

yield from self.handle_result(normalized_task, result, response)
yield from self.handle_result(normalized_task, result, prompt, response)
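
Taken together, the key plumbing is small. A condensed sketch (not the full class) of how `llm` resolves the key and the plugin hands it to `transformers`:

```python
import llm
from transformers import pipeline

class Transformers(llm.Model):
    model_id = "transformers"
    needs_key = "huggingface"  # resolved from --key, `llm keys set`, or...
    key_env_var = "HF_TOKEN"   # ...this environment variable

    def execute(self, prompt, stream, response, conversation):
        # self.key is None for ungated models; pipeline() then runs unauthenticated
        pipe = pipeline(task="text-generation", model=prompt.options.model, token=self.key)
        ...
```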
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -1,3 +1,5 @@
# Copyright (C) 2024 Andrew Wason
# SPDX-License-Identifier: Apache-2.0
[project]
name = "llm-transformers"
version = "0.1"
@@ -20,6 +22,7 @@ dependencies = [
"protobuf",
"pandas",
"av",
"numpy>=2.1.1",
]

[project.urls]
