diff --git a/changelog.d/20240104_184103_sirosen_attrs_payloads.rst b/changelog.d/20240104_184103_sirosen_attrs_payloads.rst new file mode 100644 index 000000000..8885f17d8 --- /dev/null +++ b/changelog.d/20240104_184103_sirosen_attrs_payloads.rst @@ -0,0 +1,29 @@ +Changed +~~~~~~~ + +- ``globus-sdk`` now depends on the ``attrs`` library. (:pr:`NUMBER`) + +Development +~~~~~~~~~~~ + +- A new component has been added for definition of payload classes, at + ``globus_sdk.payload``, based on ``attrs``. (:pr:`NUMBER`) + + - New payload classes should inherit from ``globus_sdk.payload.Payload`` + + - ``attrs``-style converter definitions are defined at + ``globus_sdk.payload.converters`` + + - ``Payload`` objects are fully supported by transport encoding, in a similar + way to ``utils.PayloadWrapper`` objects. + + - ``Payload``\s always support a field named ``extra`` which can be used to + incorporate additional data into the payload body, beyond the supported + fields. + + - ``Payload`` objects require that all of their arguments are keyword-only + + - ``Payload.extra`` assignment emits a ``RuntimeWarning`` if field names + collide with existing fields. This is the strongest signal we can give to + users that they should not do this short of emitting an error. Erroring is + not an option because it would make every field addition a breaking change. diff --git a/pyproject.toml b/pyproject.toml index ecc670f4a..7dab1b715 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ ] requires-python = ">=3.7" dependencies = [ + "attrs>=23.1.0", "requests>=2.19.1,<3.0.0", "pyjwt[crypto]>=2.0.0,<3.0.0", # cryptography 3.4.0 is known-bugged, see: diff --git a/src/globus_sdk/payload/__init__.py b/src/globus_sdk/payload/__init__.py new file mode 100644 index 000000000..cbd21bcc3 --- /dev/null +++ b/src/globus_sdk/payload/__init__.py @@ -0,0 +1,4 @@ +from . import converters +from .base import Payload + +__all__ = ("Payload", "converters") diff --git a/src/globus_sdk/payload/base.py b/src/globus_sdk/payload/base.py new file mode 100644 index 000000000..44d185aaa --- /dev/null +++ b/src/globus_sdk/payload/base.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing as t +import warnings + +import attrs + +from globus_sdk import utils + + +def _validate_extra_against_model( + instance: Payload, + attribute: attrs.Attribute[t.Any], # pylint: disable=unused-argument + value: dict[str, t.Any] | utils.MissingType, +) -> None: + """ + Validate the 'extra' field of a Payload object against the model defined by the + Payload (sub)class. + + This is done by checking that none of the keys in the extra dict are also defined as + fields on the class. If any such fields are found, a RuntimeWarning is emitted -- + such usage is and always will be supported, but users are advised to prefer the + "real" fields whenever possible. + """ + if isinstance(value, utils.MissingType): + return + + model = instance.__class__ + model_fields = set(attrs.fields_dict(model)) + extra_fields = set(value.keys()) + + redundant_fields = model_fields & extra_fields + if redundant_fields: + warnings.warn( + f"'extra' keys overlap with defined fields for '{model.__qualname__}'. " + "'extra' will take precedence during serialization. " + f"redundant_fields={redundant_fields}", + RuntimeWarning, + stacklevel=2, + ) + + +@attrs.define(kw_only=True) +class Payload: + """ + Payload objects are used to represent the data for a request. + + The 'extra' field is always defined, and can be used to store a dict of additional + data which will be merged with the Payload object before it is sent in a request. + """ + + extra: dict[str, t.Any] | utils.MissingType = attrs.field( + default=utils.MISSING, validator=_validate_extra_against_model + ) + + def asdict(self) -> dict[str, t.Any]: + data = attrs.asdict(self) + if data["extra"] is not utils.MISSING: + data.update(data.pop("extra")) + return data diff --git a/src/globus_sdk/payload/converters.py b/src/globus_sdk/payload/converters.py new file mode 100644 index 000000000..1964d07d1 --- /dev/null +++ b/src/globus_sdk/payload/converters.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import typing as t + +import attrs + +from globus_sdk import utils + + +def str_list( + value: str | t.Iterable[t.Any] | utils.MissingType, +) -> list[str] | utils.MissingType: + if isinstance(value, utils.MissingType): + return utils.MISSING + return list(utils.safe_strseq_iter(value)) + + +nullable_str_list = attrs.converters.optional(str_list) + + +# use underscore-suffixed names for any conflicts with builtin types, following the +# convention used by sqlalchemy +def list_( + value: t.Iterable[t.Any] | utils.MissingType, +) -> list[t.Any] | utils.MissingType: + if isinstance(value, utils.MissingType): + return utils.MISSING + return list(value) + + +nullable_list = attrs.converters.optional(list_) diff --git a/src/globus_sdk/transport/encoders.py b/src/globus_sdk/transport/encoders.py index 95b331458..07d29ce79 100644 --- a/src/globus_sdk/transport/encoders.py +++ b/src/globus_sdk/transport/encoders.py @@ -6,7 +6,7 @@ import requests -from globus_sdk import utils +from globus_sdk import payload, utils class RequestEncoder: @@ -93,6 +93,9 @@ def _prepare_data(self, data: t.Any) -> t.Any: Otherwise, it is returned as-is. """ + if isinstance(data, payload.Payload): + data = data.asdict() + if isinstance(data, (dict, utils.PayloadWrapper)): return utils.filter_missing( {k: self._prepare_data(v) for k, v in data.items()} diff --git a/src/globus_sdk/transport/requests.py b/src/globus_sdk/transport/requests.py index 6455b3864..0490c6c08 100644 --- a/src/globus_sdk/transport/requests.py +++ b/src/globus_sdk/transport/requests.py @@ -9,7 +9,7 @@ import requests -from globus_sdk import config, exc, utils +from globus_sdk import config, exc, payload, utils from globus_sdk.authorizers import GlobusAuthorizer from globus_sdk.transport.encoders import ( FormRequestEncoder, @@ -217,7 +217,12 @@ def _encode( url: str, query_params: dict[str, t.Any] | None = None, data: ( - dict[str, t.Any] | list[t.Any] | utils.PayloadWrapper | str | None + dict[str, t.Any] + | list[t.Any] + | payload.Payload + | utils.PayloadWrapper + | str + | None ) = None, headers: dict[str, str] | None = None, encoding: str | None = None, @@ -267,7 +272,14 @@ def request( method: str, url: str, query_params: dict[str, t.Any] | None = None, - data: dict[str, t.Any] | list[t.Any] | utils.PayloadWrapper | str | None = None, + data: ( + dict[str, t.Any] + | list[t.Any] + | payload.Payload + | utils.PayloadWrapper + | str + | None + ) = None, headers: dict[str, str] | None = None, encoding: str | None = None, authorizer: GlobusAuthorizer | None = None, diff --git a/src/globus_sdk/utils.py b/src/globus_sdk/utils.py index fe59be44b..b23ca32be 100644 --- a/src/globus_sdk/utils.py +++ b/src/globus_sdk/utils.py @@ -134,6 +134,14 @@ class PayloadWrapper(PayloadWrapperBase): requested encoder (e.g. as a JSON request body). """ + # XXX: DEVELOPER NOTE + # + # this class is our long-standing/legacy method of defining payload helper objects + # for any new cases, the `globus_sdk.payload` module should be preferred, offering a + # `dataclasses`-based approach for building MISSING-aware payload helpers with the + # ability to set callback types + # + # use UserDict rather than subclassing dict so that our API is always consistent # e.g. `dict.pop` does not invoke `dict.__delitem__`. Overriding `__delitem__` on a # dict subclass can lead to inconsistent behavior between usages like these: diff --git a/tests/non-pytest/mypy-ignore-tests/payload_classes.py b/tests/non-pytest/mypy-ignore-tests/payload_classes.py new file mode 100644 index 000000000..a79ef8d93 --- /dev/null +++ b/tests/non-pytest/mypy-ignore-tests/payload_classes.py @@ -0,0 +1,58 @@ +# test behaviors of the globus_sdk.payload usage of dataclasses + +import typing as t + +import attrs + +from globus_sdk import payload, utils + +my_str: str +my_int: int +my_optstr: str | None + + +@attrs.define +class MyPayloadType1(payload.Payload): + foo: str + bar: int + + +doc1 = MyPayloadType1(foo="foo", bar=1) +my_str = doc1.foo +my_int = doc1.bar +my_optstr = doc1.foo +my_str = doc1.bar # type: ignore[assignment] +my_int = doc1.foo # type: ignore[assignment] + +doc1_extra = MyPayloadType1(foo="foo", bar=1, extra={"extra": "somedata"}) + + +@attrs.define +class MyPayloadType2(payload.Payload): + foo: str | utils.MissingType = attrs.field(default=utils.MISSING) + + +doc2 = MyPayloadType2() +my_str = doc2.foo # type: ignore[assignment] +my_missingstr: str | utils.MissingType = doc2.foo + + +@attrs.define +class MyPayloadType3(payload.Payload): + foo: t.Iterable[str] | utils.MissingType = attrs.field( + default=utils.MISSING, converter=payload.converters.str_list + ) + + +doc3 = MyPayloadType3(str(i) for i in range(3)) +assert not isinstance(doc3.foo, utils.MissingType) +# in spite of the application of the converter, the type is not narrowed from the +# annotated type (Iterable[str]) to the converted type (list[str]) +# +# this is a limitations in mypy; see: +# https://github.com/python/mypy/issues/3004 +# +# it *may* be resolved when `dataclasses` adds support for converters and mypy supports +# that usage, as the `attrs` plugin could use the dataclass converter support path +my_str = doc3.foo[0] # type: ignore[index] +t.assert_type(doc3.foo, t.Iterable[str]) diff --git a/tests/unit/test_payload.py b/tests/unit/test_payload.py new file mode 100644 index 000000000..8f84d3d2f --- /dev/null +++ b/tests/unit/test_payload.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import typing as t + +import attrs +import pytest + +from globus_sdk import payload, utils +from globus_sdk.transport import JSONRequestEncoder + + +def _serialize(obj: payload.Payload): + encoder = JSONRequestEncoder() + return encoder._prepare_data(obj) + + +def test_simple_payload_class_serialization(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: str + bar: int + + doc = _serialize(MyPayloadType(foo="foo", bar=1)) + assert doc == {"foo": "foo", "bar": 1} + + +def test_simple_payload_class_serialization_with_extra(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: str + bar: int + + doc = _serialize(MyPayloadType(foo="foo", bar=1, extra={"baz": "baz"})) + assert doc == {"foo": "foo", "bar": 1, "baz": "baz"} + + +@pytest.mark.parametrize( + "input_value, expected", + ( + (["bar"], ["bar"]), + (("bar", "baz"), ["bar", "baz"]), + ("bar", ["bar"]), + (range(3), ["0", "1", "2"]), + ), +) +def test_payload_strlist_field(input_value, expected): + @attrs.define + class MyPayloadType(payload.Payload): + foo: t.Iterable[str] | utils.MissingType = attrs.field( + default=utils.MISSING, + converter=payload.converters.str_list, + ) + + # works via init + doc = _serialize(MyPayloadType(foo=input_value)) + assert doc == {"foo": expected} + + # works via setattr + p = MyPayloadType() + p.foo = input_value + assert _serialize(p) == {"foo": expected} + + +@pytest.mark.parametrize( + "input_value, expected", + ( + (None, None), + (("bar", "baz"), ["bar", "baz"]), + ("bar", ["bar"]), + (range(3), ["0", "1", "2"]), + ), +) +def test_payload_nullable_strlist_field(input_value, expected): + @attrs.define + class MyPayloadType(payload.Payload): + foo: t.Iterable[str] | None | utils.MissingType = attrs.field( + default=utils.MISSING, + converter=payload.converters.nullable_str_list, + ) + + # works via init + doc = _serialize(MyPayloadType(foo=input_value)) + assert doc == {"foo": expected} + + # works via setattr + p = MyPayloadType() + p.foo = input_value + assert _serialize(p) == {"foo": expected} + + +def test_encoder_recursively_serializes_payloads(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: str + bar: int + + @attrs.define + class MyPayloadType2(payload.Payload): + baz: MyPayloadType + + doc = _serialize(MyPayloadType2(baz=MyPayloadType(foo="foo", bar=1))) + assert doc == {"baz": {"foo": "foo", "bar": 1}} + + +def test_extra_emits_warnings_when_aligned_with_fields_but_acts_as_override(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: str + + with pytest.warns( + RuntimeWarning, + match="'extra' keys overlap with defined fields.+redundant_fields={'foo'}", + ): + x = MyPayloadType(foo="foo", extra={"foo": "bar"}) + doc = _serialize(x) + assert doc == {"foo": "bar"} + + +def test_non_nullable_list_converter_fails_on_none(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: t.Iterable[dict] = attrs.field(converter=payload.converters.list_) + + with pytest.raises(TypeError): + MyPayloadType(foo=None) + + +def test_nullable_list_converter_allows_none(): + @attrs.define + class MyPayloadType(payload.Payload): + foo: t.Iterable[dict] = attrs.field(converter=payload.converters.nullable_list) + + doc = _serialize(MyPayloadType(foo=None)) + assert doc == {"foo": None}