Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add basic implementation of the Connector interface and a StreamingAudioInputDevice Connector #350

Merged
merged 17 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/poetry-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1

- name: Create virtual audio device
run: |
apt-get update
DEBIAN_FRONTEND=noninteractive sudo apt-get --yes install jackd
jackd -d dummy -r 44100 &

- name: Install python dependencies
run: poetry install --with openset,nomad

Expand All @@ -63,4 +69,4 @@ jobs:
run: |
source /opt/ros/${{ matrix.ros_distro }}/setup.bash
source install/setup.bash
poetry run pytest
poetry run pytest -m "not billable"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ profile = "black"
[tool.pytest.ini_options]
markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
"ci_only: marks test as CI only (deselect with '-m \"not ci_only\"')",
]
addopts = "-m 'not billable' --ignore=src"
addopts = "-m 'not billable and not ci_only' --ignore=src"
log_cli = true
log_cli_level = "DEBUG"
23 changes: 23 additions & 0 deletions src/rai/rai/communication/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_connector import BaseConnector, BaseMessage
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice

__all__ = [
"BaseMessage",
"BaseConnector",
"StreamingAudioInputDevice",
"SoundDeviceError",
]
49 changes: 49 additions & 0 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable
from uuid import uuid4


class BaseMessage(ABC):
    """Common base type for messages passed to and returned by connectors.

    The class declares no fields or methods of its own.
    """


class BaseConnector(ABC):
    """Abstract interface that every connector has to implement.

    It declares message-style operations (``send_message``,
    ``receive_message``, ``send_and_wait``) and action-style operations
    (``start_action``, ``terminate_action``). The concrete meaning of
    ``target`` and ``source`` is left to each subclass.
    """

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``; implemented by subclasses."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message obtained from ``source``; implemented by subclasses."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Return the message produced by ``target``; implemented by subclasses."""

    @abstractmethod
    def start_action(
        self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None
    ) -> str:
        """Start an action on ``target`` and return a string handle for it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the caller;
        when and with which arguments they are invoked is defined by the
        implementing subclass.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""

    def _generate_handle(self) -> str:
        """Return a freshly generated random UUID (version 4) as a string."""
        return str(uuid4())
142 changes: 142 additions & 0 deletions src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict

import numpy as np
import sounddevice as sd
from scipy.signal import resample
from sounddevice import CallbackFlags

from rai.communication.base_connector import BaseConnector, BaseMessage


class SoundDeviceError(Exception):
    """Exception raised by the sound device connector.

    Parameters
    ----------
    msg : str
        Human-readable description of the failure. It is required.
    """

    def __init__(self, msg: str):
        # The explicit signature keeps ``msg`` mandatory, unlike plain
        # ``Exception`` which accepts zero arguments.
        super().__init__(msg)


class _RequiredDeviceConfig(TypedDict):
    """Keys that every ``DeviceConfig`` must provide."""

    # Device kind forwarded to ``sd.query_devices`` (e.g. "input").
    kind: str
    # Block size expressed in samples at ``consumer_sampling_rate``.
    block_size: int
    # Sampling rate that ``block_size`` refers to.
    consumer_sampling_rate: int
    # Sampling rate of the audio handed to the consumer. The misspelled key
    # is kept on purpose: existing configs already use it.
    target_smpling_rate: int
    # Sample data type forwarded to ``sd.InputStream`` (e.g. "float32").
    dtype: str


class DeviceConfig(_RequiredDeviceConfig, total=False):
    """Configuration accepted by ``StreamingAudioInputDevice.configure_device``.

    ``device_number`` is an optional key: ``configure_device`` reads it with
    ``.get()`` and falls back to the numeric ``target`` when it is missing or
    ``None``, so configs that omit it are valid.
    """

    # Device index; may be omitted or set to ``None``.
    device_number: Optional[int]

maciejmajek marked this conversation as resolved.
Show resolved Hide resolved

