From af31a8707581174f8be4d68583d26f34fc603bc4 Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Sat, 30 Nov 2024 17:23:09 -0800 Subject: [PATCH] Improve ruff format/lint rules (#385) * Add split-on-trailing-comma for imports * make format * Allow S101 assert, remove noqa comments * Reenable temporarily disabled ruff rules --- pyproject.toml | 6 ++-- .../chat_model/anthropic_chat_model.py | 24 ++++------------ src/magentic/chat_model/base.py | 4 +-- src/magentic/chat_model/litellm_chat_model.py | 18 ++++++------ src/magentic/chat_model/openai_chat_model.py | 28 +++++-------------- src/magentic/chat_model/retry_chat_model.py | 11 ++------ src/magentic/chat_model/stream.py | 28 +++++++++---------- src/magentic/chatprompt.py | 10 +------ src/magentic/function_call.py | 8 +----- src/magentic/prompt_chain.py | 7 +---- src/magentic/prompt_function.py | 10 +------ src/magentic/typing.py | 11 ++------ tests/test_backend.py | 4 +-- 13 files changed, 48 insertions(+), 121 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84d6d8d1..8165cce1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ ignore = [ "TD", # flake8-todos "FIX", # flake8-fixme "PL", # Pylint + "S101", # assert # Compatibility with ruff formatter "E501", "ISC001", @@ -116,10 +117,6 @@ ignore = [ "Q002", "Q003", "W191", - "B905", # TODO: Reenable this - "UP006", # TODO: Reenable this - "UP007", # TODO: Reenable this - "UP035", # TODO: Reenable this ] [tool.ruff.lint.flake8-pytest-style] @@ -127,6 +124,7 @@ mark-parentheses = false [tool.ruff.lint.isort] known-first-party = ["magentic"] +split-on-trailing-comma = false [tool.ruff.lint.per-file-ignores] "docs/examples/*" = [ diff --git a/src/magentic/chat_model/anthropic_chat_model.py b/src/magentic/chat_model/anthropic_chat_model.py index fe60ded3..1f949a85 100644 --- a/src/magentic/chat_model/anthropic_chat_model.py +++ b/src/magentic/chat_model/anthropic_chat_model.py @@ -1,12 +1,6 @@ import base64 import json -from collections.abc import ( - AsyncIterator, - Callable, - Iterable, - Iterator, - Sequence, -) +from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from enum import Enum from functools import singledispatch from itertools import groupby @@ -15,11 +9,7 @@ import filetype from magentic._parsing import contains_parallel_function_call_type, contains_string_type -from magentic.chat_model.base import ( - ChatModel, - aparse_stream, - parse_stream, -) +from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream from magentic.chat_model.function_schema import ( BaseFunctionSchema, FunctionCallFunctionSchema, @@ -44,11 +34,7 @@ StreamParser, StreamState, ) -from magentic.function_call import ( - FunctionCall, - ParallelFunctionCall, - _create_unique_id, -) +from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id from magentic.vision import UserImageMessage try: @@ -273,7 +259,7 @@ def update(self, item: MessageStreamEvent) -> None: current_snapshot=self._current_message_snapshot, ) if item.type == "message_stop": - assert not self.usage_ref # noqa: S101 + assert not self.usage_ref self.usage_ref.append( Usage( input_tokens=item.message.usage.input_tokens, @@ -283,7 +269,7 @@ def update(self, item: MessageStreamEvent) -> None: @property def current_message_snapshot(self) -> Message[Any]: - assert self._current_message_snapshot is not None # noqa: S101 + assert self._current_message_snapshot is not None # TODO: Possible to return AssistantMessage here? return _RawMessage(self._current_message_snapshot.model_dump()) diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index 46570338..1268fa9a 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -1,9 +1,9 @@ import types from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable +from collections.abc import AsyncIterator, Callable, Iterable, Iterator from contextvars import ContextVar from itertools import chain -from typing import Any, AsyncIterator, Iterator, TypeVar, cast, get_origin, overload +from typing import Any, TypeVar, cast, get_origin, overload from pydantic import ValidationError diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index f38c9d77..ad903502 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -37,20 +37,20 @@ class LitellmStreamParser(StreamParser[ModelResponse]): def is_content(self, item: ModelResponse) -> bool: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) return bool(item.choices[0].delta.content) def get_content(self, item: ModelResponse) -> str | None: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 - assert isinstance(item.choices[0].delta.content, str | None) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) + assert isinstance(item.choices[0].delta.content, str | None) return item.choices[0].delta.content def is_tool_call(self, item: ModelResponse) -> bool: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) return bool(item.choices[0].delta.tool_calls) def iter_tool_calls(self, item: ModelResponse) -> Iterable[FunctionCallChunk]: - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) if item.choices and item.choices[0].delta.tool_calls: for tool_call in item.choices[0].delta.tool_calls: if tool_call.function: @@ -75,13 +75,13 @@ def update(self, item: ModelResponse) -> None: # litellm requires usage is not None for its total usage calculation item.usage = litellm.Usage() # type: ignore[attr-defined] if not hasattr(item, "refusal"): - assert isinstance(item.choices[0], StreamingChoices) # noqa: S101 + assert isinstance(item.choices[0], StreamingChoices) item.choices[0].delta.refusal = None # type: ignore[attr-defined] self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type] usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined,name-defined] # Ignore usages with 0 tokens if usage and usage.prompt_tokens and usage.completion_tokens: - assert not self.usage_ref # noqa: S101 + assert not self.usage_ref self.usage_ref.append( Usage( input_tokens=usage.prompt_tokens, @@ -210,7 +210,7 @@ def complete( tool_schemas=tool_schemas, output_types=output_types ), # type: ignore[arg-type,unused-ignore] ) - assert not isinstance(response, ModelResponse) # noqa: S101 + assert not isinstance(response, ModelResponse) stream = OutputStream( stream=response, function_schemas=function_schemas, @@ -270,7 +270,7 @@ async def acomplete( tool_schemas=tool_schemas, output_types=output_types ), # type: ignore[arg-type,unused-ignore] ) - assert not isinstance(response, ModelResponse) # noqa: S101 + assert not isinstance(response, ModelResponse) stream = AsyncOutputStream( stream=response, function_schemas=function_schemas, diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index b1289487..1cf65376 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,11 +1,5 @@ import base64 -from collections.abc import ( - AsyncIterator, - Callable, - Iterable, - Iterator, - Sequence, -) +from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence from enum import Enum from functools import singledispatch from typing import Any, Generic, Literal, TypeVar, cast, overload @@ -24,11 +18,7 @@ from magentic._parsing import contains_parallel_function_call_type, contains_string_type from magentic._streamed_response import StreamedResponse -from magentic.chat_model.base import ( - ChatModel, - aparse_stream, - parse_stream, -) +from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream from magentic.chat_model.function_schema import ( BaseFunctionSchema, FunctionCallFunctionSchema, @@ -53,11 +43,7 @@ StreamParser, StreamState, ) -from magentic.function_call import ( - FunctionCall, - ParallelFunctionCall, - _create_unique_id, -) +from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id from magentic.streaming import StreamedStr from magentic.vision import UserImageMessage @@ -78,9 +64,9 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar @message_to_openai_message.register(_RawMessage) def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam: - assert isinstance(message.content, dict) # noqa: S101 - assert "role" in message.content # noqa: S101 - assert "content" in message.content # noqa: S101 + assert isinstance(message.content, dict) + assert "role" in message.content + assert "content" in message.content return cast(ChatCompletionMessageParam, message.content) @@ -316,7 +302,7 @@ def update(self, item: ChatCompletionChunk) -> None: tool_call_chunk.index = self._current_tool_call_index self._chat_completion_stream_state.handle_chunk(item) if item.usage: - assert not self.usage_ref # noqa: S101 + assert not self.usage_ref self.usage_ref.append( Usage( input_tokens=item.usage.prompt_tokens, diff --git a/src/magentic/chat_model/retry_chat_model.py b/src/magentic/chat_model/retry_chat_model.py index a5c901b6..b7abf673 100644 --- a/src/magentic/chat_model/retry_chat_model.py +++ b/src/magentic/chat_model/retry_chat_model.py @@ -2,15 +2,8 @@ from functools import singledispatchmethod from typing import Any, TypeVar, overload -from magentic.chat_model.base import ( - ChatModel, - ToolSchemaParseError, -) -from magentic.chat_model.message import ( - AssistantMessage, - Message, - ToolResultMessage, -) +from magentic.chat_model.base import ChatModel, ToolSchemaParseError +from magentic.chat_model.message import AssistantMessage, Message, ToolResultMessage from magentic.logger import logfire R = TypeVar("R") diff --git a/src/magentic/chat_model/stream.py b/src/magentic/chat_model/stream.py index 26078aa9..c5b57c8c 100644 --- a/src/magentic/chat_model/stream.py +++ b/src/magentic/chat_model/stream.py @@ -97,7 +97,7 @@ def _streamed_str( yield content if self._parser.is_tool_call(item): # TODO: Check if output types allow for early return and raise if not - assert not current_item_ref # noqa: S101 + assert not current_item_ref current_item_ref.append(item) return self._exhausted = True @@ -113,7 +113,7 @@ def _tool_call( # so that the whole stream is consumed including stop_reason/usage chunks if item.id and item.id != current_tool_call_id: # TODO: Check if output types allow for early return and raise if not - assert not current_tool_call_ref # noqa: S101 + assert not current_tool_call_ref current_tool_call_ref.append(item) return if item.args: @@ -144,13 +144,13 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: while tool_call_ref: current_tool_call_chunk = tool_call_ref.pop() current_tool_call_id = current_tool_call_chunk.id - assert current_tool_call_id is not None # noqa: S101 - assert current_tool_call_chunk.name is not None # noqa: S101 + assert current_tool_call_id is not None + assert current_tool_call_chunk.name is not None function_schema = select_function_schema( self._function_schemas, current_tool_call_chunk.name ) if function_schema is None: - assert current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None raise UnknownToolError( output_message=self._state.current_message_snapshot, tool_call_id=current_tool_call_id, @@ -169,12 +169,12 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]: if not tool_call_ref and not self._exhausted: # Finish the group to allow advancing to the next one # Output must be Iterable if parse_args above did not consume - assert isinstance(output, Iterable), output # noqa: S101 + assert isinstance(output, Iterable), output # Consume stream via the output type so it can cache consume(output) except ValidationError as e: - assert current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None raise ToolSchemaParseError( output_message=self._state.current_message_snapshot, tool_call_id=current_tool_call_id, @@ -221,7 +221,7 @@ async def _streamed_str( yield content if self._parser.is_tool_call(item): # TODO: Check if output types allow for early return - assert not current_item_ref # noqa: S101 + assert not current_item_ref current_item_ref.append(item) return self._exhausted = True @@ -235,7 +235,7 @@ async def _tool_call( async for item in stream: if item.id and item.id != current_tool_call_id: # TODO: Check if output types allow for early return - assert not current_tool_call_ref # noqa: S101 + assert not current_tool_call_ref current_tool_call_ref.append(item) return if item.args: @@ -267,13 +267,13 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: while tool_call_ref: current_tool_call_chunk = tool_call_ref.pop() current_tool_call_id = current_tool_call_chunk.id - assert current_tool_call_id is not None # noqa: S101 - assert current_tool_call_chunk.name is not None # noqa: S101 + assert current_tool_call_id is not None + assert current_tool_call_chunk.name is not None function_schema = select_function_schema( self._function_schemas, current_tool_call_chunk.name ) if function_schema is None: - assert current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None raise UnknownToolError( output_message=self._state.current_message_snapshot, tool_call_id=current_tool_call_id, @@ -292,11 +292,11 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]: if not tool_call_ref and not self._exhausted: # Finish the group to allow advancing to the next one # Output must be AsyncIterable if aparse_args above did not consume - assert isinstance(output, AsyncIterable), output # noqa: S101 + assert isinstance(output, AsyncIterable), output # Consume stream via the output type so it can cache await aconsume(output) except ValidationError as e: - assert current_tool_call_id is not None # noqa: S101 + assert current_tool_call_id is not None raise ToolSchemaParseError( output_message=self._state.current_message_snapshot, tool_call_id=current_tool_call_id, diff --git a/src/magentic/chatprompt.py b/src/magentic/chatprompt.py index 906d3b38..dfdc29db 100644 --- a/src/magentic/chatprompt.py +++ b/src/magentic/chatprompt.py @@ -1,15 +1,7 @@ import inspect from collections.abc import Awaitable, Callable, Sequence from functools import update_wrapper -from typing import ( - Any, - Generic, - ParamSpec, - Protocol, - TypeVar, - cast, - overload, -) +from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload from magentic.backend import get_chat_model from magentic.chat_model.base import ChatModel diff --git a/src/magentic/function_call.py b/src/magentic/function_call.py index 73998edb..6a2dd365 100644 --- a/src/magentic/function_call.py +++ b/src/magentic/function_call.py @@ -8,13 +8,7 @@ Iterable, Iterator, ) -from typing import ( - Any, - Generic, - ParamSpec, - TypeVar, - cast, -) +from typing import Any, Generic, ParamSpec, TypeVar, cast from uuid import uuid4 from magentic.logger import logfire diff --git a/src/magentic/prompt_chain.py b/src/magentic/prompt_chain.py index ab9b7fa8..0d6434bd 100644 --- a/src/magentic/prompt_chain.py +++ b/src/magentic/prompt_chain.py @@ -1,12 +1,7 @@ import inspect from collections.abc import Callable from functools import wraps -from typing import ( - Any, - ParamSpec, - TypeVar, - cast, -) +from typing import Any, ParamSpec, TypeVar, cast from magentic.chat import Chat from magentic.chat_model.base import ChatModel diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py index 42905e8f..08dfdee9 100644 --- a/src/magentic/prompt_function.py +++ b/src/magentic/prompt_function.py @@ -2,15 +2,7 @@ import inspect from collections.abc import Awaitable, Callable, Sequence from functools import update_wrapper -from typing import ( - Any, - Generic, - ParamSpec, - Protocol, - TypeVar, - cast, - overload, -) +from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload from magentic.backend import get_chat_model from magentic.chat_model.base import ChatModel diff --git a/src/magentic/typing.py b/src/magentic/typing.py index 21045e92..226bdad7 100644 --- a/src/magentic/typing.py +++ b/src/magentic/typing.py @@ -1,14 +1,7 @@ import inspect import types from collections.abc import Iterable, Mapping, Sequence -from typing import ( - Any, - TypeGuard, - TypeVar, - Union, - get_args, - get_origin, -) +from typing import Any, TypeGuard, TypeVar, Union, get_args, get_origin def is_union_type(type_: type) -> bool: @@ -72,7 +65,7 @@ def name_type(type_: type) -> str: return name_type(origin) + "_" + "_".join(name_type(arg) for arg in args) if name := getattr(type_, "__name__", None): - assert isinstance(name, str) # noqa: S101 + assert isinstance(name, str) if len(args) == 1: return f"{name.lower()}_of_{name_type(args[0])}" diff --git a/tests/test_backend.py b/tests/test_backend.py index a0db208a..590b53b1 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -5,9 +5,7 @@ from magentic.chat_model.litellm_chat_model import LitellmChatModel from magentic.chat_model.message import AssistantMessage, UserMessage from magentic.chat_model.mistral_chat_model import MistralChatModel -from magentic.chat_model.openai_chat_model import ( - OpenaiChatModel, -) +from magentic.chat_model.openai_chat_model import OpenaiChatModel def test_backend_anthropic_chat_model(monkeypatch):