diff --git a/circuitmatter/__init__.py b/circuitmatter/__init__.py index 3c8f657..be71108 100644 --- a/circuitmatter/__init__.py +++ b/circuitmatter/__init__.py @@ -12,6 +12,7 @@ from .message import Message from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode from . import session +from .subscription import Subscription from .device_types.utility.root_node import RootNode __version__ = "0.2.3" @@ -144,6 +145,7 @@ def add_device(self, device): for server in device.servers: device.descriptor.ServerList.append(server.CLUSTER_ID) + server.endpoint = self._next_endpoint self.add_cluster(self._next_endpoint, server) self.add_cluster(self._next_endpoint, device.descriptor) @@ -166,6 +168,8 @@ def process_packets(self): break self.process_packet(addr, self.packet_buffer[:nbytes]) + # Do any retransmits or subscriptions + self.manager.send_packets() def _build_attribute_error(self, path, status_code): report = interaction_model.AttributeReportIB() @@ -178,9 +182,9 @@ def _build_attribute_error(self, path, status_code): report.AttributeStatus = astatus return report - def get_report(self, context, cluster, path): + def get_report(self, context, cluster, path, subscription=None): reports = [] - datas = cluster.get_attribute_data(context, path) + datas = cluster.get_attribute_data(context, path, subscription=subscription) for data in datas: report = interaction_model.AttributeReportIB() report.AttributeData = data @@ -219,7 +223,7 @@ def invoke(self, session, cluster, path, fields, command_ref): return response - def read_attribute_path(self, context, path): + def read_attribute_path(self, context, path, subscription=None): attribute_reports = [] if path.Endpoint is None: endpoints = self._endpoints @@ -241,7 +245,11 @@ def read_attribute_path(self, context, path): clusters = [self._endpoints[endpoint][path.Cluster]] for cluster in clusters: temp_path.Cluster = cluster.CLUSTER_ID - attribute_reports.extend(self.get_report(context, cluster, temp_path)) + attribute_reports.extend( + self.get_report( + context, cluster, temp_path, subscription=subscription + ) + ) return attribute_reports def process_packet(self, address, data): @@ -250,6 +258,14 @@ def process_packet(self, address, data): message = Message() message.decode(data) message.source_ipaddress = address + session_context = self.manager.get_session(message) + if message.secure_session: + if session_context is None: + print("Failed to find session. Ignoring.") + return + secure_session_context = session_context + + session_context.receive(message) if message.secure_session: secure_session_context = None if message.session_id < len(self.manager.secure_session_contexts): @@ -429,8 +445,10 @@ def process_packet(self, address, data): self.read_attribute_path(secure_session_context, path) ) response = interaction_model.ReportDataMessage() + response.SuppressResponse = True response.AttributeReports = attribute_reports exchange.send(response) + exchange.close() elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST: print("Received Write Request") write_request = interaction_model.WriteRequestMessage.decode( @@ -447,6 +465,7 @@ def process_packet(self, address, data): response = interaction_model.WriteResponseMessage() response.WriteResponses = write_responses exchange.send(response) + exchange.close() elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST: print("Received Invoke Request") @@ -490,6 +509,7 @@ def process_packet(self, address, data): response.SuppressResponse = False response.InvokeResponses = invoke_responses exchange.send(response) + exchange.close() elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE: print("Received Invoke Response") elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST: @@ -497,18 +517,25 @@ def process_packet(self, address, data): subscribe_request = interaction_model.SubscribeRequestMessage.decode( message.application_payload ) - print(subscribe_request) + subscription = Subscription( + exchange.exchange_id, + secure_session_context, + subscribe_request.MinIntervalFloor, + subscribe_request.MaxIntervalCeiling, + ) attribute_reports = [] for path in subscribe_request.AttributeRequests: attribute_reports.extend( - self.read_attribute_path(secure_session_context, path) + self.read_attribute_path( + secure_session_context, path, subscription=subscription + ) ) response = interaction_model.ReportDataMessage() - response.SubscriptionId = exchange.exchange_id + response.SubscriptionId = subscription.id response.AttributeReports = attribute_reports exchange.send(response) final_response = interaction_model.SubscribeResponseMessage() - final_response.SubscriptionId = exchange.exchange_id + final_response.SubscriptionId = subscription.id final_response.MaxInterval = subscribe_request.MaxIntervalCeiling exchange.queue(final_response) elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE: @@ -519,11 +546,21 @@ def process_packet(self, address, data): f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}" ) + # Acknowledge the message because we have no further reply. + if message.exchange_flags & session.ExchangeFlags.R: + exchange.send_standalone() + if exchange.pending_payloads: if status_response.Status == interaction_model.StatusCode.SUCCESS: exchange.send(exchange.pending_payloads.pop(0)) else: exchange.pending_payloads.clear() + # Close after an error. + exchange.close() + else: + # Close if nothing is pending. + exchange.close() + else: print(message) print("application payload", message.application_payload.hex(" ")) diff --git a/circuitmatter/clusters/device_management/node_operational_credentials.py b/circuitmatter/clusters/device_management/node_operational_credentials.py index f32816d..26f0c88 100644 --- a/circuitmatter/clusters/device_management/node_operational_credentials.py +++ b/circuitmatter/clusters/device_management/node_operational_credentials.py @@ -117,7 +117,9 @@ class AddTrustedRootCertificate(tlv.Structure): # This attribute is weird because it is fabric sensitive but not marked as such. # Cluster sets current_fabric_index for use in fabric sensitive attributes and # happens to make this work as well. - current_fabric_index = NumberAttribute(5, signed=False, bits=8, default=0) + current_fabric_index = NumberAttribute( + 5, signed=False, bits=8, default=0, C_changes_omitted=True + ) attestation_request = Command(0x00, AttestationRequest, 0x01, AttestationResponse) diff --git a/circuitmatter/data_model.py b/circuitmatter/data_model.py index 5115e2e..a32aaae 100644 --- a/circuitmatter/data_model.py +++ b/circuitmatter/data_model.py @@ -104,6 +104,7 @@ def __init__( self.feature = feature self.nullable = X_nullable self.nonvolatile = N_nonvolatile + self.omit_changes = C_changes_omitted def __get__(self, instance, cls): v = instance._attribute_values.get(self.id, None) @@ -122,6 +123,23 @@ def __set__(self, instance, value): instance._nonvolatile[ATTRIBUTES_KEY][hex(self.id)] = self.to_json(value) instance.data_version += 1 + if self.id in instance._subscriptions and not self.omit_changes: + for subscription in instance._subscriptions[self.id]: + if not subscription.active: + continue + + data = interaction_model.AttributeDataIB() + data.DataVersion = instance.data_version + attribute_path = interaction_model.AttributePathIB() + attribute_path.Endpoint = instance.endpoint + attribute_path.Cluster = instance.CLUSTER_ID + attribute_path.Attribute = self.id + data.Path = attribute_path + data.Data = self.encode(value) + report = interaction_model.AttributeReportIB() + report.AttributeData = data + subscription.append_report(report) + def to_json(self, value): return value @@ -323,6 +341,7 @@ class Cluster: def __init__(self): self._attribute_values = {} + self._subscriptions = {} # Use random since this isn't for security or replayability. self.data_version = random.randint(0, 0xFFFFFFFF) @@ -365,7 +384,7 @@ def restore(self, nonvolatile): nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] = descriptor.default def get_attribute_data( - self, session, path + self, session, path, subscription=None ) -> typing.List[interaction_model.AttributeDataIB]: replies = [] for field_name, descriptor in self._attributes(): @@ -384,10 +403,15 @@ def get_attribute_data( "->", value, ) + if subscription is not None: + if path.Attribute not in self._subscriptions: + self._subscriptions[descriptor.id] = [] + print("new subscription") + self._subscriptions[descriptor.id].append(subscription) if value is None and descriptor.optional: continue data = interaction_model.AttributeDataIB() - data.DataVersion = 0 + data.DataVersion = self.data_version attribute_path = interaction_model.AttributePathIB() attribute_path.Endpoint = path.Endpoint attribute_path.Cluster = path.Cluster diff --git a/circuitmatter/device_types/lighting/extended_color.py b/circuitmatter/device_types/lighting/extended_color.py index 1eba034..6c878a9 100644 --- a/circuitmatter/device_types/lighting/extended_color.py +++ b/circuitmatter/device_types/lighting/extended_color.py @@ -32,6 +32,7 @@ def _move_to_hue_and_saturation(self, session, value): print(f"Error setting color: {e}") return - self._color_control.ColorMode = color_control.ColorMode.HUE_SATURATION - self._color_control.CurrentHue = value.Hue - self._color_control.CurrentSaturation = value.Saturation + print("update attributes") + self._color_control.color_mode = color_control.ColorMode.HUE_SATURATION + self._color_control.current_hue = value.Hue + self._color_control.current_saturation = value.Saturation diff --git a/circuitmatter/device_types/lighting/on_off.py b/circuitmatter/device_types/lighting/on_off.py index fc7815d..d425b9a 100644 --- a/circuitmatter/device_types/lighting/on_off.py +++ b/circuitmatter/device_types/lighting/on_off.py @@ -26,7 +26,7 @@ def _on(self, session): except Exception as e: print(f"Error turning on light: {e}") return - self._on_off.on_off = True + self._on_off.OnOff = True def _off(self, session): try: @@ -34,7 +34,7 @@ def _off(self, session): except Exception as e: print(f"Error turning off light: {e}") return - self._on_off.on_off = False + self._on_off.OnOff = False def on(self): raise NotImplementedError() diff --git a/circuitmatter/exchange.py b/circuitmatter/exchange.py index 136016a..6177310 100644 --- a/circuitmatter/exchange.py +++ b/circuitmatter/exchange.py @@ -1,3 +1,4 @@ +import random import time from .message import Message, ExchangeFlags, ProtocolId @@ -25,22 +26,33 @@ class Exchange: - def __init__(self, session, initiator: bool, exchange_id: int, protocols): + def __init__( + self, session, protocols, initiator: bool = True, exchange_id: int = -1 + ): self.initiator = initiator - self.exchange_id = exchange_id + self.exchange_id = session.next_exchange_id if exchange_id < 0 else exchange_id + print(f"\033[93mnew exchange {self.exchange_id}\033[0m") self.protocols = protocols self.session = session + if self.initiator: + self.session.initiator_exchanges[self.exchange_id] = self + else: + self.session.responder_exchanges[self.exchange_id] = self + self.pending_acknowledgement = None """Message number that is waiting for an ack from us""" self.send_standalone_time = None + self.retry_count = 0 self.next_retransmission_time = None """When to next resend the message that hasn't been acked""" self.pending_retransmission = None """Message that we've attempted to send but hasn't been acked""" self.pending_payloads = [] + self._closing = False + def send( self, application_payload=None, @@ -62,6 +74,8 @@ def send( if reliable: message.exchange_flags |= ExchangeFlags.R self.pending_retransmission = message + self.next_retransmission_time = None + self.retry_count = 0 message.source_node_id = self.session.local_node_id if protocol_id is None: protocol_id = application_payload.PROTOCOL_ID @@ -78,11 +92,35 @@ def send( message.application_payload = chunk[:offset] else: message.application_payload = application_payload - self.session.send(message) + if reliable: + self.send_pending() + else: + self.session.send(message) + + def send_pending(self, ignore_time=False) -> bool: + if self.pending_retransmission is None: + return False + if not ignore_time and self.next_retransmission_time is not None: + if time.monotonic() < self.next_retransmission_time: + return False + self.session.send(self.pending_retransmission) + self.retry_count += 1 + session_interval = ( + self.session.session_active_interval + if self.session.peer_active + else self.session.session_idle_interval + ) + difference = ( + session_interval + * (MRP_BACKOFF_BASE ** (max(0, self.retry_count - MRP_BACKOFF_THRESHOLD))) + * (1 + random.random() * MRP_BACKOFF_JITTER) + ) + self.next_retransmission_time = time.monotonic() + difference + return True def send_standalone(self): - if self.pending_retransmission is not None: - self.session.send(self.pending_retransmission) + # Resend the pending message when set. + if self.send_pending(ignore_time=True): return self.send( protocol_id=ProtocolId.SECURE_CHANNEL, @@ -109,19 +147,37 @@ def receive(self, message) -> bool: return True self.pending_retransmission = None self.next_retransmission_time = None + # Close if we're acked by a standalone packet that won't be handled higher up. + if ( + self._closing + and not self.pending_payloads + and message.protocol_id == ProtocolId.SECURE_CHANNEL + and message.protocol_opcode == SecureProtocolOpcode.MRP_STANDALONE_ACK + ): + print(f"\033[93mexchange closed after ack {self.exchange_id}\033[0m") + if self.initiator: + self.session.initiator_exchanges.pop(self.exchange_id) + else: + self.session.responder_exchanges.pop(self.exchange_id) if message.protocol_id not in self.protocols: # Drop messages that don't match the protocols we're waiting for. + # This is likely a standalone ACK to an interaction model response. return True # Section 4.12.5.2.2 # Incoming packets that are marked Reliable. if message.exchange_flags & ExchangeFlags.R: if message.duplicate: + if self.pending_acknowledgement is None: + self.pending_acknowledgement = message.message_counter # Send a standalone acknowledgement. self.send_standalone() return True - if self.pending_acknowledgement is not None: + if ( + self.pending_acknowledgement is not None + and self.pending_acknowledgement != message.message_counter + ): # Send a standalone acknowledgement with the message counter we're about to overwrite. self.send_standalone() self.pending_acknowledgement = message.message_counter @@ -132,3 +188,29 @@ def receive(self, message) -> bool: if message.duplicate: return True return False + + def close(self): + if self._closing: + print("Double+ close!") + return + self._closing = True + print(f"\033[93mclosing {self.exchange_id}\033[0m") + + if self.pending_retransmission is not None: + print(f"\033[93mpending retransmissions {self.exchange_id}\033[0m") + self.resend_pending() + return + + if self.pending_acknowledgement is not None: + print(f"\033[93mpending ack {self.exchange_id}\033[0m") + self.send_standalone() + return + + if self.initiator: + self.session.initiator_exchanges.pop(self.exchange_id) + else: + self.session.responder_exchanges.pop(self.exchange_id) + print(f"\033[93mexchange closed {self.exchange_id}\033[0m") + + def resend_pending(self): + self.send_pending() diff --git a/circuitmatter/session.py b/circuitmatter/session.py index cb5cad4..31bcbea 100644 --- a/circuitmatter/session.py +++ b/circuitmatter/session.py @@ -1,5 +1,4 @@ import enum -import json import time from . import case @@ -13,7 +12,6 @@ from cryptography.hazmat.primitives.ciphers.aead import AESCCM import ecdsa import hashlib -import pathlib import struct @@ -153,7 +151,31 @@ def __str__(self): return f"StatusReport: General Code: {self.general_code!r}, Protocol ID: {self.protocol_id!r}, Protocol Code: {self.protocol_code!r}, Protocol Data: {self.protocol_data.hex() if self.protocol_data else None}" -class UnsecuredSessionContext: +class SessionContext: + def __init__(self, socket): + self.socket = socket + self.responder_exchanges = {} + self.initiator_exchanges = {} + + self.active_timestamp = None + """A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created.""" + + # In seconds + self.session_idle_interval = 0.5 + self.session_active_interval = 0.3 + self.session_active_threshold = 4 + + @property + def peer_active(self): + return ( + time.monotonic() - self.active_timestamp + ) < self.session_active_threshold + + def receive(self, message): + self.active_timestamp = time.monotonic() + + +class UnsecuredSessionContext(SessionContext): def __init__( self, socket, @@ -162,15 +184,15 @@ def __init__( ephemeral_initiator_node_id, node_ipaddress, ): - self.socket = socket + super().__init__(socket) + self.initiator = initiator self.ephemeral_initiator_node_id = ephemeral_initiator_node_id self.message_reception_state = None self.message_counter = message_counter - self.node_ipaddress = node_ipaddress - self.exchanges = {} self.local_node_id = 0 + self.node_ipaddress = node_ipaddress def send(self, message): message.flags |= 1 # DSIZ = 1 for destination node @@ -182,8 +204,9 @@ def send(self, message): self.socket.sendto(buf[:nbytes], self.node_ipaddress) -class SecureSessionContext: +class SecureSessionContext(SessionContext): def __init__(self, random_source, socket, local_session_id): + super().__init__(socket) self.session_type = None """Records whether the session was established using CASE or PASE.""" self.session_role_initiator = False @@ -210,12 +233,7 @@ def __init__(self, random_source, socket, local_session_id): """The ID used when resuming a session between the local and remote peer.""" self.session_timestamp = None """A timestamp indicating the time at which the last message was sent or received. This timestamp SHALL be initialized with the time the session was created.""" - self.active_timestamp = None - """A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created.""" - self.session_idle_interval = None - self.session_active_interval = None - self.session_active_threshold = None - self.exchanges = {} + self.subscriptions = {} self.local_node_id = 0 @@ -223,12 +241,15 @@ def __init__(self, random_source, socket, local_session_id): self.socket = socket self.node_ipaddress = None + self._next_exchange_id = random_source.randbelow(0x10000) + def __str__(self): return f"Secure Session #{self.local_session_id} with {self.peer_node_id:x}" @property - def peer_active(self): - return (time.monotonic() - self.active_timestamp) < self.session_active_interval + def next_exchange_id(self): + self._next_exchange_id = (self._next_exchange_id + 1) & 0xFFFF + return self._next_exchange_id def decrypt_and_verify(self, message): cipher = self.i2r @@ -346,14 +367,11 @@ def __next__(self): class SessionManager: def __init__(self, random_source, socket, node_credentials): - persist_path = pathlib.Path("counters.json") - if persist_path.exists(): - self.nonvolatile = json.loads(persist_path.read_text()) - else: - self.nonvolatile = {} - self.nonvolatile["check_in_counter"] = None - self.nonvolatile["group_encrypted_data_message_counter"] = None - self.nonvolatile["group_encrypted_control_message_counter"] = None + # TODO: Save and restore counters + self.nonvolatile = {} + self.nonvolatile["check_in_counter"] = None + self.nonvolatile["group_encrypted_data_message_counter"] = None + self.nonvolatile["group_encrypted_control_message_counter"] = None self.unencrypted_message_counter = MessageCounter(random_source=random_source) self.group_encrypted_data_message_counter = MessageCounter( self.nonvolatile["group_encrypted_data_message_counter"], @@ -382,6 +400,8 @@ def get_session(self, message): return None # TODO: Get MRS for source node id and message type else: + if message.session_id >= len(self.secure_session_contexts): + return None session_context = self.secure_session_contexts[message.session_id] session_context.node_ipaddress = message.source_ipaddress else: @@ -398,6 +418,18 @@ def get_session(self, message): session_context = self.unsecured_session_context[message.source_node_id] return session_context + def send_packets(self): + for session in self.secure_session_contexts: + if session == "reserved": + continue + for exchange in session.responder_exchanges.values(): + exchange.resend_pending() + for exchange in session.initiator_exchanges.values(): + exchange.resend_pending() + + for subscription in session.subscriptions.values(): + subscription.send_reports() + def mark_duplicate(self, message): """Implements 4.6.7""" session_context = self.get_session(message) @@ -467,23 +499,43 @@ def process_exchange(self, message): ): # Drop illegal combination of flags. return None - if message.exchange_id not in session.exchanges: + initiator = message.exchange_flags & ExchangeFlags.I + + if initiator: + exchanges = session.responder_exchanges + else: + exchanges = session.initiator_exchanges + + if message.exchange_id not in exchanges: # Section 4.10.5.2 - initiator = message.exchange_flags & ExchangeFlags.I if initiator and not message.duplicate: - session.exchanges[message.exchange_id] = Exchange( - session, not initiator, message.exchange_id, [message.protocol_id] + # Create a new exchange if the other side is initiating one. + exchange = Exchange( + session, + [message.protocol_id], + not initiator, + message.exchange_id, ) + session.responder_exchanges[message.exchange_id] = exchange # Drop because the message isn't from an initiator. elif message.exchange_flags & ExchangeFlags.R: + ephemeral = Exchange( + session, + [message.protocol_id], + not initiator, + message.exchange_id, + ) + ephemeral.receive(message) # Send a bare acknowledgement back. - raise NotImplementedError("Send a bare acknowledgement back") + ephemeral.send_standalone() + ephemeral.close() return None else: # Just drop it. return None - exchange = session.exchanges[message.exchange_id] + exchange = exchanges[message.exchange_id] + if exchange.receive(message): # If we want to drop the message, then return None. return None @@ -544,7 +596,16 @@ def reply_to_sigma1(self, exchange, sigma1): session_context.local_fabric_index = matching_noc + 1 session_context.resumption_id = self.random.urandom(16) session_context.local_node_id = fabric.NodeID - + if sigma1.initiatorSessionParams: + session_context.session_idle_interval = ( + sigma1.initiatorSessionParams.session_idle_interval / 1000 + ) + session_context.session_active_interval = ( + sigma1.initiatorSessionParams.session_active_interval / 1000 + ) + session_context.session_active_threshold = ( + sigma1.initiatorSessionParams.session_active_threshold / 1000 + ) ephemeral_key_pair = ecdsa.keys.SigningKey.generate( curve=ecdsa.NIST256p, hashfunc=hashlib.sha256, entropy=self.random.urandom ) diff --git a/circuitmatter/subscription.py b/circuitmatter/subscription.py new file mode 100644 index 0000000..05eaf56 --- /dev/null +++ b/circuitmatter/subscription.py @@ -0,0 +1,54 @@ +import time + +from . import interaction_model +from .exchange import Exchange +from .protocol import ProtocolId + + +class Subscription: + def __init__(self, _id, session, min_interval, max_interval): + self.id = _id + session.subscriptions[self.id] = self + self.active = True + self._reports = [] + self._session = session + self._min_interval = min_interval + self._max_interval = max_interval + # Initial transmit is handled during the subscription call. + self._last_transmit = time.monotonic() + + def send_reports(self, exchange=None): + time_since = time.monotonic() - self._last_transmit + if time_since < self._min_interval: + return + if not self._reports and time_since < (self._max_interval - 1): + return + + # create a new exchange and send reports. + exchange = Exchange(self._session, [ProtocolId.INTERACTION_MODEL]) + + response = interaction_model.ReportDataMessage() + response.SubscriptionId = self.id + response.AttributeReports = self._reports + if not self._reports: + # No response on empty reports + response.SuppressResponse = True + exchange.send(response) + if not self._reports: + exchange.close() + print( + "reporting", + self._reports, + self._min_interval, + time_since, + self._max_interval, + ) + # Use a new list so we don't clear the one we're sending. + self._reports = [] + self._last_transmit = time.monotonic() + + def append_report(self, report): + self._reports.append(report) + + def ack_report(self): + pass diff --git a/circuitmatter/utility/replay.py b/circuitmatter/utility/replay.py index e3c35f7..b38a6e6 100644 --- a/circuitmatter/utility/replay.py +++ b/circuitmatter/utility/replay.py @@ -4,6 +4,7 @@ class ReplaySocket: def __init__(self, replay_data): self.replay_data = replay_data + self._last_timestamp = 0 def bind(self, address): print("bind to", address) @@ -14,13 +15,19 @@ def setblocking(self, value): def recvfrom_into(self, buffer, nbytes=None): if nbytes is None: nbytes = len(buffer) + next_timestamp = self.replay_data[0][1] + if next_timestamp - self._last_timestamp > 1000000: + self._last_timestamp = next_timestamp + raise BlockingIOError() direction = "send" while direction == "send": - direction, _, address, data_b64 = self.replay_data.pop(0) + direction, timestamp, address, data_b64 = self.replay_data.pop(0) + decoded = binascii.a2b_base64(data_b64) if len(decoded) > nbytes: raise RuntimeError("Next replay packet is larger than buffer to read into") buffer[: len(decoded)] = decoded + self._last_timestamp = timestamp return len(decoded), address def sendto(self, data, address):