Skip to content

Commit

Permalink
Merge pull request #25 from tannewt/subscribe
Browse files Browse the repository at this point in the history
Support time-based (re)transmit
  • Loading branch information
tannewt authored Oct 23, 2024
2 parents 1831a90 + 9d56cd0 commit 920742c
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 53 deletions.
53 changes: 45 additions & 8 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .message import Message
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
from . import session
from .subscription import Subscription
from .device_types.utility.root_node import RootNode

__version__ = "0.2.3"
Expand Down Expand Up @@ -144,6 +145,7 @@ def add_device(self, device):

for server in device.servers:
device.descriptor.ServerList.append(server.CLUSTER_ID)
server.endpoint = self._next_endpoint
self.add_cluster(self._next_endpoint, server)
self.add_cluster(self._next_endpoint, device.descriptor)

Expand All @@ -166,6 +168,8 @@ def process_packets(self):
break

self.process_packet(addr, self.packet_buffer[:nbytes])
# Do any retransmits or subscriptions
self.manager.send_packets()

def _build_attribute_error(self, path, status_code):
report = interaction_model.AttributeReportIB()
Expand All @@ -178,9 +182,9 @@ def _build_attribute_error(self, path, status_code):
report.AttributeStatus = astatus
return report

def get_report(self, context, cluster, path):
def get_report(self, context, cluster, path, subscription=None):
reports = []
datas = cluster.get_attribute_data(context, path)
datas = cluster.get_attribute_data(context, path, subscription=subscription)
for data in datas:
report = interaction_model.AttributeReportIB()
report.AttributeData = data
Expand Down Expand Up @@ -219,7 +223,7 @@ def invoke(self, session, cluster, path, fields, command_ref):

return response

def read_attribute_path(self, context, path):
def read_attribute_path(self, context, path, subscription=None):
attribute_reports = []
if path.Endpoint is None:
endpoints = self._endpoints
Expand All @@ -241,7 +245,11 @@ def read_attribute_path(self, context, path):
clusters = [self._endpoints[endpoint][path.Cluster]]
for cluster in clusters:
temp_path.Cluster = cluster.CLUSTER_ID
attribute_reports.extend(self.get_report(context, cluster, temp_path))
attribute_reports.extend(
self.get_report(
context, cluster, temp_path, subscription=subscription
)
)
return attribute_reports

def process_packet(self, address, data):
Expand All @@ -250,6 +258,14 @@ def process_packet(self, address, data):
message = Message()
message.decode(data)
message.source_ipaddress = address
session_context = self.manager.get_session(message)
if message.secure_session:
if session_context is None:
print("Failed to find session. Ignoring.")
return
secure_session_context = session_context

session_context.receive(message)
if message.secure_session:
secure_session_context = None
if message.session_id < len(self.manager.secure_session_contexts):
Expand Down Expand Up @@ -429,8 +445,10 @@ def process_packet(self, address, data):
self.read_attribute_path(secure_session_context, path)
)
response = interaction_model.ReportDataMessage()
response.SuppressResponse = True
response.AttributeReports = attribute_reports
exchange.send(response)
exchange.close()
elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST:
print("Received Write Request")
write_request = interaction_model.WriteRequestMessage.decode(
Expand All @@ -447,6 +465,7 @@ def process_packet(self, address, data):
response = interaction_model.WriteResponseMessage()
response.WriteResponses = write_responses
exchange.send(response)
exchange.close()

elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
Expand Down Expand Up @@ -490,25 +509,33 @@ def process_packet(self, address, data):
response.SuppressResponse = False
response.InvokeResponses = invoke_responses
exchange.send(response)
exchange.close()
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
print("Received Subscribe Request")
subscribe_request = interaction_model.SubscribeRequestMessage.decode(
message.application_payload
)
print(subscribe_request)
subscription = Subscription(
exchange.exchange_id,
secure_session_context,
subscribe_request.MinIntervalFloor,
subscribe_request.MaxIntervalCeiling,
)
attribute_reports = []
for path in subscribe_request.AttributeRequests:
attribute_reports.extend(
self.read_attribute_path(secure_session_context, path)
self.read_attribute_path(
secure_session_context, path, subscription=subscription
)
)
response = interaction_model.ReportDataMessage()
response.SubscriptionId = exchange.exchange_id
response.SubscriptionId = subscription.id
response.AttributeReports = attribute_reports
exchange.send(response)
final_response = interaction_model.SubscribeResponseMessage()
final_response.SubscriptionId = exchange.exchange_id
final_response.SubscriptionId = subscription.id
final_response.MaxInterval = subscribe_request.MaxIntervalCeiling
exchange.queue(final_response)
elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE:
Expand All @@ -519,11 +546,21 @@ def process_packet(self, address, data):
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
)

