Skip to content

Commit

Permalink
feat: add basic implementation of the Connector interface and a `St…
Browse files Browse the repository at this point in the history
…reamingAudioInputDevice` `Connector` (#350)
  • Loading branch information
rachwalk authored Jan 3, 2025
1 parent 7f441e0 commit 2cc2176
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 2 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/poetry-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1

- name: Create virtual audio device
run: |
apt-get update
DEBIAN_FRONTEND=noninteractive apt-get --yes install jackd
jackd -d dummy -r 44100 &
- name: Install python dependencies
run: poetry install --with openset,nomad

Expand All @@ -63,4 +69,4 @@ jobs:
run: |
source /opt/ros/${{ matrix.ros_distro }}/setup.bash
source install/setup.bash
poetry run pytest
poetry run pytest -m "not billable"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ profile = "black"
[tool.pytest.ini_options]
markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
"ci_only: marks test as ci only (deselect with '-m \"not ci_only\"')",
]
addopts = "-m 'not billable' --ignore=src"
addopts = "-m 'not billable and not ci_only' --ignore=src"
log_cli = true
log_cli_level = "DEBUG"
10 changes: 10 additions & 0 deletions src/rai/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner

# Names re-exported by this package's ``__init__``.
__all__ = ["ToolRunner", "create_conversational_agent", "create_state_based_agent"]
23 changes: 23 additions & 0 deletions src/rai/rai/communication/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_connector import BaseConnector, BaseMessage
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice

# Names re-exported by this package's ``__init__``.
__all__ = [
    # Connector interface
    "BaseMessage",
    "BaseConnector",
    # Sound device connector
    "StreamingAudioInputDevice",
    "SoundDeviceError",
]
49 changes: 49 additions & 0 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable
from uuid import uuid4


class BaseMessage(ABC):
    """Common base type for messages passed through a connector.

    The class declares no fields or methods of its own.
    """


class BaseConnector(ABC):
    """Abstract interface every connector has to implement.

    It groups message-style operations (``send_message``,
    ``receive_message``, ``send_and_wait``) and action-style operations
    (``start_action``, ``terminate_action``). The exact semantics of each
    operation are defined by the concrete connector.
    """

    def _generate_handle(self) -> str:
        """Return a new random UUID4, rendered as a string."""
        handle = uuid4()
        return str(handle)

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message obtained from ``source``."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Address ``target`` and return the resulting message."""

    @abstractmethod
    def start_action(
        self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None
    ) -> str:
        """Start an action on ``target`` and return a handle identifying it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the
        caller; when and with what arguments they are invoked is up to the
        concrete connector.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""
141 changes: 141 additions & 0 deletions src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict

import numpy as np
import sounddevice as sd
from scipy.signal import resample
from sounddevice import CallbackFlags

from rai.communication.base_connector import BaseConnector, BaseMessage


class SoundDeviceError(Exception):
    """Error raised by the sound device connector."""

    def __init__(self, msg: str):
        """Create the error with ``msg`` as its single argument."""
        super().__init__(msg)


class AudioInputDeviceConfig(TypedDict):
    """Settings used to configure one audio input device."""

    # Block length expressed in samples at ``consumer_sampling_rate``; it is
    # rescaled to the device's own sample rate to size the capture window.
    block_size: int
    # Sampling rate that ``block_size`` refers to.
    consumer_sampling_rate: int
    # Sampling rate the captured audio is meant to be resampled to.
    target_sampling_rate: int
    # Sample data type name handed to the input stream (e.g. "float32").
    dtype: str
    # Device number used to query the device; may be None.
    device_number: Optional[int]


class ConfiguredAudioInputDevice:
    """
    Resolved settings of one audio input device.

    Attributes
    ----------
    sample_rate: Default sample rate reported by the device
    consumer_sampling_rate (int): The sampling rate of the consumer
    window_size_samples (int): ``block_size`` rescaled from the consumer
        sampling rate to the device sample rate, truncated to an integer
    target_sampling_rate (int): The target sampling rate
    dtype (str): The data type of the audio samples
    """

    def __init__(self, config: AudioInputDeviceConfig):
        # Ask sounddevice for the input device's properties and keep only its
        # default sample rate.
        device_info = sd.query_devices(device=config["device_number"], kind="input")
        self.sample_rate = device_info["default_samplerate"]  # type: ignore

        consumer_rate = config["consumer_sampling_rate"]
        self.consumer_sampling_rate = consumer_rate

        # Same block duration, counted in samples at the device's rate.
        rescaled_block = config["block_size"] * self.sample_rate / consumer_rate
        self.window_size_samples = int(rescaled_block)

        self.target_sampling_rate = int(config["target_sampling_rate"])
        self.dtype = config["dtype"]


class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams audio from ``sounddevice`` input devices.

    Only the action API is supported: ``start_action`` opens an input stream
    on a previously configured device and forwards every captured block to a
    callback, and ``terminate_action`` stops that stream. The message methods
    always raise ``SoundDeviceError``.
    """

    def __init__(self):
        # Input streams keyed by the handle returned from ``start_action``.
        self.streams = {}
        # Sets the sounddevice module-level default latency.
        sd.default.latency = ("low", "low")
        # Device configurations keyed by the device number as a string.
        self.configred_devices: dict[str, ConfiguredAudioInputDevice] = {}

    def configure_device(self, target: str, config: AudioInputDeviceConfig):
        """Register ``config`` for the device whose number is ``target``.

        When ``config`` carries no ``device_number`` (missing or ``None``) it
        is filled in from ``target``; this writes into the caller's dict.

        Raises
        ------
        SoundDeviceError
            If ``target`` is not a string of digits, or if ``config`` names a
            device number different from ``target``.
        """
        if not target.isdigit():
            raise SoundDeviceError("target must be a device number!")

        device_number = int(target)
        if config.get("device_number") is None:
            config["device_number"] = device_number
        elif config["device_number"] != device_number:
            raise SoundDeviceError(
                "device_number in config must be the same as target"
            )
        self.configred_devices[target] = ConfiguredAudioInputDevice(config)

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport receiving messages messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        on_finish: Callable = lambda *_: None,
    ) -> str:
        """Start streaming from the configured device ``target``.

        Parameters
        ----------
        target : str
            Device number (as a string) previously passed to
            ``configure_device``.
        on_feedback : Callable
            Called for every captured block with the flattened samples
            (resampled to the target sampling rate when it differs from the
            device sample rate) and a dict of the stream status flags.
        on_finish : Callable
            Passed to the input stream as its ``finished_callback``. The
            default accepts any number of arguments, because sounddevice
            appears to invoke this callback without arguments.

        Returns
        -------
        str
            Handle under which the started stream is stored in
            ``self.streams``.

        Raises
        ------
        SoundDeviceError
            If ``target`` has not been configured, or if creating the stream
            raises ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            indata = indata.flatten()
            if target_device.sample_rate != target_device.target_sampling_rate:
                # The block was captured at the device sample rate, so the
                # resampled length is the captured length scaled by
                # target_rate / device_rate. Multiplying before dividing
                # limits floating-point error ahead of the truncation.
                resampled_length = int(
                    len(indata)
                    * target_device.target_sampling_rate
                    / target_device.sample_rate
                )
                indata = resample(indata, resampled_length)  # type: ignore
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(indata, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=target_device.sample_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as err:
            # Keep the original error attached for debugging.
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from err
        stream.start()
        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop the stream registered under ``action_handle``.

        Raises ``KeyError`` when the handle is unknown.
        """
        # NOTE(review): the stream is only stopped; it is neither closed nor
        # removed from ``self.streams``. Confirm whether that is intended.
        self.streams[action_handle].stop()
