Consume LLM output stream via returned objects to allow caching (#384)
* Consume stream via objects to allow caching

* Update tests to check caching

* Change StreamedStr consume warning to tip in docs
jackmpcollins authored Dec 1, 2024
1 parent 0b4b795 commit f7e7e37
Showing 6 changed files with 45 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docs/streaming.md
@@ -84,9 +84,9 @@ for hero in create_superhero_team("The Food Dudes"):

Some LLMs have the ability to generate text output and make tool calls in the same response. This allows them to perform chain-of-thought reasoning or provide additional context to the user. In magentic, the `StreamedResponse` (or `AsyncStreamedResponse`) class can be used to request this type of output. This object is an iterable of `StreamedStr` (or `AsyncStreamedStr`) and `FunctionCall` instances.

!!! warning "Consuming StreamedStr"
!!! tip "Consuming StreamedStr"

The StreamedStr object must be iterated over before the next item in the `StreamedResponse` is processed, otherwise the string output will be lost. This is because the `StreamedResponse` and `StreamedStr` share the same underlying generator, so advancing the `StreamedResponse` iterator skips over the `StreamedStr` items. The `StreamedStr` object has internal caching so after iterating over it once the chunks will remain available.
The StreamedStr object caches its chunks internally, so it does not have to be consumed immediately. This means you can iterate over the chunks as they are received, and/or use the StreamedStr object as a whole after the LLM has finished generating the output.

In the example below, we request that the LLM generates a greeting and then calls a function to get the weather for two cities. The `StreamedResponse` object is then iterated over to print the output, and the `StreamedStr` and `FunctionCall` items are processed separately.
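The example itself is collapsed in this hunk; the following is a minimal sketch of the pattern the paragraph describes, assuming the usual magentic imports and hypothetical `describe_weather` / `get_weather` functions:

```python
from magentic import FunctionCall, StreamedResponse, StreamedStr, prompt


def get_weather(city: str) -> str:
    """Return a (stubbed) weather report for a city."""
    return f"The weather in {city} is 20°C."


@prompt(
    "Say hello, then get the weather for each of these cities: {cities}",
    functions=[get_weather],
)
def describe_weather(cities: list[str]) -> StreamedResponse: ...


response = describe_weather(["Cape Town", "San Francisco"])

for item in response:
    if isinstance(item, StreamedStr):
        # Print the text as it streams in; the chunks are cached internally,
        # so the full text remains available after iteration.
        for chunk in item:
            print(chunk, end="")
        print()
    elif isinstance(item, FunctionCall):
        # Execute the tool call chosen by the LLM.
        print(item())
```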

2 changes: 1 addition & 1 deletion docs/structured-outputs.md
@@ -150,7 +150,7 @@ print(hero_defeated)

!!! warning "StreamedResponse"

It is now recommended to use `StreamedResponse` for chain-of-thought prompting, as this uses the LLM provider's native chain-of-thought capabilities. See [StreamedResponse](streaming.md#StreamedResponse) for more information.
It is now recommended to use `StreamedResponse` for chain-of-thought prompting, as this uses the LLM provider's native chain-of-thought capabilities. See [StreamedResponse](streaming.md#streamedresponse) for more information.

Using a simple Python type as the return annotation might lead to poor results, as the LLM has no time to arrange its thoughts before answering. To allow the LLM to work through this "chain of thought", you can instead return a pydantic model with initial fields for explaining the final response.
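A rough sketch of that pattern (the model and function names below are illustrative, not the file's actual example):

```python
from magentic import prompt
from pydantic import BaseModel, Field


class ExplainedDefeat(BaseModel):
    """Reasoning comes first so the LLM works through the problem before answering."""

    reasoning: str = Field(description="Step-by-step reasoning about the outcome.")
    defeated: bool


@prompt("Was the hero {hero} defeated by the villain {villain}?")
def is_defeated(hero: str, villain: str) -> ExplainedDefeat: ...


result = is_defeated("Mermaid Man", "Barnacle Boy")
print(result.reasoning)
print(result.defeated)
```

Because the fields are generated in declaration order, the `reasoning` text is produced before the final `defeated` value.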

58 changes: 33 additions & 25 deletions src/magentic/chat_model/stream.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable, Iterator
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator
from itertools import chain
from typing import Any, Generic, NamedTuple, TypeVar

@@ -81,6 +81,7 @@ def __init__(
self._state = state

self._iterator = self.__stream__()
self._exhausted: bool = False

def __next__(self) -> StreamedStr | OutputT:
return self._iterator.__next__()
@@ -99,6 +100,7 @@ def _streamed_str(
assert not current_item_ref # noqa: S101
current_item_ref.append(item)
return
self._exhausted = True

def _tool_call(
self,
@@ -116,6 +118,7 @@ def _tool_call(
return
if item.args:
yield item.args
self._exhausted = True

def __stream__(self) -> Iterator[StreamedStr | OutputT]:
# This works similarly to `itertools.groupby`
@@ -125,10 +128,12 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
current_item = current_item_ref.pop()
if self._parser.is_content(current_item):
stream = chain([current_item], stream)
yield StreamedStr(self._streamed_str(stream, current_item_ref))
if not current_item_ref:
streamed_str = StreamedStr(self._streamed_str(stream, current_item_ref))
yield streamed_str
if not current_item_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
consume(self._streamed_str(stream, current_item_ref))
# Consume stream via StreamedStr so it can cache
consume(streamed_str)
elif self._parser.is_tool_call(current_item):
tool_calls_stream: Iterator[FunctionCallChunk] = (
tool_call_chunk
@@ -155,20 +160,18 @@
tool_calls_stream = chain(
[current_tool_call_chunk], tool_calls_stream
)
yield function_schema.parse_args(
output = function_schema.parse_args(
self._tool_call(
tool_calls_stream, tool_call_ref, current_tool_call_id
)
)
if not tool_call_ref:
yield output
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
consume(
self._tool_call(
tool_calls_stream,
tool_call_ref,
current_tool_call_id,
)
)
# Output must be Iterable if parse_args above did not consume
assert isinstance(output, Iterable), output # noqa: S101
# Consume stream via the output type so it can cache
consume(output)

except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
@@ -201,6 +204,7 @@ def __init__(
self._state = state

self._iterator = self.__stream__()
self._exhausted: bool = False

async def __anext__(self) -> AsyncStreamedStr | OutputT:
return await self._iterator.__anext__()
@@ -220,6 +224,7 @@ async def _streamed_str(
assert not current_item_ref # noqa: S101
current_item_ref.append(item)
return
self._exhausted = True

async def _tool_call(
self,
@@ -235,6 +240,7 @@ async def _tool_call(
return
if item.args:
yield item.args
self._exhausted = True

async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
stream = aapply(self._state.update, self._stream)
@@ -243,10 +249,14 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
current_item = current_item_ref.pop()
if self._parser.is_content(current_item):
stream = achain(async_iter([current_item]), stream)
yield AsyncStreamedStr(self._streamed_str(stream, current_item_ref))
if not current_item_ref:
streamed_str = AsyncStreamedStr(
self._streamed_str(stream, current_item_ref)
)
yield streamed_str
if not current_item_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
await aconsume(self._streamed_str(stream, current_item_ref))
# Consume stream via AsyncStreamedStr so it can cache
await aconsume(streamed_str)
elif self._parser.is_tool_call(current_item):
tool_calls_stream: AsyncIterator[FunctionCallChunk] = (
tool_call_chunk
@@ -273,20 +283,18 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
tool_calls_stream = achain(
async_iter([current_tool_call_chunk]), tool_calls_stream
)
yield await function_schema.aparse_args(
output = await function_schema.aparse_args(
self._tool_call(
tool_calls_stream, tool_call_ref, current_tool_call_id
)
)
if not tool_call_ref:
yield output
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
await aconsume(
self._tool_call(
tool_calls_stream,
tool_call_ref,
current_tool_call_id,
)
)
# Output must be AsyncIterable if aparse_args above did not consume
assert isinstance(output, AsyncIterable), output # noqa: S101
# Consume stream via the output type so it can cache
await aconsume(output)
except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
raise ToolSchemaParseError(
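The effect of the change in this file is that, when the caller skips over a yielded item, the group is now drained through the yielded object (`consume(streamed_str)` / `consume(output)`) rather than through the raw chunk generator, so the object's internal cache is filled instead of the text being discarded. A standalone sketch of that idea (the `CachingStream` class below is illustrative, not magentic's actual `StreamedStr`):

```python
from collections import deque
from collections.abc import Iterator


class CachingStream:
    """Minimal stand-in for StreamedStr: iterating it caches the chunks."""

    def __init__(self, chunks: Iterator[str]) -> None:
        self._chunks = chunks
        self._cached: list[str] = []

    def __iter__(self) -> Iterator[str]:
        yield from self._cached
        for chunk in self._chunks:
            self._cached.append(chunk)
            yield chunk

    def to_string(self) -> str:
        return "".join(self)


wrapped = CachingStream(iter(["Hello", ", ", "world"]))

# Draining the raw generator directly would lose the text for good.
# Draining the wrapper fills its cache, so the text stays available later:
deque(wrapped, maxlen=0)
assert wrapped.to_string() == "Hello, world"
```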
2 changes: 1 addition & 1 deletion src/magentic/streaming.py
@@ -85,7 +85,7 @@ async def atakewhile(
yield item


def consume(iterator: Iterator[T]) -> None:
def consume(iterator: Iterable[T]) -> None:
"""Consume an iterator."""
collections.deque(iterator, maxlen=0)
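The hint is widened from `Iterator` to `Iterable` because the objects now passed to `consume` (such as `StreamedStr` and the parsed outputs asserted `Iterable` above) are iterables rather than iterators. A quick standalone illustration of why both still work:

```python
import collections
from collections.abc import Iterable


def consume(iterator: Iterable[str]) -> None:
    """Drain any iterable without keeping its items."""
    collections.deque(iterator, maxlen=0)


consume(c for c in "abc")   # an iterator (generators define __next__)
consume(["a", "b", "c"])    # a plain iterable (only __iter__), like StreamedStr
```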

4 changes: 4 additions & 0 deletions tests/chat_model/test_anthropic_chat_model.py
@@ -140,7 +140,9 @@ def get_weather(location: str) -> None:
assert len(response_items) == 2
streamed_str, function_call = response_items
assert isinstance(streamed_str, StreamedStr)
assert len(streamed_str.to_string()) > 1 # Check StreamedStr was cached
assert isinstance(function_call, FunctionCall)
assert function_call() is None # Check FunctionCall is successfully called


@pytest.mark.parametrize(
@@ -257,4 +259,6 @@ def get_weather(location: str) -> None:
assert len(response_items) == 2
streamed_str, function_call = response_items
assert isinstance(streamed_str, AsyncStreamedStr)
assert len(await streamed_str.to_string()) > 1 # Check AsyncStreamedStr was cached
assert isinstance(function_call, FunctionCall)
assert function_call() is None # Check FunctionCall is successfully called
4 changes: 4 additions & 0 deletions tests/chat_model/test_openai_chat_model.py
@@ -188,7 +188,9 @@ def get_weather(location: str) -> None:
assert len(response_items) == 2
streamed_str, function_call = response_items
assert isinstance(streamed_str, StreamedStr)
assert len(streamed_str.to_string()) > 1 # Check StreamedStr was cached
assert isinstance(function_call, FunctionCall)
assert function_call() is None # Check FunctionCall is successfully called


@pytest.mark.openai
@@ -290,7 +292,9 @@ def get_weather(location: str) -> None:
assert len(response_items) == 2
streamed_str, function_call = response_items
assert isinstance(streamed_str, AsyncStreamedStr)
assert len(await streamed_str.to_string()) > 1 # Check AsyncStreamedStr was cached
assert isinstance(function_call, FunctionCall)
assert function_call() is None # Check FunctionCall is successfully called


@pytest.mark.openai
