diff --git a/circuitmatter/tlv.py b/circuitmatter/tlv.py index 387c298..b74ff11 100644 --- a/circuitmatter/tlv.py +++ b/circuitmatter/tlv.py @@ -1,8 +1,22 @@ +from __future__ import annotations + import enum import math import struct -from typing import Any, Optional, Type, Union -from typing import Literal +from abc import ABC, abstractmethod +from typing import ( + AnyStr, + Generic, + Iterable, + Literal, + Optional, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import Buffer # As a byte string to save space. TAG_LENGTH = b"\x00\x01\x02\x04\x02\x04\x06\x08" @@ -26,8 +40,8 @@ class ElementType(enum.IntEnum): class TLVStructure: _max_length = None - def __init__(self, buffer=None): - self.buffer: memoryview = buffer + def __init__(self, buffer: Optional[Buffer] = None): + self.buffer = memoryview(buffer) if buffer is not None else None # These three dicts are keyed by tag. self.tag_value_offset = {} self.null_tags = set() @@ -38,20 +52,12 @@ def __init__(self, buffer=None): @classmethod def max_length(cls): if cls._max_length is None: - cls._max_length = 0 - for field in vars(cls): - descriptor_class = vars(cls)[field] - if field.startswith("_") or not isinstance(descriptor_class, Member): - continue - cls._max_length += descriptor_class.max_length + cls._max_length = sum(member.max_length for _, member in cls._members()) return cls._max_length def __str__(self): members = [] - for field in vars(type(self)): - descriptor_class = vars(type(self))[field] - if field.startswith("_") or not isinstance(descriptor_class, Member): - continue + for field, descriptor_class in self._members(): value = descriptor_class.print(self) if isinstance(descriptor_class, StructMember): value = value.replace("\n", "\n ") @@ -63,14 +69,17 @@ def encode(self) -> memoryview: end = self.encode_into(buffer) return memoryview(buffer)[:end] - def encode_into(self, buffer, offset=0): - for field in vars(type(self)): - descriptor_class = vars(type(self))[field] - if field.startswith("_") or not isinstance(descriptor_class, Member): - continue + def encode_into(self, buffer: bytearray, offset: int = 0) -> int: + for _, descriptor_class in self._members(): offset = descriptor_class.encode_into(self, buffer, offset) return offset + @classmethod + def _members(cls) -> Iterable[tuple[str, Member]]: + for field_name, descriptor in vars(cls).items(): + if not field_name.startswith("_") and isinstance(descriptor, Member): + yield field_name, descriptor + def scan_until(self, tag): if self.buffer is None: return @@ -169,8 +178,22 @@ def scan_until(self, tag): break -class Member: - def __init__(self, tag, optional=False, nullable=False): +_T = TypeVar("_T") +_NULLABLE = TypeVar("_NULLABLE", Literal[True], Literal[False]) +_OPT = TypeVar("_OPT", Literal[True], Literal[False]) + + +class Member(ABC, Generic[_T, _OPT, _NULLABLE]): + max_value_length: int = 0 + + def __init__( + self, tag, *, optional: _OPT = False, nullable: _NULLABLE = False + ) -> None: + """ + :param optional: Indicates whether the value MAY be omitted from the encoding. + Can be used for deprecation. + :param nullable: Indicates whether a TLV Null MAY be encoded in place of a value. + """ self.tag = tag self.optional = optional self.nullable = nullable @@ -185,11 +208,23 @@ def __init__(self, tag, optional=False, nullable=False): def max_length(self): return 1 + self.tag_length + self.max_value_length + @overload def __get__( - self, - obj: Optional[TLVStructure], + self: Union[ + Member[_T, Literal[True], _NULLABLE], Member[_T, _OPT, Literal[True]] + ], + obj: TLVStructure, + objtype: Optional[Type[TLVStructure]] = None, + ) -> Optional[_T]: ... + + @overload + def __get__( + self: Member[_T, Literal[False], Literal[False]], + obj: TLVStructure, objtype: Optional[Type[TLVStructure]] = None, - ) -> Any: + ) -> _T: ... + + def __get__(self, obj, objtype=None): if self.tag in obj.cached_values: return obj.cached_values[self.tag] if self.tag not in obj.tag_value_offset: @@ -205,13 +240,25 @@ def __get__( obj.cached_values[self.tag] = value return value - def __set__(self, obj: TLVStructure, value: Any) -> None: + @overload + def __set__( + self: Union[ + Member[_T, Literal[True], _NULLABLE], Member[_T, _OPT, Literal[True]] + ], + obj: TLVStructure, + value: Optional[_T], + ) -> None: ... + @overload + def __set__( + self: Member[_T, Literal[False], Literal[False]], obj: TLVStructure, value: _T + ) -> None: ... + def __set__(self, obj, value): if value is None and not self.nullable: raise ValueError("Not nullable") obj.cached_values[self.tag] = value def encode_into(self, obj: TLVStructure, buffer: bytearray, offset: int) -> int: - value = self.__get__(obj) + value = self.__get__(obj) # type: ignore # self inference issues element_type = ElementType.NULL if value is not None: element_type = self.encode_element_type(value) @@ -227,19 +274,74 @@ def encode_into(self, obj: TLVStructure, buffer: bytearray, offset: int) -> int: buffer[offset] = self.tag offset += 1 if value is not None: - new_offset = self.encode_value_into(value, buffer, offset) + new_offset = self.encode_value_into( # type: ignore # self inference issues + value, + buffer, + offset, + ) return new_offset return offset - def print(self, obj): - value = self.__get__(obj) + def print(self, obj: TLVStructure) -> str: + value = self.__get__(obj) # type: ignore # self inference issues if value is None: return "null" return self._print(value) - -class NumberMember(Member): - def __init__(self, tag, _format, optional=False): + @abstractmethod + def decode(self, buffer: memoryview, length: int, offset: int = 0) -> _T: + "Return the decoded value at `offset` in `buffer`" + ... + + @abstractmethod + def encode_element_type(self, value: _T) -> int: + "Return Element Type Field as defined in Appendix A in the spec" + ... + + @overload + @abstractmethod + def encode_value_into( + self: Union[ + Member[_T, Literal[True], _NULLABLE], Member[_T, _OPT, Literal[True]] + ], + value: Optional[_T], + buffer: bytearray, + offset: int, + ) -> int: ... + @overload + @abstractmethod + def encode_value_into( + self: Member[_T, Literal[False], Literal[False]], + value: _T, + buffer: bytearray, + offset: int, + ) -> int: ... + @abstractmethod + def encode_value_into( + self, value: Optional[_T], buffer: bytearray, offset: int + ) -> int: + "Encode `value` into `buffer` and return the new offset" + ... + + @abstractmethod + def _print(self, value: _T) -> str: + "Return string representation of `value`" + ... + + +# number type +_NT = TypeVar("_NT", float, int) + + +class NumberMember(Member[_NT, _OPT, _NULLABLE], Generic[_NT, _OPT, _NULLABLE]): + def __init__( + self, + tag, + _format: str, + optional: _OPT = False, + nullable: _NULLABLE = False, + **kwargs, + ): self.format = _format self.integer = _format[-1].upper() in INT_SIZE self.signed = self.format.islower() @@ -253,22 +355,22 @@ def __init__(self, tag, _format, optional=False): self._element_type = ElementType.FLOAT if self.max_value_length == 8: self._element_type |= 1 - super().__init__(tag, optional) + super().__init__(tag, optional=optional, nullable=nullable, **kwargs) def __set__(self, obj, value): - if self.integer: + if value is not None and self.integer: octets = 2 ** INT_SIZE.index(self.format.upper()[-1]) bits = 8 * octets - max_size = (2 ** (bits - 1) if self.signed else 2**bits) - 1 - min_size = -max_size - 1 if self.signed else 0 + max_size: int = (2 ** (bits - 1) if self.signed else 2**bits) - 1 + min_size: int = -max_size - 1 if self.signed else 0 if not min_size <= value <= max_size: raise ValueError( f"Out of bounds for {octets} octet {'' if self.signed else 'un'}signed int" ) - super().__set__(obj, value) + super().__set__(obj, value) # type: ignore # self inference issues - def decode(self, buffer, length, offset=0): + def decode(self, buffer, length, offset=0) -> _NT: if self.integer: encoded_format = INT_SIZE[int(math.log(length, 2))] if self.format.islower(): @@ -294,23 +396,29 @@ def encode_value_into(self, value, buffer, offset) -> int: return offset + self.max_value_length -IntOctetCount = Union[Literal[1], Literal[2], Literal[4], Literal[8]] - - -class IntMember(NumberMember): +class IntMember(NumberMember[int, _OPT, _NULLABLE]): def __init__( - self, tag, /, signed: bool = True, octets: IntOctetCount = 1, optional=False + self, + tag, + *, + signed: bool = True, + octets: Literal[1, 2, 4, 8] = 1, + optional: _OPT = False, + nullable: _NULLABLE = False, + **kwargs, ): uformat = INT_SIZE[int(math.log2(octets))] # little-endian self.format = f"<{uformat.lower() if signed else uformat}" - super().__init__(tag, _format=self.format, optional=optional) + super().__init__( + tag, _format=self.format, optional=optional, nullable=nullable, **kwargs + ) -class BoolMember(Member): +class BoolMember(Member[bool, _OPT, _NULLABLE]): max_value_length = 0 - def decode(self, buffer, length, offset=0) -> bool: + def decode(self, buffer, length, offset=0): octet = buffer[offset] return octet & 1 == 1 @@ -326,21 +434,24 @@ def encode_value_into(self, value, buffer, offset) -> int: return offset -class OctetStringMember(Member): - _base_element_type = ElementType.OCTET_STRING +class StringMember(Member[AnyStr, _OPT, _NULLABLE], Generic[AnyStr, _OPT, _NULLABLE]): + _base_element_type: ElementType - def __init__(self, tag, max_length, optional=False): + def __init__( + self, + tag, + max_length, + *, + optional: _OPT = False, + nullable: _NULLABLE = False, + **kwargs, + ): self.max_value_length = max_length - length_encoding = 0 - while max_length > (256 ** (length_encoding + 1)): - length_encoding += 1 + length_encoding = int(math.log(max_length, 256)) self._element_type = self._base_element_type | length_encoding self.length_format = INT_SIZE[length_encoding] self.length_length = struct.calcsize(self.length_format) - super().__init__(tag, optional) - - def decode(self, buffer, length, offset=0): - return buffer[offset : offset + length] + super().__init__(tag, optional=optional, nullable=nullable, **kwargs) def _print(self, value): return " ".join((f"{byte:02x}" for byte in value)) @@ -348,33 +459,51 @@ def _print(self, value): def encode_element_type(self, value): return self._element_type - def encode_value_into(self, value, buffer, offset) -> int: + def encode_value_into(self, value, buffer: bytearray, offset: int) -> int: struct.pack_into(self.length_format, buffer, offset, len(value)) offset += self.length_length buffer[offset : offset + len(value)] = value return offset + len(value) -class UTF8StringMember(OctetStringMember): +class OctetStringMember(StringMember[bytes, _OPT, _NULLABLE]): + _base_element_type: ElementType = ElementType.OCTET_STRING + + def decode(self, buffer, length, offset=0): + return buffer[offset : offset + length].tobytes() + + +class UTF8StringMember(StringMember[str, _OPT, _NULLABLE]): _base_element_type = ElementType.UTF8_STRING def decode(self, buffer, length, offset=0): - return super().decode(buffer, length, offset).decode("utf-8") + return buffer[offset : offset + length].tobytes().decode("utf-8") - def encode_value_into(self, value, buffer, offset) -> int: + def encode_value_into(self, value: str, buffer, offset) -> int: return super().encode_value_into(value.encode("utf-8"), buffer, offset) def _print(self, value): return f'"{value}"' -class StructMember(Member): - def __init__(self, tag, substruct_class, optional=False): +_TLVStruct = TypeVar("_TLVStruct", bound=TLVStructure) + + +class StructMember(Member[_TLVStruct, _OPT, _NULLABLE]): + def __init__( + self, + tag, + substruct_class: Type[_TLVStruct], + *, + optional: _OPT = False, + nullable: _NULLABLE = False, + **kwargs, + ): self.substruct_class = substruct_class self.max_value_length = substruct_class.max_length() + 1 - super().__init__(tag, optional) + super().__init__(tag, optional=optional, nullable=nullable, **kwargs) - def decode(self, buffer, length, offset=0) -> TLVStructure: + def decode(self, buffer, length, offset=0): return self.substruct_class(buffer[offset : offset + length]) def _print(self, value): @@ -383,7 +512,7 @@ def _print(self, value): def encode_element_type(self, value): return ElementType.STRUCTURE - def encode_value_into(self, value, buffer, offset) -> int: + def encode_value_into(self, value, buffer: bytearray, offset: int) -> int: offset = value.encode_into(buffer, offset) buffer[offset] = ElementType.END_OF_CONTAINER return offset + 1 diff --git a/pyproject.toml b/pyproject.toml index 8b913a8..3989f20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ test = [ "hypothesis", "pytest", "pytest-cov", - # "typing_extensions", + "typing_extensions", ] [tool.coverage.run] diff --git a/tests/test_tlv.py b/tests/test_tlv.py index c17b379..7c6f388 100644 --- a/tests/test_tlv.py +++ b/tests/test_tlv.py @@ -1,8 +1,12 @@ -from circuitmatter import tlv -from hypothesis import given, strategies as st +import math +from typing import Optional + import pytest +from hypothesis import given +from hypothesis import strategies as st +from typing_extensions import assert_type -import math +from circuitmatter import tlv # Test TLV encoding using examples from spec @@ -194,6 +198,21 @@ def test_roundtrip(self, v: int): assert s2.i == s.i assert str(s2) == str(s) + def test_nullability(self): + class Struct(tlv.TLVStructure): + i = tlv.IntMember(None) + ni = tlv.IntMember(None, nullable=True) + + s = Struct() + assert_type(s.i, int) + assert_type(s.ni, Optional[int]) + + s.ni = None + assert s.ni is None + + with pytest.raises(ValueError): + s.i = None + # UTF-8 String, 1-octet length, "Hello!" # 0c 06 48 65 6c 6c 6f 21 @@ -273,6 +292,11 @@ class Null(tlv.TLVStructure): n = tlv.BoolMember(None, nullable=True) +class NotNull(tlv.TLVStructure): + n = tlv.BoolMember(None, nullable=True) + b = tlv.BoolMember(None) + + class TestNull: def test_null_decode(self): s = Null(b"\x14") @@ -284,6 +308,13 @@ def test_null_encode(self): s.n = None assert s.encode().tobytes() == b"\x14" + def test_nullable(self): + s = NotNull() + + assert_type(s.b, bool) + with pytest.raises(ValueError): + s.b = None # type: ignore # testing runtime behaviour + # Single precision floating point 0.0 # 0a 00 00 00 00 @@ -451,8 +482,8 @@ def test_roundtrip(self, v: float): class InnerStruct(tlv.TLVStructure): - a = tlv.NumberMember(0, "