Skip to content
This repository has been archived by the owner on Feb 9, 2023. It is now read-only.

Commit

Permalink
Unify all spicomm implementations.
Browse files Browse the repository at this point in the history
Change-Id: I0f88681782bb2a9d37b221608d6b3b49d2eb2365
  • Loading branch information
dmitriykovalev committed Jul 31, 2018
1 parent 9db1e06 commit b3e7f42
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 108 deletions.
220 changes: 113 additions & 107 deletions src/aiy/_drivers/_spicomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@
# limitations under the License.
"""Python wrapper around the VisionBonnet Spicomm device node."""

import array
import fcntl
import mmap
import multiprocessing as mp
import os
import signal
import struct
import sys
import threading

SPICOMM_DEV = '/dev/vision_spicomm'

SPICOMM_IOCTL_TRANSACT = 0xc0108903
SPICOMM_IOCTL_TRANSACT_MMAP = 0xc0108904

HEADER_SIZE = 16
DEFAULT_PAYLOAD_SIZE = 12 * 1024 * 1024 # 12 M
HEADER_SIZE = 4 * 4
DEFAULT_PAYLOAD_SIZE = 12 * 1024 * 1024 # 12M

FLAG_ERROR = 1 << 0
FLAG_TIMEOUT = 1 << 1
Expand All @@ -40,11 +39,6 @@ class SpicommError(IOError):
pass


class SpicommDevNotFoundError(SpicommError):
"""A usable Spicomm device node not found."""
pass


class SpicommOverflowError(SpicommError):
"""Transaction buffer too small for response.
Expand All @@ -68,62 +62,75 @@ def _read_header(buf):
return struct.unpack('IIII', buf[0:HEADER_SIZE])


def _read_payload(buf):
def _read_payload(buf, payload_size):
"""Returns payload bytes."""
_, _, _, payload_size = _read_header(buf)
return buf[HEADER_SIZE:HEADER_SIZE + payload_size]


def _write_header(buf, timeout, payload_size):
"""Writes data into transaction header."""
buf[0:4] = struct.pack('I', 0) # flags (used in response)
buf[4:8] = struct.pack('I', int(timeout * 1000)) # timeout (ms)
buf[8:12] = struct.pack('I', len(buf)) # buffer size
buf[12:16] = struct.pack('I', payload_size) # payload size
def _write_header(buf, timeout_ms, payload_size):
"""Writes transaction header into buffer."""
buf[0:HEADER_SIZE] = struct.pack('IIII', 0, timeout_ms, len(buf), payload_size)


def _write_payload(buf, payload):
"""Writes transaction payload into buffer."""
buf[HEADER_SIZE:HEADER_SIZE + len(payload)] = payload


def _get_timeout(payload_size):
"""Conservatively assume min 5 seconds or 3 seconds per 1MB."""
return max(3 * payload_size / 1024 / 1024, 5)
def _get_timeout_ms(timeout, payload_size):
"""Conservatively assume minimum 5 seconds or 3 seconds per 1MB."""
if timeout is not None:
return int(1000 * timeout)

return int(1000 * max(3 * payload_size / 1024 / 1024, 5))


def _get_exception(header):
flags, timeout_ms, _, payload_size = _read_header(header)
def _get_exception(flags, timeout_ms, payload_size):
if flags & FLAG_ERROR:
if flags & FLAG_TIMEOUT:
return SpicommTimeoutError(timeout_ms / 1000.0)
elif flags & FLAG_OVERFLOW:
return SpicommOverflowError(payload_size)
return SpicommError()
return SpicommError()
return None


def _async_loop(dev, pipe):
def _check_flags(flags, timeout_ms, payload_size):
e = _get_exception(flags, timeout_ms, payload_size)
if e is not None:
raise e


def _async_loop(dev, pipe, default_payload_size):
# Essentially this process can only receive SIGKILL.
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN)

allocated_buf = bytearray(HEADER_SIZE + DEFAULT_PAYLOAD_SIZE)
allocated_buf = bytearray(HEADER_SIZE + default_payload_size)
while True:
size, timeout = pipe.recv()
if size <= DEFAULT_PAYLOAD_SIZE:
payload_size, timeout = pipe.recv()
use_allocated_buf = payload_size <= (len(allocated_buf) - HEADER_SIZE)

if use_allocated_buf:
buf = allocated_buf
else:
buf = bytearray(HEADER_SIZE + size)
buf = bytearray(HEADER_SIZE + payload_size)

if timeout is None:
timeout = _get_timeout(size)
timeout_ms = _get_timeout_ms(timeout, payload_size)

_write_header(buf, timeout, size)
_write_header(buf, timeout_ms, payload_size)
pipe.recv_bytes_into(buf, HEADER_SIZE)

try:
fcntl.ioctl(dev, SPICOMM_IOCTL_TRANSACT, buf)
pipe.send(_read_payload(buf))
except (IOError, OSError) as e:
pipe.send(_get_exception(buf))
flags, _, _, payload_size = _read_header(buf)
e = _get_exception(flags, timeout_ms, payload_size)
if e is not None:
pipe.send(e)
else:
pipe.send(_read_payload(buf, payload_size))
except Exception as e:
pipe.send(e)

class AsyncSpicomm(object):
"""Class for communication with VisionBonnet via kernel driver.
Expand All @@ -133,11 +140,13 @@ class AsyncSpicomm(object):
because of global interpreter lock.
"""

def __init__(self):
self._dev = open(SPICOMM_DEV, 'r+b', 0)
def __init__(self, default_payload_size=DEFAULT_PAYLOAD_SIZE):
self._dev = os.open(SPICOMM_DEV, os.O_RDWR)
self._pipe, pipe = mp.Pipe()
self._lock = threading.Lock()
ctx = mp.get_context('fork')
self._process = ctx.Process(target=_async_loop, daemon=True, args=(self._dev, pipe))
self._process = ctx.Process(target=_async_loop, daemon=True,
args=(self._dev, pipe, default_payload_size))
self._process.start()

def __enter__(self):
Expand All @@ -149,7 +158,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):
def close(self):
os.kill(self._process.pid, signal.SIGKILL)
self._process.join()
self._dev.close()
os.close(self._dev)

def transact(self, request, timeout=None):
"""Execute transaction in a separate process.
Expand All @@ -168,25 +177,26 @@ def transact(self, request, timeout=None):
"""

# Setup temporary SIGINT handler
captured_args = None
def handler(*args):
nonlocal captured_args
captured_args = args
old_handler = signal.signal(signal.SIGINT, handler)
with self._lock:
captured_args = None
def handler(*args):
nonlocal captured_args
captured_args = args
old_handler = signal.signal(signal.SIGINT, handler)

# Execute communication transaction without SIGINT interruptions
self._pipe.send((len(request), timeout))
self._pipe.send_bytes(request)
response = self._pipe.recv()
# Execute communication transaction without SIGINT interruptions
self._pipe.send((len(request), timeout))
self._pipe.send_bytes(request)
response = self._pipe.recv()

# Setup old SIGINT handler or call it directly if SIGINT already happened
signal.signal(signal.SIGINT, old_handler)
if captured_args:
old_handler(*captured_args)
# Setup old SIGINT handler or call it directly if SIGINT already happened
signal.signal(signal.SIGINT, old_handler)
if captured_args:
old_handler(*captured_args)

if isinstance(response, Exception):
raise response
return response
if isinstance(response, Exception):
raise response
return response


class SyncSpicomm(object):
Expand All @@ -196,12 +206,10 @@ class SyncSpicomm(object):
process are blocked while icotl() is running because of global interpreter lock.
"""

