Skip to content

Commit

Permalink
Add basic implementation of the Connector interface and a StreamingAu…
Browse files Browse the repository at this point in the history
…dioInputDevice Connector
  • Loading branch information
rachwalk committed Dec 20, 2024
1 parent 412824b commit 10a04c8
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/rai/rai/communication/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_connector import BaseConnector, BaseMessage
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice

# Names re-exported as the public API of this package; each one is imported
# from a submodule just above.
__all__ = [
"BaseMessage",
"BaseConnector",
"StreamingAudioInputDevice",
"SoundDeviceError",
]
46 changes: 46 additions & 0 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractmethod
from typing import Callable
from uuid import uuid4

class BaseMessage(ABC):
    """Base type for the messages that connectors send and receive.

    The class declares no fields or methods of its own; it only serves as the
    common parent used in the connector method signatures.
    """

class BaseConnector(ABC):
    """Abstract interface implemented by every connector.

    Subclasses must provide all five abstract methods below; the only concrete
    member is :meth:`_generate_handle`.
    """

    def _generate_handle(self) -> str:
        """Return a freshly generated random UUID (version 4) as a string."""
        new_id = uuid4()
        return str(new_id)

    @abstractmethod
    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Send ``msg`` to ``target``."""

    @abstractmethod
    def receive_message(self, source: str) -> BaseMessage:
        """Return a message obtained from ``source``."""

    @abstractmethod
    def send_and_wait(self, target: str) -> BaseMessage:
        """Address ``target`` and return the message that comes back."""

    @abstractmethod
    def start_action(
        self,
        target: str,
        on_feedback: Callable,
        on_finish: Callable = lambda _: None,
    ) -> str:
        """Start an action on ``target`` and return a string identifying it.

        ``on_feedback`` and ``on_finish`` are callbacks supplied by the caller;
        when they are invoked, and with which arguments, is defined by each
        implementation. ``on_finish`` defaults to a one-argument no-op.
        """

    @abstractmethod
    def terminate_action(self, action_handle: str):
        """Terminate the action identified by ``action_handle``."""
129 changes: 129 additions & 0 deletions src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, TypedDict

import numpy as np
import sounddevice as sd
from scipy.signal import resample
from sounddevice import CallbackFlags

from rai.communication.base_connector import BaseConnector, BaseMessage


class SoundDeviceError(Exception):
    """Error raised by the sound device connector.

    Args:
        msg: Human-readable description; becomes the exception's only argument.
    """

    def __init__(self, msg: str):
        # A message is mandatory: unlike a bare Exception, this type cannot be
        # constructed without one.
        super().__init__(msg)


class DeviceConfig(TypedDict):
    """Settings dictionary describing how a sound device should be used.

    NOTE: the key ``target_smpling_rate`` is misspelled, but it is part of the
    public schema and must be written exactly like this by callers.
    """

    # Device kind forwarded to ``sounddevice.query_devices`` (e.g. "input").
    kind: str
    # Block length in samples, expressed at ``sampling_rate``.
    block_size: int
    # Sampling rate that ``block_size`` refers to.
    sampling_rate: int
    # Sampling rate the captured audio should be converted to.
    target_smpling_rate: int
    # Sample format name handed to the audio stream (e.g. "float32").
    dtype: str
    # Numeric device id, or None when not specified.
    device_number: Optional[int]


class ConfiguredDevice:
    """Stream parameters resolved for one sound device.

    Attributes:
        sample_rate: The ``default_samplerate`` that sounddevice reports for
            the device.
        window_size_samples: ``block_size`` converted from the configured
            ``sampling_rate`` to ``sample_rate`` and truncated to an int.
        target_samping_rate: ``config["target_smpling_rate"]`` as an int. The
            attribute name is misspelled but is kept because callers read it.
        dtype: ``config["dtype"]`` unchanged.

    Args:
        config: Device settings. ``device_number`` may be omitted or None.

    Raises:
        KeyError: If a required key other than ``device_number`` is missing.
        ZeroDivisionError: If ``config["sampling_rate"]`` is 0.
    """

    def __init__(self, config: DeviceConfig):
        # ``device_number`` is read with .get() because configs are routinely
        # built without it; sounddevice then receives device=None and picks
        # the device itself based on ``kind``.
        device_info = sd.query_devices(
            device=config.get("device_number"), kind=config["kind"]
        )
        self.sample_rate = device_info["default_samplerate"]  # type: ignore

        # Keep the block duration constant: block_size samples at the
        # configured sampling_rate last as long as window_size_samples at the
        # device's own rate.
        self.window_size_samples = int(
            config["block_size"] * self.sample_rate / config["sampling_rate"]
        )
        self.target_samping_rate = int(config["target_smpling_rate"])
        self.dtype = config["dtype"]


class StreamingAudioInputDevice(BaseConnector):
    """Connector that streams audio blocks captured from sound input devices.

    A device is registered with :meth:`configure_device`, a capture stream is
    opened with :meth:`start_action` and ended with :meth:`terminate_action`.
    The message-oriented connector methods are not supported and always raise
    :class:`SoundDeviceError`.
    """

    def __init__(self):
        # Streams keyed by the handle returned from start_action().
        self.streams = {}
        # NOTE: this changes the sounddevice module-wide default latency.
        sd.default.latency = ("low", "low")
        # Resolved settings keyed by the target string (the device number).
        self.configred_devices: dict[str, ConfiguredDevice] = {}

    def configure_device(self, target: str, config: DeviceConfig):
        """Register the settings to use for the device named by ``target``.

        Args:
            target: Device number written as a decimal string, e.g. ``"0"``.
            config: Device settings. ``device_number`` may be omitted or None;
                when present it must equal ``int(target)``. The dictionary is
                not modified.

        Raises:
            SoundDeviceError: If ``target`` is not a decimal number or it
                disagrees with ``config["device_number"]``.
        """
        # isdecimal() rather than isdigit(): the latter also accepts
        # characters such as superscript digits that int() cannot parse.
        if not target.isdecimal():
            raise SoundDeviceError("target must be a device number!")

        device_number = int(target)
        requested_number = config.get("device_number")
        if requested_number is not None and requested_number != device_number:
            raise SoundDeviceError(
                "device_number in config must be the same as target"
            )

        # Work on a copy: writing device_number into the caller's dict would
        # make a second configure_device() call that reuses the same config
        # for another device fail the consistency check above.
        resolved = dict(config)
        resolved["device_number"] = device_number
        self.configred_devices[target] = ConfiguredDevice(resolved)  # type: ignore

    def send_message(self, msg: BaseMessage, target: str) -> None:
        """Unsupported; always raises :class:`SoundDeviceError`."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def receive_message(self, source: str) -> BaseMessage:
        """Unsupported; always raises :class:`SoundDeviceError`."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport receiving messages messages"
        )

    def send_and_wait(self, target: str) -> BaseMessage:
        """Unsupported; always raises :class:`SoundDeviceError`."""
        raise SoundDeviceError(
            "StreamingAudioInputDevice does not suport sending messages"
        )

    def start_action(
        self,
        target: str,
        on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
        # sounddevice invokes ``finished_callback`` without arguments, so the
        # default no-op must accept zero arguments as well as one.
        on_finish: Callable = lambda *_: None,
    ) -> str:
        """Open and start a single-channel input stream on a configured device.

        Args:
            target: Device number string previously passed to
                :meth:`configure_device`.
            on_feedback: Called for every captured block with the flattened
                samples (resampled to the target rate when it differs from the
                device rate) and a dict of the stream status flags.
            on_finish: Passed to the stream as its ``finished_callback``.

        Returns:
            Handle identifying the stream; pass it to :meth:`terminate_action`.

        Raises:
            SoundDeviceError: If the device was not configured, or creating
                the stream raised ``AttributeError``.
        """
        target_device = self.configred_devices.get(target)
        if target_device is None:
            raise SoundDeviceError(f"Device {target} has not been configured")

        source_rate = target_device.sample_rate
        target_rate = target_device.target_samping_rate

        def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
            samples = indata.flatten()
            if source_rate != target_rate:
                # The block was captured at source_rate, so the number of
                # output samples must be scaled by target_rate / source_rate.
                # (Deriving the duration from target_rate instead would give
                # back the input length and leave the rate unchanged.)
                resampled_length = int(len(samples) * target_rate / source_rate)
                samples = resample(samples, resampled_length)  # type: ignore
            flag_dict = {
                "input_overflow": status.input_overflow,
                "input_underflow": status.input_underflow,
                "output_overflow": status.output_overflow,
                "output_underflow": status.output_underflow,
                "priming_output": status.priming_output,
            }
            on_feedback(samples, flag_dict)

        handle = self._generate_handle()
        try:
            stream = sd.InputStream(
                samplerate=source_rate,
                channels=1,
                device=int(target),
                dtype=target_device.dtype,
                blocksize=target_device.window_size_samples,
                callback=callback,
                finished_callback=on_finish,
            )
        except AttributeError as err:
            raise SoundDeviceError(
                f"Device {target} has not been correctly configured"
            ) from err

        try:
            stream.start()
        except Exception:
            # Do not leak the underlying audio stream when it fails to start.
            stream.close()
            raise

        self.streams[handle] = stream
        return handle

    def terminate_action(self, action_handle: str):
        """Stop the stream behind ``action_handle`` and release its resources.

        The handle stays registered, so terminating the same handle again
        still finds the (already closed) stream, as it did before.

        Raises:
            KeyError: If ``action_handle`` was not returned by
                :meth:`start_action`.
        """
        stream = self.streams[action_handle]
        try:
            stream.stop()
        finally:
            # stop() alone keeps the stream allocated; close() frees it.
            stream.close()
100 changes: 100 additions & 0 deletions tests/communication/test_sound_device_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from rai.communication import SoundDeviceError, StreamingAudioInputDevice


@pytest.fixture
def setup_mock_input_stream():
    """Replace ``sounddevice.InputStream`` with a mock while a test runs.

    Yields the mock standing in for the class; the patch is undone afterwards.
    """
    patcher = mock.patch("sounddevice.InputStream")
    mocked_stream_cls = patcher.start()
    try:
        yield mocked_stream_cls
    finally:
        patcher.stop()


@pytest.fixture
def device_config():
    """Return a fresh device configuration without a ``device_number`` key."""
    return dict(
        kind="input",
        block_size=1024,
        sampling_rate=44100,
        target_smpling_rate=16000,
        dtype="float32",
    )


def test_configure(
    setup_mock_input_stream,
    device_config,
):
    """configure_device() stores the parameters resolved for the device."""
    audio_input_device = StreamingAudioInputDevice()

    # Mock the device query so the test does not depend on the audio hardware
    # (or lack of it) of the machine running the suite.
    with mock.patch(
        "sounddevice.query_devices",
        return_value={"default_samplerate": 44100.0},
    ) as mock_query_devices:
        audio_input_device.configure_device("0", device_config)

    mock_query_devices.assert_called_once_with(device=0, kind="input")
    configured = audio_input_device.configred_devices["0"]
    assert configured.sample_rate == 44100
    assert configured.window_size_samples == 1024
    assert configured.target_samping_rate == 16000
    assert configured.dtype == "float32"


def test_start_action_failed_init(
    setup_mock_input_stream,
):
    """start_action() rejects a device that was never configured."""
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    recording_device = 0
    with pytest.raises(SoundDeviceError, match="Device 0 has not been configured"):
        audio_input_device.start_action(
            str(recording_device), feedback_callback, finish_callback
        )

    # The failure must happen before any stream is created or registered.
    setup_mock_input_stream.assert_not_called()
    assert audio_input_device.streams == {}


def test_start_action(
    setup_mock_input_stream,
    device_config,
):
    """start_action() opens, starts and registers a stream for the device."""
    mock_input_stream = setup_mock_input_stream
    mock_instance = mock.MagicMock()
    mock_input_stream.return_value = mock_instance
    audio_input_device = StreamingAudioInputDevice()

    feedback_callback = mock.MagicMock()
    finish_callback = mock.MagicMock()

    recording_device = "0"
    # Mock the device query so the expected sample rate below does not depend
    # on the audio hardware of the machine running the suite.
    with mock.patch(
        "sounddevice.query_devices",
        return_value={"default_samplerate": 44100.0},
    ):
        audio_input_device.configure_device(recording_device, device_config)

    stream_handle = audio_input_device.start_action(
        str(recording_device), feedback_callback, finish_callback
    )

    assert mock_input_stream.call_count == 1
    init_args = mock_input_stream.call_args.kwargs
    assert init_args["samplerate"] == 44100.0
    assert init_args["channels"] == 1
    assert init_args["device"] == int(recording_device)
    assert init_args["dtype"] == "float32"
    assert init_args["blocksize"] == 1024
    assert init_args["finished_callback"] == finish_callback

    mock_instance.start.assert_called_once_with()
    assert audio_input_device.streams.get(stream_handle) is mock_instance

0 comments on commit 10a04c8

Please sign in to comment.