From f7e7e37117eb1e0ba2943ca129cbc4102b4a3d74 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 30 Nov 2024 16:10:27 -0800 Subject: [PATCH] Consume LLM output stream via returned objects to allow caching (#384) * Consume stream via objects to allow caching * Update tests to check caching * Change StreamedStr consume warning to tip in docs --- docs/streaming.md | 4 +- docs/structured-outputs.md | 2 +- src/magentic/chat_model/stream.py | 58 +++++++++++-------- src/magentic/streaming.py | 2 +- tests/chat_model/test_anthropic_chat_model.py | 4 ++ tests/chat_model/test_openai_chat_model.py | 4 ++ 6 files changed, 45 insertions(+), 29 deletions(-) diff --git a/docs/streaming.md b/docs/streaming.md index 348f350..cfc2bc2 100644 --- a/docs/streaming.md +++ b/docs/streaming.md @@ -84,9 +84,9 @@ for hero in create_superhero_team("The Food Dudes"): Some LLMs have the ability to generate text output and make tool calls in the same response. This allows them to perform chain-of-thought reasoning or provide additional context to the user. In magentic, the `StreamedResponse` (or `AsyncStreamedResponse`) class can be used to request this type of output. This object is an iterable of `StreamedStr` (or `AsyncStreamedStr`) and `FunctionCall` instances. -!!! warning "Consuming StreamedStr" +!!! tip "Consuming StreamedStr" - The StreamedStr object must be iterated over before the next item in the `StreamedResponse` is processed, otherwise the string output will be lost. This is because the `StreamedResponse` and `StreamedStr` share the same underlying generator, so advancing the `StreamedResponse` iterator skips over the `StreamedStr` items. The `StreamedStr` object has internal caching so after iterating over it once the chunks will remain available. + The StreamedStr object caches its chunks internally, so it does not have to be consumed immediately. This means you can iterate over the chunks as they are received, and/or use the StreamedStr object as a whole after the LLM has finished generating the output. In the example below, we request that the LLM generates a greeting and then calls a function to get the weather for two cities. The `StreamedResponse` object is then iterated over to print the output, and the `StreamedStr` and `FunctionCall` items are processed separately. diff --git a/docs/structured-outputs.md b/docs/structured-outputs.md index 1b1e1f6..059aaba 100644 --- a/docs/structured-outputs.md +++ b/docs/structured-outputs.md @@ -150,7 +150,7 @@ print(hero_defeated) !!! warning "StreamedResponse" - It is now recommended to use `StreamedResponse` for chain-of-thought prompting, as this uses the LLM provider's native chain-of-thought capabilities. See [StreamedResponse](streaming.md#StreamedResponse) for more information. + It is now recommended to use `StreamedResponse` for chain-of-thought prompting, as this uses the LLM provider's native chain-of-thought capabilities. See [StreamedResponse](streaming.md#streamedresponse) for more information. Using a simple Python type as the return annotation might result in poor results as the LLM has no time to arrange its thoughts before answering. To allow the LLM to work through this "chain of thought" you can instead return a pydantic model with initial fields for explaining the final response. diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index da08f28..26078aa 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterable, Iterator +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator from itertools import chain from typing import Any, Generic, NamedTuple, TypeVar @@ -81,6 +81,7 @@ def __init__( self._state = state self._iterator = self.__stream__() + self._exhausted: bool = False def __next__(self) -> StreamedStr | OutputT: return self._iterator.__next__() @@ -99,6 +100,7 @@ def _streamed_str( assert not current_item_ref # noqa: S101 current_item_ref.append(item) return + self._exhausted = True def _tool_call( self, @@ -116,6 +118,7 @@ def _tool_call( return if item.args: yield item.args + self._exhausted = True def __stream__(self) -> Iterator[StreamedStr | OutputT]: # This works similarly to `itertools.groupby` @@ -125,10 +128,12 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: current_item = current_item_ref.pop() if self._parser.is_content(current_item): stream = chain([current_item], stream) - yield StreamedStr(self._streamed_str(stream, current_item_ref)) - if not current_item_ref: + streamed_str = StreamedStr(self._streamed_str(stream, current_item_ref)) + yield streamed_str + if not current_item_ref and not self._exhausted: # Finish the group to allow advancing to the next one - consume(self._streamed_str(stream, current_item_ref)) + # Consume stream via StreamedStr so it can cache + consume(streamed_str) elif self._parser.is_tool_call(current_item): tool_calls_stream: Iterator[FunctionCallChunk] = ( tool_call_chunk @@ -155,20 +160,18 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: tool_calls_stream = chain( [current_tool_call_chunk], tool_calls_stream ) - yield function_schema.parse_args( + output = function_schema.parse_args( self._tool_call( tool_calls_stream, tool_call_ref, current_tool_call_id ) ) - if not tool_call_ref: + yield output + if not tool_call_ref and not self._exhausted: # Finish the group to allow advancing to the next one - consume( - self._tool_call( - tool_calls_stream, - tool_call_ref, - current_tool_call_id, - ) - ) + # Output must be Iterable if parse_args above did not consume + assert isinstance(output, Iterable), output # noqa: S101 + # Consume stream via the output type so it can cache + consume(output) except ValidationError as e: assert current_tool_call_id is not None # noqa: S101 @@ -201,6 +204,7 @@ def __init__( self._state = state self._iterator = self.__stream__() + self._exhausted: bool = False async def __anext__(self) -> AsyncStreamedStr | OutputT: return await self._iterator.__anext__() @@ -220,6 +224,7 @@ async def _streamed_str( assert not current_item_ref # noqa: S101 current_item_ref.append(item) return + self._exhausted = True async def _tool_call( self, @@ -235,6 +240,7 @@ async def _tool_call( return if item.args: yield item.args + self._exhausted = True async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: stream = aapply(self._state.update, self._stream) @@ -243,10 +249,14 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: current_item = current_item_ref.pop() if self._parser.is_content(current_item): stream = achain(async_iter([current_item]), stream) - yield AsyncStreamedStr(self._streamed_str(stream, current_item_ref)) - if not current_item_ref: + streamed_str = AsyncStreamedStr( + self._streamed_str(stream, current_item_ref) + ) + yield streamed_str + if not current_item_ref and not self._exhausted: # Finish the group to allow advancing to the next one - await aconsume(self._streamed_str(stream, current_item_ref)) + # Consume stream via AsyncStreamedStr so it can cache + await aconsume(streamed_str) elif self._parser.is_tool_call(current_item): tool_calls_stream: AsyncIterator[FunctionCallChunk] = ( tool_call_chunk @@ -273,20 +283,18 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: tool_calls_stream = achain( async_iter([current_tool_call_chunk]), tool_calls_stream ) - yield await function_schema.aparse_args( + output = await function_schema.aparse_args( self._tool_call( tool_calls_stream, tool_call_ref, current_tool_call_id ) ) - if not tool_call_ref: + yield output + if not tool_call_ref and not self._exhausted: # Finish the group to allow advancing to the next one - await aconsume( - self._tool_call( - tool_calls_stream, - tool_call_ref, - current_tool_call_id, - ) - ) + # Output must be AsyncIterable if aparse_args above did not consume + assert isinstance(output, AsyncIterable), output # noqa: S101 + # Consume stream via the output type so it can cache + await aconsume(output) except ValidationError as e: assert current_tool_call_id is not None # noqa: S101 raise ToolSchemaParseError( diff --git a/src/magentic/streaming.py b/src/magentic/streaming.py index fdc6881..f0911d6 100644 --- a/src/magentic/streaming.py +++ b/src/magentic/streaming.py @@ -85,7 +85,7 @@ async def atakewhile( yield item -def consume(iterator: Iterator[T]) -> None: +def consume(iterator: Iterable[T]) -> None: """Consume an iterator.""" collections.deque(iterator, maxlen=0) diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 086f4da..8b88a6c 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -140,7 +140,9 @@ def get_weather(location: str) -> None: assert len(response_items) == 2 streamed_str, function_call = response_items assert isinstance(streamed_str, StreamedStr) + assert len(streamed_str.to_string()) > 1 # Check StreamedStr was cached assert isinstance(function_call, FunctionCall) + assert function_call() is None # Check FunctionCall is successfully called @pytest.mark.parametrize( @@ -257,4 +259,6 @@ def get_weather(location: str) -> None: assert len(response_items) == 2 streamed_str, function_call = response_items assert isinstance(streamed_str, AsyncStreamedStr) + assert len(await streamed_str.to_string()) > 1 # Check AsyncStreamedStr was cached assert isinstance(function_call, FunctionCall) + assert function_call() is None # Check FunctionCall is successfully called diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index 98d37aa..be6e539 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -188,7 +188,9 @@ def get_weather(location: str) -> None: assert len(response_items) == 2 streamed_str, function_call = response_items assert isinstance(streamed_str, StreamedStr) + assert len(streamed_str.to_string()) > 1 # Check StreamedStr was cached assert isinstance(function_call, FunctionCall) + assert function_call() is None # Check FunctionCall is successfully called @pytest.mark.openai @@ -290,7 +292,9 @@ def get_weather(location: str) -> None: assert len(response_items) == 2 streamed_str, function_call = response_items assert isinstance(streamed_str, AsyncStreamedStr) + assert len(await streamed_str.to_string()) > 1 # Check AsyncStreamedStr was cached assert isinstance(function_call, FunctionCall) + assert function_call() is None # Check FunctionCall is successfully called @pytest.mark.openai