114 changes: 114 additions & 0 deletions tests/communication/test_sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest
import sounddevice as sd

from rai.communication import SoundDeviceError, StreamingAudioInputDevice


@pytest.fixture
def setup_mock_input_stream():
    """Patch ``sounddevice.InputStream`` while a test runs and yield the mock."""
    patcher = mock.patch("sounddevice.InputStream")
    mock_input_stream = patcher.start()
    try:
        yield mock_input_stream
    finally:
        # Always restore the real class, even when the test fails.
        patcher.stop()


@pytest.fixture
def device_config():
    """Return a device configuration without a ``device_number`` entry."""
    return dict(
        block_size=1024,
        consumer_sampling_rate=44100,
        target_sampling_rate=16000,
        dtype="float32",
    )


@pytest.mark.ci_only
def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """Configuring the default input device stores the derived settings."""
    audio_input_device = StreamingAudioInputDevice()

    # Resolve the index of an available input device.
    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        raise AssertionError("No input device found")

    audio_input_device.configure_device(device_id, device_config)

    configured = audio_input_device.configred_devices[device_id]
    assert configured.consumer_sampling_rate == 44100
    # Holds only when the device's default sample rate equals the consumer
    # rate of 44100 (presumably the dummy JACK device started in CI).
    assert configured.window_size_samples == 1024
    assert configured.target_sampling_rate == 16000
    assert configured.dtype == "float32"


@pytest.mark.ci_only
def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """Starting an action on an unconfigured device raises SoundDeviceError."""
    setup_mock_input_stream.return_value = mock.MagicMock()
    connector = StreamingAudioInputDevice()

    on_feedback = mock.MagicMock()
    on_finish = mock.MagicMock()

    # Device 0 is never configured in this test.
    unconfigured_target = str(0)
    with pytest.raises(SoundDeviceError, match="Device 0 has not been configured"):
        connector.start_action(unconfigured_target, on_feedback, on_finish)


@pytest.mark.ci_only
def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """Starting an action creates one input stream and registers its handle."""
    mock_input_stream = setup_mock_input_stream
    mock_instance = mock.MagicMock()
    mock_input_stream.return_value = mock_instance
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    # Resolve the index of an available input device.
    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        raise AssertionError("No input device found")
    audio_input_device.configure_device(device_id, device_config)

    stream_handle = audio_input_device.start_action(
        device_id, feedback_callback, finish_callback
    )

    # Exactly one stream was created, for the right device and with the
    # caller's finish callback passed through.
    assert mock_input_stream.call_count == 1
    init_args = mock_input_stream.call_args.kwargs
    assert init_args["device"] == int(device_id)
    assert init_args["finished_callback"] == finish_callback

    assert audio_input_device.streams.get(stream_handle) is not None

0 comments on commit 2cc2176

Please sign in to comment.