class ConfiguredDevice:
    """Settings resolved for a single audio device.

    Attributes
    ----------
    sample_rate
        The ``default_samplerate`` that ``sd.query_devices`` reports for the
        device.
    consumer_sampling_rate : int
        Copied from ``config["consumer_sampling_rate"]``.
    window_size_samples : int
        ``config["block_size"]`` rescaled from the consumer rate to the device
        rate and truncated to an integer.
    target_sampling_rate : int
        ``config["target_smpling_rate"]`` converted to ``int``.
    dtype : str
        Copied from ``config["dtype"]``.
    """

    def __init__(self, config: DeviceConfig):
        device_info = sd.query_devices(
            device=config["device_number"], kind=config["kind"]
        )
        self.sample_rate = device_info["default_samplerate"]  # type: ignore

        consumer_rate = config["consumer_sampling_rate"]
        self.consumer_sampling_rate = consumer_rate

        # Number of device-rate samples covering the same time span as
        # ``block_size`` samples at the consumer rate.
        scaled_block = config["block_size"] * self.sample_rate
        self.window_size_samples = int(scaled_block / consumer_rate)

        self.target_sampling_rate = int(config["target_smpling_rate"])
        self.dtype = config["dtype"]


class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams audio blocks from sounddevice input devices.

    A device is first registered with ``configure_device``. ``start_action``
    then opens an ``sd.InputStream`` on it and forwards every captured block
    to a feedback callback, and ``terminate_action`` stops and closes that
    stream. The message-based methods of ``BaseConnector`` are not supported
    and always raise ``SoundDeviceError``.
    """

    def __init__(self):
        # Open streams keyed by the handle returned from ``start_action``.
        self.streams: dict[str, sd.InputStream] = {}
        # NOTE: this sets sounddevice's module-wide default latency.
        sd.default.latency = ("low", "low")
        # The attribute name (including its spelling) is kept as-is because
        # code outside this class reads it.
        self.configred_devices: dict[str, ConfiguredDevice] = {}

    def configure_device(self, target: str, config: DeviceConfig):
        """Register the device numbered ``target`` using ``config``.

        Parameters
        ----------
        target : str
            Device number written as a decimal string.
        config : DeviceConfig
            Device settings. When ``device_number`` is missing or ``None`` the
            number parsed from ``target`` is used. The mapping passed in is
            not modified.

        Raises
        ------
        SoundDeviceError
            If ``target`` is not a decimal number, or if
            ``config["device_number"]`` is set and differs from ``target``.
        """
        # ``isdecimal`` only accepts characters that ``int`` can parse
        # (``isdigit`` also accepts e.g. superscript digits, which ``int``
        # rejects with a ValueError).
        if not target.isdecimal():
            raise SoundDeviceError("target must be a device number!")

        device_number = int(target)
        configured_number = config.get("device_number")
        if configured_number is not None and configured_number != device_number:
            raise SoundDeviceError(
                "device_number in config must be the same as target"
            )

        # Work on a copy so the caller's dict is left untouched.
        resolved: DeviceConfig = {**config, "device_number": device_number}  # type: ignore
        self.configred_devices[target] = ConfiguredDevice(resolved)

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Not supported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Not supported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport receiving messages messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Not supported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        on_finish: Callable = lambda *_: None,
    ) -> str:
        """Open and start an input stream on a previously configured device.

        Parameters
        ----------
        target : str
            Device number (decimal string) passed earlier to
            ``configure_device``.
        on_feedback : Callable[[np.ndarray, dict[str, Any]], None]
            Called for every captured block with the flattened samples
            (resampled to the device's target sampling rate when it differs
            from the device rate) and a dict of the stream status flags.
        on_finish : Callable
            Passed to ``sd.InputStream`` as ``finished_callback``, which
            sounddevice invokes without arguments. The default accepts any
            number of arguments and does nothing.

        Returns
        -------
        str
            Handle to pass to ``terminate_action``.

        Raises
        ------
        SoundDeviceError
            If the device has not been configured, or if creating the stream
            raises ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        # Resolved once; the callback runs for every captured block.
        source_rate = target_device.sample_rate
        target_rate = target_device.target_sampling_rate
        needs_resampling = source_rate != target_rate

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            indata = indata.flatten()
            if needs_resampling:
                # The block was captured at ``source_rate``; keep its duration
                # and change only the number of samples.
                resampled_length = int(len(indata) * target_rate / source_rate)
                indata = resample(indata, resampled_length)  # type: ignore
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(indata, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=source_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as exc:
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from exc

        try:
            stream.start()
        except Exception:
            # Do not leave a half-opened stream behind when starting fails.
            stream.close()
            raise
        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop and close the stream identified by ``action_handle``.

        The handle is removed from ``self.streams`` and is no longer valid
        afterwards.

        Raises
        ------
        KeyError
            If ``action_handle`` does not refer to an open stream.
        """
        stream = self.streams.pop(action_handle)
        try:
            stream.stop()
        finally:
            # Release the underlying audio resources even if stopping fails.
            stream.close()
Copy link
Contributor

@coderabbitai coderabbitai bot Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix typos and enhance exception chaining

  1. The resample call references the incorrect key target_samping_rate; it should read target_sampling_rate.
  2. In the except AttributeError block, you can leverage exception chaining by using raise <Exception> from exc or raise SoundDeviceError(...) from None to distinguish the new exception from the original cause.
-indata = resample(indata, int(sample_time_length * target_device.target_samping_rate)) 
+indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate)) 

...
except AttributeError as exc:
-    raise SoundDeviceError(f"Device {target} has not been correctly configured")
+    raise SoundDeviceError(f"Device {target} has not been correctly configured") from exc
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def start_action(
self,
target: str,
on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
on_finish: Callable = lambda _: None,
) -> str:
target_device = self.configred_devices.get(target)
if target_device is None:
raise SoundDeviceError(f"Device {target} has not been configured")
def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
indata = indata.flatten()
sample_time_length = len(indata) / target_device.target_sampling_rate
if target_device.sample_rate != target_device.target_sampling_rate:
indata = resample(indata, int(sample_time_length * target_device.target_samping_rate)) # type: ignore
flag_dict = {
"input_overflow": status.input_overflow,
"input_underflow": status.input_underflow,
"output_overflow": status.output_overflow,
"output_underflow": status.output_underflow,
"priming_output": status.priming_output,
}
on_feedback(indata, flag_dict)
handle = self._generate_handle()
try:
stream = sd.InputStream(
samplerate=target_device.sample_rate,
channels=1,
device=int(target),
dtype=target_device.dtype,
blocksize=target_device.window_size_samples,
callback=callback,
finished_callback=on_finish,
)
except AttributeError:
raise SoundDeviceError(f"Device {target} has not been correctly configured")
stream.start()
self.streams[handle] = stream
return handle
def terminate_action(self, action_handle: str):
self.streams[action_handle].stop()
def start_action(
self,
target: str,
on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
on_finish: Callable = lambda _: None,
) -> str:
target_device = self.configred_devices.get(target)
if target_device is None:
raise SoundDeviceError(f"Device {target} has not been configured")
def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
indata = indata.flatten()
sample_time_length = len(indata) / target_device.target_sampling_rate
if target_device.sample_rate != target_device.target_sampling_rate:
indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate)) # type: ignore
flag_dict = {
"input_overflow": status.input_overflow,
"input_underflow": status.input_underflow,
"output_overflow": status.output_overflow,
"output_underflow": status.output_underflow,
"priming_output": status.priming_output,
}
on_feedback(indata, flag_dict)
handle = self._generate_handle()
try:
stream = sd.InputStream(
samplerate=target_device.sample_rate,
channels=1,
device=int(target),
dtype=target_device.dtype,
blocksize=target_device.window_size_samples,
callback=callback,
finished_callback=on_finish,
)
except AttributeError as exc:
raise SoundDeviceError(f"Device {target} has not been correctly configured") from exc
stream.start()
self.streams[handle] = stream
return handle
def terminate_action(self, action_handle: str):
self.streams[action_handle].stop()
🧰 Tools
🪛 Ruff (0.8.2)

136-136: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maciejmajek, here’s what we have from the latest script output:

  1. Typo correction:

    • The string "target_samping_rate" no longer appears in the repository, confirming the mismatch is presumably corrected to "target_sampling_rate".
  2. Exception chaining:

    • We still don’t see "from exc" usage or except AttributeError as exc in the file. The script found an except AttributeError: but not with exception chaining (from exc).

Thus, it appears that the typo fix has been integrated, but the recommended exception chaining is not in place yet. If you wish to apply exception chaining, please modify that block to use, for example:

except AttributeError as exc:
    raise SoundDeviceError(f"Device {target} has not been correctly configured") from exc

That should address both points from your earlier feedback. Let me know if you have any other questions!

🔗 Analysis chain
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Let's verify that the typo "target_samping_rate" is removed and replaced by "target_sampling_rate".
echo "Searching for 'target_samping_rate' occurrences..."
rg "target_samping_rate"

# Let's also verify that we're now using exception chaining with 'from exc' in the except block.
echo ""
echo "Searching for 'except AttributeError as exc' lines and the following lines..."
rg -A 3 "except AttributeError as exc" src/rai/rai/communication/sound_device_connector.py

Length of output: 416


Script:

#!/bin/bash
# Let's broaden our search efforts to see if the fix was applied anywhere, or if "from exc" occurs in the codebase.
# We will search for "from exc" usage, references to raising SoundDeviceError, and references to "AttributeError".

echo "Searching for 'from exc' usage..."
rg "from exc"

echo ""
echo "Searching for 'raise SoundDeviceError' lines to see if chaining might have been manually inserted..."
rg -A 3 "raise SoundDeviceError"

echo ""
echo "Searching for 'AttributeError' references in entire codebase..."
rg "AttributeError" -A 5

Length of output: 3554

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rachwalk please apply exception chaining suggestion

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

115 changes: 115 additions & 0 deletions tests/communication/test_sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest
import sounddevice as sd

from rai.communication import SoundDeviceError, StreamingAudioInputDevice


@pytest.fixture
def setup_mock_input_stream():
    """Yield a mock that replaces ``sounddevice.InputStream`` during a test."""
    with mock.patch("sounddevice.InputStream") as patched_input_stream:
        # The patch stays active until the test that requested the fixture ends.
        yield patched_input_stream


@pytest.fixture
def device_config():
    """Return a device configuration without an explicit ``device_number``."""
    return dict(
        kind="input",
        block_size=1024,
        consumer_sampling_rate=44100,
        target_smpling_rate=16000,
        dtype="float32",
    )


@pytest.mark.ci_only
def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """Configuring a device stores the resolved settings under its number."""
    audio_input_device = StreamingAudioInputDevice()

    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        # ``pytest.fail`` still fails the test when asserts are stripped (-O).
        pytest.fail(f"unexpected query_devices result: {type(device)!r}")

    audio_input_device.configure_device(device_id, device_config)

    configured = audio_input_device.configred_devices[device_id]
    assert configured.consumer_sampling_rate == 44100
    # Holds only when the device's default sample rate equals the 44100 Hz
    # consumer rate, i.e. assumes the CI virtual audio device runs at 44100 Hz.
    assert configured.window_size_samples == 1024
    assert configured.target_sampling_rate == 16000
    assert configured.dtype == "float32"


@pytest.mark.ci_only
def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """Starting an action on an unconfigured device raises SoundDeviceError."""
    stream_instance = mock.MagicMock()
    setup_mock_input_stream.return_value = stream_instance
    connector = StreamingAudioInputDevice()

    on_feedback = mock.MagicMock()
    on_finish = mock.MagicMock()

    # Device 0 is never passed to configure_device in this test.
    unconfigured_device = str(0)
    with pytest.raises(SoundDeviceError, match="Device 0 has not been configured"):
        connector.start_action(unconfigured_device, on_feedback, on_finish)


@pytest.mark.ci_only
def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """Starting an action opens, starts and registers exactly one stream."""
    mock_input_stream = setup_mock_input_stream
    mock_instance = mock.MagicMock()
    mock_input_stream.return_value = mock_instance
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        # ``pytest.fail`` still fails the test when asserts are stripped (-O).
        pytest.fail(f"unexpected query_devices result: {type(device)!r}")
    audio_input_device.configure_device(device_id, device_config)

    stream_handle = audio_input_device.start_action(
        device_id, feedback_callback, finish_callback
    )

    assert mock_input_stream.call_count == 1
    init_args = mock_input_stream.call_args.kwargs
    assert init_args["device"] == int(device_id)
    assert init_args["finished_callback"] == finish_callback

    # The stream must actually be started, not only constructed.
    mock_instance.start.assert_called_once()
    assert audio_input_device.streams.get(stream_handle) is not None
Loading