Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add basic implementation of the Connector interface and a StreamingAudioInputDevice Connector #350

Merged
merged 17 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/poetry-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1

- name: Create virtual audio device
run: |
apt-get update
DEBIAN_FRONTEND=noninteractive apt-get --yes install jackd
jackd -d dummy -r 44100 &

- name: Install python dependencies
run: poetry install --with openset,nomad

Expand All @@ -63,4 +69,4 @@ jobs:
run: |
source /opt/ros/${{ matrix.ros_distro }}/setup.bash
source install/setup.bash
poetry run pytest
poetry run pytest -m "not billable"
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ profile = "black"
[tool.pytest.ini_options]
markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
"ci_only: marks test as CI only (deselect with '-m \"not ci_only\"')",
]
addopts = "-m 'not billable' --ignore=src"
addopts = "-m 'not billable and not ci_only' --ignore=src"
log_cli = true
log_cli_level = "DEBUG"
10 changes: 10 additions & 0 deletions src/rai/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner

__all__ = [
"ToolRunner",
"create_conversational_agent",
"create_state_based_agent",
]
23 changes: 23 additions & 0 deletions src/rai/rai/communication/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_connector import BaseConnector, BaseMessage
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice

__all__ = [
"BaseMessage",
"BaseConnector",
"StreamingAudioInputDevice",
"SoundDeviceError",
]
49 changes: 49 additions & 0 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Callable
from uuid import uuid4


class BaseMessage(ABC):
    """Common base type for messages exchanged through a connector.

    The class declares no fields or abstract methods; it exists so connector
    signatures have a shared message type to refer to.
    """


class BaseConnector(ABC):
    """Abstract interface every connector has to implement.

    Subclasses must provide the message methods (``send_message``,
    ``receive_message``, ``send_and_wait``) and the action methods
    (``start_action``, ``terminate_action``). What ``target`` / ``source``
    identify is left to the concrete connector.
    """

    def _generate_handle(self) -> str:
        """Return a new random UUID4 rendered as a string."""
        handle = uuid4()
        return str(handle)

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``; semantics are connector specific."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message obtained from ``source``."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Return the message obtained from an exchange with ``target``."""

    @abstractmethod
    def start_action(
        self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None
    ) -> str:
        """Start an action on ``target`` and return a string identifying it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the
        caller; when and with which arguments they are invoked is defined by
        the concrete connector. ``on_finish`` defaults to a no-op.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""
141 changes: 141 additions & 0 deletions src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict

import numpy as np
import sounddevice as sd
from scipy.signal import resample
from sounddevice import CallbackFlags

from rai.communication.base_connector import BaseConnector, BaseMessage


class SoundDeviceError(Exception):
    """Error raised by the sound device connector."""

    def __init__(self, msg: str):
        """Create the error with ``msg`` as its only argument."""
        super().__init__(msg)


class AudioInputDeviceConfig(TypedDict):
    """Settings used to configure an audio input device."""

    # Block size expressed in samples at ``consumer_sampling_rate``.
    block_size: int
    # Sampling rate that ``block_size`` refers to.
    consumer_sampling_rate: int
    # Sampling rate of the audio handed to the consumer callback.
    target_sampling_rate: int
    # Sample data type name passed on to the input stream (e.g. "float32").
    dtype: str
    # Index of the input device; ``None`` lets ``configure_device`` fill it in.
    device_number: Optional[int]


class ConfiguredAudioInputDevice:
    """
    Resolved settings of a single audio input device

    Attributes
    ----------
    sample_rate: Default sample rate reported by ``sd.query_devices`` for the device
    consumer_sampling_rate (int): The sampling rate of the consumer
    window_size_samples (int): ``block_size`` rescaled from the consumer
        sampling rate to the device sample rate
    target_sampling_rate (int): The target sampling rate
    dtype (str): The data type of the audio samples
    """

    def __init__(self, config: AudioInputDeviceConfig):
        device_info = sd.query_devices(device=config["device_number"], kind="input")
        self.sample_rate = device_info["default_samplerate"]  # type: ignore

        consumer_rate = config["consumer_sampling_rate"]
        self.consumer_sampling_rate = consumer_rate

        # The device records at its own rate, so the block has to cover the
        # same time span as ``block_size`` samples at the consumer rate.
        scaled_block = config["block_size"] * self.sample_rate / consumer_rate
        self.window_size_samples = int(scaled_block)

        self.target_sampling_rate = int(config["target_sampling_rate"])
        self.dtype = config["dtype"]


