-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add basic implementation of the `Connector` interface and a `StreamingAudioInputDevice` `Connector` (#350)
- Loading branch information
Showing
7 changed files
with
346 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (C) 2024 Robotec.AI | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .base_connector import BaseConnector, BaseMessage | ||
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice | ||
|
||
# Public API of the package: the names re-exported by the imports above.
__all__ = [
    "BaseMessage",
    "BaseConnector",
    "StreamingAudioInputDevice",
    "SoundDeviceError",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (C) 2024 Robotec.AI | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Callable | ||
from uuid import uuid4 | ||
|
||
|
||
class BaseMessage(ABC):
    """Base type for messages exchanged through connectors.

    Declares no fields or methods of its own.
    """
|
||
|
||
class BaseConnector(ABC):
    """Abstract interface that every connector must implement.

    Subclasses provide message passing (``send_message``,
    ``receive_message``, ``send_and_wait``) and handle-based actions
    (``start_action``, ``terminate_action``).
    """

    def _generate_handle(self) -> str:
        """Return a freshly generated random UUID (version 4) as a string."""
        new_id = uuid4()
        return str(new_id)

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``; must be implemented by subclasses."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message from ``source``; must be implemented by subclasses."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Return a message obtained from ``target``.

        NOTE(review): presumably a request/response exchange, but no
        outgoing message parameter is declared - confirm the intended use.
        """

    @abstractmethod
    def start_action(
        self, target: str, on_feedback: Callable, on_finish: Callable = lambda _: None
    ) -> str:
        """Start an action on ``target`` and return a string handle for it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the
        caller; ``on_finish`` defaults to a one-argument no-op.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright (C) 2024 Robotec.AI | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Callable, Optional, TypedDict | ||
|
||
import numpy as np | ||
import sounddevice as sd | ||
from scipy.signal import resample | ||
from sounddevice import CallbackFlags | ||
|
||
from rai.communication.base_connector import BaseConnector, BaseMessage | ||
|
||
|
||
class SoundDeviceError(Exception):
    """Exception type used for sound-device connector errors."""

    def __init__(self, msg: str):
        # Forward the message unchanged to the base Exception.
        Exception.__init__(self, msg)
|
||
|
||
class AudioInputDeviceConfig(TypedDict):
    """Typed dictionary describing how an audio input device is configured."""

    # Block size, expressed at the consumer's sampling rate.
    block_size: int
    # Sampling rate that block_size refers to.
    consumer_sampling_rate: int
    # Sampling rate of the audio handed to the consumer.
    target_sampling_rate: int
    # Sample data type name passed to the input stream.
    dtype: str
    # Index of the input device; may be None.
    device_number: Optional[int]
|
||
|
||
class ConfiguredAudioInputDevice:
    """
    Resolved configuration of a single audio input device.

    Attributes
    ----------
    sample_rate: Default sample rate reported by the device
    consumer_sampling_rate (int): The sampling rate of the consumer
    window_size_samples (int): Block size converted to device-rate samples
    target_sampling_rate (int): The target sampling rate
    dtype (str): The data type of the audio samples
    """

    def __init__(self, config: AudioInputDeviceConfig):
        # Ask sounddevice for the device's native (default) sample rate.
        device_info = sd.query_devices(device=config["device_number"], kind="input")
        self.sample_rate = device_info["default_samplerate"]  # type: ignore
        self.consumer_sampling_rate = config["consumer_sampling_rate"]
        # block_size is given at the consumer rate; rescale it to the number
        # of samples covering the same duration at the device rate.
        scaled_block = config["block_size"] * self.sample_rate
        self.window_size_samples = int(scaled_block / config["consumer_sampling_rate"])
        self.target_sampling_rate = int(config["target_sampling_rate"])
        self.dtype = config["dtype"]
|
||
|
||
class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams audio blocks from sounddevice input devices.

    Devices are first registered with ``configure_device``; ``start_action``
    then opens an input stream for a configured device and forwards every
    recorded block to a feedback callback. Message-style communication is
    not supported and always raises ``SoundDeviceError``.
    """

    def __init__(self):
        # Maps action handle -> started input stream.
        self.streams = {}
        # Sets the sounddevice module-wide default latency (input, output).
        sd.default.latency = ("low", "low")
        # Maps target (device number as a string) -> resolved configuration.
        self.configred_devices: dict[str, ConfiguredAudioInputDevice] = {}

    def configure_device(self, target: str, config: AudioInputDeviceConfig):
        """Register the configuration for the device identified by ``target``.

        Parameters
        ----------
        target: Device number given as a string of decimal digits.
        config: Device configuration. If its ``device_number`` is missing or
            None it is filled in (in place) from ``target``.

        Raises
        ------
        SoundDeviceError: If ``target`` is not a device number, or if
            ``config["device_number"]`` disagrees with ``target``.
        """
        # isdecimal() matches exactly the characters int() can parse;
        # isdigit() also accepts e.g. superscript digits, on which int() fails.
        if not target.isdecimal():
            raise SoundDeviceError("target must be a device number!")

        device_number = int(target)
        if config.get("device_number") is None:
            config["device_number"] = device_number
        elif config["device_number"] != device_number:
            raise SoundDeviceError(
                "device_number in config must be the same as target"
            )
        self.configred_devices[target] = ConfiguredAudioInputDevice(config)

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Not supported by this connector; always raises SoundDeviceError."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Not supported by this connector; always raises SoundDeviceError."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport receiving messages messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Not supported by this connector; always raises SoundDeviceError."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        on_finish: Callable = lambda _: None,
    ) -> str:
        """Open and start an input stream on a previously configured device.

        Parameters
        ----------
        target: Device number (as a string) passed earlier to
            ``configure_device``.
        on_feedback: Called for every recorded block with the flattened
            (and, if needed, resampled) samples and a dict of stream
            status flags.
        on_finish: Passed to the stream as its ``finished_callback``.
            NOTE(review): sounddevice documents ``finished_callback`` as
            taking no arguments, while the default here takes one - confirm.

        Returns
        -------
        The handle under which the started stream is stored in
        ``self.streams``.

        Raises
        ------
        SoundDeviceError: If ``target`` has not been configured, or if
            creating the stream raises ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            indata = indata.flatten()
            if target_device.sample_rate != target_device.target_sampling_rate:
                # The block was recorded at the device rate, so its duration
                # must be measured against that rate; the resampled block
                # then holds duration * target-rate samples.
                sample_time_length = len(indata) / target_device.sample_rate
                indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate))  # type: ignore
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(indata, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=target_device.sample_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as err:
            # Keep the original error as the cause for easier debugging.
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from err
        stream.start()
        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop the stream registered under ``action_handle``.

        Raises ``KeyError`` if the handle is unknown.
        NOTE(review): the stream is only stopped; it is neither closed nor
        removed from ``self.streams`` - confirm whether that is intended.
        """
        self.streams[action_handle].stop()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (C) 2024 Robotec.AI | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from unittest import mock | ||
|
||
import pytest | ||
import sounddevice as sd | ||
|
||
from rai.communication import SoundDeviceError, StreamingAudioInputDevice | ||
|
||
|
||
@pytest.fixture
def setup_mock_input_stream():
    """Patch ``sounddevice.InputStream`` for one test and yield the mock."""
    patcher = mock.patch("sounddevice.InputStream")
    patched_stream = patcher.start()
    try:
        yield patched_stream
    finally:
        # Always restore the real InputStream, even if the test failed.
        patcher.stop()
|
||
|
||
@pytest.fixture
def device_config():
    """Return a device configuration dict without a ``device_number`` key."""
    return dict(
        block_size=1024,
        consumer_sampling_rate=44100,
        target_sampling_rate=16000,
        dtype="float32",
    )
|
||
|
||
@pytest.mark.ci_only
def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """Configuring the default input device stores the expected settings."""
    mock_input_stream = setup_mock_input_stream
    mock_input_stream.return_value = mock.MagicMock()
    audio_input_device = StreamingAudioInputDevice()

    device = sd.query_devices(kind="input")
    if isinstance(device, dict):
        device_info = device
    elif isinstance(device, list):
        device_info = device[0]  # type: ignore
    else:
        raise AssertionError("No input device found")
    device_id = str(device_info["index"])

    audio_input_device.configure_device(device_id, device_config)
    configured = audio_input_device.configred_devices[device_id]

    # The window size depends on the device's native sample rate, so derive
    # the expectation from the queried device instead of assuming 44100 Hz.
    expected_window = int(
        device_config["block_size"]
        * device_info["default_samplerate"]
        / device_config["consumer_sampling_rate"]
    )

    assert configured.consumer_sampling_rate == 44100
    assert configured.window_size_samples == expected_window
    assert configured.target_sampling_rate == 16000
    assert configured.dtype == "float32"
|
||
|
||
@pytest.mark.ci_only
def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """Starting an action on an unconfigured device raises SoundDeviceError."""
    patched_stream = setup_mock_input_stream
    patched_stream.return_value = mock.MagicMock()
    connector = StreamingAudioInputDevice()

    on_feedback = mock.MagicMock()
    on_finish = mock.MagicMock()

    recording_device = 0
    target = str(recording_device)
    with pytest.raises(SoundDeviceError, match="Device 0 has not been configured"):
        _ = connector.start_action(target, on_feedback, on_finish)
|
||
|
||
@pytest.mark.ci_only
def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """Starting an action opens one stream for the device and registers it."""
    patched_stream = setup_mock_input_stream
    patched_stream.return_value = mock.MagicMock()
    connector = StreamingAudioInputDevice()

    on_feedback = mock.MagicMock()
    on_finish = mock.MagicMock()

    # Resolve the index of the default input device.
    device = sd.query_devices(kind="input")
    if type(device) is dict:
        device_id = str(device["index"])
    elif isinstance(device, list):
        device_id = str(device[0]["index"])  # type: ignore
    else:
        raise AssertionError("No input device found")
    connector.configure_device(device_id, device_config)

    stream_handle = connector.start_action(device_id, on_feedback, on_finish)

    # Exactly one stream was created, with the expected device and callback.
    assert patched_stream.call_count == 1
    init_args = patched_stream.call_args.kwargs
    assert init_args["device"] == int(device_id)
    assert init_args["finished_callback"] == on_finish

    assert connector.streams.get(stream_handle) is not None