Support images directly in UserMessage (#387)
* Move OpenaiChatModel UserImageMessage support into openai_chat_model.py

* Implement message_to_anthropic_message for UserImageMessage

* Get test passing using _combine_messages

* WIP: Add ImageBytes, ImageUrl. Expand UserMessage types.

* Remove duplicate message_to_x_message for UserImageMessage

* Add make test-fix-snapshots

* Move image_bytes fixtures into top-level conftest

* Fix typing for ImageBytes.mime_type. Add tests

* Fix mypy errors on UserMessage conversion typing

* Validate bytes are valid image

* Use ImageBytes in UserImageMessage conversion functions

* Add typevar UserMessageContentT

* Fix mypy errors due to UserMessage now generic

* Attempt to coerce type in Placeholder.format

* Fix: list -> Iterable in UserMessage serialization

* Fix: return message in message_to_openai_message

* Make Placeholder a BaseModel

* Make ContentT covariant

* Add covariant PlaceholderT

* Fix type checking for UserMessage format

* Remove unused type ignores

* Require pydantic 2.10 to fix generic in BaseModel

* Revert "Remove unused type ignores"

This reverts commit 87b81ca.

* Remove pydantic url from error messages in tests

* Add TypeAlias UserMessageContentBlock

* Add trailing .0 to pydantic version in pyproject

* Use TypeAdapter for Placeholder coercion

* Add typing-extensions as dependency

* Remove todo for testing AssistantMessage with FunctionCall

* Ignore logfire not configured warnings

* Deprecate UserImageMessage

* Add make test-snapshots-create and improve naming

* Add tests for UserMessage with ImageBytes/Url

* Upgrade mypy version

* Add tests for ImageBytes with ChatModel

* Improve AssistantMessage typing, add tests

* Handle Literal string in AssistantMessage.format typing

* Add NotPlaceholder Protocol

* Fix type hints for UserMessage

* Add github issue link for failing type tests

* Improve handling of Literal in AssistantMessage typing

* Rename to PlaceholderTypeT

* Remove done todo re name not in kwargs error

* Add top-level imports for ImageBytes, ImageUrl

* Update docs for vision

* Add note about Placeholder coercion
jackmpcollins authored Jan 6, 2025
1 parent 755b553 commit 0cb9e7b
Showing 31 changed files with 1,348 additions and 468 deletions.
12 changes: 10 additions & 2 deletions Makefile
@@ -47,10 +47,18 @@ testcov: test # Run tests and generate a coverage report
test-vcr-once: # Run the tests and record new VCR cassettes
uv run pytest -vv --record-mode=once

.PHONY: test-fix-vcr
test-fix-vcr: # Run the last failed tests and rewrite the VCR cassettes
.PHONY: test-vcr-fix
test-vcr-fix: # Run the last failed tests and rewrite the VCR cassettes
uv run pytest -vv --last-failed --last-failed-no-failures=none --record-mode=rewrite

.PHONY: test-snapshots-create
test-snapshots-create: # Run the tests and create new inline-snapshots
uv run pytest -vv --inline-snapshot=create

.PHONY: test-snapshots-fix
test-snapshots-fix: # Run the tests and fix inline-snapshots
uv run pytest -vv --inline-snapshot=fix

.PHONY: docs
docs: # Build the documentation
uv run mkdocs build
4 changes: 2 additions & 2 deletions docs/chat-prompting.md
@@ -48,7 +48,7 @@ escaped_string.format(example="test")

## Placeholder

The `Placeholder` class enables templating of `AssistantMessage` content within the `@chatprompt` decorator. This allows dynamic changing of the messages used to prompt the model based on the arguments provided when the function is called.
The `Placeholder` class enables templating of message content within the `@chatprompt` decorator. This allows dynamic changing of the messages used to prompt the model based on the arguments provided when the function is called.

```python
from magentic import chatprompt, AssistantMessage, Placeholder, UserMessage
@@ -75,7 +75,7 @@ get_similar_quote(
# Quote(quote='The Force will be with you, always.', character='Obi-Wan Kenobi')
```

`Placeholder` can also be utilized in the `format` method of custom `Message` subclasses to provide an explicit way of inserting values from the function arguments. For example, see `UserImageMessage` in (TODO: link to GPT-vision page).
`Placeholder` can also be used in `UserMessage` to allow inserting `ImageBytes`, `ImageUrl`, or other content blocks from function arguments. For more information see [Vision](vision.md).

## FunctionCall

13 changes: 8 additions & 5 deletions docs/examples/vision_renaming_screenshots.ipynb
@@ -98,13 +98,16 @@
"source": [
"# Create a prompt-function to return details given an image\n",
"\n",
"from magentic import OpenaiChatModel, Placeholder, UserMessage, chatprompt\n",
"from magentic.vision import UserImageMessage\n",
"from magentic import ImageBytes, OpenaiChatModel, Placeholder, UserMessage, chatprompt\n",
"\n",
"\n",
"@chatprompt(\n",
" UserMessage(\"Describe the screenshot, then provide a suitable file name.\"),\n",
" UserImageMessage(Placeholder(bytes, \"image\")),\n",
" UserMessage(\n",
" [\n",
" \"Describe the screenshot, then provide a suitable file name.\",\n",
" Placeholder(ImageBytes, \"image\"),\n",
" ]\n",
" ),\n",
" model=OpenaiChatModel(\"gpt-4-turbo\"),\n",
")\n",
"def describe_image(image: bytes) -> ScreenshotDetails: ..."
@@ -250,7 +253,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.15"
}
},
"nbformat": 4,
49 changes: 31 additions & 18 deletions docs/vision.md
@@ -1,22 +1,25 @@
# Vision

Image inputs can be provided to LLMs in magentic by using the `UserImageMessage` message type.
Image inputs can be provided to LLMs in magentic by using `ImageBytes` or `ImageUrl` within the `UserMessage` message type. The LLM used must support vision, for example `gpt-4o` (the default `ChatModel`). The model can be set by passing the `model` parameter to `@chatprompt`, or through the other methods of [configuration](configuration.md).

!!! note "Anthropic Image URLs"

Anthropic models currently do not support supplying an image as a url, just bytes.

For more information visit the [OpenAI Vision API documentation](https://platform.openai.com/docs/guides/vision) or the [Anthropic Vision API documentation](https://docs.anthropic.com/en/docs/build-with-claude/vision#example-multiple-images).

## UserImageMessage
!!! warning "UserImageMessage Deprecation"

The `UserImageMessage` can be used in `@chatprompt` alongside other messages. The LLM must be set to an OpenAI or Anthropic model that supports vision, for example `gpt-4o` (the default `ChatModel`). This can be done by passing the `model` parameter to `@chatprompt`, or through the other methods of [configuration](configuration.md).
Previously the `UserImageMessage` was used for vision capabilities. This is now deprecated and will be removed in a future version of Magentic. It is recommended to use `ImageBytes` or `ImageUrl` within the `UserMessage` message type instead to ensure compatibility with future updates.

## ImageUrl

As shown in [Chat Prompting](chat-prompting.md), `@chatprompt` can be used to supply a group of messages as a prompt to the LLM. `UserMessage` accepts a sequence of content blocks as input, which can be `str`, `ImageBytes`, `ImageUrl`, or other content types. `ImageUrl` is used to provide an image url to the LLM.

```python
from pydantic import BaseModel, Field

from magentic import chatprompt, UserMessage
from magentic.vision import UserImageMessage
from magentic import chatprompt, ImageUrl, UserMessage


IMAGE_URL_WOODEN_BOARDWALK = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
@@ -28,8 +31,12 @@ class ImageDetails(BaseModel):


@chatprompt(
UserMessage("Describe the following image in one sentence."),
UserImageMessage(IMAGE_URL_WOODEN_BOARDWALK),
UserMessage(
[
"Describe the following image in one sentence.",
ImageUrl(IMAGE_URL_WOODEN_BOARDWALK),
]
),
)
def describe_image() -> ImageDetails: ...

@@ -45,19 +52,22 @@ For more info on the `@chatprompt` decorator, see [Chat Prompting](chat-promptin

## Placeholder

In the previous example, the image url was tied to the function. To provide the image as a function parameter, use `Placeholder`. This substitutes a function argument into the message when the function is called.
In the previous example, the image url was tied to the function. To provide the image as a function parameter, use `Placeholder`. This substitutes a function argument into the message when the function is called. The placeholder will also automatically coerce the argument to the correct type if possible, for example `str` to `ImageUrl`.

```python hl_lines="10"
from magentic import chatprompt, Placeholder, UserMessage
from magentic.vision import UserImageMessage
from magentic import chatprompt, ImageUrl, Placeholder, UserMessage


IMAGE_URL_WOODEN_BOARDWALK = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


@chatprompt(
UserMessage("Describe the following image in one sentence."),
UserImageMessage(Placeholder(str, "image_url")),
UserMessage(
[
"Describe the following image in one sentence.",
Placeholder(ImageUrl, "image_url"),
]
),
)
def describe_image(image_url: str) -> str: ...

@@ -66,15 +76,14 @@ describe_image(IMAGE_URL_WOODEN_BOARDWALK)
# 'A wooden boardwalk meanders through lush green wetlands under a partly cloudy blue sky.'
```
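
The coercion described above can be illustrated with a minimal, self-contained sketch. `ImageUrlSketch`, `PlaceholderSketch`, and `format_blocks` are hypothetical stand-ins written for this illustration only. The commit message mentions using a pydantic `TypeAdapter` for the real coercion; this sketch swaps in a plain constructor call so that it runs with the standard library alone.

```python
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class ImageUrlSketch:
    """Hypothetical stand-in for an image-url content block."""

    url: str


class PlaceholderSketch(Generic[T]):
    """Hypothetical, simplified placeholder: looks up a named argument and coerces it."""

    def __init__(self, type_: type[T], name: str) -> None:
        self.type_ = type_
        self.name = name

    def format(self, **kwargs: Any) -> T:
        value = kwargs[self.name]
        # Already the right type: pass it through unchanged.
        if isinstance(value, self.type_):
            return value
        # Assumption: coercion is approximated by calling the target type's constructor.
        return self.type_(value)


def format_blocks(blocks: list[Any], **kwargs: Any) -> list[Any]:
    """Replace each placeholder in a list of content blocks with its formatted value."""
    return [
        block.format(**kwargs) if isinstance(block, PlaceholderSketch) else block
        for block in blocks
    ]


blocks = [
    "Describe the following image in one sentence.",
    PlaceholderSketch(ImageUrlSketch, "image_url"),
]
formatted = format_blocks(blocks, image_url="https://example.com/boardwalk.jpg")
```

Here the `str` argument ends up wrapped in `ImageUrlSketch`, while an argument that is already an `ImageUrlSketch` would be passed through as-is.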

## bytes
## ImageBytes

`UserImageMessage` can also accept `bytes` as input. Like `str`, this can be passed directly or via `Placeholder`.
`UserMessage` can also accept `ImageBytes` as a content block. Like `ImageUrl`, this can be passed directly or via `Placeholder`.

```python
import requests

from magentic import chatprompt, Placeholder, UserMessage
from magentic.vision import UserImageMessage
from magentic import chatprompt, ImageBytes, Placeholder, UserMessage


IMAGE_URL_WOODEN_BOARDWALK = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
@@ -90,8 +99,12 @@ def url_to_bytes(url: str) -> bytes:


@chatprompt(
UserMessage("Describe the following image in one sentence."),
UserImageMessage(Placeholder(bytes, "image_bytes")),
UserMessage(
[
"Describe the following image in one sentence.",
Placeholder(ImageBytes, "image_bytes"),
]
),
)
def describe_image(image_bytes: bytes) -> str: ...

10 changes: 9 additions & 1 deletion pyproject.toml
@@ -8,8 +8,9 @@ dependencies = [
"filetype>=1.2.0",
"logfire-api>=0.1.0",
"openai>=1.40.0",
"pydantic>=2.7.0",
"pydantic>=2.10.0",
"pydantic-settings>=2.0.0",
"typing-extensions>=4.5.0",
]

[project.optional-dependencies]
@@ -52,6 +53,9 @@ examples = ["ghapi>=1.0.5", "jupyter", "pandas>=2.2.1"]
skip_covered = "true"
show_missing = "true"

[tool.logfire]
ignore_no_config = true

[tool.mypy]
check_untyped_defs = true
disable_error_code = ["empty-body"]
@@ -76,6 +80,10 @@ venv = ".venv"
addopts = "--block-network --cov-report=html --cov-report=term --cov=magentic"
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
filterwarnings = [
"ignore::DeprecationWarning::",
"default::DeprecationWarning:magentic.*:",
]
markers = [
"anthropic: Tests that query the Anthropic API. Requires the ANTHROPIC_API_KEY environment variable to be set.",
"litellm_anthropic: Tests that query the Anthropic API via litellm. Requires the ANTHROPIC_API_KEY environment variable to be set.",
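
The two `filterwarnings` entries added in this hunk appear intended to silence `DeprecationWarning` from third-party code while still surfacing it for `magentic` modules. A rough analogue using the standard-library `warnings` module is sketched below. The helper `is_emitted` and the module names passed to it are hypothetical, and pytest's own handling of the ini option may differ in detail.

```python
import warnings


def is_emitted(module_name: str) -> bool:
    """Return True if a DeprecationWarning attributed to `module_name` passes the filters."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.resetwarnings()
        # Later filters take precedence, mirroring the order of the two ini entries.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        warnings.filterwarnings(
            "default", category=DeprecationWarning, module=r"magentic.*"
        )
        # warn_explicit lets the sketch choose which module the warning is attributed to.
        warnings.warn_explicit(
            "deprecated",
            DeprecationWarning,
            filename="example.py",
            lineno=1,
            module=module_name,
            registry={},
        )
        return len(caught) > 0


shown_for_magentic = is_emitted("magentic.vision")
shown_for_other = is_emitted("some_other_library")
```

With these filters, the warning attributed to `magentic.vision` is recorded and the one attributed to the other module is dropped.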
2 changes: 2 additions & 0 deletions src/magentic/__init__.py
@@ -5,6 +5,8 @@
from .chat_model.message import AnyMessage as AnyMessage
from .chat_model.message import AssistantMessage as AssistantMessage
from .chat_model.message import FunctionResultMessage as FunctionResultMessage
from .chat_model.message import ImageBytes as ImageBytes
from .chat_model.message import ImageUrl as ImageUrl
from .chat_model.message import Placeholder as Placeholder
from .chat_model.message import SystemMessage as SystemMessage
from .chat_model.message import ToolResultMessage as ToolResultMessage
1 change: 1 addition & 0 deletions src/magentic/chat.py
@@ -19,6 +19,7 @@
from magentic.streaming import async_iter, azip

P = ParamSpec("P")
# TODO: Use `Self` from typing_extensions
Self = TypeVar("Self", bound="Chat")


45 changes: 33 additions & 12 deletions src/magentic/chat_model/anthropic_chat_model.py
@@ -1,13 +1,10 @@
import base64
import json
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from enum import Enum
from functools import singledispatch
from itertools import groupby
from typing import Any, Generic, TypeVar, cast, overload

import filetype

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
@@ -20,6 +17,7 @@
)
from magentic.chat_model.message import (
AssistantMessage,
ImageBytes,
Message,
SystemMessage,
ToolResultMessage,
@@ -42,7 +40,9 @@
from anthropic.lib.streaming import MessageStreamEvent
from anthropic.lib.streaming._messages import accumulate_event
from anthropic.types import (
ImageBlockParam,
MessageParam,
TextBlockParam,
ToolChoiceParam,
ToolChoiceToolParam,
ToolParam,
@@ -70,29 +70,50 @@ def _(message: _RawMessage[Any]) -> MessageParam:
return message.content # type: ignore[no-any-return]


@message_to_anthropic_message.register
def _(message: UserMessage) -> MessageParam:
return {"role": AnthropicMessageRole.USER.value, "content": message.content}
@message_to_anthropic_message.register(UserMessage)
def _(message: UserMessage[Any]) -> MessageParam:
if isinstance(message.content, str):
return {"role": AnthropicMessageRole.USER.value, "content": message.content}
if isinstance(message.content, Iterable):
content: list[TextBlockParam | ImageBlockParam] = []
for block in message.content:
if isinstance(block, str):
content.append({"type": "text", "text": block})
elif isinstance(block, ImageBytes):
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": block.mime_type,
"data": block.as_base64(),
},
}
)
else:
msg = f"Invalid content type for UserMessage: {type(block)}"
raise TypeError(msg)
return {"role": AnthropicMessageRole.USER.value, "content": content}
msg = f"Invalid content type for UserMessage: {type(message.content)}"
raise TypeError(msg)


@message_to_anthropic_message.register(UserImageMessage)
def _(message: UserImageMessage[Any]) -> MessageParam:
if isinstance(message.content, bytes):
mime_type = filetype.guess_mime(message.content)
base64_image = base64.b64encode(message.content).decode("utf-8")
else:
if not isinstance(message.content, bytes):
msg = f"Invalid content type: {type(message.content)}"
raise TypeError(msg)

image_bytes = ImageBytes(message.content)
return {
"role": AnthropicMessageRole.USER.value,
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
"media_type": image_bytes.mime_type,
"data": image_bytes.as_base64(),
},
}
],
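
The conversion in this file can be tried in isolation with a standalone sketch that mirrors the same branching: a plain string passes through, and a sequence of blocks becomes a list of text and base64 image dicts. `ImageBytesSketch` and `user_content_to_message` are hypothetical names for this illustration. The diff shows the previous code detecting the MIME type with the `filetype` package; a small magic-number check for PNG and JPEG stands in for that here, as an assumption, to keep the sketch dependency-free.

```python
import base64
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ImageBytesSketch:
    """Hypothetical stand-in for an image-bytes content block."""

    data: bytes

    @property
    def mime_type(self) -> str:
        # Assumption: only PNG and JPEG signatures are recognized in this sketch.
        if self.data.startswith(b"\x89PNG\r\n\x1a\n"):
            return "image/png"
        if self.data.startswith(b"\xff\xd8\xff"):
            return "image/jpeg"
        raise ValueError("Unrecognized image format")

    def as_base64(self) -> str:
        return base64.b64encode(self.data).decode("utf-8")


def user_content_to_message(content: Any) -> dict[str, Any]:
    """Convert user content (a str or a list of blocks) to a message dict."""
    if isinstance(content, str):
        return {"role": "user", "content": content}
    blocks: list[dict[str, Any]] = []
    for block in content:
        if isinstance(block, str):
            blocks.append({"type": "text", "text": block})
        elif isinstance(block, ImageBytesSketch):
            blocks.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": block.mime_type,
                        "data": block.as_base64(),
                    },
                }
            )
        else:
            raise TypeError(f"Invalid content type for user message: {type(block)}")
    return {"role": "user", "content": blocks}


fake_png = b"\x89PNG\r\n\x1a\n" + b"not-a-real-image"
message = user_content_to_message(["Describe this image.", ImageBytesSketch(fake_png)])
```

The bytes used here only carry a PNG signature and are not a decodable image, which is enough to exercise the branching but would not be accepted by a real API.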
2 changes: 1 addition & 1 deletion src/magentic/chat_model/function_schema.py
@@ -158,7 +158,7 @@ def _register(cls: TypeFunctionSchemaT) -> TypeFunctionSchemaT:
_async_function_schema_registry.register(type_, cls)
if issubclass(cls, FunctionSchema):
_function_schema_registry.register(type_, cls)
return cls # type: ignore[return-value]
return cls

return _register