def __init__(self):
try:
self._dev = open(SPICOMM_DEV, 'r+b', 0)
except (IOError, OSError):
raise SpicommDevNotFoundError
self._allocated_buf = bytearray(HEADER_SIZE + DEFAULT_PAYLOAD_SIZE)
def __init__(self, default_payload_size=DEFAULT_PAYLOAD_SIZE):
self._dev = os.open(SPICOMM_DEV, os.O_RDWR)
self._allocated_buf = bytearray(HEADER_SIZE + default_payload_size)
self._lock = threading.Lock()

def __enter__(self):
return self
Expand All @@ -210,7 +218,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):
self.close()

def close(self):
self._dev.close()
os.close(self._dev)

def transact(self, request, timeout=None):
"""Execute transaction in the current process.
Expand All @@ -227,45 +235,45 @@ def transact(self, request, timeout=None):
SpicommTimeoutError: Transaction timed out.
SpicommError: Transaction error.
"""
size = len(request)
if size <= DEFAULT_PAYLOAD_SIZE:
buf = self._allocated_buf
else:
buf = bytearray(HEADER_SIZE + size)
with self._lock:
payload_size = len(request)
use_allocated_buf = payload_size <= (len(self._allocated_buf) - HEADER_SIZE)

if timeout is None:
timeout = _get_timeout(size)
if use_allocated_buf:
buf = self._allocated_buf
else:
buf = bytearray(HEADER_SIZE + payload_size)

_write_header(buf, timeout, size)
_write_payload(buf, request)
timeout_ms = _get_timeout_ms(timeout, payload_size)

