From 756e10c75e9af7aa3912a603372bef21dd842191 Mon Sep 17 00:00:00 2001
From: Peter Malkin
Date: Thu, 11 Jan 2018 14:39:24 -0800
Subject: [PATCH] Run autopep8 on public python codebase to make travis happy

pip3 install --user autopep8
find . -name '*.py' | xargs -n1 ~/.local/bin/autopep8 --aggressive --in-place --max-line-length 100

Change-Id: I7afbf5621b987255d31823fd8f1ca35e2406fe92
---
 .travis.yml | 4 +-
 checkpoints/check_audio.py | 3 +-
 checkpoints/check_cloud.py | 2 +-
 checkpoints/check_wifi.py | 2 +-
 src/aiy/_apis/_speech.py | 3 +-
 src/aiy/_drivers/_buzzer.py | 12 +-
 src/aiy/_drivers/_hat.py | 40 +-
 src/aiy/_drivers/_rgbled.py | 470 ++++----
 src/aiy/_drivers/_spicomm.py | 178 +--
 src/aiy/_drivers/_transport.py | 82 +-
 src/aiy/toneplayer.py | 12 +-
 src/aiy/trackplayer.py | 47 +-
 src/aiy/vision/inference.py | 444 +++----
 src/aiy/vision/models/dish_classifier.py | 62 +-
 .../vision/models/dish_classifier_classes.py | 2 +-
 src/aiy/vision/models/face_detection.py | 74 +-
 src/aiy/vision/models/image_classification.py | 60 +-
 src/aiy/vision/models/object_detection.py | 344 +++---
 src/aiy/vision/models/utils.py | 6 +-
 src/aiy/vision/pins.py | 1074 ++++++++---------
 src/examples/vision/annotator.py | 256 ++--
 src/examples/vision/buzzer/buzzer_demo.py | 2 +-
 src/examples/vision/dish_classifier.py | 20 +-
 src/examples/vision/face_camera_trigger.py | 27 +-
 src/examples/vision/face_detection.py | 32 +-
 src/examples/vision/face_detection_camera.py | 94 +-
 src/examples/vision/gpiozero/led_example.py | 8 +-
 src/examples/vision/gpiozero/servo_example.py | 18 +-
 .../vision/gpiozero/simple_button_example.py | 8 +-
 src/examples/vision/image_classification.py | 36 +-
 src/examples/vision/joy/joy_detection_demo.py | 234 ++--
 src/examples/vision/object_detection.py | 46 +-
 .../vision/object_detection_camera.py | 92 +-
 .../assistant_library_with_button_demo.py | 1 +
 34 files changed, 1908 insertions(+), 1887 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 3b30dc49..373991c3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -7,11 +7,11 @@ before_install:
 - sudo apt-get install -y python3-numpy python3-scipy
 install:
-- pip3 install pep8 pyflakes coverage
+- pip3 install pyflakes coverage pycodestyle
 - pip3 install -r requirements.txt
 script:
-- pep8 --max-line-length=100 .
+- pycodestyle --max-line-length=100 --exclude=*_pb2.py .
- nosetests --with-coverage after_success: diff --git a/checkpoints/check_audio.py b/checkpoints/check_audio.py index acb6fedf..08081310 100755 --- a/checkpoints/check_audio.py +++ b/checkpoints/check_audio.py @@ -39,6 +39,7 @@ RECORD_DURATION_SECONDS = 3 + def get_sound_cards(): """Read a dictionary of ALSA cards from /proc, indexed by number.""" cards = {} @@ -159,4 +160,4 @@ def main(): input('Press Enter to close...') except Exception: # pylint: disable=W0703 traceback.print_exc() - input('Press Enter to close...') \ No newline at end of file + input('Press Enter to close...') diff --git a/checkpoints/check_cloud.py b/checkpoints/check_cloud.py index 906f1a5c..6e301e96 100755 --- a/checkpoints/check_cloud.py +++ b/checkpoints/check_cloud.py @@ -102,6 +102,6 @@ def main(): try: main() input('Press Enter to close...') - except: # pylint: disable=bare-except + except Exception: traceback.print_exc() input('Press Enter to close...') diff --git a/checkpoints/check_wifi.py b/checkpoints/check_wifi.py index 0ae83313..3ec89702 100755 --- a/checkpoints/check_wifi.py +++ b/checkpoints/check_wifi.py @@ -75,6 +75,6 @@ def main(): try: main() input('Press Enter to close...') - except: # pylint: disable=bare-except + except Exception: traceback.print_exc() input('Press Enter to close...') diff --git a/src/aiy/_apis/_speech.py b/src/aiy/_apis/_speech.py index 35a2d8d2..f18a5d44 100644 --- a/src/aiy/_apis/_speech.py +++ b/src/aiy/_apis/_speech.py @@ -341,7 +341,8 @@ def _stop_sending_audio(self, resp): resp.speech_event_type) logger.info('endpointer_type: %s', speech_event_type) - END_OF_SINGLE_UTTERANCE = types.StreamingRecognizeResponse.SpeechEventType.Value('END_OF_SINGLE_UTTERANCE') + END_OF_SINGLE_UTTERANCE = types.StreamingRecognizeResponse.SpeechEventType.Value( + 'END_OF_SINGLE_UTTERANCE') return resp.speech_event_type == END_OF_SINGLE_UTTERANCE def _handle_response(self, resp): diff --git a/src/aiy/_drivers/_buzzer.py b/src/aiy/_drivers/_buzzer.py index 8f5d32a4..3f09c2c8 100644 --- a/src/aiy/_drivers/_buzzer.py +++ b/src/aiy/_drivers/_buzzer.py @@ -97,7 +97,6 @@ def _wait_for_access(self, path): if not os.access(path, os.W_OK): raise IOError('Could not open %s' % path) - def _pwrite_int(self, path, data): """Helper method to quickly write a value to a sysfs node. @@ -131,7 +130,7 @@ def _export_pwm(self): """ try: self._pwrite_int(self.PWM_SOFT_EXPORT_PATH, self.gpio) - except: + except BaseException: self._exported = False raise @@ -141,7 +140,7 @@ def _export_pwm(self): try: self._wait_for_access(period_path) self._period_fh = open(period_path, 'w') - except: + except BaseException: self._unexport_pwm() raise @@ -149,11 +148,10 @@ def _export_pwm(self): try: self._wait_for_access(pulse_path) self._pulse_fh = open(pulse_path, 'w') - except: + except BaseException: self._unexport_pwm() raise - def _unexport_pwm(self): """Unexports the given GPIO from the pwm-soft driver. @@ -161,10 +159,10 @@ def _unexport_pwm(self): previously opened, and then unexporting the given gpio. 
""" if self._exported: - if self._period_fh != None: + if self._period_fh is not None: self._period_fh.close() - if self._pulse_fh != None: + if self._pulse_fh is not None: self._pulse_fh.close() self._pwrite_int(self.PWM_SOFT_UNEXPORT_PATH, self.gpio) diff --git a/src/aiy/_drivers/_hat.py b/src/aiy/_drivers/_hat.py index 7da9059f..1037862f 100644 --- a/src/aiy/_drivers/_hat.py +++ b/src/aiy/_drivers/_hat.py @@ -25,28 +25,32 @@ 3: 'Voice Bonnet', } + def _is_hat_attached(): - return os.path.exists(HAT_PATH) + return os.path.exists(HAT_PATH) + def _get_hat_product(): - with open(os.path.join(HAT_PATH, 'product')) as f: - return f.readline().strip() + with open(os.path.join(HAT_PATH, 'product')) as f: + return f.readline().strip() + def _get_hat_product_id(): - with open(os.path.join(HAT_PATH, 'product_id')) as f: - matches = HAT_PRODUCT_ID_RE.match(f.readline().strip()) - if matches: - return int(matches.group(0), 16) + with open(os.path.join(HAT_PATH, 'product_id')) as f: + matches = HAT_PRODUCT_ID_RE.match(f.readline().strip()) + if matches: + return int(matches.group(0), 16) + def get_aiy_device_name(): - if not _is_hat_attached(): - return None - product = _get_hat_product() - if not 'AIY' in product: - return None - product_id = _get_hat_product_id() - if not product_id: - return None - if not product_id in AIY_HATS: - return None - return AIY_HATS[product_id] + if not _is_hat_attached(): + return None + product = _get_hat_product() + if 'AIY' not in product: + return None + product_id = _get_hat_product_id() + if not product_id: + return None + if product_id not in AIY_HATS: + return None + return AIY_HATS[product_id] diff --git a/src/aiy/_drivers/_rgbled.py b/src/aiy/_drivers/_rgbled.py index 37f667ed..29c9c671 100644 --- a/src/aiy/_drivers/_rgbled.py +++ b/src/aiy/_drivers/_rgbled.py @@ -15,250 +15,250 @@ class RGBLED(object): - """Sets the KTD2026 driver chip to show patterns with the attached RGB LED. - - Simple usage: - from aiy._drivers._rgbled import RGBLED - rgbled = RGBLED() - rgbled.SetAnimation(color=RGBLED.BLUE, pattern=RGBLED.BLINK, rate_hz=4) - """ - - OFF = 0 - ON = 1 - BLINK = 2 - BREATHE = 3 - - # These values use a mapping of red, green, blue. Changing the channel map - # will affect this. - RED = (0xFF, 0x00, 0x00) - GREEN = (0x00, 0xFF, 0x00) - YELLOW = (0xFF, 0xFF, 0x00) - BLUE = (0x00, 0x00, 0xFF) - PURPLE = (0xFF, 0x00, 0xFF) - CYAN = (0x00, 0xFF, 0xFF) - WHITE = (0xFF, 0xFF, 0xFF) - - ENABLE_OFF = 0 - ENABLE_ON = 1 - ENABLE_PWM1 = 2 - ENABLE_PWM2 = 3 - - DEFAULT_CHANNEL_MAP = {'red': 1, 'green': 2, 'blue': 3, 'privacy': 4} - - def __init__(self, channel_map=None, debug=False): - """Initializes the RGB LED driver. - - Args: - channel_map: a dictionary of name -> channel number. Determined - experimentally. Typically this will be red, gree, blue, with the - values 1, 2, 3, respectively. Defaults to the aforementioned map. - debug: whether or not to output what is being written raw to the - various sysfs nodes. - """ - if channel_map is None: - self._channel_map = self.DEFAULT_CHANNEL_MAP - else: - self._channel_map = channel_map - self._debug = debug - self.Reset() - - def __del__(self): - self.Reset() - - def _MakeChannelPath(self, channel): - """Generates a ktd202x sysfs node path from a given channel name. - - Args: - channel: a string naming the channel to select. - Returns: - A string containing the base path to the channel's LED class device - sysfs path. 
- """ - return '/sys/class/leds/ktd202x:led%d/' % self._channel_map[channel] - - def _PWriteInt(self, channel, filename, data): - """Helper method to quickly write a value to a channel sysfs node. - - This is functionally equivalent to the pwrite system call, though does - not have the same OS semantics. - - Args: - channel: string, the name of the channel to write to. - filename: string, the name of the file in the channel's sysfs directory to - write to. - data: integer, the value to write to the file. - """ - path = self._MakeChannelPath(channel) + filename - if self._debug: - print('_PWriteInt(channel=%s, file=%s, data=%s)' % (channel, filename, - data)) - with open(path, 'w') as output: - output.write('%d\n' % data) - - def SetChannelMapping(self, mapping): - """Set the channel mapping from color to channel number. - - Args: - mapping: dictionary of channel name (red, green, blue) to channel - number (1, 2, 3). - """ - self._channel_map = mapping - - def EnableChannel(self, channel, enable_state=ENABLE_ON): - """Sets the enable value for a given channel name. - - Args: - channel: string, the name of the channel to set the enable bits for. - enable_state: integer, one of ENABLE_OFF, ENABLE_ON, ENABLE_PWM1, or - ENABLE_PWM2. - """ - channel_num = self._channel_map[channel] - self._PWriteInt(channel, 'device/ch%d_enable' % channel_num, enable_state) - - def SetBrightness(self, channel, brightness): - """Sets a given channel's brightness value. - - Args: - channel: string, the name of the channel to set the brightness for. - brightness: integer, the brightness to set. 255 is brightest. - """ - self._PWriteInt(channel, 'brightness', brightness) - - def SetFlashPeriod(self, times_per_second): - """Sets the flash period in Hz for the whole device. - - Args: - times_per_second: float, the frequency in Hz of how frequently to - flash the LEDs. - """ - seconds_per_time = 1 / times_per_second - period = seconds_per_time * (126 / 16.38) - if period > 126: - period = 126 - self._PWriteInt('red', 'device/tflash', period) - - def SetRiseTime(self, time): - """Sets the rising time for the LED flashing. - - Args: - time: the amount of time to take to do a rise. Max 15. - """ - self._PWriteInt('red', 'device/trise', time) + """Sets the KTD2026 driver chip to show patterns with the attached RGB LED. - def SetFallTime(self, time): - """Sets the falling time for the LED flashing. - - Args: - time: the amount of time to take to do a fall. Max 15. - """ - self._PWriteInt('red', 'device/tfall', time) - - def SetPWM1Percentage(self, percentage=1): - """Sets the percentage of the flash period for PWM1 channels to be on. - - Args: - percentage: float, from 0.0 to 1.0, percentage of the flash time to - keep the channels on. - """ - self._PWriteInt('red', 'device/pwm1', int(255 * percentage)) - - def SetPWM2Percentage(self, percentage=1): - """Sets the percentage of the flash period for PWM2 channels to be on. - - Args: - percentage: float, from 0.0 to 1.0, percentage of the flash time to - keep the channels on. + Simple usage: + from aiy._drivers._rgbled import RGBLED + rgbled = RGBLED() + rgbled.SetAnimation(color=RGBLED.BLUE, pattern=RGBLED.BLINK, rate_hz=4) """ - self._PWriteInt('red', 'device/pwm2', int(255 * percentage)) - def SetColorMix(self, red=0, green=0, blue=0): - """Sets the solid color mix to display on all three channels. - - Note: this will reset the chip and force it into solid color mode. - - Args: - red: integer, max 255, the brightness of the red channel. 
- green: integer, max 255, the brightness of the green channel. - blue: integer, max 255, the brightness of the blue channel. - """ - # self.Reset() - colors = {'red': red, 'green': green, 'blue': blue} - for channel_name, color in colors.items(): - self.SetBrightness(channel_name, color) - - def Reset(self): - """Forces a KTD202x chip reset. - """ - self._PWriteInt('red', 'device/reset', 1) - - def _SetAnimationPattern(self, pattern=BLINK, rate_hz=1): - """Helper function to setup the given blink pattern with the given rate. - - Note: resets the chip. - - Args: - pattern: integer, one of OFF, ON, BLINK, or BREATHE. ON is solid on, - BLINK is a 50% duty cycle hard blink with no ramps enabled, BREATHE - is a 30% duty cycle soft blink with ramps set to 5. - rate_hz: float, the rate in Hz of how often to blink the given - pattern. Irrelevant for OFF or ON. - """ - self.Reset() - if pattern == self.ON: - self.SetPWM1Percentage(1) - elif pattern == self.BLINK: - self.SetFlashPeriod(rate_hz) - self.SetPWM1Percentage(0.5) - elif pattern == self.BREATHE: - self.SetFlashPeriod(rate_hz) - self.SetRiseTime(5) - self.SetFallTime(5) - self.SetPWM1Percentage(0.3) - - def SetAnimation(self, color=RED, pattern=BLINK, rate_hz=1): - """Sets the given animation for the given color at the given rate. - - Note: resets the chip. - - Args: - color: tuple, one of RED, GREEN, YELLOW, BLUE, PURPLE, CYAN, WHITE, or - a tuple of three values signifying which channels to enable (1 - enables, 0 disables) for the given flashing sequence. - pattern: integer, one of OFF, ON, BLINK, or BREATHE. ON is solid on, - BLINK is a 50% duty cycle hard blink with no ramps enabled, BREATHE - is a 30% duty cycle soft blink with ramps set to 5. - rate_hz: float, the rate in Hz of how often to blink the given - pattern. Irrelevant for OFF or ON. - """ - self._SetAnimationPattern(pattern, rate_hz) - for (color, value) in zip(('red', 'green', 'blue'), color): - print((color, value)) - state = self.ENABLE_OFF - if value > 0: - state = self.ENABLE_PWM1 - self.EnableChannel(color, state) + OFF = 0 + ON = 1 + BLINK = 2 + BREATHE = 3 + + # These values use a mapping of red, green, blue. Changing the channel map + # will affect this. + RED = (0xFF, 0x00, 0x00) + GREEN = (0x00, 0xFF, 0x00) + YELLOW = (0xFF, 0xFF, 0x00) + BLUE = (0x00, 0x00, 0xFF) + PURPLE = (0xFF, 0x00, 0xFF) + CYAN = (0x00, 0xFF, 0xFF) + WHITE = (0xFF, 0xFF, 0xFF) + + ENABLE_OFF = 0 + ENABLE_ON = 1 + ENABLE_PWM1 = 2 + ENABLE_PWM2 = 3 + + DEFAULT_CHANNEL_MAP = {'red': 1, 'green': 2, 'blue': 3, 'privacy': 4} + + def __init__(self, channel_map=None, debug=False): + """Initializes the RGB LED driver. + + Args: + channel_map: a dictionary of name -> channel number. Determined + experimentally. Typically this will be red, gree, blue, with the + values 1, 2, 3, respectively. Defaults to the aforementioned map. + debug: whether or not to output what is being written raw to the + various sysfs nodes. + """ + if channel_map is None: + self._channel_map = self.DEFAULT_CHANNEL_MAP + else: + self._channel_map = channel_map + self._debug = debug + self.Reset() + + def __del__(self): + self.Reset() + + def _MakeChannelPath(self, channel): + """Generates a ktd202x sysfs node path from a given channel name. + + Args: + channel: a string naming the channel to select. + Returns: + A string containing the base path to the channel's LED class device + sysfs path. 
+ """ + return '/sys/class/leds/ktd202x:led%d/' % self._channel_map[channel] + + def _PWriteInt(self, channel, filename, data): + """Helper method to quickly write a value to a channel sysfs node. + + This is functionally equivalent to the pwrite system call, though does + not have the same OS semantics. + + Args: + channel: string, the name of the channel to write to. + filename: string, the name of the file in the channel's sysfs directory to + write to. + data: integer, the value to write to the file. + """ + path = self._MakeChannelPath(channel) + filename + if self._debug: + print('_PWriteInt(channel=%s, file=%s, data=%s)' % (channel, filename, + data)) + with open(path, 'w') as output: + output.write('%d\n' % data) + + def SetChannelMapping(self, mapping): + """Set the channel mapping from color to channel number. + + Args: + mapping: dictionary of channel name (red, green, blue) to channel + number (1, 2, 3). + """ + self._channel_map = mapping + + def EnableChannel(self, channel, enable_state=ENABLE_ON): + """Sets the enable value for a given channel name. + + Args: + channel: string, the name of the channel to set the enable bits for. + enable_state: integer, one of ENABLE_OFF, ENABLE_ON, ENABLE_PWM1, or + ENABLE_PWM2. + """ + channel_num = self._channel_map[channel] + self._PWriteInt(channel, 'device/ch%d_enable' % channel_num, enable_state) + + def SetBrightness(self, channel, brightness): + """Sets a given channel's brightness value. + + Args: + channel: string, the name of the channel to set the brightness for. + brightness: integer, the brightness to set. 255 is brightest. + """ + self._PWriteInt(channel, 'brightness', brightness) + + def SetFlashPeriod(self, times_per_second): + """Sets the flash period in Hz for the whole device. + + Args: + times_per_second: float, the frequency in Hz of how frequently to + flash the LEDs. + """ + seconds_per_time = 1 / times_per_second + period = seconds_per_time * (126 / 16.38) + if period > 126: + period = 126 + self._PWriteInt('red', 'device/tflash', period) + + def SetRiseTime(self, time): + """Sets the rising time for the LED flashing. + + Args: + time: the amount of time to take to do a rise. Max 15. + """ + self._PWriteInt('red', 'device/trise', time) + + def SetFallTime(self, time): + """Sets the falling time for the LED flashing. + + Args: + time: the amount of time to take to do a fall. Max 15. + """ + self._PWriteInt('red', 'device/tfall', time) + + def SetPWM1Percentage(self, percentage=1): + """Sets the percentage of the flash period for PWM1 channels to be on. + + Args: + percentage: float, from 0.0 to 1.0, percentage of the flash time to + keep the channels on. + """ + self._PWriteInt('red', 'device/pwm1', int(255 * percentage)) + + def SetPWM2Percentage(self, percentage=1): + """Sets the percentage of the flash period for PWM2 channels to be on. + + Args: + percentage: float, from 0.0 to 1.0, percentage of the flash time to + keep the channels on. + """ + self._PWriteInt('red', 'device/pwm2', int(255 * percentage)) + + def SetColorMix(self, red=0, green=0, blue=0): + """Sets the solid color mix to display on all three channels. + + Note: this will reset the chip and force it into solid color mode. + + Args: + red: integer, max 255, the brightness of the red channel. + green: integer, max 255, the brightness of the green channel. + blue: integer, max 255, the brightness of the blue channel. 
+ """ + # self.Reset() + colors = {'red': red, 'green': green, 'blue': blue} + for channel_name, color in colors.items(): + self.SetBrightness(channel_name, color) + + def Reset(self): + """Forces a KTD202x chip reset. + """ + self._PWriteInt('red', 'device/reset', 1) + + def _SetAnimationPattern(self, pattern=BLINK, rate_hz=1): + """Helper function to setup the given blink pattern with the given rate. + + Note: resets the chip. + + Args: + pattern: integer, one of OFF, ON, BLINK, or BREATHE. ON is solid on, + BLINK is a 50% duty cycle hard blink with no ramps enabled, BREATHE + is a 30% duty cycle soft blink with ramps set to 5. + rate_hz: float, the rate in Hz of how often to blink the given + pattern. Irrelevant for OFF or ON. + """ + self.Reset() + if pattern == self.ON: + self.SetPWM1Percentage(1) + elif pattern == self.BLINK: + self.SetFlashPeriod(rate_hz) + self.SetPWM1Percentage(0.5) + elif pattern == self.BREATHE: + self.SetFlashPeriod(rate_hz) + self.SetRiseTime(5) + self.SetFallTime(5) + self.SetPWM1Percentage(0.3) + + def SetAnimation(self, color=RED, pattern=BLINK, rate_hz=1): + """Sets the given animation for the given color at the given rate. + + Note: resets the chip. + + Args: + color: tuple, one of RED, GREEN, YELLOW, BLUE, PURPLE, CYAN, WHITE, or + a tuple of three values signifying which channels to enable (1 + enables, 0 disables) for the given flashing sequence. + pattern: integer, one of OFF, ON, BLINK, or BREATHE. ON is solid on, + BLINK is a 50% duty cycle hard blink with no ramps enabled, BREATHE + is a 30% duty cycle soft blink with ramps set to 5. + rate_hz: float, the rate in Hz of how often to blink the given + pattern. Irrelevant for OFF or ON. + """ + self._SetAnimationPattern(pattern, rate_hz) + for (color, value) in zip(('red', 'green', 'blue'), color): + print((color, value)) + state = self.ENABLE_OFF + if value > 0: + state = self.ENABLE_PWM1 + self.EnableChannel(color, state) class PrivacyLED(RGBLED): - """Wrapper for LED driver to enable/disable privacy LED + """Wrapper for LED driver to enable/disable privacy LED - Simple usage: - from aiy._drivers._rgbled import PrivacyLED - with PrivacyLED() # Illuminated on entry. - """ + Simple usage: + from aiy._drivers._rgbled import PrivacyLED + with PrivacyLED() # Illuminated on entry. + """ - def __init__(self): - """Initializes the parent LED driver. + def __init__(self): + """Initializes the parent LED driver. - Configures PWM2 to breathe on the privacy channel. + Configures PWM2 to breathe on the privacy channel. 
- """ - super().__init__() + """ + super().__init__() - def __enter__(self): - """Configures the privacy channel to be fully illuminated.""" - super().EnableChannel(channel='privacy', enable_state=RGBLED.ENABLE_ON) + def __enter__(self): + """Configures the privacy channel to be fully illuminated.""" + super().EnableChannel(channel='privacy', enable_state=RGBLED.ENABLE_ON) - def __exit__(self, exc_type, exc_value, exc_tb): - """On exit, turn off the LED.""" - super().EnableChannel(channel='privacy', enable_state=RGBLED.ENABLE_OFF) + def __exit__(self, exc_type, exc_value, exc_tb): + """On exit, turn off the LED.""" + super().EnableChannel(channel='privacy', enable_state=RGBLED.ENABLE_OFF) diff --git a/src/aiy/_drivers/_spicomm.py b/src/aiy/_drivers/_spicomm.py index c4876bf1..38557d8a 100644 --- a/src/aiy/_drivers/_spicomm.py +++ b/src/aiy/_drivers/_spicomm.py @@ -33,111 +33,111 @@ class SpicommError(IOError): - """Base class for Spicomm errors.""" - pass + """Base class for Spicomm errors.""" + pass class SpicommDevNotFoundError(SpicommError): - """A usable Spicomm device node not found.""" - pass + """A usable Spicomm device node not found.""" + pass class SpicommOverflowError(SpicommError): - """Transaction buffer too small for response. + """Transaction buffer too small for response. - Attributes: - size: Number of bytes needed for the response. - """ + Attributes: + size: Number of bytes needed for the response. + """ - def __init__(self, size): - self.size = size - super(SpicommOverflowError, self).__init__() + def __init__(self, size): + self.size = size + super(SpicommOverflowError, self).__init__() class SpicommTimeoutError(SpicommError): - """Transaction timed out.""" - pass + """Transaction timed out.""" + pass class SpicommInternalError(SpicommError): - """Internal unexpected error.""" - pass + """Internal unexpected error.""" + pass class Spicomm(object): - """VisionBonnet Spicomm wrapper. - - Provides the ability to send and receive data as a transaction. - This means that every call to transact consists of a combined - send and receive step that's atomic from the calling application's - point of view. Multiple threads and processes can access the device - node concurrently using one Spicomm instance per thread. - Transactions are serialized in the underlying kernel driver. - """ - - def __init__(self): - try: - self._dev = open(SPICOMM_DEV, 'r+b', 0) - except (IOError, OSError): - raise SpicommDevNotFoundError - self._tbuf = bytearray(HEADER_SIZE + PAYLOAD_SIZE) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - self.close() - - def close(self): - if self._dev: - self._dev.close() - - def transact(self, request, timeout=15): - """Execute a Spicomm transaction. - - The bytes in request are sent, a response is waited for and returned. - If the request or response is too large SpicommOverflowError is raised. - - Args: - request: Request bytes to send. - timeout: How long a response will be waited for, in seconds. - - Returns: - Bytes-like object with response data. - - Raises: - SpicommOverflowError: Transaction buffer was too small for response. - The 'size' attribute contains the required size. - SpicommTimeoutError : Transaction timed out. - SpicommInternalError: Unexpected error interacting with kernel driver. + """VisionBonnet Spicomm wrapper. + + Provides the ability to send and receive data as a transaction. 
+ This means that every call to transact consists of a combined + send and receive step that's atomic from the calling application's + point of view. Multiple threads and processes can access the device + node concurrently using one Spicomm instance per thread. + Transactions are serialized in the underlying kernel driver. """ - payload_len = len(request) - if payload_len > PAYLOAD_SIZE: - raise SpicommOverflowError(PAYLOAD_SIZE) - - # Fill in transaction buffer. - self._tbuf[0:4] = struct.pack('I', 0) # flags, not currently used. - self._tbuf[4:8] = struct.pack('I', int(timeout * 1000)) # timeout, ms. - self._tbuf[8:12] = struct.pack('I', len(self._tbuf)) # total buffer size. - self._tbuf[12:16] = struct.pack('I', payload_len) # filled range of buffer. - self._tbuf[16:16 + payload_len] = request - - try: - # Send transaction to kernel driver. - fcntl.ioctl(self._dev, SPICOMM_IOCTL_TRANSACT, self._tbuf) - - # No exception means errno 0 and self._tbuf is now mutated. - _, _, _, payload_len = struct.unpack('IIII', self._tbuf[0:16]) - return self._tbuf[16:16 + payload_len] - except (IOError, OSError): - # FLAG_ERROR is set if we actually talked to the kernel. - flags, _, _, payload_len = struct.unpack('IIII', self._tbuf[0:16]) - if flags & FLAG_ERROR: - if flags & FLAG_TIMEOUT: - raise SpicommTimeoutError - elif flags & FLAG_OVERFLOW: - raise SpicommOverflowError(payload_len) - - # This is unexpected. - raise SpicommInternalError + def __init__(self): + try: + self._dev = open(SPICOMM_DEV, 'r+b', 0) + except (IOError, OSError): + raise SpicommDevNotFoundError + self._tbuf = bytearray(HEADER_SIZE + PAYLOAD_SIZE) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() + + def close(self): + if self._dev: + self._dev.close() + + def transact(self, request, timeout=15): + """Execute a Spicomm transaction. + + The bytes in request are sent, a response is waited for and returned. + If the request or response is too large SpicommOverflowError is raised. + + Args: + request: Request bytes to send. + timeout: How long a response will be waited for, in seconds. + + Returns: + Bytes-like object with response data. + + Raises: + SpicommOverflowError: Transaction buffer was too small for response. + The 'size' attribute contains the required size. + SpicommTimeoutError : Transaction timed out. + SpicommInternalError: Unexpected error interacting with kernel driver. + """ + + payload_len = len(request) + if payload_len > PAYLOAD_SIZE: + raise SpicommOverflowError(PAYLOAD_SIZE) + + # Fill in transaction buffer. + self._tbuf[0:4] = struct.pack('I', 0) # flags, not currently used. + self._tbuf[4:8] = struct.pack('I', int(timeout * 1000)) # timeout, ms. + self._tbuf[8:12] = struct.pack('I', len(self._tbuf)) # total buffer size. + self._tbuf[12:16] = struct.pack('I', payload_len) # filled range of buffer. + self._tbuf[16:16 + payload_len] = request + + try: + # Send transaction to kernel driver. + fcntl.ioctl(self._dev, SPICOMM_IOCTL_TRANSACT, self._tbuf) + + # No exception means errno 0 and self._tbuf is now mutated. + _, _, _, payload_len = struct.unpack('IIII', self._tbuf[0:16]) + return self._tbuf[16:16 + payload_len] + except (IOError, OSError): + # FLAG_ERROR is set if we actually talked to the kernel. + flags, _, _, payload_len = struct.unpack('IIII', self._tbuf[0:16]) + if flags & FLAG_ERROR: + if flags & FLAG_TIMEOUT: + raise SpicommTimeoutError + elif flags & FLAG_OVERFLOW: + raise SpicommOverflowError(payload_len) + + # This is unexpected. 
+ raise SpicommInternalError diff --git a/src/aiy/_drivers/_transport.py b/src/aiy/_drivers/_transport.py index 6fb1f92a..a0251da4 100644 --- a/src/aiy/_drivers/_transport.py +++ b/src/aiy/_drivers/_transport.py @@ -23,69 +23,69 @@ class _SpiTransport(object): - """Communicate with VisionBonnet over SPI bus.""" + """Communicate with VisionBonnet over SPI bus.""" - def __init__(self): - self._spicomm = _spicomm.Spicomm() + def __init__(self): + self._spicomm = _spicomm.Spicomm() - # TODO(dkovalev): add timeout when implemented in Spicomm - def send(self, request): - return self._spicomm.transact(request) + # TODO(dkovalev): add timeout when implemented in Spicomm + def send(self, request): + return self._spicomm.transact(request) - def close(self): - self._spicomm.close() + def close(self): + self._spicomm.close() def _socket_recvall(s, size): - buf = b'' - while size: - newbuf = s.recv(size) - if not newbuf: - return None - buf += newbuf - size -= len(newbuf) - return buf + buf = b'' + while size: + newbuf = s.recv(size) + if not newbuf: + return None + buf += newbuf + size -= len(newbuf) + return buf def _socket_receive_message(s): - buf = _socket_recvall(s, 4) # 4 bytes - if not buf: - return None - size = struct.unpack('!I', buf)[0] - return _socket_recvall(s, size) + buf = _socket_recvall(s, 4) # 4 bytes + if not buf: + return None + size = struct.unpack('!I', buf)[0] + return _socket_recvall(s, size) def _socket_send_message(s, msg): - s.sendall(struct.pack('!I', len(msg))) # 4 bytes - s.sendall(msg) # len(msg) bytes + s.sendall(struct.pack('!I', len(msg))) # 4 bytes + s.sendall(msg) # len(msg) bytes class _SocketTransport(object): - """Communicate with VisionBonnet over socket.""" + """Communicate with VisionBonnet over socket.""" - def __init__(self): - """Open connection to the bonnet.""" - self._client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def __init__(self): + """Open connection to the bonnet.""" + self._client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - host = os.environ.get('VISION_BONNET_HOST', '172.28.28.10') - port = int(os.environ.get('VISION_BONNET_PORT', '35000')) - self._client.connect((host, port)) + host = os.environ.get('VISION_BONNET_HOST', '172.28.28.10') + port = int(os.environ.get('VISION_BONNET_PORT', '35000')) + self._client.connect((host, port)) - # TODO(dkovalev,weiranzhao): add timeout parameter - def send(self, request): - _socket_send_message(self._client, request) - return _socket_receive_message(self._client) + # TODO(dkovalev,weiranzhao): add timeout parameter + def send(self, request): + _socket_send_message(self._client, request) + return _socket_receive_message(self._client) - def close(self): - self._client.close() + def close(self): + self._client.close() def _is_arm(): - return os.uname()[4].startswith('arm') + return os.uname()[4].startswith('arm') def make_transport(): - if _is_arm(): - return _SpiTransport() - else: - return _SocketTransport() + if _is_arm(): + return _SpiTransport() + else: + return _SocketTransport() diff --git a/src/aiy/toneplayer.py b/src/aiy/toneplayer.py index 0dc700b4..0a1f2cdc 100644 --- a/src/aiy/toneplayer.py +++ b/src/aiy/toneplayer.py @@ -157,16 +157,16 @@ def _parse(self, array): def _parse_note(self, note_str): """Parses a single note/rest string into its given class instance.""" result = TonePlayer.REST_RE.match(note_str) - if result != None: + if result is not None: length = TonePlayer.PERIOD_MAP[result.group('length')] return Rest(self.bpm, length) result = 
TonePlayer.NOTE_RE.match(note_str) - if result != None: + if result is not None: name = result.group('name') octave = 4 - if result.group('octave') != None: + if result.group('octave') is not None: octave = int(result.group('octave')) if octave > 8: octave = 8 @@ -174,7 +174,7 @@ def _parse_note(self, note_str): octave = 1 length = Rest.QUARTER - if result.group('length') != None: + if result.group('length') is not None: length = TonePlayer.PERIOD_MAP[result.group('length')] return Note(name, octave, self.bpm, length) @@ -188,7 +188,9 @@ def play(self, *args): for note in parsed_notes: if isinstance(note, Note): if self.debug: - print(note.name + str(note.octave), '(' + str(note.to_frequency()) + ')', str(note.to_length_secs()) + 's') + print(note.name + str(note.octave), + '(' + str(note.to_frequency()) + ')', + str(note.to_length_secs()) + 's') controller.set_frequency(note.to_frequency()) else: controller.set_frequency(0) diff --git a/src/aiy/trackplayer.py b/src/aiy/trackplayer.py index 2031559c..0ea588d6 100644 --- a/src/aiy/trackplayer.py +++ b/src/aiy/trackplayer.py @@ -25,6 +25,7 @@ class Command(object): """Base class for all commands.""" + def apply(self, player, controller, note, tick_delta): """Applies the effect of this command.""" pass @@ -39,8 +40,10 @@ def parse(klass, *args): """ pass + class Glissando(Command): """Pitchbends a note up or down by the given rate.""" + def __init__(self, direction, hz_per_tick): self.direction = direction self.hz_per_tick = hz_per_tick @@ -61,6 +64,7 @@ def parse(klass, *args): class PulseChange(Command): """Changes the pulse width of a note up or down by the given rate.""" + def __init__(self, direction, usec_per_tick): self.direction = direction self.usec_per_tick = usec_per_tick @@ -81,6 +85,7 @@ def parse(klass, *args): class SetPulseWidth(Command): """Changes the pulse width of a note up or down by the given rate.""" + def __init__(self, pulse_width_usec): self.pulse = pulse_width_usec @@ -98,6 +103,7 @@ def parse(klass, *args): class Arpeggio(Command): """Plays an arpeggiated chord.""" + def __init__(self, *args): self.chord = args @@ -125,12 +131,13 @@ def parse(klass, *args): class Vibrato(Command): """Vibrates the frequency by the given amount.""" + def __init__(self, depth_hz, speed): self.depth_hz = depth_hz self.speed = speed def apply(self, player, controller, note, tick_delta): - freq_delta = round(math.sin(tick_delta * (1/self.speed))) + freq_delta = round(math.sin(tick_delta * (1 / self.speed))) freq = note.to_frequency() freq += freq_delta * self.depth_hz controller.set_frequency(int(freq)) @@ -144,8 +151,10 @@ def parse(klass, *args): speed = int(args[1]) return klass(depth_hz, speed), 2 + class Retrigger(Command): """Retriggers a note a consecutive number of times.""" + def __init__(self, times): self.times = times @@ -168,6 +177,7 @@ def parse(klass, *args): class NoteOff(Command): """Stops a given note from playing.""" + def apply(self, player, controller, note, tick_delta): if tick_delta == 0: controller.set_frequency(0) @@ -182,6 +192,7 @@ def parse(klass, *args): class SetSpeed(Command): """Changes the speed of the given song.""" + def __init__(self, speed): self.speed = speed @@ -200,6 +211,7 @@ def parse(klass, *args): class JumpToPosition(Command): """Jumps to the given position in a song.""" + def __init__(self, position): self.position = position @@ -218,6 +230,7 @@ def parse(klass, *args): class StopPlaying(Command): """Stops the TrackPlayer from playing.""" + def apply(self, player, controller, note, 
tick_delta):
         if tick_delta == 0:
             controller.set_frequency(0)
 
@@ -258,31 +271,31 @@ def add_pattern(self, pattern):
             The new pattern index.
         """
         self.patterns.append(pattern)
-        if self.debug == True:
+        if self.debug:
             print('Added new pattern %d' % (len(self.patterns) - 1))
         return len(self.patterns) - 1
 
     def add_order(self, pattern_number):
         """Adds a pattern index to the order."""
-        if self.debug == True:
+        if self.debug:
             print('Adding order[%d] == %d' % (len(self.order), pattern_number))
         self.order.append(pattern_number)
 
     def set_order(self, position, pattern_number):
         """Changes a pattern index in the order."""
-        if self.debug == True:
+        if self.debug:
             print('Setting order[%d] == %d' % (position, pattern_number))
         self.order[position] = pattern_number
 
     def set_speed(self, new_speed):
         """Sets the playing speed in ticks/row."""
-        if self.debug == True:
+        if self.debug:
             print('Setting speed to %d' % (new_speed))
         self.speed = new_speed
 
     def set_position(self, new_position):
         """Sets the position inside of the current pattern."""
-        if self.debug == True:
+        if self.debug:
             print('Jumping position to %d' % (new_position))
         self.current_position = new_Position
 
@@ -298,7 +311,7 @@ def play(self):
         self.playing = True
 
         with PWMController(self.gpio) as controller:
-            while self.playing == True:
+            while self.playing:
                 if self.current_order >= len(self.order):
                     self.current_order = 0
 
@@ -321,21 +334,22 @@ def play(self):
                 if isinstance(note_command, Command):
                     last_command = note_command
                     note_command.apply(self, controller, last_note, t)
-            if self.playing == False:
+            if not self.playing:
                 print()
                 return
 
             self.tick += 1
             time.sleep(0.01)
 
-            if self.debug == True:
+            if self.debug:
                 print(' ' * 70 + '\r', end='')
-                print('pos: %03d pattern: %02d' % (self.current_position, self.current_pattern), end='')
-                if last_note != None:
+                print('pos: %03d pattern: %02d' %
+                      (self.current_position, self.current_pattern), end='')
+                if last_note is not None:
                     print(' note: %s' % (str(last_note)), end='')
                 else:
                     print(' ', end='')
-                if last_command != None:
+                if last_command is not None:
                     print(' command: %s' % (str(last_command)), end='')
 
             self.current_position += 1
@@ -344,6 +358,7 @@ def play(self):
 
         controller.set_frequency(0)
 
+
 class TrackLoader(object):
     """Simple track module loader.
@@ -484,13 +499,13 @@ def _parse_pattern_line(self, line): while word_idx < len(line): word = line[word_idx] result = TrackLoader.NOTE_RE.match(word) - if result != None: + if result is not None: name = result.group('name') octave = result.group('octave') row.append(Note(result.group('name'), int(result.group('octave')))) result = TrackLoader.COMMAND_RE.match(word) - if result != None: + if result is not None: name = result.group('name') args = line[word_idx + 1:] klass = TrackLoader.COMMANDS[name] @@ -504,7 +519,7 @@ def _parse_pattern_line(self, line): def _debug(self, str, *args): """Helper method to print out a line only if debug is on.""" - if self.debug == True: + if self.debug: print(str % args) def load(self): @@ -529,7 +544,7 @@ def load(self): for line in lines: line = line.split() - if header_finished == True: + if header_finished: if len(line) == 0: if not between_patterns: current_pattern.append([]) diff --git a/src/aiy/vision/inference.py b/src/aiy/vision/inference.py index e6112d41..96b8f16b 100644 --- a/src/aiy/vision/inference.py +++ b/src/aiy/vision/inference.py @@ -27,258 +27,258 @@ def _tobytes(img): - try: - return img.tobytes() - except AttributeError: - return img.tostring() + try: + return img.tobytes() + except AttributeError: + return img.tostring() class CameraInference(object): - """Helper class to run camera inference.""" + """Helper class to run camera inference.""" - def __init__(self, descriptor, params=None): - self._engine = InferenceEngine() - self._key = self._engine.load_model(descriptor) - self._engine.start_camera_inference(self._key, params) + def __init__(self, descriptor, params=None): + self._engine = InferenceEngine() + self._key = self._engine.load_model(descriptor) + self._engine.start_camera_inference(self._key, params) - def camera_state(self): - return self._engine.get_camera_state() + def camera_state(self): + return self._engine.get_camera_state() - def run(self): - while True: - yield self._engine.camera_inference() + def run(self): + while True: + yield self._engine.camera_inference() - def close(self): - self._engine.stop_camera_inference() - self._engine.unload_model(self._key) + def close(self): + self._engine.stop_camera_inference() + self._engine.unload_model(self._key) - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, exc_type, exc_value, exc_tb): - self.close() + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() class ImageInference(object): - """Helper class to run image inference.""" + """Helper class to run image inference.""" - def __init__(self, descriptor): - self._engine = InferenceEngine() - self._key = self._engine.load_model(descriptor) + def __init__(self, descriptor): + self._engine = InferenceEngine() + self._key = self._engine.load_model(descriptor) - def run(self, image, params=None): - return self._engine.image_inference(self._key, image, params) + def run(self, image, params=None): + return self._engine.image_inference(self._key, image, params) - def close(self): - self._engine.unload_model(self._key) + def close(self): + self._engine.unload_model(self._key) - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, exc_type, exc_value, exc_tb): - self.close() + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() class ModelDescriptor(object): - """Info used by VisionBonnet to load model.""" - - def __init__(self, name, input_shape, input_normalizer, compute_graph): - """Initialzes ModelDescriptor. 
- - Args: - name: string, a name used to refer the model, should not conflict - with existing model names. - input_shape: (batch, height, width, depth). For now, only batch=1 and - depth=3 are supported. - input_normalizer: (mean, stddev) to convert input image (for analysis) to - the same range model is - trained. For example, if the model is trained with [-1, 1] input. To - analyze an RGB image (input range [0, 255]), one needs to specify the - input normalizer as (128.0, 128.0). - compute_graph: string, converted model proto - """ - self.name = name - self.input_shape = input_shape - self.input_normalizer = input_normalizer - self.compute_graph = compute_graph + """Info used by VisionBonnet to load model.""" + + def __init__(self, name, input_shape, input_normalizer, compute_graph): + """Initialzes ModelDescriptor. + + Args: + name: string, a name used to refer the model, should not conflict + with existing model names. + input_shape: (batch, height, width, depth). For now, only batch=1 and + depth=3 are supported. + input_normalizer: (mean, stddev) to convert input image (for analysis) to + the same range model is + trained. For example, if the model is trained with [-1, 1] input. To + analyze an RGB image (input range [0, 255]), one needs to specify the + input normalizer as (128.0, 128.0). + compute_graph: string, converted model proto + """ + self.name = name + self.input_shape = input_shape + self.input_normalizer = input_normalizer + self.compute_graph = compute_graph class InferenceException(Exception): - def __init__(self, *args, **kwargs): - Exception.__init__(self, *args, **kwargs) + def __init__(self, *args, **kwargs): + Exception.__init__(self, *args, **kwargs) class InferenceEngine(object): - """Class to access InferenceEngine on VisionBonnet board. - - Inference result has the following format: - - message InferenceResult { - string model_name; // Name of the model to run inference on. - int32 width; // Input image/frame width. - int32 height; // Input image/frame height. - Rectangle window; // Window inside width x height image/frame. - int32 duration_ms; // Inference duration. - map tensors; // Output tensors. - - message Frame { - int32 index; // Frame number. - int64 timestamp_us; // Frame timestamp. - } - - Frame frame; // Frame-specific inference data. - } - """ - - def __init__(self): - self._transport = make_transport() - logging.info('InferenceEngine transport: %s', - self._transport.__class__.__name__) - - def close(self): - self._transport.close() + """Class to access InferenceEngine on VisionBonnet board. - def __enter__(self): - return self + Inference result has the following format: - def __exit__(self, exc_type, exc_value, exc_tb): - self.close() + message InferenceResult { + string model_name; // Name of the model to run inference on. + int32 width; // Input image/frame width. + int32 height; // Input image/frame height. + Rectangle window; // Window inside width x height image/frame. + int32 duration_ms; // Inference duration. + map tensors; // Output tensors. - def _communicate(self, request, debug=False): - """Gets response and logs messages if need to. + message Frame { + int32 index; // Frame number. + int64 timestamp_us; // Frame timestamp. + } - Args: - request: protocol_pb2.Request - debug: boolean, if True and response contains inference result, print up - to the first 10 elements of each returned vector. 
- - Returns: - protocol_pb2.Response - """ - response = protocol_pb2.Response() - response.ParseFromString(self._transport.send(request.SerializeToString())) - if response.status.code != protocol_pb2.Response.Status.OK: - raise InferenceException(response.status.message) - - if debug: - # Print up to the first 10 elements of each returned vector. - for name, tensor in response.result.tensors.items(): - data = tensor.data[0:10] - logging.info('First %d elements of output tensor:', len(data)) - for index, value in enumerate(data): - logging.info(' %s[%d] = %f', name, index, value) - - return response - - def load_model(self, descriptor): - """Loads model on VisionBonnet. - - Args: - descriptor: ModelDescriptor, meta info that defines model name, - where to get the model and etc. - Returns: - Model identifier. - """ - logging.info('Loading model "%s"...', descriptor.name) - - batch, height, width, depth = descriptor.input_shape - assert batch == 1, 'Only batch == 1 is currently supported' - assert depth == 3, 'Only depth == 3 is currently supported' - mean, stddev = descriptor.input_normalizer - - request = protocol_pb2.Request() - request.load_model.model_name = descriptor.name - request.load_model.input_shape.batch = batch - request.load_model.input_shape.height = height - request.load_model.input_shape.width = width - request.load_model.input_shape.depth = depth - request.load_model.input_normalizer.mean = mean - request.load_model.input_normalizer.stddev = stddev - if descriptor.compute_graph: - request.load_model.compute_graph = descriptor.compute_graph - - try: - self._communicate(request) - except InferenceException as e: - logging.warning(str(e)) - - return descriptor.name - - def unload_model(self, model_name): - """Deletes model on VisionBonnet. - - Args: - model_name: string, unique identifier used to refer a model. - """ - logging.info('Unloading model "%s"...', model_name) - - request = protocol_pb2.Request() - request.unload_model.model_name = model_name - self._communicate(request) - - def start_camera_inference(self, model_name, params=None): - """Starts inference running on VisionBonnet.""" - request = protocol_pb2.Request() - request.start_camera_inference.model_name = model_name - - for key, value in (params or {}).items(): - request.start_camera_inference.params[key] = str(value) - - self._communicate(request) - - def camera_inference(self): - """Returns the latest inference result from VisionBonnet.""" - request = protocol_pb2.Request() - request.camera_inference.SetInParent() - return self._communicate(request).inference_result - - def stop_camera_inference(self): - """Stops inference running on VisionBonnet.""" - request = protocol_pb2.Request() - request.stop_camera_inference.SetInParent() - self._communicate(request) - - def get_camera_state(self): - request = protocol_pb2.Request() - request.get_camera_state.SetInParent() - return self._communicate(request).camera_state - - def image_inference(self, model_name, image, params=None): - """Runs inference on image using model (identified by model_name). - - Args: - model_name: string, unique identifier used to refer a model. - image: PIL.Image, - params: dict, additional parameters to run inference - - Returns: - protocol_pb2.Response + Frame frame; // Frame-specific inference data. 
+ } """ - assert model_name, 'model_name must not be empty' - - logging.info('Image inference with model "%s"...', model_name) - - width, height = image.size - - request = protocol_pb2.Request() - request.image_inference.model_name = model_name - request.image_inference.tensor.shape.height = height - request.image_inference.tensor.shape.width = width - - if image.mode == 'RGB': - r, g, b = image.split() - request.image_inference.tensor.shape.depth = 3 - request.image_inference.tensor.data = _tobytes(r) + _tobytes( - g) + _tobytes(b) - elif image.mode == 'L': - request.image_inference.tensor.shape.depth = 1 - request.image_inference.tensor.data = _tobytes(image) - else: - assert False, 'Only RGB and L modes are supported.' - - for key, value in (params or {}).items(): - request.image_inference.params[key] = str(value) - - return self._communicate(request).inference_result + def __init__(self): + self._transport = make_transport() + logging.info('InferenceEngine transport: %s', + self._transport.__class__.__name__) + + def close(self): + self._transport.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() + + def _communicate(self, request, debug=False): + """Gets response and logs messages if need to. + + Args: + request: protocol_pb2.Request + debug: boolean, if True and response contains inference result, print up + to the first 10 elements of each returned vector. + + Returns: + protocol_pb2.Response + """ + response = protocol_pb2.Response() + response.ParseFromString(self._transport.send(request.SerializeToString())) + if response.status.code != protocol_pb2.Response.Status.OK: + raise InferenceException(response.status.message) + + if debug: + # Print up to the first 10 elements of each returned vector. + for name, tensor in response.result.tensors.items(): + data = tensor.data[0:10] + logging.info('First %d elements of output tensor:', len(data)) + for index, value in enumerate(data): + logging.info(' %s[%d] = %f', name, index, value) + + return response + + def load_model(self, descriptor): + """Loads model on VisionBonnet. + + Args: + descriptor: ModelDescriptor, meta info that defines model name, + where to get the model and etc. + Returns: + Model identifier. + """ + logging.info('Loading model "%s"...', descriptor.name) + + batch, height, width, depth = descriptor.input_shape + assert batch == 1, 'Only batch == 1 is currently supported' + assert depth == 3, 'Only depth == 3 is currently supported' + mean, stddev = descriptor.input_normalizer + + request = protocol_pb2.Request() + request.load_model.model_name = descriptor.name + request.load_model.input_shape.batch = batch + request.load_model.input_shape.height = height + request.load_model.input_shape.width = width + request.load_model.input_shape.depth = depth + request.load_model.input_normalizer.mean = mean + request.load_model.input_normalizer.stddev = stddev + if descriptor.compute_graph: + request.load_model.compute_graph = descriptor.compute_graph + + try: + self._communicate(request) + except InferenceException as e: + logging.warning(str(e)) + + return descriptor.name + + def unload_model(self, model_name): + """Deletes model on VisionBonnet. + + Args: + model_name: string, unique identifier used to refer a model. 
+ """ + logging.info('Unloading model "%s"...', model_name) + + request = protocol_pb2.Request() + request.unload_model.model_name = model_name + self._communicate(request) + + def start_camera_inference(self, model_name, params=None): + """Starts inference running on VisionBonnet.""" + request = protocol_pb2.Request() + request.start_camera_inference.model_name = model_name + + for key, value in (params or {}).items(): + request.start_camera_inference.params[key] = str(value) + + self._communicate(request) + + def camera_inference(self): + """Returns the latest inference result from VisionBonnet.""" + request = protocol_pb2.Request() + request.camera_inference.SetInParent() + return self._communicate(request).inference_result + + def stop_camera_inference(self): + """Stops inference running on VisionBonnet.""" + request = protocol_pb2.Request() + request.stop_camera_inference.SetInParent() + self._communicate(request) + + def get_camera_state(self): + request = protocol_pb2.Request() + request.get_camera_state.SetInParent() + return self._communicate(request).camera_state + + def image_inference(self, model_name, image, params=None): + """Runs inference on image using model (identified by model_name). + + Args: + model_name: string, unique identifier used to refer a model. + image: PIL.Image, + params: dict, additional parameters to run inference + + Returns: + protocol_pb2.Response + """ + + assert model_name, 'model_name must not be empty' + + logging.info('Image inference with model "%s"...', model_name) + + width, height = image.size + + request = protocol_pb2.Request() + request.image_inference.model_name = model_name + request.image_inference.tensor.shape.height = height + request.image_inference.tensor.shape.width = width + + if image.mode == 'RGB': + r, g, b = image.split() + request.image_inference.tensor.shape.depth = 3 + request.image_inference.tensor.data = _tobytes(r) + _tobytes( + g) + _tobytes(b) + elif image.mode == 'L': + request.image_inference.tensor.shape.depth = 1 + request.image_inference.tensor.data = _tobytes(image) + else: + assert False, 'Only RGB and L modes are supported.' + + for key, value in (params or {}).items(): + request.image_inference.params[key] = str(value) + + return self._communicate(request).inference_result diff --git a/src/aiy/vision/models/dish_classifier.py b/src/aiy/vision/models/dish_classifier.py index c4f8669e..0be5c53c 100644 --- a/src/aiy/vision/models/dish_classifier.py +++ b/src/aiy/vision/models/dish_classifier.py @@ -21,37 +21,37 @@ def model(): - return ModelDescriptor( - name='dish_classifier', - input_shape=(1, 192, 192, 3), - input_normalizer=(128.0, 128.0), - compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME)) + return ModelDescriptor( + name='dish_classifier', + input_shape=(1, 192, 192, 3), + input_normalizer=(128.0, 128.0), + compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME)) def get_classes(result, max_num_objects=None, object_prob_threshold=0.0): - """Converts dish classifier model output to list of detected objects. - - Args: - result: output tensor from dish classifier model. - max_num_objects: int; max number of objects to return. - object_prob_threshold: float; min probability of each returned object. - - Returns: - A list of (class_name: string, probability: float) pairs ordered by - probability from highest to lowest. The number of pairs is not greater than - max_num_objects. Each probability is greater than object_prob_threshold. 
For - example: - - [('Ramen', 0.981934) - ('Yaka mein, 0.005497)] - """ - assert len(result.tensors) == 1 - tensor = result.tensors['MobilenetV1/Predictions/Softmax'] - probs, shape = tensor.data, tensor.shape - assert (shape.batch, shape.height, shape.width, shape.depth) == (1, 1, 1, - 2024) - - pairs = [pair for pair in enumerate(probs) if pair[1] > object_prob_threshold] - pairs = sorted(pairs, key=lambda pair: pair[1], reverse=True) - pairs = pairs[0:max_num_objects] - return [('/'.join(CLASSES[index]), prob) for index, prob in pairs] + """Converts dish classifier model output to list of detected objects. + + Args: + result: output tensor from dish classifier model. + max_num_objects: int; max number of objects to return. + object_prob_threshold: float; min probability of each returned object. + + Returns: + A list of (class_name: string, probability: float) pairs ordered by + probability from highest to lowest. The number of pairs is not greater than + max_num_objects. Each probability is greater than object_prob_threshold. For + example: + + [('Ramen', 0.981934) + ('Yaka mein, 0.005497)] + """ + assert len(result.tensors) == 1 + tensor = result.tensors['MobilenetV1/Predictions/Softmax'] + probs, shape = tensor.data, tensor.shape + assert (shape.batch, shape.height, shape.width, shape.depth) == (1, 1, 1, + 2024) + + pairs = [pair for pair in enumerate(probs) if pair[1] > object_prob_threshold] + pairs = sorted(pairs, key=lambda pair: pair[1], reverse=True) + pairs = pairs[0:max_num_objects] + return [('/'.join(CLASSES[index]), prob) for index, prob in pairs] diff --git a/src/aiy/vision/models/dish_classifier_classes.py b/src/aiy/vision/models/dish_classifier_classes.py index 5fd45cd9..70c711ea 100644 --- a/src/aiy/vision/models/dish_classifier_classes.py +++ b/src/aiy/vision/models/dish_classifier_classes.py @@ -1242,7 +1242,7 @@ ('Beef bourguignon',), # index=1225 ('Truffade',), # index=1226 ('B\xc3\xb2 n\xc6\xb0\xe1\xbb\x9bng l\xc3\xa1 l\xe1\xbb\x91t', - ), # index=1227 + ), # index=1227 ('Ful medames',), # index=1228 ('Aligot',), # index=1229 ('Kolach',), # index=1230 diff --git a/src/aiy/vision/models/face_detection.py b/src/aiy/vision/models/face_detection.py index d5365fdf..078cdeab 100644 --- a/src/aiy/vision/models/face_detection.py +++ b/src/aiy/vision/models/face_detection.py @@ -22,52 +22,52 @@ def _reshape(array, width): - assert len(array) % width == 0 - height = len(array) // width - return [array[i * width:(i + 1) * width] for i in range(height)] + assert len(array) % width == 0 + height = len(array) // width + return [array[i * width:(i + 1) * width] for i in range(height)] class Face(object): - """Face detection result.""" + """Face detection result.""" - def __init__(self, bounding_box, face_score, joy_score): - """Creates a new Face instance. + def __init__(self, bounding_box, face_score, joy_score): + """Creates a new Face instance. - Args: - bounding_box: (x, y, width, height). - face_score: float, face confidence score. - joy_score: float, face joy score. - """ - self.bounding_box = bounding_box - self.face_score = face_score - self.joy_score = joy_score + Args: + bounding_box: (x, y, width, height). + face_score: float, face confidence score. + joy_score: float, face joy score. 
+ """ + self.bounding_box = bounding_box + self.face_score = face_score + self.joy_score = joy_score - def __str__(self): - return 'face_score=%f, joy_score=%f, bbox=%s' % (self.face_score, - self.joy_score, - str(self.bounding_box)) + def __str__(self): + return 'face_score=%f, joy_score=%f, bbox=%s' % (self.face_score, + self.joy_score, + str(self.bounding_box)) def model(): - # Face detection model has special implementation in VisionBonnet firmware. - # input_shape, input_normalizer, and computate_graph params have on effect. - return ModelDescriptor( - name='FaceDetection', - input_shape=(1, 0, 0, 3), - input_normalizer=(0, 0), - compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME)) + # Face detection model has special implementation in VisionBonnet firmware. + # input_shape, input_normalizer, and computate_graph params have on effect. + return ModelDescriptor( + name='FaceDetection', + input_shape=(1, 0, 0, 3), + input_normalizer=(0, 0), + compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME)) def get_faces(result): - """Retunrs list of Face objects decoded from the inference result.""" - assert len(result.tensors) == 3 - # TODO(dkovalev): check tensor shapes - bboxes = _reshape(result.tensors['bounding_boxes'].data, 4) - face_scores = result.tensors['face_scores'].data - joy_scores = result.tensors['joy_scores'].data - assert len(bboxes) == len(joy_scores) - assert len(bboxes) == len(face_scores) - return [ - Face(tuple(bbox), face_score, joy_score) - for bbox, face_score, joy_score in zip(bboxes, face_scores, joy_scores) - ] + """Retunrs list of Face objects decoded from the inference result.""" + assert len(result.tensors) == 3 + # TODO(dkovalev): check tensor shapes + bboxes = _reshape(result.tensors['bounding_boxes'].data, 4) + face_scores = result.tensors['face_scores'].data + joy_scores = result.tensors['joy_scores'].data + assert len(bboxes) == len(joy_scores) + assert len(bboxes) == len(face_scores) + return [ + Face(tuple(bbox), face_score, joy_score) + for bbox, face_score, joy_score in zip(bboxes, face_scores, joy_scores) + ] diff --git a/src/aiy/vision/models/image_classification.py b/src/aiy/vision/models/image_classification.py index 22667fc1..20bd89dd 100644 --- a/src/aiy/vision/models/image_classification.py +++ b/src/aiy/vision/models/image_classification.py @@ -36,40 +36,40 @@ def model(model_type=MOBILENET): - return ModelDescriptor( - name=model_type, - input_shape=(1, 160, 160, 3), - input_normalizer=(128.0, 128.0), - compute_graph=utils.load_compute_graph( - _COMPUTE_GRAPH_NAME_MAP[model_type])) + return ModelDescriptor( + name=model_type, + input_shape=(1, 160, 160, 3), + input_normalizer=(128.0, 128.0), + compute_graph=utils.load_compute_graph( + _COMPUTE_GRAPH_NAME_MAP[model_type])) def get_classes(result, max_num_objects=None, object_prob_threshold=0.0): - """Converts image classification model output to list of detected objects. + """Converts image classification model output to list of detected objects. - Args: - result: output tensor from image classification model. - max_num_objects: int; max number of objects to return. - object_prob_threshold: float; min probability of each returned object. + Args: + result: output tensor from image classification model. + max_num_objects: int; max number of objects to return. + object_prob_threshold: float; min probability of each returned object. - Returns: - A list of (class_name: string, probability: float) pairs ordered by - probability from highest to lowest. 
diff --git a/src/aiy/vision/models/image_classification.py b/src/aiy/vision/models/image_classification.py
index 22667fc1..20bd89dd 100644
--- a/src/aiy/vision/models/image_classification.py
+++ b/src/aiy/vision/models/image_classification.py
@@ -36,40 +36,40 @@


 def model(model_type=MOBILENET):
-  return ModelDescriptor(
-      name=model_type,
-      input_shape=(1, 160, 160, 3),
-      input_normalizer=(128.0, 128.0),
-      compute_graph=utils.load_compute_graph(
-          _COMPUTE_GRAPH_NAME_MAP[model_type]))
+    return ModelDescriptor(
+        name=model_type,
+        input_shape=(1, 160, 160, 3),
+        input_normalizer=(128.0, 128.0),
+        compute_graph=utils.load_compute_graph(
+            _COMPUTE_GRAPH_NAME_MAP[model_type]))


 def get_classes(result, max_num_objects=None, object_prob_threshold=0.0):
-  """Converts image classification model output to list of detected objects.
+    """Converts image classification model output to list of detected objects.

-  Args:
-    result: output tensor from image classification model.
-    max_num_objects: int; max number of objects to return.
-    object_prob_threshold: float; min probability of each returned object.
+    Args:
+      result: output tensor from image classification model.
+      max_num_objects: int; max number of objects to return.
+      object_prob_threshold: float; min probability of each returned object.

-  Returns:
-    A list of (class_name: string, probability: float) pairs ordered by
-    probability from highest to lowest. The number of pairs is not greater than
-    max_num_objects. Each probability is greater than object_prob_threshold. For
-    example:
+    Returns:
+      A list of (class_name: string, probability: float) pairs ordered by
+      probability from highest to lowest. The number of pairs is not greater than
+      max_num_objects. Each probability is greater than object_prob_threshold. For
+      example:

-    [('Egyptian cat', 0.767578)
-     ('tiger cat, 0.163574)
-     ('lynx/catamount', 0.039795)]
-  """
-  assert len(result.tensors) == 1
-  tensor_name = _OUTPUT_TENSOR_NAME_MAP[result.model_name]
-  tensor = result.tensors[tensor_name]
-  probs, shape = tensor.data, tensor.shape
-  assert (shape.batch, shape.height, shape.width, shape.depth) == (1, 1, 1,
-                                                                   1001)
+      [('Egyptian cat', 0.767578),
+       ('tiger cat', 0.163574),
+       ('lynx/catamount', 0.039795)]
+    """
+    assert len(result.tensors) == 1
+    tensor_name = _OUTPUT_TENSOR_NAME_MAP[result.model_name]
+    tensor = result.tensors[tensor_name]
+    probs, shape = tensor.data, tensor.shape
+    assert (shape.batch, shape.height, shape.width, shape.depth) == (1, 1, 1,
+                                                                     1001)

-  pairs = [pair for pair in enumerate(probs) if pair[1] > object_prob_threshold]
-  pairs = sorted(pairs, key=lambda pair: pair[1], reverse=True)
-  pairs = pairs[0:max_num_objects]
-  return [('/'.join(CLASSES[index]), prob) for index, prob in pairs]
+    pairs = [pair for pair in enumerate(probs) if pair[1] > object_prob_threshold]
+    pairs = sorted(pairs, key=lambda pair: pair[1], reverse=True)
+    pairs = pairs[0:max_num_objects]
+    return [('/'.join(CLASSES[index]), prob) for index, prob in pairs]
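
image_classification follows the same calling pattern as the dish classifier example later in this patch. Note that object_prob_threshold is a strict lower bound and max_num_objects truncates only after sorting, so the call below yields at most three labels, each with probability above 0.1 ('test.jpg' is a stand-in path):

    from PIL import Image
    from aiy.vision.inference import ImageInference
    from aiy.vision.models import image_classification

    with ImageInference(image_classification.model()) as inference:
        classes = image_classification.get_classes(
            inference.run(Image.open('test.jpg')),
            max_num_objects=3, object_prob_threshold=0.1)
        for label, score in classes:
            print('%s (prob=%f)' % (label, score))
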
+ """ + self.bounding_box = bounding_box + self.kind = kind + self.score = score + + def __str__(self): + return 'kind=%s(%d), score=%f, bbox=%s' % (self._LABELS[self.kind], + self.kind, self.score, + str(self.bounding_box)) def _reshape(array, height, width): - assert len(array) == height * width - return [array[i * width:(i + 1) * width] for i in range(height)] + assert len(array) == height * width + return [array[i * width:(i + 1) * width] for i in range(height)] def _decode_and_nms_detection_result(logit_scores, box_encodings, anchors, score_threshold, image_size, offset): - """Decodes result as bounding boxes and runs Non-Maximum Suppression. - - Args: - logit_scores: list of scores - box_encodings: list of bounding boxes - anchors: list of anchors - score_threshold: float, bounding box candidates below this threshold will - be rejected. - image_size: (width, height) - offset: (x, y) - Returns: - A list of ObjectDetection.Result. - """ - - assert len(box_encodings) == len(anchors) - assert len(logit_scores) == len(anchors) - - x0, y0 = offset - results = [] - for logit_score, box_encoding, anchor in zip(logit_scores, box_encodings, - anchors): - scores = _logit_score_to_score(logit_score) - max_score_index, max_score = max(enumerate(scores), key=lambda x: x[1]) - # Skip if max score is below threshold or max score is 'background'. - if max_score <= score_threshold or max_score_index == 0: - continue - - x, y, w, h = _decode_box_encoding(box_encoding, anchor, image_size) - results.append(Object((x0 + x, y0 + y, w, h), max_score_index, max_score)) - - return _non_maximum_suppression(results) + """Decodes result as bounding boxes and runs Non-Maximum Suppression. + + Args: + logit_scores: list of scores + box_encodings: list of bounding boxes + anchors: list of anchors + score_threshold: float, bounding box candidates below this threshold will + be rejected. + image_size: (width, height) + offset: (x, y) + Returns: + A list of ObjectDetection.Result. + """ + + assert len(box_encodings) == len(anchors) + assert len(logit_scores) == len(anchors) + + x0, y0 = offset + results = [] + for logit_score, box_encoding, anchor in zip(logit_scores, box_encodings, + anchors): + scores = _logit_score_to_score(logit_score) + max_score_index, max_score = max(enumerate(scores), key=lambda x: x[1]) + # Skip if max score is below threshold or max score is 'background'. + if max_score <= score_threshold or max_score_index == 0: + continue + + x, y, w, h = _decode_box_encoding(box_encoding, anchor, image_size) + results.append(Object((x0 + x, y0 + y, w, h), max_score_index, max_score)) + + return _non_maximum_suppression(results) def _logit_score_to_score(logit_score): - return [1.0 / (1.0 + math.exp(-val)) for val in logit_score] + return [1.0 / (1.0 + math.exp(-val)) for val in logit_score] def _decode_box_encoding(box_encoding, anchor, image_size): - """Decodes bounding box encoding. - - Args: - box_encoding: a tuple of 4 floats. - anchor: a tuple of 4 floats. - image_size: a tuple of 2 ints, (width, height) - Returns: - A tuple of 4 integer, in the order of (left, upper, right, lower). 
- """ - assert len(box_encoding) == 4 - assert len(anchor) == 4 - y_scale = 10.0 - x_scale = 10.0 - height_scale = 5.0 - width_scale = 5.0 - - rel_y_translation = box_encoding[0] / y_scale - rel_x_translation = box_encoding[1] / x_scale - rel_height_dilation = box_encoding[2] / height_scale - rel_width_dilation = box_encoding[3] / width_scale - - anchor_ymin, anchor_xmin, anchor_ymax, anchor_xmax = anchor - anchor_ycenter = (anchor_ymax + anchor_ymin) / 2 - anchor_xcenter = (anchor_xmax + anchor_xmin) / 2 - anchor_height = anchor_ymax - anchor_ymin - anchor_width = anchor_xmax - anchor_xmin - - ycenter = anchor_ycenter + anchor_height * rel_y_translation - xcenter = anchor_xcenter + anchor_width * rel_x_translation - height = math.exp(rel_height_dilation) * anchor_height - width = math.exp(rel_width_dilation) * anchor_width - - image_width, image_height = image_size - x0 = int(max(0.0, xcenter - width / 2) * image_width) - y0 = int(max(0.0, ycenter - height / 2) * image_height) - x1 = int(min(1.0, xcenter + width / 2) * image_width) - y1 = int(min(1.0, ycenter + height / 2) * image_height) - return (x0, y0, x1 - x0, y1 - y0) + """Decodes bounding box encoding. + + Args: + box_encoding: a tuple of 4 floats. + anchor: a tuple of 4 floats. + image_size: a tuple of 2 ints, (width, height) + Returns: + A tuple of 4 integer, in the order of (left, upper, right, lower). + """ + assert len(box_encoding) == 4 + assert len(anchor) == 4 + y_scale = 10.0 + x_scale = 10.0 + height_scale = 5.0 + width_scale = 5.0 + + rel_y_translation = box_encoding[0] / y_scale + rel_x_translation = box_encoding[1] / x_scale + rel_height_dilation = box_encoding[2] / height_scale + rel_width_dilation = box_encoding[3] / width_scale + + anchor_ymin, anchor_xmin, anchor_ymax, anchor_xmax = anchor + anchor_ycenter = (anchor_ymax + anchor_ymin) / 2 + anchor_xcenter = (anchor_xmax + anchor_xmin) / 2 + anchor_height = anchor_ymax - anchor_ymin + anchor_width = anchor_xmax - anchor_xmin + + ycenter = anchor_ycenter + anchor_height * rel_y_translation + xcenter = anchor_xcenter + anchor_width * rel_x_translation + height = math.exp(rel_height_dilation) * anchor_height + width = math.exp(rel_width_dilation) * anchor_width + + image_width, image_height = image_size + x0 = int(max(0.0, xcenter - width / 2) * image_width) + y0 = int(max(0.0, ycenter - height / 2) * image_height) + x1 = int(min(1.0, xcenter + width / 2) * image_width) + y1 = int(min(1.0, ycenter + height / 2) * image_height) + return (x0, y0, x1 - x0, y1 - y0) def _overlap_ratio(box1, box2): - """Computes overlap ratio of two bounding boxes. - - Args: - box1: (x, y, width, height). - box2: (x, y, width, height). - - Returns: - float, represents overlap ratio between given boxes. - """ - - def _area(box): - _, _, width, height = box - area = width * height - assert area >= 0 - return area - - def _intersection_area(box1, box2): - x1, y1, width1, height1 = box1 - x2, y2, width2, height2 = box2 - x = max(x1, x2) - y = max(y1, y2) - width = max(min(x1 + width1, x2 + width2) - x, 0) - height = max(min(y1 + height1, y2 + height2) - y, 0) - area = width * height - assert area >= 0 - return area - - intersection_area = _intersection_area(box1, box2) - union_area = _area(box1) + _area(box2) - intersection_area - assert union_area >= 0 - if union_area > 0: - return float(intersection_area) / float(union_area) - return 1.0 + """Computes overlap ratio of two bounding boxes. + + Args: + box1: (x, y, width, height). + box2: (x, y, width, height). 


 def _overlap_ratio(box1, box2):
-  """Computes overlap ratio of two bounding boxes.
-
-  Args:
-    box1: (x, y, width, height).
-    box2: (x, y, width, height).
-
-  Returns:
-    float, represents overlap ratio between given boxes.
-  """
-
-  def _area(box):
-    _, _, width, height = box
-    area = width * height
-    assert area >= 0
-    return area
-
-  def _intersection_area(box1, box2):
-    x1, y1, width1, height1 = box1
-    x2, y2, width2, height2 = box2
-    x = max(x1, x2)
-    y = max(y1, y2)
-    width = max(min(x1 + width1, x2 + width2) - x, 0)
-    height = max(min(y1 + height1, y2 + height2) - y, 0)
-    area = width * height
-    assert area >= 0
-    return area
-
-  intersection_area = _intersection_area(box1, box2)
-  union_area = _area(box1) + _area(box2) - intersection_area
-  assert union_area >= 0
-  if union_area > 0:
-    return float(intersection_area) / float(union_area)
-  return 1.0
+    """Computes overlap ratio of two bounding boxes.
+
+    Args:
+      box1: (x, y, width, height).
+      box2: (x, y, width, height).
+
+    Returns:
+      float, represents overlap ratio between given boxes.
+    """
+
+    def _area(box):
+        _, _, width, height = box
+        area = width * height
+        assert area >= 0
+        return area
+
+    def _intersection_area(box1, box2):
+        x1, y1, width1, height1 = box1
+        x2, y2, width2, height2 = box2
+        x = max(x1, x2)
+        y = max(y1, y2)
+        width = max(min(x1 + width1, x2 + width2) - x, 0)
+        height = max(min(y1 + height1, y2 + height2) - y, 0)
+        area = width * height
+        assert area >= 0
+        return area
+
+    intersection_area = _intersection_area(box1, box2)
+    union_area = _area(box1) + _area(box2) - intersection_area
+    assert union_area >= 0
+    if union_area > 0:
+        return float(intersection_area) / float(union_area)
+    return 1.0


 def _non_maximum_suppression(boxes, overlap_threshold=0.5):
-  """Runs Non Maximum Suppression.
-
-  Removes box candidate that overlaps with existing candidate who has higher
-  score.
-
-  Args:
-    boxes: list of Object
-    overlap_threshold: float
-  Returns:
-    A list of Object
-  """
-  boxes = sorted(boxes, key=lambda x: x.score, reverse=True)
-  for i in range(len(boxes)):
-    if boxes[i].score < 0.0:
-      continue
-    # Suppress any nearby bounding boxes having lower score than boxes[i]
-    for j in range(i + 1, len(boxes)):
-      if boxes[j].score < 0.0:
-        continue
-      if _overlap_ratio(boxes[i].bounding_box,
-                        boxes[j].bounding_box) > overlap_threshold:
-        boxes[j].score = -1.0  # Suppress box
-
-  return [box for box in boxes if box.score >= 0.0]  # Exclude suppressed boxes
+    """Runs Non Maximum Suppression.
+
+    Removes box candidates that overlap with an existing candidate that has a
+    higher score.
+
+    Args:
+      boxes: list of Object
+      overlap_threshold: float
+    Returns:
+      A list of Object
+    """
+    boxes = sorted(boxes, key=lambda x: x.score, reverse=True)
+    for i in range(len(boxes)):
+        if boxes[i].score < 0.0:
+            continue
+        # Suppress any nearby bounding boxes having lower score than boxes[i]
+        for j in range(i + 1, len(boxes)):
+            if boxes[j].score < 0.0:
+                continue
+            if _overlap_ratio(boxes[i].bounding_box,
+                              boxes[j].bounding_box) > overlap_threshold:
+                boxes[j].score = -1.0  # Suppress box
+
+    return [box for box in boxes if box.score >= 0.0]  # Exclude suppressed boxes


 def model():
-  return ModelDescriptor(
-      name='object_detection',
-      input_shape=(1, 256, 256, 3),
-      input_normalizer=(128.0, 128.0),
-      compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME))
+    return ModelDescriptor(
+        name='object_detection',
+        input_shape=(1, 256, 256, 3),
+        input_normalizer=(128.0, 128.0),
+        compute_graph=utils.load_compute_graph(_COMPUTE_GRAPH_NAME))


 # TODO: check all tensor shapes
 def get_objects(result, score_threshold=0.3, offset=(0, 0)):
-  assert len(result.tensors) == 2
-  logit_scores = result.tensors['concat_1'].data
-  logit_scores = _reshape(logit_scores, len(ANCHORS), 4)
-  box_encodings = result.tensors['concat'].data
-  box_encodings = _reshape(box_encodings, len(ANCHORS), 4)
-
-  size = (result.window.width, result.window.height)
-  return _decode_and_nms_detection_result(logit_scores, box_encodings, ANCHORS,
-                                          score_threshold, size, offset)
+    assert len(result.tensors) == 2
+    logit_scores = result.tensors['concat_1'].data
+    logit_scores = _reshape(logit_scores, len(ANCHORS), 4)
+    box_encodings = result.tensors['concat'].data
+    box_encodings = _reshape(box_encodings, len(ANCHORS), 4)
+
+    size = (result.window.width, result.window.height)
+    return _decode_and_nms_detection_result(logit_scores, box_encodings, ANCHORS,
+                                            score_threshold, size, offset)
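
get_objects plugs into the same inference wrappers as the other models; a minimal sketch mirroring the object_detection example script elsewhere in this repo ('image.jpg' is a stand-in path):

    from PIL import Image
    from aiy.vision.inference import ImageInference
    from aiy.vision.models import object_detection

    with ImageInference(object_detection.model()) as inference:
        image = Image.open('image.jpg')
        for obj in object_detection.get_objects(inference.run(image), score_threshold=0.3):
            print(obj)
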
diff --git a/src/aiy/vision/models/utils.py b/src/aiy/vision/models/utils.py
index c448012e..007705c5 100644
--- a/src/aiy/vision/models/utils.py
+++ b/src/aiy/vision/models/utils.py
@@ -4,6 +4,6 @@


 def load_compute_graph(name):
-  path = os.environ.get('VISION_BONNET_MODELS_PATH', '/opt/aiy/models')
-  with open(os.path.join(path, name), 'rb') as f:
-    return f.read()
+    path = os.environ.get('VISION_BONNET_MODELS_PATH', '/opt/aiy/models')
+    with open(os.path.join(path, name), 'rb') as f:
+        return f.read()
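
load_compute_graph resolves model files under /opt/aiy/models unless the VISION_BONNET_MODELS_PATH environment variable points elsewhere, which is handy for testing locally built graphs; a sketch (the directory and file name below are hypothetical):

    import os
    os.environ['VISION_BONNET_MODELS_PATH'] = '/home/pi/my_models'  # hypothetical path

    from aiy.vision.models import utils
    graph_bytes = utils.load_compute_graph('my_graph.binaryproto')  # hypothetical file
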
diff --git a/src/aiy/vision/pins.py b/src/aiy/vision/pins.py
index ee580673..cdcae21d 100644
--- a/src/aiy/vision/pins.py
+++ b/src/aiy/vision/pins.py
@@ -26,33 +26,33 @@


 def _detect_gpio_offset(module_path):
-  for folder in listdir(module_path):
-    try:
-      with open('%s/%s/base' % (module_path, folder), 'r') as offset:
-        return int(offset.read())
-    except IOError:
-      pass
-  return None
+    for folder in listdir(module_path):
+        try:
+            with open('%s/%s/base' % (module_path, folder), 'r') as offset:
+                return int(offset.read())
+        except IOError:
+            pass
+    return None


 _FsNodeSpec = namedtuple('_FsNodeSpec', ['pin', 'name'])


 class GpioSpec(_FsNodeSpec):
-  _MODULE_PATH = '/sys/bus/i2c/drivers/aiy-io-i2c/1-0051/gpio-aiy-io/gpio'
-  _PIN_OFFSET = _detect_gpio_offset(_MODULE_PATH)
+    _MODULE_PATH = '/sys/bus/i2c/drivers/aiy-io-i2c/1-0051/gpio-aiy-io/gpio'
+    _PIN_OFFSET = _detect_gpio_offset(_MODULE_PATH)

-  def __new__(cls, pin, name):
-    return super(GpioSpec, cls).__new__(cls, GpioSpec._PIN_OFFSET + pin, name)
+    def __new__(cls, pin, name):
+        return super(GpioSpec, cls).__new__(cls, GpioSpec._PIN_OFFSET + pin, name)

-  def __str__(self):
-    return 'gpio %s (%d)' % (self.name, self.pin - self._PIN_OFFSET)
+    def __str__(self):
+        return 'gpio %s (%d)' % (self.name, self.pin - self._PIN_OFFSET)


 class PwmSpec(_FsNodeSpec):
-  def __str__(self):
-    return 'pwm %d' % self.pin
+    def __str__(self):
+        return 'pwm %d' % self.pin


 AIYPinSpec = namedtuple('AIYPinSpec', ['gpio_spec', 'pwm_spec'])

@@ -71,571 +71,571 @@ def __str__(self):


 class SysFsPin(object):
-  """Generic SysFsPin which implements generic SysFs driver functionality."""
+    """Generic SysFsPin which implements generic SysFs driver functionality."""

-  def __init__(self, spec, fs_root):
-    self._pin = spec.pin
-    self._name = spec.name
-    self._fs_root = fs_root
-    # Ensure things start out unexported.
-    try:
-      self.unexport()
-    except IOError:
-      pass
+    def __init__(self, spec, fs_root):
+        self._pin = spec.pin
+        self._name = spec.name
+        self._fs_root = fs_root
+        # Ensure things start out unexported.
+        try:
+            self.unexport()
+        except IOError:
+            pass

-  def set_function(self, function):
-    raise NotImplementedError('Setting function not supported')
+    def set_function(self, function):
+        raise NotImplementedError('Setting function not supported')

-  def get_function(self):
-    raise NotImplementedError('Getting function not supported')
+    def get_function(self):
+        raise NotImplementedError('Getting function not supported')

-  def export(self):
-    try:
-      with open(self.root_path('export'), 'w') as export:
-        export.write('%d' % self._pin)
-    except IOError:
-      raise GPIOPinInUse('Pin already in use')
+    def export(self):
+        try:
+            with open(self.root_path('export'), 'w') as export:
+                export.write('%d' % self._pin)
+        except IOError:
+            raise GPIOPinInUse('Pin already in use')

-  def unexport(self):
-    with open(self.root_path('unexport'), 'w') as unexport:
-      unexport.write('%d' % self._pin)
+    def unexport(self):
+        with open(self.root_path('unexport'), 'w') as unexport:
+            unexport.write('%d' % self._pin)

-  def open(self):
-    self.export()
+    def open(self):
+        self.export()

-  def close(self):
-    self.unexport()
+    def close(self):
+        self.unexport()

-  def wait_for_permissions(self, prop):
-    """Wait for write permissions on the given property.
-
-    We must wait because the the file system needs to grant permissions for the
-    newly created node."""
-    while True:
-      try:
-        with open(self.property_path(prop), 'w'):
-          pass
-        return
-      except IOError:
-        time.sleep(.01)
+    def wait_for_permissions(self, prop):
+        """Wait for write permissions on the given property.
+
+        We must wait because the file system needs to grant permissions for the
+        newly created node."""
+        while True:
+            try:
+                with open(self.property_path(prop), 'w'):
+                    pass
+                return
+            except IOError:
+                time.sleep(.01)

-  def get_value(self):
-    raise NotImplementedError('Value getting not implemented')
+    def get_value(self):
+        raise NotImplementedError('Value getting not implemented')

-  def set_value(self, value):
-    raise NotImplementedError('Value setting not implemented')
+    def set_value(self, value):
+        raise NotImplementedError('Value setting not implemented')

-  def write_property(self, prop, value):
-    """Writes the given sysfs node property to the pin."""
-    with open(self.property_path(prop), 'w') as node:
-      node.write(value)
+    def write_property(self, prop, value):
+        """Writes the given sysfs node property to the pin."""
+        with open(self.property_path(prop), 'w') as node:
+            node.write(value)

-  def read_property(self, prop):
-    """Reads the given sysfs node property from the pin."""
-    with open(self.property_path(prop), 'r') as node:
-      return node.read()
+    def read_property(self, prop):
+        """Reads the given sysfs node property from the pin."""
+        with open(self.property_path(prop), 'r') as node:
+            return node.read()

-  def root_path(self, node):
-    return '%s/%s' % (self._fs_root, node)
+    def root_path(self, node):
+        return '%s/%s' % (self._fs_root, node)

-  def property_path(self, prop):
-    return '%s/%s/%s' % (self._fs_root, self._name, prop)
+    def property_path(self, prop):
+        return '%s/%s/%s' % (self._fs_root, self._name, prop)


 class SysFsGpioPin(SysFsPin):
-  """SysFs support for GPIO pins.
-
-  Supports the SysFs node for GPIO control.
- """ - _FS_ROOT = '/sys/class/gpio' - - def __init__(self, spec): - super(SysFsGpioPin, self).__init__(spec, self._FS_ROOT) - if not isinstance(spec, GpioSpec): - raise TypeError('Pin specification not compatible with SysFS GPIO') - self._spec = spec - self._out = False - self._value = None - - def _get_direction(self): - return self.read_property('direction') - - def _set_direction(self, direction): - if direction not in ('in', 'out'): - raise ValueError('Direction must be either in or out') - self.write_property('direction', direction) - - def _get_value(self): - return self.read_property('value') - - def _set_value(self, value): - self.write_property('value', value) - - def _get_active_low(self): - return self.read_property('active_low') - - def _set_active_low(self, active_low): - self.write_property('active_low', '1' if active_low else '0') - - def set_function(self, function): - if function == 'input': - self._set_direction('in') - self._out = False - elif function == 'output': - self._set_direction('out') - self._out = True - else: - raise ValueError('pin function must be either input or output') - - def get_function(self): - direction = self._get_direction() - if direction == 'input': - return 'in' - if direction == 'output': - return 'out' - - def set_value(self, value): - if not self._out: - raise PinSetInput('Pin is not open for output') - self._set_value('1' if value else '0') - self._value = value - - def get_value(self): - if self._out: - return self._value - return bool(int(self._get_value())) - - def open(self): - super(SysFsGpioPin, self).open() - self.wait_for_permissions('active_low') - self.wait_for_permissions('direction') - # GPIO pins on the hat seem to be inverted by default. - self._set_active_low(True) - - def close(self): - # Restore the default direction (turns off LED) before closing. - self._set_direction('in') - super(SysFsGpioPin, self).close() - + """SysFs support for GPIO pins. -class SysFsPwmPin(SysFsPin): - """SysFs support for PWM pins. + Supports the SysFs node for GPIO control. 
+ """ + _FS_ROOT = '/sys/class/gpio' + + def __init__(self, spec): + super(SysFsGpioPin, self).__init__(spec, self._FS_ROOT) + if not isinstance(spec, GpioSpec): + raise TypeError('Pin specification not compatible with SysFS GPIO') + self._spec = spec + self._out = False + self._value = None + + def _get_direction(self): + return self.read_property('direction') + + def _set_direction(self, direction): + if direction not in ('in', 'out'): + raise ValueError('Direction must be either in or out') + self.write_property('direction', direction) + + def _get_value(self): + return self.read_property('value') + + def _set_value(self, value): + self.write_property('value', value) + + def _get_active_low(self): + return self.read_property('active_low') + + def _set_active_low(self, active_low): + self.write_property('active_low', '1' if active_low else '0') + + def set_function(self, function): + if function == 'input': + self._set_direction('in') + self._out = False + elif function == 'output': + self._set_direction('out') + self._out = True + else: + raise ValueError('pin function must be either input or output') + + def get_function(self): + direction = self._get_direction() + if direction == 'input': + return 'in' + if direction == 'output': + return 'out' + + def set_value(self, value): + if not self._out: + raise PinSetInput('Pin is not open for output') + self._set_value('1' if value else '0') + self._value = value + + def get_value(self): + if self._out: + return self._value + return bool(int(self._get_value())) + + def open(self): + super(SysFsGpioPin, self).open() + self.wait_for_permissions('active_low') + self.wait_for_permissions('direction') + # GPIO pins on the hat seem to be inverted by default. + self._set_active_low(True) + + def close(self): + # Restore the default direction (turns off LED) before closing. + self._set_direction('in') + super(SysFsGpioPin, self).close() - Supports the SysFs node for pwm control. - """ - _FS_ROOT = '/sys/class/pwm/pwmchip0' - class PwmState(object): - """Container for the state of the pwm. +class SysFsPwmPin(SysFsPin): + """SysFs support for PWM pins. - Used to recover after disable/enable and ensure consistency. + Supports the SysFs node for pwm control. """ - - def __init__(self): - self.duty_cycle = 0 - self.period_ns = _NS_PER_SECOND / 50 - self.enabled = False - self.function = None - - def __init__(self, spec): - super(SysFsPwmPin, self).__init__(spec, self._FS_ROOT) - if not isinstance(spec, PwmSpec): - raise TypeError('Pin specification not compatible with SysFS PWM') - if spec.pin < 0 or spec.pin > 3: - raise ValueError('Pin must be between 0 and 3 (inclusive)') - self._spec = spec - self._state = SysFsPwmPin.PwmState() - - def _set_enabled(self, enabled): - self.write_property('enable', '1' if enabled else '0') - self._state.enabled = enabled - - def _get_enabled(self): - return int(self.read_property('enable')) != 0 - - def _set_period_ns(self, period_ns): - self.write_property('period', '%d' % period_ns) - self._state.period_ns = int(period_ns) - - def _get_period_ns(self): - return int(self.read_property('period')) - - def _set_duty_cycle(self, duty_cycle): - self.write_property('duty_cycle', '%d' % duty_cycle) - self._state.duty_cycle = duty_cycle - - def _get_duty_cycle(self): - return int(self.read_property('duty_cycle')) - - def _update_state(self, new_state): - # Each time we enable, we need to first re-set the period and duty cycle (in - # that order). 
- if new_state.period_ns != self._state.period_ns or (not self._state.enabled - and new_state.enabled): - self._set_period_ns(new_state.period_ns) - if new_state.duty_cycle != self._state.duty_cycle or ( - not self._state.enabled and new_state.enabled): - self._set_duty_cycle(new_state.duty_cycle) - if new_state.enabled != self._state.enabled: - self._set_enabled(new_state.enabled) - - def _read_state(self): - self._state.period_ns = self._get_period_ns() - self._state.enabled = self._get_enabled() - self._state.duty_cycle = self._get_duty_cycle() - - def set_function(self, function): - if function != 'pwm' and function != 'output': - raise ValueError('PWM pins only support pwm and output functionality') - self._state.function = function - - def get_function(self): - return self._state.function - - def get_value(self): - return self._state.duty_cycle / self._state.period_ns - - def set_value(self, value): - new_state = deepcopy(self._state) - if value is None: - new_state.enabled = False - else: - new_state.enabled = True - new_state.duty_cycle = value * self._state.period_ns - self._update_state(new_state) - - def set_period_ns(self, period_ns): - new_state = deepcopy(self._state) - new_state.period_ns = period_ns - self._update_state(new_state) - - def get_period_ns(self): - return self._state.period_ns - - def open(self): - super(SysFsPwmPin, self).open() - self.wait_for_permissions('period') - self.wait_for_permissions('enable') - self._read_state() - new_state = deepcopy(self._state) - new_state.period_ns = _NS_PER_SECOND / 50 - new_state.enabled = True - self._update_state(new_state) - - def close(self): - self._set_enabled(False) - super(SysFsPwmPin, self).close() + _FS_ROOT = '/sys/class/pwm/pwmchip0' + + class PwmState(object): + """Container for the state of the pwm. + + Used to recover after disable/enable and ensure consistency. + """ + + def __init__(self): + self.duty_cycle = 0 + self.period_ns = _NS_PER_SECOND / 50 + self.enabled = False + self.function = None + + def __init__(self, spec): + super(SysFsPwmPin, self).__init__(spec, self._FS_ROOT) + if not isinstance(spec, PwmSpec): + raise TypeError('Pin specification not compatible with SysFS PWM') + if spec.pin < 0 or spec.pin > 3: + raise ValueError('Pin must be between 0 and 3 (inclusive)') + self._spec = spec + self._state = SysFsPwmPin.PwmState() + + def _set_enabled(self, enabled): + self.write_property('enable', '1' if enabled else '0') + self._state.enabled = enabled + + def _get_enabled(self): + return int(self.read_property('enable')) != 0 + + def _set_period_ns(self, period_ns): + self.write_property('period', '%d' % period_ns) + self._state.period_ns = int(period_ns) + + def _get_period_ns(self): + return int(self.read_property('period')) + + def _set_duty_cycle(self, duty_cycle): + self.write_property('duty_cycle', '%d' % duty_cycle) + self._state.duty_cycle = duty_cycle + + def _get_duty_cycle(self): + return int(self.read_property('duty_cycle')) + + def _update_state(self, new_state): + # Each time we enable, we need to first re-set the period and duty cycle (in + # that order). 
+ if new_state.period_ns != self._state.period_ns or (not self._state.enabled + and new_state.enabled): + self._set_period_ns(new_state.period_ns) + if new_state.duty_cycle != self._state.duty_cycle or ( + not self._state.enabled and new_state.enabled): + self._set_duty_cycle(new_state.duty_cycle) + if new_state.enabled != self._state.enabled: + self._set_enabled(new_state.enabled) + + def _read_state(self): + self._state.period_ns = self._get_period_ns() + self._state.enabled = self._get_enabled() + self._state.duty_cycle = self._get_duty_cycle() + + def set_function(self, function): + if function != 'pwm' and function != 'output': + raise ValueError('PWM pins only support pwm and output functionality') + self._state.function = function + + def get_function(self): + return self._state.function + + def get_value(self): + return self._state.duty_cycle / self._state.period_ns + + def set_value(self, value): + new_state = deepcopy(self._state) + if value is None: + new_state.enabled = False + else: + new_state.enabled = True + new_state.duty_cycle = value * self._state.period_ns + self._update_state(new_state) + + def set_period_ns(self, period_ns): + new_state = deepcopy(self._state) + new_state.period_ns = period_ns + self._update_state(new_state) + + def get_period_ns(self): + return self._state.period_ns + + def open(self): + super(SysFsPwmPin, self).open() + self.wait_for_permissions('period') + self.wait_for_permissions('enable') + self._read_state() + new_state = deepcopy(self._state) + new_state.period_ns = _NS_PER_SECOND / 50 + new_state.enabled = True + self._update_state(new_state) + + def close(self): + self._set_enabled(False) + super(SysFsPwmPin, self).close() # Debounce by making sure the last change wasn't less than d_time in the past -> # should be agnostic to direction. class DebouncingPoller(object): - """Manages debouncing and polling a function periodically in the background. - - Calls a given getter periodically and when the debounced value changes such - that detector(old, new) returns true, the callback is called. Only runs while - detector, getter, and callback are set. 
- """ - _MIN_POLL_INTERVAL = .0001 - - def __init__(self, value_getter, callback, detector=lambda old, new: True): - self._poll_thread = None - self._debounce_time = .001 - self._poll_interval = .00051 - self._getter = value_getter - self._detector = detector - self._callback = callback - - @property - def poll_interval(self): - return self._poll_interval - - @poll_interval.setter - def poll_interval(self, interval): - self._poll_interval = max(interval, self._MIN_POLL_INTERVAL) - self.restart_polling() - - @property - def debounce_time(self): - return self._debounce_time - - @debounce_time.setter - def debounce_time(self, debounce_time): - self._debounce_time = debounce_time - self.restart_polling() - - @property - def callback(self): - return self._callback - - @callback.setter - def callback(self, callback): - self.stop_polling() - self._callback = callback - self.try_start_polling() - - @property - def detector(self): - return self._detector - - @detector.setter - def detector(self, detector): - self._detector = detector - self.restart_polling() - - def try_start_polling(self): - if (not self._poll_thread and self._getter and self._callback and - self._detector): - self._poll_thread = GPIOThread( - target=self._poll, - args=(self._poll_interval, self._debounce_time, self._getter, - self._detector, self._callback)) - self._poll_thread.start() - - def stop_polling(self): - if self._poll_thread: - self._poll_thread.stop() - self._poll_thread = None - - def restart_polling(self): - self.stop_polling() - self.try_start_polling() - - # Only called from the polling thread. - def _poll(self, poll_interval, debounce_interval, getter, detector, callback): - """Debounces and monitors the value retrieved by _getter. - - Triggers callback if detector(old_value, new_value) returns true. - Args: - poll_interval: positive float, time in seconds between polling the getter. - debounce_interval: positive float, time in seconds to wait after a change - to allow a future change to the value to trigger the callback. - getter: function() -> value, gets the value. This will be called - periodically and the value type will be the same type passed to the - detector function. - detector: function(old, new) -> bool, filters changes to determine when - the callback should be called. Can be used for edge detection - callback: function() to be invoked when detector conditions are met. + """Manages debouncing and polling a function periodically in the background. + + Calls a given getter periodically and when the debounced value changes such + that detector(old, new) returns true, the callback is called. Only runs while + detector, getter, and callback are set. 
""" - last_time = time.time() - last_value = getter() - while not self._poll_thread.stopping.wait(poll_interval): - value = getter() - new_time = time.time() - if not debounce_interval or (new_time - last_time) > debounce_interval: - if detector(last_value, value): - callback() - last_value = value - last_time = new_time + _MIN_POLL_INTERVAL = .0001 + + def __init__(self, value_getter, callback, detector=lambda old, new: True): + self._poll_thread = None + self._debounce_time = .001 + self._poll_interval = .00051 + self._getter = value_getter + self._detector = detector + self._callback = callback + + @property + def poll_interval(self): + return self._poll_interval + + @poll_interval.setter + def poll_interval(self, interval): + self._poll_interval = max(interval, self._MIN_POLL_INTERVAL) + self.restart_polling() + + @property + def debounce_time(self): + return self._debounce_time + + @debounce_time.setter + def debounce_time(self, debounce_time): + self._debounce_time = debounce_time + self.restart_polling() + + @property + def callback(self): + return self._callback + + @callback.setter + def callback(self, callback): + self.stop_polling() + self._callback = callback + self.try_start_polling() + + @property + def detector(self): + return self._detector + + @detector.setter + def detector(self, detector): + self._detector = detector + self.restart_polling() + + def try_start_polling(self): + if (not self._poll_thread and self._getter and self._callback and + self._detector): + self._poll_thread = GPIOThread( + target=self._poll, + args=(self._poll_interval, self._debounce_time, self._getter, + self._detector, self._callback)) + self._poll_thread.start() + + def stop_polling(self): + if self._poll_thread: + self._poll_thread.stop() + self._poll_thread = None + + def restart_polling(self): + self.stop_polling() + self.try_start_polling() + + # Only called from the polling thread. + def _poll(self, poll_interval, debounce_interval, getter, detector, callback): + """Debounces and monitors the value retrieved by _getter. + + Triggers callback if detector(old_value, new_value) returns true. + Args: + poll_interval: positive float, time in seconds between polling the getter. + debounce_interval: positive float, time in seconds to wait after a change + to allow a future change to the value to trigger the callback. + getter: function() -> value, gets the value. This will be called + periodically and the value type will be the same type passed to the + detector function. + detector: function(old, new) -> bool, filters changes to determine when + the callback should be called. Can be used for edge detection + callback: function() to be invoked when detector conditions are met. + """ + last_time = time.time() + last_value = getter() + while not self._poll_thread.stopping.wait(poll_interval): + value = getter() + new_time = time.time() + if not debounce_interval or (new_time - last_time) > debounce_interval: + if detector(last_value, value): + callback() + last_value = value + last_time = new_time class HatPin(Pin): - """A Pin implemenation that supports pins controlled by the hat's MCU. - - Only one HatPin should exist at a given time for a given pin system wide. - Behavior is completely unpredictable if more than one pin exists concurrently. - If the factory is used for construction there are protections in place to - prevent this, however if multiple programs are running simultaneously the - protections do not limit cross program duplication. 
- """ - _EDGE_DETECTORS = { - 'both': lambda old, new: old != new, - 'rising': lambda old, new: not old and new, - 'falling': lambda old, new: old and not new, - None: None, - } - - def __init__(self, spec, pwm=False): - super(HatPin, self).__init__() - self.gpio_pin = None - self.pwm_pin = None - self.pwm_active = False - self.gpio_active = False - if spec.gpio_spec is not None: - self.gpio_pin = SysFsGpioPin(spec.gpio_spec) - - if spec.pwm_spec is not None: - self.pwm_pin = SysFsPwmPin(spec.pwm_spec) - - self._closed = False - self._poller = DebouncingPoller(self._get_state, None) - self._edges = None - self._set_bounce(.001) - # Start out with gpio enabled for compatibility. - self._enable_gpio() - - def _enable_pwm(self): - if self._closed: - return - if self.pwm_pin is None: - raise PinPWMUnsupported( - 'PWM was enabled, but is not supported on pin %r' % self.pwm_pin) - self._disable_gpio() - if not self.pwm_active: - self.pwm_pin.open() - self.pwm_active = True - - def _disable_pwm(self): - if self.pwm_active and self.pwm_pin is not None: - self.pwm_pin.close() - self.pwm_active = False - - def _enable_gpio(self): - if self._closed: - return - if self.gpio_pin is None: - raise PinUnsupported( - 'GPIO was enabled, but is not supported on pin %r' % self.gpio_pin) - self._disable_pwm() - if not self.gpio_active: - self.gpio_pin.open() - self.gpio_active = True - - def _disable_gpio(self): - if self.gpio_active and self.gpio_pin is not None: - self.gpio_pin.close() - self.gpio_active = False - - def close(self): - self._closed = True - self._poller.stop_polling() - self._disable_pwm() - self._disable_gpio() - - def _active_pin(self): - if self.pwm_active: - return self.pwm_pin - if self.gpio_active: - return self.gpio_pin - return None + """A Pin implemenation that supports pins controlled by the hat's MCU. - def _get_function(self): - return self._active_pin().get_function() - - def _set_function(self, value): - if value == 'input': - if self.pwm_active: - raise InputDeviceError('PWM Pin cannot be set to input') - self._enable_gpio() - elif value == 'pwm': - if self.gpio_active: - raise PinPWMUnsupported('GPIO Pin cannot be set to pwm') - self._enable_pwm() - elif self._active_pin() is None: - self._enable_gpio() - - if value != 'input': - self._poller.stop_polling() - self._active_pin().set_function(value) - - def _get_state(self): - return self._active_pin().get_value() - - def _set_state(self, state): - self._active_pin().set_value(state) - - def _get_frequency(self): - if self.pwm_pin is None or not self.pwm_active: - return None - return _NS_PER_SECOND / self.pwm_pin.get_period_ns() - - def _set_frequency(self, frequency): - if frequency is None: - self._enable_gpio() - else: - self._enable_pwm() - self.pwm_pin.set_period_ns(_NS_PER_SECOND / frequency) - - def _set_pull(self, pull): - if pull != 'up': - raise PinFixedPull('Only pull up is supported right now (%s)' % pull) - - def _get_pull(self): - return 'up' - - def _set_edges(self, edges): - if edges not in HatPin._EDGE_DETECTORS.keys(): - raise PinInvalidEdges('Edge must be "both", "falling", "rising", or None') - self._poller.detector = HatPin._EDGE_DETECTORS[edges] - self._edges = edges - - def _get_edges(self): - return self._edges - - def _set_when_changed(self, callback): - self._poller.callback = callback - - def _get_when_changed(self): - return self._poller.callback - - def set_poll_interval(self, poll_interval): - """Sets the time between polling the pin value. 
- - If a debounce time is set, this will be set to .51 * the debounce time. - There is a natural minimum value of _MIN_POLL_INTERVAL to which all smaller - values will be clipped. - Args: - poll_interval: positve float, time in seconds between polling the pin. + Only one HatPin should exist at a given time for a given pin system wide. + Behavior is completely unpredictable if more than one pin exists concurrently. + If the factory is used for construction there are protections in place to + prevent this, however if multiple programs are running simultaneously the + protections do not limit cross program duplication. """ - self._poller.poll_interval = poll_interval - - def _set_bounce(self, debounce_time): - if debounce_time is None: - self._poller.debounce_time = debounce_time - elif debounce_time < 0: - raise PinInvalidBounce('Bounce must be positive.') - else: - self._poller.debounce_time = debounce_time - self.set_poll_interval(debounce_time * .51) - - def _get_bounce(self): - return self._poller.debounce_time + _EDGE_DETECTORS = { + 'both': lambda old, new: old != new, + 'rising': lambda old, new: not old and new, + 'falling': lambda old, new: old and not new, + None: None, + } + + def __init__(self, spec, pwm=False): + super(HatPin, self).__init__() + self.gpio_pin = None + self.pwm_pin = None + self.pwm_active = False + self.gpio_active = False + if spec.gpio_spec is not None: + self.gpio_pin = SysFsGpioPin(spec.gpio_spec) + + if spec.pwm_spec is not None: + self.pwm_pin = SysFsPwmPin(spec.pwm_spec) + + self._closed = False + self._poller = DebouncingPoller(self._get_state, None) + self._edges = None + self._set_bounce(.001) + # Start out with gpio enabled for compatibility. + self._enable_gpio() + + def _enable_pwm(self): + if self._closed: + return + if self.pwm_pin is None: + raise PinPWMUnsupported( + 'PWM was enabled, but is not supported on pin %r' % self.pwm_pin) + self._disable_gpio() + if not self.pwm_active: + self.pwm_pin.open() + self.pwm_active = True + + def _disable_pwm(self): + if self.pwm_active and self.pwm_pin is not None: + self.pwm_pin.close() + self.pwm_active = False + + def _enable_gpio(self): + if self._closed: + return + if self.gpio_pin is None: + raise PinUnsupported( + 'GPIO was enabled, but is not supported on pin %r' % self.gpio_pin) + self._disable_pwm() + if not self.gpio_active: + self.gpio_pin.open() + self.gpio_active = True + + def _disable_gpio(self): + if self.gpio_active and self.gpio_pin is not None: + self.gpio_pin.close() + self.gpio_active = False + + def close(self): + self._closed = True + self._poller.stop_polling() + self._disable_pwm() + self._disable_gpio() + + def _active_pin(self): + if self.pwm_active: + return self.pwm_pin + if self.gpio_active: + return self.gpio_pin + return None + + def _get_function(self): + return self._active_pin().get_function() + + def _set_function(self, value): + if value == 'input': + if self.pwm_active: + raise InputDeviceError('PWM Pin cannot be set to input') + self._enable_gpio() + elif value == 'pwm': + if self.gpio_active: + raise PinPWMUnsupported('GPIO Pin cannot be set to pwm') + self._enable_pwm() + elif self._active_pin() is None: + self._enable_gpio() + + if value != 'input': + self._poller.stop_polling() + self._active_pin().set_function(value) + + def _get_state(self): + return self._active_pin().get_value() + + def _set_state(self, state): + self._active_pin().set_value(state) + + def _get_frequency(self): + if self.pwm_pin is None or not self.pwm_active: + return None + return 
_NS_PER_SECOND / self.pwm_pin.get_period_ns()
+
+    def _set_frequency(self, frequency):
+        if frequency is None:
+            self._enable_gpio()
+        else:
+            self._enable_pwm()
+            self.pwm_pin.set_period_ns(_NS_PER_SECOND / frequency)
+
+    def _set_pull(self, pull):
+        if pull != 'up':
+            raise PinFixedPull('Only pull up is supported right now (%s)' % pull)
+
+    def _get_pull(self):
+        return 'up'
+
+    def _set_edges(self, edges):
+        if edges not in HatPin._EDGE_DETECTORS.keys():
+            raise PinInvalidEdges('Edge must be "both", "falling", "rising", or None')
+        self._poller.detector = HatPin._EDGE_DETECTORS[edges]
+        self._edges = edges
+
+    def _get_edges(self):
+        return self._edges
+
+    def _set_when_changed(self, callback):
+        self._poller.callback = callback
+
+    def _get_when_changed(self):
+        return self._poller.callback
+
+    def set_poll_interval(self, poll_interval):
+        """Sets the time between polling the pin value.
+
+        If a debounce time is set, this will be set to .51 * the debounce time.
+        There is a natural minimum value of _MIN_POLL_INTERVAL to which all smaller
+        values will be clipped.
+        Args:
+          poll_interval: positive float, time in seconds between polling the pin.
+        """
+        self._poller.poll_interval = poll_interval
+
+    def _set_bounce(self, debounce_time):
+        if debounce_time is None:
+            self._poller.debounce_time = debounce_time
+        elif debounce_time < 0:
+            raise PinInvalidBounce('Bounce must be positive.')
+        else:
+            self._poller.debounce_time = debounce_time
+            self.set_poll_interval(debounce_time * .51)
+
+    def _get_bounce(self):
+        return self._poller.debounce_time
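
The HatPin machinery above exists so that stock gpiozero devices can drive the bonnet's MCU pins through the factories defined next; a minimal sketch in the spirit of the gpiozero examples elsewhere in this repo (assuming the PIN_A constant this module exports):

    from gpiozero import LED
    from aiy.vision.pins import PIN_A  # assumed pin constant from this module

    led = LED(PIN_A)  # constructed via the default-factory override noted below
    led.blink()
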


 class HybridFactory(Factory):
-  """Factory for selecting between other factories based on priority/success."""
+    """Factory for selecting between other factories based on priority/success."""

-  def __init__(self, *factories):
-    super(HybridFactory, self).__init__()
-    self.factories = factories
+    def __init__(self, *factories):
+        super(HybridFactory, self).__init__()
+        self.factories = factories

-  def close(self):
-    for factory in self.factories:
-      factory.close()
+    def close(self):
+        for factory in self.factories:
+            factory.close()

-  def pin(self, spec):
-    for factory in self.factories:
-      try:
-        # Try to make the pin from each factory (in order), until one works.
-        return factory.pin(spec)
-      except (TypeError, ValueError):
-        pass
-    raise TypeError(
-        'No registered factory was able to construct a pin for the given '
-        'specification')
+    def pin(self, spec):
+        for factory in self.factories:
+            try:
+                # Try to make the pin from each factory (in order), until one works.
+                return factory.pin(spec)
+            except (TypeError, ValueError):
+                pass
+        raise TypeError(
+            'No registered factory was able to construct a pin for the given '
+            'specification')


 class HatFactory(Factory):
-  """Factory for pins accessed through the hat's MCU."""
-  pins = {}
-
-  def __init__(self):
-    super(HatFactory, self).__init__()
-
-    self.pins = HatFactory.pins
-
-  def close(self):
-    for pin in self.pins.values():
-      pin.close()
-
-  def pin(self, spec):
-    if spec in self.pins:
-      return self.pins.get(spec)
-    if isinstance(spec, AIYPinSpec):
-      pin = HatPin(spec)
-      self.pins[spec] = pin
-      return pin
-    raise TypeError('Hat factory invoked on non-hat pin')
+    """Factory for pins accessed through the hat's MCU."""
+    pins = {}
+
+    def __init__(self):
+        super(HatFactory, self).__init__()
+
+        self.pins = HatFactory.pins
+
+    def close(self):
+        for pin in self.pins.values():
+            pin.close()
+
+    def pin(self, spec):
+        if spec in self.pins:
+            return self.pins.get(spec)
+        if isinstance(spec, AIYPinSpec):
+            pin = HatPin(spec)
+            self.pins[spec] = pin
+            return pin
+        raise TypeError('Hat factory invoked on non-hat pin')


 # This overrides the default factory being used by all gpiozero devices. It will
diff --git a/src/examples/vision/annotator.py b/src/examples/vision/annotator.py
index a4f76fb2..db53581f 100644
--- a/src/examples/vision/annotator.py
+++ b/src/examples/vision/annotator.py
@@ -31,150 +31,150 @@


 def _round_to_bit(value, power):
-  """Rounds the given value to the next multiple of 2^power.
+    """Rounds the given value to the next multiple of 2^power.

-  Args:
-    value: int to be rounded.
-    power: power of two which the value should be rounded up to.
-  Returns:
-    the result of value rounded to the next multiple 2^power.
-  """
-  return (((value - 1) >> power) + 1) << power
+    Args:
+      value: int to be rounded.
+      power: power of two which the value should be rounded up to.
+    Returns:
+      the result of value rounded to the next multiple 2^power.
+    """
+    return (((value - 1) >> power) + 1) << power


 def _round_buffer_dims(dims):
-  """Appropriately rounds the given dimensions for image overlaying.
+    """Appropriately rounds the given dimensions for image overlaying.

-  The overlay buffer must be rounded the next multiple of 32 for the hight, and
-  the next multiple of 16 for the width."""
-  return (_round_to_bit(dims[0], 5), _round_to_bit(dims[1], 4))
+    The overlay buffer must be rounded to the next multiple of 32 for the width,
+    and the next multiple of 16 for the height."""
+    return (_round_to_bit(dims[0], 5), _round_to_bit(dims[1], 4))


 # TODO(namiller): Add an annotator for images.
 class Annotator(object):
-  """Utility for managing annotations on the camera preview.
-
-  Args:
-    camera: picamera.PiCamera camera object to overlay on top of.
-    bg_color: PIL.ImageColor (with alpha) for the background of the overlays.
-    default_color: PIL.ImageColor (with alpha) default for the drawn content.
- """ - - def __init__(self, camera, bg_color=None, default_color=None, - dimensions=None): - self._dims = dimensions if dimensions else camera.resolution - self._buffer_dims = _round_buffer_dims(self._dims) - self._buffer = Image.new('RGBA', self._buffer_dims) - self._overlay = camera.add_overlay( - self._buffer.tobytes(), format='rgba', layer=3, size=self._buffer_dims) - self._draw = ImageDraw.Draw(self._buffer) - self._bg_color = bg_color if bg_color else (0, 0, 0, 0xA0) - self._default_color = default_color if default_color else (0xFF, 0, 0, 0xFF) - - # MMALPort has a bug in enable.wrapper, where it always calls - # self._pool.send_buffer(block=False) regardless of the port direction. - # This is in contrast to setup time when it only calls - # self._pool.send_all_buffers(block=False) - # if self._port[0].type == mmal.MMAL_PORT_TYPE_OUTPUT. - # Because of this bug updating an overlay once will log a MMAL_EAGAIN - # error every update. This is safe to ignore as we the user is driving - # the renderer input port with calls to update() that dequeue buffers - # and sends them to the input port (so queue is empty on when - # send_all_buffers(block=False) is called from wrapper). - # As a workaround, monkey patch MMALPortPool.send_buffer and - # silence the "error" if thrown by our overlay instance. - original_send_buffer = picamera.mmalobj.MMALPortPool.send_buffer - - def silent_send_buffer(zelf, **kwargs): - try: - original_send_buffer(zelf, **kwargs) - except picamera.exc.PiCameraMMALError as error: - # Only silence MMAL_EAGAIN for our target instance. - our_target = self._overlay.renderer.inputs[0].pool == zelf - if not our_target or error.status != 14: - raise error - - picamera.mmalobj.MMALPortPool.send_buffer = silent_send_buffer - - def update(self): - """Updates the contents of the overlay.""" - self._overlay.update(self._buffer.tobytes()) - - def stop(self): - """Removes the overlay from the screen.""" - self._draw.rectangle((0, 0) + self._dims, fill=0) - self.update() - - def clear(self): - """Clears the contents of the overlay - leaving only the plain background. - """ - self._draw.rectangle((0, 0) + self._dims, fill=self._bg_color) - - def bounding_box(self, rect, outline=None, fill=None): - """Draws a bounding box around the specified rectangle. + """Utility for managing annotations on the camera preview. Args: - rect: (x1, y1, x2, y2) rectangle to be drawn - where (x1,y1) and (x2, y2) - are opposite corners of the desired rectangle. - outline: PIL.ImageColor with which to draw the outline (defaults to the - configured default_color). - fill: PIL.ImageColor with which to fill the rectangel (defaults to None - which will not cover up drawings under the region. + camera: picamera.PiCamera camera object to overlay on top of. + bg_color: PIL.ImageColor (with alpha) for the background of the overlays. + default_color: PIL.ImageColor (with alpha) default for the drawn content. """ - outline = self._default_color if outline is None else outline - self._draw.rectangle(rect, fill=fill, outline=outline) - - #TODO(namiller): Add a font size parameter and load a truetype font. - def text(self, location, text, color=None): - """Draws the given text at the given location. - Args: - location: (x,y) point at which to draw the text (upper left corner). - text: string to be drawn. - color: PIL.ImageColor to draw the string in (defaults to default_color). 
- """ - color = self._default_color if color is None else color - self._draw.text(location, text, fill=color) - - def point(self, location, radius=1, color=None): - """Draws a point of the given size at the given location. - - Args: - location: (x,y) center of the point to be drawn. - radius: the radius of the point to be drawn. - color: The color to draw the point in (defaults to default_color). - """ - color = self._default_color if color is None else color - self._draw.ellipse( - (location[0] - radius, location[1] - radius, location[0] + radius, - location[1] + radius), - fill=color) + def __init__(self, camera, bg_color=None, default_color=None, + dimensions=None): + self._dims = dimensions if dimensions else camera.resolution + self._buffer_dims = _round_buffer_dims(self._dims) + self._buffer = Image.new('RGBA', self._buffer_dims) + self._overlay = camera.add_overlay( + self._buffer.tobytes(), format='rgba', layer=3, size=self._buffer_dims) + self._draw = ImageDraw.Draw(self._buffer) + self._bg_color = bg_color if bg_color else (0, 0, 0, 0xA0) + self._default_color = default_color if default_color else (0xFF, 0, 0, 0xFF) + + # MMALPort has a bug in enable.wrapper, where it always calls + # self._pool.send_buffer(block=False) regardless of the port direction. + # This is in contrast to setup time when it only calls + # self._pool.send_all_buffers(block=False) + # if self._port[0].type == mmal.MMAL_PORT_TYPE_OUTPUT. + # Because of this bug updating an overlay once will log a MMAL_EAGAIN + # error every update. This is safe to ignore as we the user is driving + # the renderer input port with calls to update() that dequeue buffers + # and sends them to the input port (so queue is empty on when + # send_all_buffers(block=False) is called from wrapper). + # As a workaround, monkey patch MMALPortPool.send_buffer and + # silence the "error" if thrown by our overlay instance. + original_send_buffer = picamera.mmalobj.MMALPortPool.send_buffer + + def silent_send_buffer(zelf, **kwargs): + try: + original_send_buffer(zelf, **kwargs) + except picamera.exc.PiCameraMMALError as error: + # Only silence MMAL_EAGAIN for our target instance. + our_target = self._overlay.renderer.inputs[0].pool == zelf + if not our_target or error.status != 14: + raise error + + picamera.mmalobj.MMALPortPool.send_buffer = silent_send_buffer + + def update(self): + """Updates the contents of the overlay.""" + self._overlay.update(self._buffer.tobytes()) + + def stop(self): + """Removes the overlay from the screen.""" + self._draw.rectangle((0, 0) + self._dims, fill=0) + self.update() + + def clear(self): + """Clears the contents of the overlay - leaving only the plain background. + """ + self._draw.rectangle((0, 0) + self._dims, fill=self._bg_color) + + def bounding_box(self, rect, outline=None, fill=None): + """Draws a bounding box around the specified rectangle. + + Args: + rect: (x1, y1, x2, y2) rectangle to be drawn - where (x1,y1) and (x2, y2) + are opposite corners of the desired rectangle. + outline: PIL.ImageColor with which to draw the outline (defaults to the + configured default_color). + fill: PIL.ImageColor with which to fill the rectangel (defaults to None + which will not cover up drawings under the region. + """ + outline = self._default_color if outline is None else outline + self._draw.rectangle(rect, fill=fill, outline=outline) + + # TODO(namiller): Add a font size parameter and load a truetype font. + def text(self, location, text, color=None): + """Draws the given text at the given location. 
+
+        Args:
+            location: (x,y) point at which to draw the text (upper left corner).
+            text: string to be drawn.
+            color: PIL.ImageColor to draw the string in (defaults to default_color).
+        """
+        color = self._default_color if color is None else color
+        self._draw.text(location, text, fill=color)
+
+    def point(self, location, radius=1, color=None):
+        """Draws a point of the given size at the given location.
+
+        Args:
+            location: (x,y) center of the point to be drawn.
+            radius: the radius of the point to be drawn.
+            color: The color to draw the point in (defaults to default_color).
+        """
+        color = self._default_color if color is None else color
+        self._draw.ellipse(
+            (location[0] - radius, location[1] - radius, location[0] + radius,
+             location[1] + radius),
+            fill=color)


 def _main():
-  """Example usage of the annotator utility.
-
-  Demonstrates setting up a camera preview, drawing slowly moving/intersecting
-  animations over it, and clearing the overlays."""
-  with picamera.PiCamera() as camera:
-    # Resolution can be arbitrary.
-    camera.resolution = (351, 561)
-    camera.start_preview()
-    annotator = Annotator(camera)
-    for i in range(10):
-      annotator.clear()
-      annotator.bounding_box(
-          (20, 20, 70, 70), outline=(0, 0xFF, 0, 0xFF), fill=0)
-      annotator.bounding_box((10 * i, 10, 10 * i + 50, 60))
-      annotator.bounding_box(
-          (80, 0, 130, 50), outline=(0, 0, 0xFF, 0xFF), fill=0)
-      annotator.text((100, 100), 'Hello World')
-      annotator.point((10, 100), radius=5)
-      annotator.update()
-      time.sleep(1)
-    annotator.stop()
-    time.sleep(10)
+    """Example usage of the annotator utility.
+
+    Demonstrates setting up a camera preview, drawing slowly moving/intersecting
+    animations over it, and clearing the overlays."""
+    with picamera.PiCamera() as camera:
+        # Resolution can be arbitrary.
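+        # Odd sizes are fine: the Annotator pads its overlay buffer up to the
+        # next multiple of 32x16 pixels, which picamera's MMAL-based overlay
+        # renderer requires (see _round_buffer_dims above).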
+        camera.resolution = (351, 561)
+        camera.start_preview()
+        annotator = Annotator(camera)
+        for i in range(10):
+            annotator.clear()
+            annotator.bounding_box(
+                (20, 20, 70, 70), outline=(0, 0xFF, 0, 0xFF), fill=0)
+            annotator.bounding_box((10 * i, 10, 10 * i + 50, 60))
+            annotator.bounding_box(
+                (80, 0, 130, 50), outline=(0, 0, 0xFF, 0xFF), fill=0)
+            annotator.text((100, 100), 'Hello World')
+            annotator.point((10, 100), radius=5)
+            annotator.update()
+            time.sleep(1)
+        annotator.stop()
+        time.sleep(10)


 if __name__ == '__main__':
-  _main()
+    _main()
diff --git a/src/examples/vision/buzzer/buzzer_demo.py b/src/examples/vision/buzzer/buzzer_demo.py
index 151b1c30..0d9e9cf5 100755
--- a/src/examples/vision/buzzer/buzzer_demo.py
+++ b/src/examples/vision/buzzer/buzzer_demo.py
@@ -46,7 +46,7 @@ def main():
     ]

     player = aiy.toneplayer.TonePlayer(22)
-    player.play(*tetris_theme);
+    player.play(*tetris_theme)


 if __name__ == '__main__':
diff --git a/src/examples/vision/dish_classifier.py b/src/examples/vision/dish_classifier.py
index 5c55d65d..8b077309 100644
--- a/src/examples/vision/dish_classifier.py
+++ b/src/examples/vision/dish_classifier.py
@@ -22,17 +22,17 @@


 def main():
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--input', '-i', dest='input', required=True)
-  args = parser.parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', '-i', dest='input', required=True)
+    args = parser.parse_args()

-  with ImageInference(dish_classifier.model()) as inference:
-    image = Image.open(args.input)
-    classes = dish_classifier.get_classes(
-        inference.run(image), max_num_objects=5, object_prob_threshold=0.1)
-    for i, (label, score) in enumerate(classes):
-      print('Result %d: %s (prob=%f)' % (i, label, score))
+    with ImageInference(dish_classifier.model()) as inference:
+        image = Image.open(args.input)
+        classes = dish_classifier.get_classes(
+            inference.run(image), max_num_objects=5, object_prob_threshold=0.1)
+        for i, (label, score) in enumerate(classes):
+            print('Result %d: %s (prob=%f)' % (i, label, score))


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/face_camera_trigger.py b/src/examples/vision/face_camera_trigger.py
index 29054f10..9c32ac37 100755
--- a/src/examples/vision/face_camera_trigger.py
+++ b/src/examples/vision/face_camera_trigger.py
@@ -20,22 +20,21 @@


 def main():
-  with PiCamera() as camera:
-    # Configure camera
-    camera.resolution = (1640, 922)  # Full Frame, 16:9 (Camera v2)
-    camera.start_preview()
+    with PiCamera() as camera:
+        # Configure camera
+        camera.resolution = (1640, 922)  # Full Frame, 16:9 (Camera v2)
+        camera.start_preview()

-    # Do inference on VisionBonnet
-    with CameraInference(face_detection.model()) as inference:
-      for result in inference.run():
-        if len(face_detection.get_faces(result)) >= 1:
-          camera.capture('faces.jpg')
-          break
+        # Do inference on VisionBonnet
+        with CameraInference(face_detection.model()) as inference:
+            for result in inference.run():
+                if len(face_detection.get_faces(result)) >= 1:
+                    camera.capture('faces.jpg')
+                    break

-    # Stop preview
-    camera.stop_preview()
+        # Stop preview
+        camera.stop_preview()


 if __name__ == '__main__':
-  main()
-
+    main()
diff --git a/src/examples/vision/face_detection.py b/src/examples/vision/face_detection.py
index 705daebc..9c8eb234 100755
--- a/src/examples/vision/face_detection.py
+++ b/src/examples/vision/face_detection.py
@@ -29,23 +29,23 @@


 def main():
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--input', '-i', dest='input',
-                      required=True)
-  parser.add_argument('--output', '-o', dest='output')
-  args = parser.parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', '-i', dest='input', required=True)
+    parser.add_argument('--output', '-o', dest='output')
+    args = parser.parse_args()

-  with ImageInference(face_detection.model()) as inference:
-    image = Image.open(
-        io.BytesIO(sys.stdin.buffer.read())
-        if args.input == '-' else args.input)
-    draw = ImageDraw.Draw(image)
-    for i, face in enumerate(face_detection.get_faces(inference.run(image))):
-      print('Face #%d: %s' % (i, str(face)))
-      x, y, width, height = face.bounding_box
-      draw.rectangle((x, y, x + width, y + height), outline='red')
-    if args.output:
-      image.save(args.output)
+    with ImageInference(face_detection.model()) as inference:
+        image = Image.open(
+            io.BytesIO(sys.stdin.buffer.read())
+            if args.input == '-' else args.input)
+        draw = ImageDraw.Draw(image)
+        for i, face in enumerate(face_detection.get_faces(inference.run(image))):
+            print('Face #%d: %s' % (i, str(face)))
+            x, y, width, height = face.bounding_box
+            draw.rectangle((x, y, x + width, y + height), outline='red')
+        if args.output:
+            image.save(args.output)


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/face_detection_camera.py b/src/examples/vision/face_detection_camera.py
index b4200f05..d7ba3f26 100755
--- a/src/examples/vision/face_detection_camera.py
+++ b/src/examples/vision/face_detection_camera.py
@@ -29,60 +29,62 @@


 def main():
-  """Face detection camera inference example."""
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--num_frames',
-      '-n',
-      type=int,
-      dest='num_frames',
-      default=-1,
-      help='Sets the number of frames to run for, otherwise runs forever.')
-  args = parser.parse_args()
+    """Face detection camera inference example."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--num_frames',
+        '-n',
+        type=int,
+        dest='num_frames',
+        default=-1,
+        help='Sets the number of frames to run for, otherwise runs forever.')
+    args = parser.parse_args()

-  with PiCamera() as camera:
-    # Forced sensor mode, 1640x1232, full FoV. See:
-    # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes
-    # This is the resolution inference run on.
-    camera.sensor_mode = 4
+    with PiCamera() as camera:
+        # Forced sensor mode, 1640x1232, full FoV. See:
+        # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes
+        # This is the resolution inference run on.
+        camera.sensor_mode = 4

-    # Scaled and cropped resolution. If different from sensor mode implied
-    # resolution, inference results must be adjusted accordingly. This is
-    # true in particular when camera.start_recording is used to record an
-    # encoded h264 video stream as the Pi encoder can't encode all native
-    # sensor resolutions, or a standard one like 1080p may be desired.
-    camera.resolution = (1640, 1232)
+        # Scaled and cropped resolution. If different from sensor mode implied
+        # resolution, inference results must be adjusted accordingly. This is
+        # true in particular when camera.start_recording is used to record an
+        # encoded h264 video stream as the Pi encoder can't encode all native
+        # sensor resolutions, or a standard one like 1080p may be desired.
+        camera.resolution = (1640, 1232)

-    # Start the camera stream.
-    camera.framerate = 30
-    camera.start_preview()
+        # Start the camera stream.
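+        # Per the picamera sensor-mode table, mode 4 supports 0.1-40 fps,
+        # so a 30 fps preview stays within the mode's supported range.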
+        camera.framerate = 30
+        camera.start_preview()

-    # Annotator renders in software so use a smaller size and scale results
-    # for increased performace.
-    annotator = Annotator(camera, dimensions=(320, 240))
-    scale_x = 320 / 1640
-    scale_y = 240 / 1232
+        # Annotator renders in software so use a smaller size and scale results
+        # for increased performance.
+        annotator = Annotator(camera, dimensions=(320, 240))
+        scale_x = 320 / 1640
+        scale_y = 240 / 1232

-    # Incoming boxes are of the form (x, y, width, height). Scale and
-    # transform to the form (x1, y1, x2, y2).
-    def transform(bounding_box):
-      x, y, width, height = bounding_box
-      return (scale_x * x, scale_y * y, scale_x * (x + width),
-              scale_y * (y + height))
+        # Incoming boxes are of the form (x, y, width, height). Scale and
+        # transform to the form (x1, y1, x2, y2).
+        def transform(bounding_box):
+            x, y, width, height = bounding_box
+            return (scale_x * x, scale_y * y, scale_x * (x + width),
+                    scale_y * (y + height))

-    with CameraInference(face_detection.model()) as inference:
-      for i, result in enumerate(inference.run()):
-        if i == args.num_frames:
-          break
-        faces = face_detection.get_faces(result)
-        annotator.clear()
-        for face in faces:
-          annotator.bounding_box(transform(face.bounding_box), fill=0)
-        annotator.update()
-        print('Iteration #%d: num_faces=%d' % (i, len(faces)))
+        with CameraInference(face_detection.model()) as inference:
+            for i, result in enumerate(inference.run()):
+                if i == args.num_frames:
+                    break
+                faces = face_detection.get_faces(result)
+                annotator.clear()
+                for face in faces:
+                    annotator.bounding_box(transform(face.bounding_box), fill=0)
+                annotator.update()
+                print('Iteration #%d: num_faces=%d' % (i, len(faces)))

-    camera.stop_preview()
+        camera.stop_preview()


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/gpiozero/led_example.py b/src/examples/vision/gpiozero/led_example.py
index 4d02cbf5..9a287afe 100644
--- a/src/examples/vision/gpiozero/led_example.py
+++ b/src/examples/vision/gpiozero/led_example.py
@@ -10,7 +10,7 @@
 led = LED(LED_1)
 # Alternate turning the LED off and on until the user terminates the example.
 while True:
-  led.on()
-  sleep(1)
-  led.off()
-  sleep(1)
+    led.on()
+    sleep(1)
+    led.off()
+    sleep(1)
diff --git a/src/examples/vision/gpiozero/servo_example.py b/src/examples/vision/gpiozero/servo_example.py
index 024f0231..7195bcf9 100644
--- a/src/examples/vision/gpiozero/servo_example.py
+++ b/src/examples/vision/gpiozero/servo_example.py
@@ -16,12 +16,15 @@

 # Move the Servos back and forth until the user terminates the example.
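+# min()/mid()/max() move a gpiozero Servo to its shortest, middle, and longest
+# configured pulse widths; the sleep(1) calls give the servo time to reach
+# each position before the next command.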
 while True:
-  simple_servo.min()
-  tuned_servo.max()
-  sleep(1)
-  simple_servo.mid()
-  tuned_servo.mid()
-  sleep(1)
-  simple_servo.max()
-  tuned_servo.min()
-  sleep(1)
+    simple_servo.min()
+    tuned_servo.max()
+    sleep(1)
+    simple_servo.mid()
+    tuned_servo.mid()
+    sleep(1)
+    simple_servo.max()
+    tuned_servo.min()
+    sleep(1)
diff --git a/src/examples/vision/gpiozero/simple_button_example.py b/src/examples/vision/gpiozero/simple_button_example.py
index 5ef147b0..cca59112 100644
--- a/src/examples/vision/gpiozero/simple_button_example.py
+++ b/src/examples/vision/gpiozero/simple_button_example.py
@@ -17,7 +17,7 @@
 button = Button(BUTTON_GPIO_PIN)

 while True:
-  if button.is_pressed:
-    led.on()
-  else:
-    led.off()
+    if button.is_pressed:
+        led.on()
+    else:
+        led.off()
diff --git a/src/examples/vision/image_classification.py b/src/examples/vision/image_classification.py
index 3de83db2..07e764d2 100644
--- a/src/examples/vision/image_classification.py
+++ b/src/examples/vision/image_classification.py
@@ -24,25 +24,25 @@


 def main():
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--input', '-i', dest='input', required=True)
-  args = parser.parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', '-i', dest='input', required=True)
+    args = parser.parse_args()

-  # There are two models available for image classification task:
-  # 1) MobileNet based (image_classification.MOBILENET), which has 59.9% top-1
-  # accuracy on ImageNet;
-  # 2) SqueezeNet based (image_classification.SQUEEZENET), which has 45.3% top-1
-  # accuracy on ImageNet;
-  model_type = image_classification.MOBILENET
-  with ImageInference(image_classification.model(model_type)) as inference:
-    image = Image.open(
-        io.BytesIO(sys.stdin.buffer.read())
-        if args.input == '-' else args.input)
-    classes = image_classification.get_classes(
-        inference.run(image), max_num_objects=5, object_prob_threshold=0.1)
-    for i, (label, score) in enumerate(classes):
-      print('Result %d: %s (prob=%f)' % (i, label, score))
+    # There are two models available for image classification task:
+    # 1) MobileNet based (image_classification.MOBILENET), which has 59.9% top-1
+    # accuracy on ImageNet;
+    # 2) SqueezeNet based (image_classification.SQUEEZENET), which has 45.3% top-1
+    # accuracy on ImageNet;
+    model_type = image_classification.MOBILENET
+    with ImageInference(image_classification.model(model_type)) as inference:
+        image = Image.open(
+            io.BytesIO(sys.stdin.buffer.read())
+            if args.input == '-' else args.input)
+        classes = image_classification.get_classes(
+            inference.run(image), max_num_objects=5, object_prob_threshold=0.1)
+        for i, (label, score) in enumerate(classes):
+            print('Result %d: %s (prob=%f)' % (i, label, score))


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/joy/joy_detection_demo.py b/src/examples/vision/joy/joy_detection_demo.py
index 91d09427..7b56e25f 100755
--- a/src/examples/vision/joy/joy_detection_demo.py
+++ b/src/examples/vision/joy/joy_detection_demo.py
@@ -44,131 +44,131 @@


 def blend(color_a, color_b, alpha):
-  return tuple([
-      math.ceil(a * alpha + b * (1.0 - alpha))
-      for a, b in zip(color_a, color_b)
-  ])
+    return tuple([
+        math.ceil(a * alpha + b * (1.0 - alpha))
+        for a, b in zip(color_a, color_b)
+    ])


 class JoyDetector(object):

-  def __init__(self, num_frames, preview_alpha):
-    self._rgbled = RGBLED(debug=False)
-    self._num_frames = num_frames
-    self._preview_alpha = preview_alpha
-    self._toneplayer = TonePlayer(22, bpm=10)
-    self._sound_played = False
-    self._detector = threading.Thread(target=self._run_detector)
-    self._animator = threading.Thread(target=self._run_animator)
-    self._joy_score_lock = threading.Lock()
-    self._joy_score = 0.0
-    self._joy_score_window = collections.deque(maxlen=WINDOW_SIZE)
-    self._run_event = threading.Event()
-    signal.signal(signal.SIGINT, lambda signal, frame: self.stop())
-    signal.signal(signal.SIGTERM, lambda signal, frame: self.stop())
-
-  @property
-  def joy_score(self):
-    with self._joy_score_lock:
-      return self._joy_score
-
-  @joy_score.setter
-  def joy_score(self, value):
-    with self._joy_score_lock:
-      self._joy_score = value
-
-  def start(self):
-    print('Starting JoyDetector...')
-    self._run_event.set()
-    self._detector.start()
-
-  def join(self):
-    self._detector.join()
-    self._animator.join()
-
-  def stop(self):
-    print('Stopping JoyDetector...')
-    self._run_event.clear()
-
-  def _play_sound(self, sound):
-    if not self._sound_played:
-      self._sound_played = True
-      self._sound = threading.Thread(target=self._toneplayer.play, args=(*sound,))
-      self._sound.start()
-
-  def _run_animator(self):
-    while self._run_event.is_set():
-      joy_score = self.joy_score
-      if joy_score > JOY_SCORE_PEAK:
-        self._play_sound(JOY_SOUND)
-      elif joy_score < JOY_SCORE_MIN:
-        self._play_sound(SAD_SOUND)
-      else:
+    def __init__(self, num_frames, preview_alpha):
+        self._rgbled = RGBLED(debug=False)
+        self._num_frames = num_frames
+        self._preview_alpha = preview_alpha
+        self._toneplayer = TonePlayer(22, bpm=10)
         self._sound_played = False
-
-      if joy_score > 0:
-        self._rgbled.SetColorMix(*blend(JOY_COLOR, SAD_COLOR, joy_score))
-      else:
-        self._rgbled.SetColorMix(*NONE_COLOR)
-      time.sleep(0.1)
-
-  def _run_detector(self):
-    with PiCamera() as camera, PrivacyLED():
-      # Forced sensor mode, 1640x1232, full FoV. See:
-      # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes
-      # This is the resolution inference run on.
-      camera.sensor_mode = 4
-      camera.resolution = (1640, 1232)
-      camera.framerate = 15
-      # Blend the preview layer with the alpha value from the flags.
-      camera.start_preview(alpha=self._preview_alpha)
-      with CameraInference(face_detection.model()) as inference:
-        self._play_sound(MODEL_LOAD_SOUND)
-        self._animator.start()
-        for i, result in enumerate(inference.run()):
-          faces = face_detection.get_faces(result)
-          # Calculate joy score as an average for all detected faces.
-          joy_score = 0.0
-          if faces:
-            joy_score = sum([face.joy_score for face in faces]) / len(faces)
-
-          # Append new joy score to the window and calculate mean value.
-          self._joy_score_window.append(joy_score)
-          self.joy_score = sum(self._joy_score_window) / len(
-              self._joy_score_window)
-          if self._num_frames == i or not self._run_event.is_set():
-            break
+        self._detector = threading.Thread(target=self._run_detector)
+        self._animator = threading.Thread(target=self._run_animator)
+        self._joy_score_lock = threading.Lock()
+        self._joy_score = 0.0
+        self._joy_score_window = collections.deque(maxlen=WINDOW_SIZE)
+        self._run_event = threading.Event()
+        signal.signal(signal.SIGINT, lambda signal, frame: self.stop())
+        signal.signal(signal.SIGTERM, lambda signal, frame: self.stop())
+
+    @property
+    def joy_score(self):
+        with self._joy_score_lock:
+            return self._joy_score
+
+    @joy_score.setter
+    def joy_score(self, value):
+        with self._joy_score_lock:
+            self._joy_score = value
+
+    def start(self):
+        print('Starting JoyDetector...')
+        self._run_event.set()
+        self._detector.start()
+
+    def join(self):
+        self._detector.join()
+        self._animator.join()
+
+    def stop(self):
+        print('Stopping JoyDetector...')
+        self._run_event.clear()
+
+    def _play_sound(self, sound):
+        if not self._sound_played:
+            self._sound_played = True
+            self._sound = threading.Thread(target=self._toneplayer.play, args=(*sound,))
+            self._sound.start()
+
+    def _run_animator(self):
+        while self._run_event.is_set():
+            joy_score = self.joy_score
+            if joy_score > JOY_SCORE_PEAK:
+                self._play_sound(JOY_SOUND)
+            elif joy_score < JOY_SCORE_MIN:
+                self._play_sound(SAD_SOUND)
+            else:
+                self._sound_played = False
+
+            if joy_score > 0:
+                self._rgbled.SetColorMix(*blend(JOY_COLOR, SAD_COLOR, joy_score))
+            else:
+                self._rgbled.SetColorMix(*NONE_COLOR)
+            time.sleep(0.1)
+
+    def _run_detector(self):
+        with PiCamera() as camera, PrivacyLED():
+            # Forced sensor mode, 1640x1232, full FoV. See:
+            # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes
+            # This is the resolution inference run on.
+            camera.sensor_mode = 4
+            camera.resolution = (1640, 1232)
+            camera.framerate = 15
+            # Blend the preview layer with the alpha value from the flags.
+            camera.start_preview(alpha=self._preview_alpha)
+            with CameraInference(face_detection.model()) as inference:
+                self._play_sound(MODEL_LOAD_SOUND)
+                self._animator.start()
+                for i, result in enumerate(inference.run()):
+                    faces = face_detection.get_faces(result)
+                    # Calculate joy score as an average for all detected faces.
+                    joy_score = 0.0
+                    if faces:
+                        joy_score = sum([face.joy_score for face in faces]) / len(faces)
+
+                    # Append new joy score to the window and calculate mean value.
+                    self._joy_score_window.append(joy_score)
+                    self.joy_score = sum(self._joy_score_window) / len(
+                        self._joy_score_window)
+                    if self._num_frames == i or not self._run_event.is_set():
+                        break


 def main():
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--num_frames',
-      '-n',
-      type=int,
-      dest='num_frames',
-      default=-1,
-      help='Sets the number of frames to run for. '
-      'Setting this parameter to -1 will '
-      'cause the demo to not automatically terminate.')
-  parser.add_argument(
-      '--preview_alpha',
-      '-pa',
-      type=int,
-      dest='preview_alpha',
-      default=0,
-      help='Sets the transparency value of the preview overlay (0-255).')
-  args = parser.parse_args()
-
-  device = get_aiy_device_name()
-  if not device or not 'Vision' in device:
-    print('Do you have an AIY Vision bonnet installed? Exiting.')
-    sys.exit(0)
-
-  detector = JoyDetector(args.num_frames, args.preview_alpha)
-  detector.start()
-  detector.join()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--num_frames',
+        '-n',
+        type=int,
+        dest='num_frames',
+        default=-1,
+        help='Sets the number of frames to run for. '
+        'Setting this parameter to -1 will '
+        'cause the demo to not automatically terminate.')
+    parser.add_argument(
+        '--preview_alpha',
+        '-pa',
+        type=int,
+        dest='preview_alpha',
+        default=0,
+        help='Sets the transparency value of the preview overlay (0-255).')
+    args = parser.parse_args()
+
+    device = get_aiy_device_name()
+    if not device or 'Vision' not in device:
+        print('Do you have an AIY Vision bonnet installed? Exiting.')
+        sys.exit(0)
+
+    detector = JoyDetector(args.num_frames, args.preview_alpha)
+    detector.start()
+    detector.join()


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/object_detection.py b/src/examples/vision/object_detection.py
index e21c0441..a5f3896a 100755
--- a/src/examples/vision/object_detection.py
+++ b/src/examples/vision/object_detection.py
@@ -29,32 +29,32 @@


 def _crop_center(image):
-  width, height = image.size
-  size = min(width, height)
-  x, y = (width - size) / 2, (height - size) / 2
-  return image.crop((x, y, x + size, y + size)), (x, y)
+    width, height = image.size
+    size = min(width, height)
+    x, y = (width - size) / 2, (height - size) / 2
+    return image.crop((x, y, x + size, y + size)), (x, y)


 def main():
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--input', '-i', dest='input', required=True)
-  parser.add_argument('--output', '-o', dest='output')
-  args = parser.parse_args()
-
-  with ImageInference(object_detection.model()) as inference:
-    image = Image.open(
-        io.BytesIO(sys.stdin.buffer.read())
-        if args.input == '-' else args.input)
-    image_center, offset = _crop_center(image)
-    draw = ImageDraw.Draw(image)
-    result = inference.run(image_center)
-    for i, obj in enumerate(object_detection.get_objects(result, 0.3, offset)):
-      print('Object #%d: %s' % (i, str(obj)))
-      x, y, width, height = obj.bounding_box
-      draw.rectangle((x, y, x + width, y + height), outline='red')
-    if args.output:
-      image.save(args.output)
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input', '-i', dest='input', required=True)
+    parser.add_argument('--output', '-o', dest='output')
+    args = parser.parse_args()
+
+    with ImageInference(object_detection.model()) as inference:
+        image = Image.open(
+            io.BytesIO(sys.stdin.buffer.read())
+            if args.input == '-' else args.input)
+        image_center, offset = _crop_center(image)
+        draw = ImageDraw.Draw(image)
+        result = inference.run(image_center)
+        for i, obj in enumerate(object_detection.get_objects(result, 0.3, offset)):
+            print('Object #%d: %s' % (i, str(obj)))
+            x, y, width, height = obj.bounding_box
+            draw.rectangle((x, y, x + width, y + height), outline='red')
+        if args.output:
+            image.save(args.output)


 if __name__ == '__main__':
-  main()
+    main()
diff --git a/src/examples/vision/object_detection_camera.py b/src/examples/vision/object_detection_camera.py
index 7bd6e344..17a798a5 100644
--- a/src/examples/vision/object_detection_camera.py
+++ b/src/examples/vision/object_detection_camera.py
@@ -28,60 +28,62 @@


 def main():
-  """Object detection camera inference example."""
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--num_frames',
-      '-n',
-      type=int,
-      dest='num_frames',
-      default=-1,
-      help='Sets the number of frames to run for, otherwise runs forever.')
+    """Object detection camera inference example."""
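+    # Note: this example drives the image classification model (see the
+    # CameraInference call below) and prints the top classes per frame.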
"""Object detection camera inference example.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--num_frames', + '-n', + type=int, + dest='num_frames', + default=-1, + help='Sets the number of frames to run for, otherwise runs forever.') - parser.add_argument( - '--num_objects', - '-c', - type=int, - dest='num_objects', - default=3, - help='Sets the number of object interences to print.') + parser.add_argument( + '--num_objects', + '-c', + type=int, + dest='num_objects', + default=3, + help='Sets the number of object interences to print.') - args = parser.parse_args() + args = parser.parse_args() - def print_classes(classes, object_count): - s = '' - for index, (obj, prob) in enumerate(classes): - if index > object_count - 1: - break - s+='%s=%1.2f\t|\t' % (obj, prob) - print('%s\r' % s) + def print_classes(classes, object_count): + s = '' + for index, (obj, prob) in enumerate(classes): + if index > object_count - 1: + break + s += '%s=%1.2f\t|\t' % (obj, prob) + print('%s\r' % s) - with PiCamera() as camera: - # Forced sensor mode, 1640x1232, full FoV. See: - # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes - # This is the resolution inference run on. - camera.sensor_mode = 4 + with PiCamera() as camera: + # Forced sensor mode, 1640x1232, full FoV. See: + # https://picamera.readthedocs.io/en/release-1.13/fov.html#sensor-modes + # This is the resolution inference run on. + camera.sensor_mode = 4 - # Scaled and cropped resolution. If different from sensor mode implied - # resolution, inference results must be adjusted accordingly. This is - # true in particular when camera.start_recording is used to record an - # encoded h264 video stream as the Pi encoder can't encode all native - # sensor resolutions, or a standard one like 1080p may be desired. - camera.resolution = (1640, 1232) + # Scaled and cropped resolution. If different from sensor mode implied + # resolution, inference results must be adjusted accordingly. This is + # true in particular when camera.start_recording is used to record an + # encoded h264 video stream as the Pi encoder can't encode all native + # sensor resolutions, or a standard one like 1080p may be desired. + camera.resolution = (1640, 1232) - # Start the camera stream. - camera.framerate = 30 - camera.start_preview() + # Start the camera stream. + camera.framerate = 30 + camera.start_preview() - with CameraInference(image_classification.model()) as inference: - for i, result in enumerate(inference.run()): - if i == args.num_frames: - break - classes = image_classification.get_classes(result) - print_classes(classes, args.num_objects) + with CameraInference(image_classification.model()) as inference: + for i, result in enumerate(inference.run()): + if i == args.num_frames: + break + classes = image_classification.get_classes(result) + print_classes(classes, args.num_objects) - camera.stop_preview() + camera.stop_preview() if __name__ == '__main__': - main() + main() diff --git a/src/examples/voice/assistant_library_with_button_demo.py b/src/examples/voice/assistant_library_with_button_demo.py index 24f2156f..437abb67 100755 --- a/src/examples/voice/assistant_library_with_button_demo.py +++ b/src/examples/voice/assistant_library_with_button_demo.py @@ -47,6 +47,7 @@ class MyAssistant(object): thread. Otherwise, the on_button_pressed() method will never get a chance to be invoked. """ + def __init__(self): self._task = threading.Thread(target=self._run_task) self._can_start_conversation = False