Skip to content

Commit

Permalink
Smooth out the code
Browse files Browse the repository at this point in the history
  • Loading branch information
DiegoPino committed May 22, 2024
1 parent 8dd8af9 commit 1023cdd
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions nlpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
#app.config['var1'] = 'test'
app.config["YOLO_MODEL_NAME"] = "yolov8m.pt"
app.config["MOBILENET_MODEL_NAME"] = "mobilenet_v3_small.tflite"
app.config["EFFICIENTDET_DETECT_MODEL_NAME"] = "efficientdet_lite2.tflite"
app.config["MOBILENET_DETECT_MODEL_NAME"] = "ssd_mobilenet_v2.tflite"
for variable, value in os.environ.items():
if variable == "YOLO_MODEL_NAME":
# Can be set via Docker ENV
app.config["YOLO_MODEL_NAME"] = value
if variable == "MOBILENET_MODEL_NAME":
app.config["MOBILENET_MODEL_NAME"] = value
if variable == "EFFICIENTDET_MODEL_NAME":
app.config["EFFICIENTDET_MODEL_NAME"] = value
if variable == "MOBILENET_DETECT_MODEL_NAME":
app.config["MOBILENET_DETECT_MODEL_NAME"] = value

default_data = {}
default_data['web64'] = {
Expand Down Expand Up @@ -574,7 +580,7 @@ def loadImage(url, size = 640):
try:
response = requests.get(url)
response.raise_for_status()
except requests.exceptions.HTTPError as err:
except requests.exceptions.RequestException as err:
data['error'] = err.strerror
return jsonify(data)

Expand Down Expand Up @@ -631,12 +637,13 @@ def loadImage(url, size = 640):
if hasattr(object_detect_result, "obb") and object_detect_result.obb is not None: # Access the .obb attribute instead of .boxes
print('An obb model')
data['yolo']['objects'] = json.loads(object_detect_result.tojson(True))
elif hasattr(object_detect_result, "boxes") and object_detect_result.boxes is not None:
elif hasattr(object_detect_result, "boxes") and object_detect_result.boxes is not None and object_detect_result.probs is not None:
print('Not an obb model')
data['yolo']['objects'] = json.loads(object_detect_result.tojson(True))
if type(object_detect_result) != 'NoneType':
data['yolo']['objects'] = json.loads(object_detect_result.tojson(True))
else:
data['error'] = 'No features detected'
data['yolo']['objects'] = json.loads(object_detect_result.tojson(True))
data['yolo']['objects'] = []

data['yolo']['modelinfo'] = {'train_args': model.ckpt["train_args"], 'date': model.ckpt["date"], 'version': model.ckpt["version"]}

Expand Down Expand Up @@ -674,22 +681,24 @@ def loadImage(url, size = 480):
try:
response = requests.get(url)
response.raise_for_status()
except requests.exceptions.HTTPError as err:
except requests.exceptions.RequestException as err:
data['error'] = err.strerror
return jsonify(data)

img_bytes = BytesIO(response.content)
img = Image.open(img_bytes)
img = img.convert('RGB')
img = img.resize((size,size), Image.NEAREST)
img.thumbnail((size,size), Image.NEAREST)
# Media pipe uses a different format than YOLO, img here is PIL
img = np.asarray(img)
return img

data = dict(default_data)
data['message'] = "mobilenet - Parameters: 'iiif_image_url"
data['mobilenet'] = {}
data['efficientdet'] = {}
params = {}
objects = []


if request.method == 'GET':
Expand All @@ -709,40 +718,56 @@ def loadImage(url, size = 480):
return jsonify(data)
try:
# Create options for Image Embedder
base_options = python.BaseOptions(model_asset_path='models/mobilenet/' + app.config["MOBILENET_MODEL_NAME"])
base_options_embedder = python.BaseOptions(model_asset_path='models/mobilenet/' + app.config["MOBILENET_MODEL_NAME"])
base_options_detected = python.BaseOptions(model_asset_path='models/mobilenet/' + app.config["MOBILENET_DETECT_MODEL_NAME"])
l2_normalize = True #@param {type:"boolean"}
quantize = True #@param {type:"boolean"}
options = vision.ImageEmbedderOptions(
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
options_embedder = vision.ImageEmbedderOptions(base_options=base_options_embedder, l2_normalize=l2_normalize, quantize=quantize)
options_detector = vision.ObjectDetectorOptions(base_options=base_options_detected, score_threshold=0.5)




# Create Image Embedder
with vision.ImageEmbedder.create_from_options(options) as embedder:
with vision.ImageEmbedder.create_from_options(options_embedder) as embedder:

# Format images for MediaPipe
img = loadImage(params['iiif_image_url'], 480)
img = loadImage(params['iiif_image_url'], 640)
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
embedding_result = embedder.embed(image)
with vision.ObjectDetector.create_from_options(options_detector) as detector:
detector_results = detector.detect(image)

except ValueError:
data['error'] = 'models/mobilenet/' + app.config["MOBILENET_MODEL_NAME"] + ' not found'
return jsonify(data)


if not detector_results.detections:
objects = []
else:
# make the coordinates percentage based.
for ml_result_index in range(len(detector_results.detections)):
detector_results.detections[ml_result_index].bounding_box.origin_x = detector_results.detections[ml_result_index].bounding_box.origin_x/image.width
detector_results.detections[ml_result_index].bounding_box.origin_y = detector_results.detections[ml_result_index].bounding_box.origin_y/image.height
detector_results.detections[ml_result_index].bounding_box.width = detector_results.detections[ml_result_index].bounding_box.width/image.width
detector_results.detections[ml_result_index].bounding_box.height = detector_results.detections[ml_result_index].bounding_box.width/image.height
objects = detector_results.detections
vector = embedding_result.embeddings[0].embedding
print(embedding_result.embeddings[0].embedding.shape[0])
# print(embedding_result.embeddings[0].embedding.shape[0])
# Vector size for this layer (inumlayers - 1) is 1024
# This "should" return a Unit Vector so we can use "dot_product" in Solr
X_l1 = preprocessing.normalize([vector], norm='l1')
# see https://nightlies.apache.org/solr/draft-guides/solr-reference-guide-antora/solr/10_0/query-guide/dense-vector-search.html
data['mobilenet']['vector'] = X_l1[0].tolist()

data['mobilenet']['objects'] = objects
data['message'] = 'done'
return jsonify(data)

# @app.route("/tester", methods=['GET', 'POST'])
# def tester():
# return render_template('form.html')

app.run(host='0.0.0.0', port=6400, debug=False)
app.run(host='0.0.0.0', port=6401, debug=False)


0 comments on commit 1023cdd

Please sign in to comment.