class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams audio from ``sounddevice`` input devices.

    Only the action part of the connector interface is supported:
    ``start_action`` opens an input stream that hands recorded blocks to a
    callback and ``terminate_action`` stops that stream. The message based
    methods always raise ``SoundDeviceError``.

    Devices are addressed by their numeric index passed as a string and must
    be registered with ``configure_device`` before a stream can be started.
    """

    def __init__(self):
        # Streams created by ``start_action``, keyed by action handle.
        self.streams = {}
        sd.default.latency = ("low", "low")
        # NOTE: the misspelled attribute name is kept on purpose because it is
        # accessed directly from outside the class (e.g. by the tests).
        self.configred_devices: dict[str, ConfiguredAudioInputDevice] = {}

    def configure_device(self, target: str, config: AudioInputDeviceConfig):
        """Register ``config`` for the input device with index ``target``.

        ``config`` is modified in place: a missing or ``None``
        ``device_number`` is replaced by ``int(target)``.

        Raises
        ------
        SoundDeviceError
            If ``target`` is not made of digits, or if ``config`` already
            carries a ``device_number`` different from ``target``.
        """
        if not target.isdigit():
            raise SoundDeviceError("target must be a device number!")

        device_number = int(target)
        if config.get("device_number") is None:
            config["device_number"] = device_number
        elif config["device_number"] != device_number:
            raise SoundDeviceError(
                "device_number in config must be the same as target"
            )
        self.configred_devices[target] = ConfiguredAudioInputDevice(config)

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport receiving messages messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Unsupported; always raises ``SoundDeviceError``."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        on_finish: Callable = lambda _: None,
    ) -> str:
        """Open and start an input stream on the configured device ``target``.

        Parameters
        ----------
        target
            Device index as a string; must have been passed to
            ``configure_device`` before.
        on_feedback
            Called for every recorded block with the flattened samples
            (resampled to the configured ``target_sampling_rate`` when it
            differs from the device sample rate) and a dict of the stream
            status flags.
        on_finish
            Passed to the stream as its ``finished_callback``.

        Returns
        -------
        str
            Handle under which the stream is stored in ``self.streams``; pass
            it to ``terminate_action`` to stop the stream.

        Raises
        ------
        SoundDeviceError
            If the device has not been configured, or if creating the stream
            fails with an ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            # The stream is opened with a single channel, so flattening
            # yields a 1-D array of samples.
            samples = indata.flatten()
            if target_device.sample_rate != target_device.target_sampling_rate:
                # The block was recorded at the device rate; keep its duration
                # and express it at the target rate. Multiplying before
                # dividing avoids losing a sample to float rounding.
                resampled_length = int(
                    len(samples)
                    * target_device.target_sampling_rate
                    / target_device.sample_rate
                )
                samples = resample(samples, resampled_length)  # type: ignore
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(samples, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=target_device.sample_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as err:
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from err

        try:
            stream.start()
        except Exception:
            # Do not leak the opened stream when it cannot be started.
            stream.close()
            raise
        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop the stream registered under ``action_handle``.

        The stream is only stopped; it stays registered in ``self.streams``.
        An unknown handle raises ``KeyError``.
        """
        self.streams[action_handle].stop()
114 changes: 114 additions & 0 deletions tests/communication/test_sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest
import sounddevice as sd

from rai.communication import SoundDeviceError, StreamingAudioInputDevice


@pytest.fixture
def setup_mock_input_stream():
    """Patch ``sounddevice.InputStream`` for the duration of a test.

    Yields the mock that replaces the class so tests can inspect how the
    stream was constructed.
    """
    patcher = mock.patch("sounddevice.InputStream")
    patched_input_stream = patcher.start()
    try:
        yield patched_input_stream
    finally:
        patcher.stop()


@pytest.fixture
def device_config():
    """Return an input device configuration without a ``device_number``."""
    config = dict(
        block_size=1024,
        consumer_sampling_rate=44100,
        target_sampling_rate=16000,
        dtype="float32",
    )
    return config


@pytest.mark.ci_only
def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """Configuring the default input device stores the derived settings."""
    setup_mock_input_stream.return_value = mock.MagicMock()
    audio_input_device = StreamingAudioInputDevice()

    # Resolve the index of an available input device to configure.
    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        pytest.fail(f"unexpected result from sd.query_devices: {type(device)!r}")

    audio_input_device.configure_device(device_id, device_config)

    configured = audio_input_device.configred_devices[device_id]
    assert configured.consumer_sampling_rate == 44100
    # The window equals block_size only when the device's default sample rate
    # matches the configured consumer_sampling_rate.
    assert configured.window_size_samples == 1024
    assert configured.target_sampling_rate == 16000
    assert configured.dtype == "float32"


@pytest.mark.ci_only
def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """Starting an action on an unconfigured device raises ``SoundDeviceError``."""
    setup_mock_input_stream.return_value = mock.MagicMock()
    connector = StreamingAudioInputDevice()

    on_feedback = mock.MagicMock()
    on_finish = mock.MagicMock()

    # Device "0" is never passed to configure_device in this test.
    unconfigured_device = str(0)
    expected_error = "Device 0 has not been configured"
    with pytest.raises(SoundDeviceError, match=expected_error):
        connector.start_action(unconfigured_device, on_feedback, on_finish)


@pytest.mark.ci_only
def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """Starting an action creates one input stream and registers its handle."""
    mock_input_stream = setup_mock_input_stream
    mock_input_stream.return_value = mock.MagicMock()
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    # Resolve the index of an available input device to configure.
    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        pytest.fail(f"unexpected result from sd.query_devices: {type(device)!r}")
    audio_input_device.configure_device(device_id, device_config)

    stream_handle = audio_input_device.start_action(
        device_id, feedback_callback, finish_callback
    )

    # The (mocked) stream class is instantiated exactly once with the
    # requested device and finish callback.
    assert mock_input_stream.call_count == 1
    init_args = mock_input_stream.call_args.kwargs
    assert init_args["device"] == int(device_id)
    assert init_args["finished_callback"] == finish_callback

    assert audio_input_device.streams.get(stream_handle) is not None