Skip to content

Commit

Permalink
chunking broken
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Oct 11, 2024
1 parent aa514ac commit c12fc55
Show file tree
Hide file tree
Showing 19 changed files with 503 additions and 216 deletions.
54 changes: 32 additions & 22 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time

from . import case
from . import data_model
from . import interaction_model
from .message import Message
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
Expand Down Expand Up @@ -115,14 +114,7 @@ def add_device(self, device):
if self._next_endpoint > 0:
self.root_node.descriptor.PartsList.append(self._next_endpoint)

device.descriptor = data_model.DescriptorCluster()
device_type = data_model.DescriptorCluster.DeviceTypeStruct()
device_type.DeviceType = device.DEVICE_TYPE_ID
device_type.Revision = device.REVISION
device.descriptor.DeviceTypeList = [device_type]
device.descriptor.PartsList = [self._next_endpoint]
device.descriptor.ServerList = []
device.descriptor.ClientList = []
device.descriptor.PartsList.append(self._next_endpoint)

for server in device.servers:
device.descriptor.ServerList.append(server.CLUSTER_ID)
Expand Down Expand Up @@ -154,15 +146,16 @@ def _build_attribute_error(self, path, status_code):
report.AttributeStatus = astatus
return report

def get_report(self, cluster, path):
def get_report(self, context, cluster, path):
reports = []
datas = cluster.get_attribute_data(path)
datas = cluster.get_attribute_data(context, path)
for data in datas:
report = interaction_model.AttributeReportIB()
report.AttributeData = data
reports.append(report)
# Only add status if an error occurs
if not datas:
print("Unsupported attribute", cluster, path)
report = self._build_attribute_error(
path, interaction_model.StatusCode.UNSUPPORTED_ATTRIBUTE
)
Expand Down Expand Up @@ -194,7 +187,7 @@ def invoke(self, session, cluster, path, fields, command_ref):

return response

def read_attribute_path(self, path):
def read_attribute_path(self, context, path):
attribute_reports = []
if path.Endpoint is None:
endpoints = self._endpoints
Expand All @@ -203,6 +196,8 @@ def read_attribute_path(self, path):

# Wildcard so we get it from every endpoint.
for endpoint in endpoints:
temp_path = path.copy()
temp_path.Endpoint = endpoint
if path.Cluster is None:
clusters = self._endpoints[endpoint].values()
else:
Expand All @@ -213,11 +208,8 @@ def read_attribute_path(self, path):
continue
clusters = [self._endpoints[endpoint][path.Cluster]]
for cluster in clusters:
# TODO: The path object probably needs to be cloned. Otherwise we'll
# change the endpoint for all uses.
path.Endpoint = endpoint
path.Cluster = cluster.CLUSTER_ID
attribute_reports.extend(self.get_report(cluster, path))
temp_path.Cluster = cluster.CLUSTER_ID
attribute_reports.extend(self.get_report(context, cluster, temp_path))
return attribute_reports

def process_packet(self, address, data):
Expand All @@ -241,6 +233,10 @@ def process_packet(self, address, data):
if exchange is None:
print(f"Dropping message {message.message_counter}")
return
else:
print(
f"Processing message {message.message_counter} for exchange {exchange.exchange_id}"
)

protocol_id = message.protocol_id
protocol_opcode = message.protocol_opcode
Expand Down Expand Up @@ -403,7 +399,9 @@ def process_packet(self, address, data):
attribute_reports = []
for path in read_request.AttributeRequests:
print("read", path)
attribute_reports.extend(self.read_attribute_path(path))
attribute_reports.extend(
self.read_attribute_path(secure_session_context, path)
)
response = interaction_model.ReportDataMessage()
response.AttributeReports = attribute_reports
exchange.send(response)
Expand All @@ -412,14 +410,17 @@ def process_packet(self, address, data):
write_request, _ = interaction_model.WriteRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
)
print(write_request)
write_responses = []
for request in write_request.WriteRequests:
path = request.Path
if path.Cluster in self._endpoints[path.Endpoint]:
cluster = self._endpoints[path.Endpoint][path.Cluster]
print(cluster)
write_responses.append(cluster.set_attribute(request))
write_responses.append(
cluster.set_attribute(secure_session_context, request)
)
response = interaction_model.WriteResponseMessage()
response.WriteResponses = write_responses
exchange.send(response)

elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
Expand Down Expand Up @@ -474,8 +475,11 @@ def process_packet(self, address, data):
print(subscribe_request)
attribute_reports = []
for path in subscribe_request.AttributeRequests:
attribute_reports.extend(self.read_attribute_path(path))
attribute_reports.extend(
self.read_attribute_path(secure_session_context, path)
)
response = interaction_model.ReportDataMessage()
response.SubscriptionId = exchange.exchange_id
response.AttributeReports = attribute_reports
exchange.send(response)
final_response = interaction_model.SubscribeResponseMessage()
Expand All @@ -489,6 +493,12 @@ def process_packet(self, address, data):
print(
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
)

