Skip to content

Commit

Permalink
feat: state change handler
Browse files Browse the repository at this point in the history
  • Loading branch information
bj00rn committed Dec 23, 2024
1 parent 9a1000a commit 972e56b
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 56 deletions.
69 changes: 50 additions & 19 deletions src/pysaleryd/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import logging
from typing import Callable

from websockets.protocol import State

from .const import DataKey, MessageType
from .data import IncomingMessage, OutgoingMessage, ParseError
from .helpers.error_cache import ErrorCache
from .helpers.websocket import ConnectionState, ReconnectingWebsocketClient
from .helpers.websocket import ReconnectingWebsocketClient

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,10 +37,17 @@ def __init__(
self._data: dict[DataKey, str] = {}
self._error_cache = ErrorCache()
self._on_message_handlers: set[Callable[[dict[DataKey, str]]]] = set()
self._on_state_change_handlers: set[Callable[[dict[DataKey, str]]]] = set()
self._connect_timeout = connect_timeout
self._state = ConnectionState.NONE
self._tasks = [asyncio.create_task(self._do_call_message_handlers())]
self._websocket: ReconnectingWebsocketClient = None
self._websocket = ReconnectingWebsocketClient(
host=self._ip,
port=self._port,
connect_timeout=self._connect_timeout,
on_message=self._on_message,
on_connect=self._send_start_message,
on_state_change=self._on_state_change,
)

@property
def state(self):
Expand All @@ -47,25 +56,20 @@ def state(self):

@property
def data(self):
"""Get data from system"""
return self._data
"""Get data from system if connection is alive"""
if self.state == State.OPEN:
return self._data

return dict()

def connect(self):
"""Connect to HRV and begin receiving"""

async def send_start_message():
# server won't begin sending until message is received
await self._websocket.send("#:\r")

self._websocket = ReconnectingWebsocketClient(
host=self._ip,
port=self._port,
connect_timeout=self._connect_timeout,
on_message=self._on_message,
on_connect=send_start_message,
)
self._websocket.connect()

async def _send_start_message(self):
"""Send start message to server to begin receiving data"""
await self._websocket.send("#:\r")

async def _do_call_message_handlers(self):
"""Call message handlers with data at update_interval"""
while True:
Expand All @@ -76,9 +80,17 @@ def _call_message_handlers(self):
"""Call handlers with data"""
for handler in self._on_message_handlers:
try:
handler(self._data)
handler(self.data)
except Exception:
_LOGGER.error("Failed to call handler %s", handler, exc_info=1)

def _call_state_change_handlers(self, state):
"""Call handlers with data"""
for handler in self._on_state_change_handlers:
try:
handler(state)
except Exception:
_LOGGER.error("Failed to call handler", exc_info=1)
_LOGGER.error("Failed to call handler %s", handler, exc_info=1)

def close(self):
"""Disconnect from system"""
Expand All @@ -90,6 +102,9 @@ def close(self):
for task in self._tasks:
task.cancel()

async def _on_state_change(self, state):
self._call_state_change_handlers(state)

async def _on_message(self, msg: str):
"""Update data"""
try:
Expand All @@ -114,6 +129,22 @@ async def _on_message(self, msg: str):
except ParseError as e:
_LOGGER.warning(e, exc_info=1)

def add_state_change_handler(self, handler: Callable[[str], None]):
"""Add state change handler to be called when client state changes
:param handler: handler to be added
:type handler: Callable[[str], None]
"""
self._on_state_change_handlers.add(handler)

def remove_state_change_handler(self, handler: Callable[[str], None]):
"""Remove state change handler
:param handler: handler to be removed
:type handler: Callable[[str], None]
"""
self._on_state_change_handlers.remove(handler)

def add_message_handler(self, handler: Callable[[str], None]):
"""Add message handler
Expand Down
62 changes: 38 additions & 24 deletions src/pysaleryd/helpers/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
from typing import Coroutine

from websockets.asyncio.client import ClientConnection, connect
from websockets.exceptions import ConnectionClosed, InvalidHandshake
from websockets.asyncio.client import ClientConnection, connect, process_exception
from websockets.exceptions import ConnectionClosed
from websockets.protocol import State

from ..const import ConnectionState
from .task import TaskList, task_manager

_LOGGER = logging.getLogger(__name__)
Expand All @@ -20,7 +20,7 @@ def __init__(
host: str,
port: int,
on_message: Coroutine[None, str, None],
on_state_change: Coroutine[None, ConnectionState, None] = None,
on_state_change: Coroutine[None, State, None] = None,
on_connect: Coroutine[None, None, None] = None,
connect_timeout=15,
):
Expand All @@ -32,18 +32,18 @@ def __init__(
self._on_message = on_message
self._on_state_change = on_state_change
self._on_connect = on_connect
self._state = ConnectionState.NONE
self._tasks = TaskList()
self._ws = None

@property
def state(self):
"""State of connection"""
return self._state
if self._ws:
return self._ws.protocol.state

async def _set_state(self, new_state):
self._state = new_state
async def _do_on_state_change(self):
if callable(self._on_state_change):
await self._on_state_change(new_state, self._state)
await self._on_state_change(self.state)

async def _do_on_message(self, message: str):
if callable(self._on_message):
Expand Down Expand Up @@ -110,41 +110,55 @@ async def websocket_handler(websocket: ClientConnection):
_LOGGER.info("Connecting to %s", uri)
async for websocket in connect(uri, open_timeout=self._connect_timeout):
try:
self._ws = websocket
_LOGGER.info("Connection established to %s", uri)
await self._set_state(ConnectionState.RUNNING)
await self._do_on_connect()
await self._do_on_state_change()
await websocket_handler(
websocket
) # this will return if connection is closed OK by remote
_LOGGER.warning("Connection to %s was closed, will retry", uri)
await self._set_state(ConnectionState.RETRYING)
except ConnectionClosed:
_LOGGER.warning(
"Connection to %s was closed unexpectedly, will retry", uri
"Connection to %s was closed, will reconnect", uri
)
# reconnect if connection fails
await self._set_state(ConnectionState.RETRYING)
continue

except (OSError, TimeoutError, InvalidHandshake):
_LOGGER.error("Failed to connect to %s, will retry", uri, exc_info=1)
await self._set_state(ConnectionState.RETRYING)
continue
await self._do_on_state_change()
except Exception as e: # pylint: disable=W0718
try:
await self._do_on_state_change()
processed_exception = process_exception(e)
if not process_exception:
# transient error
_LOGGER.error(
"Failed to connect to %s, will retry",
uri,
exc_info=1,
)
continue
else:
raise processed_exception # pylint: disable=W0707
except ConnectionClosed:
_LOGGER.warning(
"Connection to %s was closed unexpectedly, will retry",
uri,
)
# reconnect if connection closed
continue

except asyncio.CancelledError:
_LOGGER.info("Shutting down connection to %s", uri)
await self._set_state(ConnectionState.STOPPED)
await self._do_on_state_change()
raise

def connect(self):
"""Connect to server"""
if self._state in [ConnectionState.NONE, ConnectionState.STOPPED]:
if self._ws is None:
runner_task = asyncio.create_task(self.runner(), name="runner")
message_runner_task = asyncio.create_task(self._on_message_runner())
self._tasks.append(runner_task, message_runner_task)

def close(self):
"""Close connection and perform clean up"""
self._tasks.cancel()
self._ws = None

def __enter__(self):
"""Initiate connection"""
Expand Down
50 changes: 37 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import pytest_asyncio
from websockets.protocol import State

from pysaleryd.client import Client
from pysaleryd.const import ConnectionState
Expand All @@ -30,14 +31,14 @@ async def has_state(client: Client, state: ConnectionState):
@pytest_asyncio.fixture(name="hrv_client")
async def _hrv_client(ws_server: "TestServer"):
"""HRV Client"""
with Client("localhost", 3001, 3) as client:
with Client("localhost", 3001, 3, connect_timeout=3) as client:
yield client


@pytest.mark.asyncio
async def test_client_connect(hrv_client: "Client"):
"""test connect"""
await has_state(hrv_client, ConnectionState.RUNNING)
await has_state(hrv_client, State.OPEN)


@pytest.mark.asyncio
Expand All @@ -62,6 +63,30 @@ def broken_handler(data):
assert any(data.keys())


@pytest.mark.asyncio
async def test_state_change_handler(
hrv_client: "Client", ws_server: "TestServer", mocker, caplog
):
caplog.set_level(logging.DEBUG)
"""Test state change handler callback"""

class Foo:
def handler(self, state: State):
pass

foo = Foo()
spy = mocker.spy(foo, "handler")
hrv_client.add_state_change_handler(foo.handler)
await asyncio.sleep(1)
spy.assert_called_once_with(State.OPEN)
await ws_server.close()
await asyncio.sleep(1)
spy.assert_called_with(State.CLOSED)
await ws_server.start()
await asyncio.sleep(5)
spy.assert_called_with(State.OPEN)


@pytest.mark.asyncio
async def test_get_data(hrv_client: "Client", caplog):
"""Test get data"""
Expand All @@ -76,11 +101,11 @@ async def test_reconnect(hrv_client: "Client", ws_server: "TestServer", caplog):
"""Test reconnect"""
caplog.set_level(logging.DEBUG)

await asyncio.wait_for(has_state(hrv_client, ConnectionState.RUNNING), 15)
await asyncio.wait_for(has_state(hrv_client, State.OPEN), 15)
await ws_server.close()
await asyncio.wait_for(has_state(hrv_client, ConnectionState.RETRYING), 15)
await asyncio.wait_for(has_state(hrv_client, State.CLOSED), 15)
await ws_server.start()
await asyncio.wait_for(has_state(hrv_client, ConnectionState.RUNNING), 15)
await asyncio.wait_for(has_state(hrv_client, State.OPEN), 15)


@pytest.mark.asyncio
Expand All @@ -89,13 +114,12 @@ async def test_connect_unresponsive(ws_server: "TestServer", caplog):
caplog.set_level(logging.INFO)

await ws_server.close()
await asyncio.sleep(5)
await asyncio.sleep(1)
client = Client("localhost", 3001, 3, 1)
client.connect()
await asyncio.sleep(5)
await asyncio.wait_for(has_state(client, ConnectionState.NONE), 15)
await asyncio.sleep(1)
await ws_server.start()
await asyncio.wait_for(has_state(client, ConnectionState.RUNNING), 15)
await asyncio.wait_for(has_state(client, State.OPEN), 15)


@pytest.mark.asyncio
Expand All @@ -108,8 +132,8 @@ async def test_send_command(hrv_client: "Client"):
async def test_disconnect(hrv_client: "Client", caplog):
caplog.set_level(logging.DEBUG)
"""Test disconnected client remains disconnected"""
await has_state(hrv_client, ConnectionState.RUNNING)
await asyncio.sleep(3)
await has_state(hrv_client, State.OPEN)
await asyncio.sleep(1)
hrv_client.close()
await asyncio.sleep(3)
await has_state(hrv_client, ConnectionState.STOPPED)
await asyncio.sleep(5)
await has_state(hrv_client, None)

0 comments on commit 972e56b

Please sign in to comment.