# Acknowledge the message because we have no further reply.
if message.exchange_flags & session.ExchangeFlags.R:
exchange.send_standalone()

if exchange.pending_payloads:
if status_response.Status == interaction_model.StatusCode.SUCCESS:
exchange.send(exchange.pending_payloads.pop(0))
else:
exchange.pending_payloads.clear()
# Close after an error.
exchange.close()
else:
# Close if nothing is pending.
exchange.close()

else:
print(message)
print("application payload", message.application_payload.hex(" "))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ class AddTrustedRootCertificate(tlv.Structure):
# This attribute is weird because it is fabric sensitive but not marked as such.
# Cluster sets current_fabric_index for use in fabric sensitive attributes and
# happens to make this work as well.
current_fabric_index = NumberAttribute(5, signed=False, bits=8, default=0)
current_fabric_index = NumberAttribute(
5, signed=False, bits=8, default=0, C_changes_omitted=True
)

attestation_request = Command(0x00, AttestationRequest, 0x01, AttestationResponse)

Expand Down
28 changes: 26 additions & 2 deletions circuitmatter/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
self.feature = feature
self.nullable = X_nullable
self.nonvolatile = N_nonvolatile
self.omit_changes = C_changes_omitted

def __get__(self, instance, cls):
v = instance._attribute_values.get(self.id, None)
Expand All @@ -122,6 +123,23 @@ def __set__(self, instance, value):
instance._nonvolatile[ATTRIBUTES_KEY][hex(self.id)] = self.to_json(value)
instance.data_version += 1

if self.id in instance._subscriptions and not self.omit_changes:
for subscription in instance._subscriptions[self.id]:
if not subscription.active:
continue

data = interaction_model.AttributeDataIB()
data.DataVersion = instance.data_version
attribute_path = interaction_model.AttributePathIB()
attribute_path.Endpoint = instance.endpoint
attribute_path.Cluster = instance.CLUSTER_ID
attribute_path.Attribute = self.id
data.Path = attribute_path
data.Data = self.encode(value)
report = interaction_model.AttributeReportIB()
report.AttributeData = data
subscription.append_report(report)

def to_json(self, value):
return value

Expand Down Expand Up @@ -323,6 +341,7 @@ class Cluster:

def __init__(self):
self._attribute_values = {}
self._subscriptions = {}
# Use random since this isn't for security or replayability.
self.data_version = random.randint(0, 0xFFFFFFFF)

Expand Down Expand Up @@ -365,7 +384,7 @@ def restore(self, nonvolatile):
nonvolatile[ATTRIBUTES_KEY][hex(descriptor.id)] = descriptor.default

def get_attribute_data(
self, session, path
self, session, path, subscription=None
) -> typing.List[interaction_model.AttributeDataIB]:
replies = []
for field_name, descriptor in self._attributes():
Expand All @@ -384,10 +403,15 @@ def get_attribute_data(
"->",
value,
)
if subscription is not None:
if path.Attribute not in self._subscriptions:
self._subscriptions[descriptor.id] = []
print("new subscription")
self._subscriptions[descriptor.id].append(subscription)
if value is None and descriptor.optional:
continue
data = interaction_model.AttributeDataIB()
data.DataVersion = 0
data.DataVersion = self.data_version
attribute_path = interaction_model.AttributePathIB()
attribute_path.Endpoint = path.Endpoint
attribute_path.Cluster = path.Cluster
Expand Down
7 changes: 4 additions & 3 deletions circuitmatter/device_types/lighting/extended_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _move_to_hue_and_saturation(self, session, value):
print(f"Error setting color: {e}")
return

self._color_control.ColorMode = color_control.ColorMode.HUE_SATURATION
self._color_control.CurrentHue = value.Hue
self._color_control.CurrentSaturation = value.Saturation
print("update attributes")
self._color_control.color_mode = color_control.ColorMode.HUE_SATURATION
self._color_control.current_hue = value.Hue
self._color_control.current_saturation = value.Saturation
4 changes: 2 additions & 2 deletions circuitmatter/device_types/lighting/on_off.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def _on(self, session):
except Exception as e:
print(f"Error turning on light: {e}")
return
self._on_off.on_off = True
self._on_off.OnOff = True

def _off(self, session):
try:
self.off()
except Exception as e:
print(f"Error turning off light: {e}")
return
self._on_off.on_off = False
self._on_off.OnOff = False

def on(self):
raise NotImplementedError()
Expand Down
Loading

0 comments on commit 920742c

Please sign in to comment.