_write_header(buf, timeout_ms, payload_size)
_write_payload(buf, request)

try:
fcntl.ioctl(self._dev, SPICOMM_IOCTL_TRANSACT, buf)
return _read_payload(buf)
except (IOError, OSError):
raise _get_exception(buf)
flags, _, _, payload_size = _read_header(buf)
_check_flags(flags, timeout_ms, payload_size)

if use_allocated_buf:
return bytearray(_read_payload(buf, payload_size))
else:
return _read_payload(buf, payload_size)


def _transact_mmap(dev, mm, offset, request, timeout):
payload_size = len(request)
if timeout is None:
timeout = _get_timeout(payload_size)
timeout_ms = _get_timeout_ms(timeout, payload_size)
flags = 0

mm[0:payload_size] = request

buf = bytearray(HEADER_SIZE)
buf[0:4] = struct.pack('I', 0) # flags (used in response)
buf[4:8] = struct.pack('I', int(timeout * 1000)) # timeout (ms)
buf[8:12] = struct.pack('I', offset) # page offset
buf[12:16] = struct.pack('I', payload_size) # payload size
buf = bytearray(struct.pack('IIII', flags, timeout_ms, offset, payload_size))
assert(len(buf) == HEADER_SIZE)

try:
# Buffer size is small (< 1024 bytes), so ioctl call doesn't block other threads.
fcntl.ioctl(dev, SPICOMM_IOCTL_TRANSACT_MMAP, buf)
_, _, _, payload_size = _read_header(buf)
return bytearray(mm[0:payload_size])
except (IOError, OSError):
raise _get_exception(buf)
# Buffer size is small (< 1024 bytes), so ioctl call doesn't block other threads.
fcntl.ioctl(dev, SPICOMM_IOCTL_TRANSACT_MMAP, buf)
flags, _, _, payload_size = _read_header(buf)
_check_flags(flags, timeout_ms, payload_size)
return bytearray(mm[0:payload_size])


class SyncSpicommMmap(object):
Expand All @@ -275,13 +283,10 @@ class SyncSpicommMmap(object):
process are *not* blocked while icotl() is running.
"""

def __init__(self):
try:
self._dev = os.open(SPICOMM_DEV, os.O_RDWR)
except (IOError, OSError):
raise SpicommDevNotFoundError

self._mm = mmap.mmap(self._dev, length=DEFAULT_PAYLOAD_SIZE, offset=0)
def __init__(self, default_payload_size=DEFAULT_PAYLOAD_SIZE):
self._dev = os.open(SPICOMM_DEV, os.O_RDWR)
self._mm = mmap.mmap(self._dev, length=default_payload_size, offset=0)
self._lock = threading.Lock()

def __enter__(self):
return self
Expand All @@ -294,24 +299,25 @@ def close(self):
os.close(self._dev)

def transact(self, request, timeout=None):
if len(request) < len(self._mm):
return _transact_mmap(self._dev, self._mm, 0, request, timeout)
else:
offset = (len(self._mm) + (mmap.PAGESIZE - 1)) // mmap.PAGESIZE
with mmap.mmap(self._dev, length=len(request),
offset=mmap.PAGESIZE * offset) as mm:
return _transact_mmap(self._dev, mm, offset, request, timeout)
with self._lock:
if len(request) < len(self._mm):
# Default buffer
return _transact_mmap(self._dev, self._mm, 0, request, timeout)
else:
# Temporary bigger buffer
offset = (len(self._mm) + (mmap.PAGESIZE - 1)) // mmap.PAGESIZE
with mmap.mmap(self._dev, length=len(request), offset=mmap.PAGESIZE * offset) as mm:
return _transact_mmap(self._dev, mm, offset, request, timeout)


_spicomm_type = os.environ.get('VISION_BONNET_SPICOMM', None)
_spicomm_types = {'sync': SyncSpicomm,
'sync_mmap': SyncSpicommMmap,
'async': AsyncSpicomm}

# Scicomm class provides the ability to send and receive data as a transaction.
# This means that every call to transact consists of a combined
# send and receive step that's atomic from the calling application's
# point of view. Multiple threads and processes can access the device
# node concurrently using one Spicomm instance per thread.
# Transactions are serialized in the underlying kernel driver.
_spicomm_type = os.environ.get('VISION_BONNET_SPICOMM', None)
_spicomm_types = {'sync': SyncSpicomm,
'sync_mmap': SyncSpicommMmap,
'async': AsyncSpicomm}
Spicomm = _spicomm_types.get(_spicomm_type, AsyncSpicomm)
Loading

0 comments on commit b3e7f42

Please sign in to comment.