-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
121 lines (97 loc) · 3.99 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Utility functions for using TFLite Interpreter
"""
import numpy as np
from PIL import Image
import tflite_runtime.interpreter as tflite
import platform
# File name of the Edge TPU runtime shared library for the current OS.
# Indexing by platform.system() raises KeyError at import time on any
# operating system other than the three listed here.
EDGETPU_SHARED_LIB = dict(
    Linux='libedgetpu.so.1',
    Darwin='libedgetpu.1.dylib',
    Windows='edgetpu.dll',
)[platform.system()]
def make_interpreter_0(model_file):
    """Create a TFLite interpreter with no delegate attached.

    Args:
        model_file: Path to the model. Anything from the first ``@`` onward
            (the optional device suffix) is ignored here.

    Returns:
        A ``tflite.Interpreter`` built from the path portion of ``model_file``.
    """
    model_path = model_file.split('@')[0]
    return tflite.Interpreter(model_path=model_path)
def make_interpreter_1(model_file):
    """Create a TFLite interpreter that runs through the Edge TPU delegate.

    Args:
        model_file: Model path, optionally followed by ``@<device>``. When a
            device part is present it is passed to the delegate as its
            ``'device'`` option; otherwise the delegate gets no options.

    Returns:
        A ``tflite.Interpreter`` with the Edge TPU delegate loaded from
        ``EDGETPU_SHARED_LIB``.
    """
    model_path, *device = model_file.split('@')
    delegate_options = {'device': device[0]} if device else {}
    edgetpu_delegate = tflite.load_delegate(EDGETPU_SHARED_LIB,
                                            delegate_options)
    return tflite.Interpreter(model_path=model_path,
                              experimental_delegates=[edgetpu_delegate])
def set_input(interpreter, image, resample=Image.NEAREST):
    """Resize ``image`` to the model's input size and copy it into the input.

    Args:
        interpreter: Interpreter whose first input tensor is written.
        image: Object with a PIL-style ``resize((width, height), resample)``.
        resample: Resampling filter forwarded to ``image.resize``.
    """
    width, height = input_image_size(interpreter)[0:2]
    resized = image.resize((width, height), resample)
    # Write in place through the view so the interpreter sees the new pixels.
    input_tensor(interpreter)[:, :] = resized
def input_image_size(interpreter):
    """Return the first input's size as a (width, height, channels) tuple.

    The input shape is unpacked as (batch, height, width, channels), so it
    must have exactly four entries; the batch entry is discarded.
    """
    input_shape = interpreter.get_input_details()[0]['shape']
    _batch, rows, cols, depth = input_shape
    return cols, rows, depth
def input_tensor(interpreter):
    """Return element 0 (along the first axis) of the first input tensor.

    The value comes from calling the accessor returned by
    ``interpreter.tensor(index)``.
    NOTE(review): presumably a writable (height, width, channels) numpy view
    for an image model -- confirm against the model in use.
    """
    first_input = interpreter.get_input_details()[0]
    accessor = interpreter.tensor(first_input['index'])
    return accessor()[0]
def output_tensor(interpreter, i):
    """Return output tensor ``i``, dequantized when it is quantized.

    Args:
        interpreter: Object exposing ``get_output_details()`` and
            ``tensor(index)`` (the latter returns a callable yielding the
            tensor data).
        i: Position of the output in ``get_output_details()``.

    Returns:
        The squeezed tensor data. If the details carry no ``'quantization'``
        entry the data is returned as is. If the scale is 0 the data is
        returned with only the zero point subtracted. Otherwise
        ``scale * (data - zero_point)`` is returned.
    """
    output_details = interpreter.get_output_details()[i]
    output_data = np.squeeze(interpreter.tensor(output_details['index'])())
    if 'quantization' not in output_details:
        return output_data
    scale, zero_point = output_details['quantization']
    if scale == 0:
        return output_data - zero_point
    # Quantized tensors are narrow integers (e.g. uint8/int8). Subtracting the
    # zero point in that dtype wraps around (uint8 100 - 128 -> 228), so widen
    # integer data first. Float data is left untouched.
    if np.issubdtype(output_data.dtype, np.integer):
        output_data = output_data.astype(np.int64)
    return scale * (output_data - zero_point)
import time
def time_elapsed(start_time, event):
    """Print the time since ``start_time`` in milliseconds, tagged with ``event``.

    Args:
        start_time: Earlier timestamp in seconds, as returned by ``time.time()``.
        event: Label printed in parentheses after the duration.
    """
    elapsed_ms = round((time.time() - start_time) * 1000, 2)
    print(">>> ", elapsed_ms, " ms (", event, ")")
import os
def load_model(model_dir, model, lbl, edgetpu):
    """Load a model and its label map from ``model_dir``.

    Args:
        model_dir: Directory containing both files.
        model: Model file name inside ``model_dir``.
        lbl: Labels file name inside ``model_dir``.
        edgetpu: ``0`` selects ``make_interpreter_0``; any other value selects
            ``make_interpreter_1``.

    Returns:
        An ``(interpreter, labels)`` pair. The interpreter has had
        ``allocate_tensors()`` called, and ``labels`` is the result of
        ``load_labels``.
    """
    for template, value in (
        ('Loading from directory: {} ', model_dir),
        ('Loading Model: {} ', model),
        ('Loading Labels: {} ', lbl),
    ):
        print(template.format(value))

    model_path = os.path.join(model_dir, model)
    labels_path = os.path.join(model_dir, lbl)

    build_interpreter = (
        make_interpreter_0 if edgetpu == 0 else make_interpreter_1)
    interpreter = build_interpreter(model_path)
    interpreter.allocate_tensors()

    return interpreter, load_labels(labels_path)
import re
def load_labels(path):
    """Read a labels file into an ``{id: label}`` dictionary.

    Each non-blank line must start (after optional whitespace) with a decimal
    id followed by the label text. The label is stripped of surrounding
    whitespace. Blank lines, such as a trailing empty line, are skipped.

    Args:
        path: Path of a UTF-8 encoded labels file.

    Returns:
        Dict mapping the integer id to the label string.

    Raises:
        ValueError: A non-blank line does not match the ``<id> <label>`` form.
    """
    pattern = re.compile(r'\s*(\d+)(.+)')
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # Without this, match() returns None on an empty line and
                # the lookup dies with an unhelpful AttributeError.
                continue
            match = pattern.match(line)
            if match is None:
                raise ValueError(
                    'Malformed label at {}:{}: {!r}'.format(
                        path, line_no, line))
            num, text = match.groups()
            labels[int(num)] = text.strip()
    return labels
#----------------------------------------------------------------------
import collections
# One detection result: class id, confidence score and bounding box.
Object = collections.namedtuple('Object', 'id score bbox')
class BBox(collections.namedtuple('BBox', 'xmin ymin xmax ymax')):
    """Axis-aligned bounding box.

    Represents a rectangle whose sides are parallel to the x or y axis,
    stored as the four fields ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
    """

    # No per-instance __dict__: instances stay as small as a plain tuple.
    __slots__ = ()
def get_output(interpreter, score_threshold, top_k, image_scale=1.0):
    """Return the detected objects as a list of ``Object`` tuples.

    Reads four outputs through ``output_tensor``: boxes (0), class ids (1),
    scores (2) and the detection count (3).

    Args:
        interpreter: Interpreter that has already been invoked.
        score_threshold: Minimum score a detection needs to be returned.
        top_k: Maximum number of leading detections to inspect.
        image_scale: Unused; kept so existing callers keep working.

    Returns:
        A list of ``Object(id, score, bbox)``. Each ``bbox`` is a ``BBox``
        with coordinates clamped to the range [0.0, 1.0].
    """
    # output_tensor() squeezes its result, which collapses the detection axis
    # when there is exactly one slot. Restore a fixed layout before indexing.
    boxes = np.reshape(output_tensor(interpreter, 0), (-1, 4))
    class_ids = np.reshape(output_tensor(interpreter, 1), -1)
    scores = np.reshape(output_tensor(interpreter, 2), -1)
    count = int(output_tensor(interpreter, 3))

    # Never look past the reported number of detections or past the end of
    # the tensors. Slots beyond `count` do not hold valid results, and
    # top_k may exceed what the model emits.
    limit = min(top_k, count, len(scores))

    def make(i):
        ymin, xmin, ymax, xmax = boxes[i]
        return Object(
            id=int(class_ids[i]),
            score=scores[i],
            bbox=BBox(xmin=np.maximum(0.0, xmin),
                      ymin=np.maximum(0.0, ymin),
                      xmax=np.minimum(1.0, xmax),
                      ymax=np.minimum(1.0, ymax)))

    return [make(i) for i in range(limit) if scores[i] >= score_threshold]
#--------------------------------------------------------------------