if exchange.pending_payloads:
if status_response.Status == interaction_model.StatusCode.SUCCESS:
exchange.send(exchange.pending_payloads.pop(0))
else:
exchange.pending_payloads.clear()
else:
print(message)
print("application payload", message.application_payload.hex(" "))
Expand Down
3 changes: 2 additions & 1 deletion circuitmatter/clusters/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def add_noc(
# Store the NOC.
noc_struct = data_model.NodeOperationalCredentialsCluster.NOCStruct()
noc_struct.NOC = args.NOCValue
noc_struct.ICAC = args.ICACValue
if args.ICACValue:
noc_struct.ICAC = args.ICACValue
self.nocs.append(noc_struct)

# Get the root cert public key so we can create the compressed fabric id.
Expand Down
1 change: 1 addition & 0 deletions circuitmatter/clusters/device_management/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Clusters defined in the Spec (not cluster spec) in Chapter 11
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from circuitmatter import data_model


class GeneralDiagnosticsCluster(data_model.Cluster):
CLUSTER_ID = 0x0033
REVISION = 2
4 changes: 3 additions & 1 deletion circuitmatter/clusters/general/on_off.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class OnOff(data_model.Cluster):
GlobalSceneControl = data_model.BoolAttribute(0x4000, default=True)
OnTime = data_model.NumberAttribute(0x4001, signed=False, bits=16, default=0)
OffWaitTime = data_model.NumberAttribute(0x4002, signed=False, bits=16, default=0)
StartUpOnOff = data_model.EnumAttribute(0x4003, StartUpOnOffEnum)
StartUpOnOff = data_model.EnumAttribute(
0x4003, StartUpOnOffEnum, N_nonvolatile=True, X_nullable=True
)

off = data_model.Command(0x00, None)
on = data_model.Command(0x01, None)
Expand Down
1 change: 1 addition & 0 deletions circuitmatter/clusters/system_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Clusters defined in the Spec (not cluster spec) in Chapter 9
17 changes: 17 additions & 0 deletions circuitmatter/clusters/system_model/binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from circuitmatter import data_model
from circuitmatter import tlv


class BindingCluster(data_model.Cluster):
CLUSTER_ID = 0x001E
REVISION = 1

class TargetStruct(tlv.Structure):
Node = data_model.NodeId(1, optional=True)
Group = data_model.GroupId(2, optional=True)
Endpoint = data_model.EndpointNumber(3, optional=True)
Cluster = data_model.ClusterId(4, optional=True)

Binding = data_model.ListAttribute(
0x0000, TargetStruct, default=[], N_nonvolatile=True
)
15 changes: 15 additions & 0 deletions circuitmatter/clusters/system_model/user_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from circuitmatter import data_model
from circuitmatter import tlv


class UserLabelCluster(data_model.Cluster):
CLUSTER_ID = 0x0041
REVISION = 1

class LabelStruct(tlv.Structure):
Label = tlv.UTF8StringMember(0, 16, default="")
Value = tlv.UTF8StringMember(1, 16, default="")

LabelList = data_model.ListAttribute(
0x0000, LabelStruct, default=[], N_nonvolatile=True
)
54 changes: 37 additions & 17 deletions circuitmatter/data_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import enum
import inspect
import random
import struct
import traceback
import typing
from typing import Iterable, Union

Expand All @@ -17,24 +19,29 @@ class Enum16(enum.IntEnum):


class Uint16(tlv.IntMember):
def __init__(self, _id=None, minimum=0):
super().__init__(_id, signed=False, octets=2, minimum=minimum)
def __init__(self, _id=None, minimum=0, **kwargs):
super().__init__(_id, signed=False, octets=2, minimum=minimum, **kwargs)


class Uint32(tlv.IntMember):
def __init__(self, _id=None, minimum=0):
super().__init__(_id, signed=False, octets=4, minimum=minimum)
def __init__(self, _id=None, minimum=0, **kwargs):
super().__init__(_id, signed=False, octets=4, minimum=minimum, **kwargs)


class Uint64(tlv.IntMember):
def __init__(self, _id=None, minimum=0):
super().__init__(_id, signed=False, octets=8, minimum=minimum)
def __init__(self, _id=None, minimum=0, **kwargs):
super().__init__(_id, signed=False, octets=8, minimum=minimum, **kwargs)


class GroupId(Uint16):
class NodeId(Uint64):
pass


class GroupId(Uint16):
def __init__(self, _id=None, **kwargs):
super().__init__(_id, minimum=1, **kwargs)


class ClusterId(Uint16):
pass

Expand All @@ -44,8 +51,8 @@ class DeviceTypeId(Uint32):


class EndpointNumber(Uint16):
def __init__(self, _id=None):
super().__init__(_id, minimum=1)
def __init__(self, _id=None, **kwargs):
super().__init__(_id, minimum=1, **kwargs)


# Data model "lists" are encoded as tlv arrays. 🙄
Expand Down Expand Up @@ -82,12 +89,10 @@ def __get__(self, instance, cls):

def __set__(self, instance, value):
old_value = instance._attribute_values.get(self.id, None)
print("set old_value", old_value)
if old_value == value:
return
instance._attribute_values[self.id] = value
instance.data_version += 1
print("set new version", instance.data_version)

def encode(self, value) -> bytes:
if value is None and self.nullable:
Expand Down Expand Up @@ -142,6 +147,8 @@ def __init__(self, _id, enum_type, **kwargs):

class ListAttribute(Attribute):
def __init__(self, _id, element_type, **kwargs):
if inspect.isclass(element_type) and issubclass(element_type, enum.Enum):
element_type = tlv.EnumMember(None, element_type)
self.tlv_type = tlv.ArrayMember(None, element_type)
self._element_type = element_type
# Copy the default list so we don't accidentally share it with another
Expand All @@ -156,6 +163,8 @@ def _encode(self, value) -> bytes:
def element_from_value(self, value):
if issubclass(self._element_type, tlv.Container):
return self._element_type.from_value(value)
if issubclass(self._element_type, enum.Enum):
return self._element_type(value)
return value


Expand Down Expand Up @@ -235,15 +244,17 @@ def _attributes(cls) -> Iterable[tuple[str, Attribute]]:
yield field_name, descriptor

def get_attribute_data(
self, path
self, session, path
) -> typing.List[interaction_model.AttributeDataIB]:
replies = []
for field_name, descriptor in self._attributes():
if path.Attribute is not None and descriptor.id != path.Attribute:
continue
if descriptor.feature and not (self.feature_map & descriptor.feature):
continue
self.current_fabric_index = session.local_fabric_index
value = getattr(self, field_name)
self.current_fabric_index = None
print(
"reading",
f"EP{path.Endpoint}",
Expand All @@ -262,14 +273,19 @@ def get_attribute_data(
attribute_path.Attribute = descriptor.id
data.Path = attribute_path
data.Data = descriptor.encode(value)
print(
f"{path.Endpoint}/{path.Cluster:x}/{descriptor.id:x} -> {data.Data.hex()}"
)
replies.append(data)
if path.Attribute is not None:
break
if not replies:
print("not found", path.Attribute)
return replies

def set_attribute(self, attribute_data) -> interaction_model.AttributeStatusIB:
def set_attribute(
self, context, attribute_data
) -> interaction_model.AttributeStatusIB:
status_code = interaction_model.StatusCode.SUCCESS
for field_name, descriptor in self._attributes():
path = attribute_data.Path
Expand Down Expand Up @@ -321,7 +337,7 @@ def invoke(
if descriptor.command_id != path.Command:
continue

print("invoke", self, field_name, descriptor)
print("invoke", type(self).__name__, field_name)
command = getattr(self, field_name)
if callable(command):
if descriptor.request_type is not None:
Expand All @@ -330,9 +346,10 @@ def invoke(
except ValueError:
return interaction_model.StatusCode.INVALID_COMMAND
try:
print(arg)
result = command(session, arg)
except Exception as e:
print(e)
traceback.print_exception(e)
return interaction_model.StatusCode.FAILURE
else:
try:
Expand Down Expand Up @@ -415,7 +432,7 @@ class AccessControlEntryStruct(tlv.Structure):
class AccessControlExtensionStruct(tlv.Structure):
Data = tlv.OctetStringMember(1, max_length=128)

ACL = ListAttribute(0x0000, AccessControlEntryStruct)
ACL = ListAttribute(0x0000, AccessControlEntryStruct, default=[])
Extension = ListAttribute(0x0001, AccessControlExtensionStruct, optional=True)
SubjectsPerAccessControlEntry = NumberAttribute(
0x0002, signed=False, bits=16, default=4
Expand Down Expand Up @@ -754,7 +771,7 @@ class NodeOperationalCredentialsCluster(Cluster):

class NOCStruct(tlv.Structure):
NOC = tlv.OctetStringMember(0, 400)
ICAC = tlv.OctetStringMember(1, 400)
ICAC = tlv.OctetStringMember(1, 400, nullable=True)

class FabricDescriptorStruct(tlv.Structure):
RootPublicKey = tlv.OctetStringMember(1, 65)
Expand Down Expand Up @@ -816,6 +833,9 @@ class AddTrustedRootCertificate(tlv.Structure):
trusted_root_certificates = ListAttribute(
4, tlv.OctetStringMember(None, 400), N_nonvolatile=True, C_changes_omitted=True
)
# This attribute is weird because it is fabric sensitive but not marked as such.
# Cluster sets current_fabric_index for use in fabric sensitive attributes and
# happens to make this work as well.
current_fabric_index = NumberAttribute(5, signed=False, bits=8, default=0)

attestation_request = Command(0x00, AttestationRequest, 0x01, AttestationResponse)
Expand Down
Loading

0 comments on commit c12fc55

Please sign in to comment.