From 3a531565a366e774980c0c038fbc5047fa192c20 Mon Sep 17 00:00:00 2001 From: Emanuele Plebani Date: Sat, 26 Nov 2016 10:15:34 +0100 Subject: [PATCH] support CoCo classes, autodetect num_boxes, fix RGB order in input The 'yolo_detect' module now supports CoCo classes by passing the '--coco' option and the number of boxes per region is now detected automatically. Set detection threshold to 0.25 (Darknet default). Removed RGB->BGR conversion because Yolo uses RGB (xingwangsfu/caffe-yolo/pull/14) and fixed a bug in detection (xingwangsfu/caffe-yolo/issues/12) --- yolo_detect.py | 95 +++++++++++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/yolo_detect.py b/yolo_detect.py index 0b8c31f..04922ea 100644 --- a/yolo_detect.py +++ b/yolo_detect.py @@ -1,7 +1,7 @@ """ YOLO detection demo in Caffe """ from __future__ import print_function, division -import getopt +import argparse import sys from datetime import datetime @@ -42,15 +42,18 @@ def get_boxes(output, img_size, grid_size, num_boxes): return boxes -def parse_yolo_output(output, img_size): +def parse_yolo_output(output, img_size, num_classes): """ convert the output of YOLO's last layer to boxes and confidence in each class """ - num_classes = 20 - num_boxes = 2 + n_coord_box = 4 # coordinate per bounding box grid_size = 7 sc_offset = grid_size * grid_size * num_classes + + # autodetect num_boxes + num_boxes = int((output.shape[0] - sc_offset) / + (grid_size*grid_size*(n_coord_box+1))) box_offset = sc_offset + grid_size * grid_size * num_boxes class_probs = np.reshape(output[0:sc_offset], (grid_size, grid_size, num_classes)) @@ -66,23 +69,40 @@ class """ return boxes, probs -def get_candidate_objects(output, img_size): +def get_candidate_objects(output, img_size, coco=False): """ convert network output to bounding box predictions """ - classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", - "cat", "chair", "cow", "diningtable", "dog", "horse", 
"motorbike", - "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] - - threshold = 0.2 + classes_voc = [ + "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", + "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", + "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] + classes_coco = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", + "truck", "boat", "traffic light", "fire hydrant", "stop sign", + "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", + "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", + "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", + "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", + "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", + "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", + "couch", "potted plant", "bed", "dining table", "toilet", "tv", + "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", + "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" + ] + classes = classes_coco if coco else classes_voc + + threshold = 0.25 iou_threshold = 0.5 - boxes, probs = parse_yolo_output(output, img_size) + boxes, probs = parse_yolo_output(output, img_size, len(classes)) filter_mat_probs = (probs >= threshold) filter_mat_boxes = np.nonzero(filter_mat_probs)[0:3] boxes_filtered = boxes[filter_mat_boxes] probs_filtered = probs[filter_mat_probs] - classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[filter_mat_boxes] + classes_num_filtered = np.argmax(probs, axis=3)[filter_mat_boxes] idx = np.argsort(probs_filtered)[::-1] boxes_filtered = boxes_filtered[idx] @@ -155,7 +175,7 @@ def show_results(img, results): cv2.imshow('YOLO detection', img) -def detect(model_filename, weight_filename, img_filename): +def 
detect(model_filename, weight_filename, img_filename, coco=False): """ given a YOLO caffe model and an image, detect the objects in the image """ net = caffe.Net(model_filename, weight_filename, caffe.TEST) @@ -163,7 +183,6 @@ def detect(model_filename, weight_filename, img_filename): transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) transformer.set_transpose('data', (2, 0, 1)) - transformer.set_channel_swap('data', (2, 1, 0)) t_start = datetime.now() out = net.forward_all(data=np.asarray([transformer.preprocess('data', img)])) @@ -171,44 +190,26 @@ def detect(model_filename, weight_filename, img_filename): print('total time is {:.2f} milliseconds'.format((t_end-t_start).total_seconds()*1e3)) img_cv = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - results = get_candidate_objects(out['result'][0], img.shape) + results = get_candidate_objects(out['result'][0], img.shape, coco) show_results(img_cv, results) cv2.waitKey() -def main(argv): +def main(): """ script entry point """ - model_filename = '' - weight_filename = '' - img_filename = '' - usage_message = 'Usage: yolo_detect.py -m -w -i ' - - try: - opts, _ = getopt.getopt(argv, "hm:w:i:") - print(opts) - except getopt.GetoptError: - print(usage_message) - return - if not opts: - print(usage_message) - return - - for opt, arg in opts: - if opt == '-h': - print(usage_message) - return - elif opt == "-m": - model_filename = arg - elif opt == "-w": - weight_filename = arg - elif opt == "-i": - img_filename = arg - print('model file is {}'.format(model_filename)) - print('weight file is {}'.format(weight_filename)) - print('image file is {}'.format(img_filename)) - - detect(model_filename, weight_filename, img_filename) + parser = argparse.ArgumentParser(description='Caffe-YOLO detection test') + parser.add_argument('model', type=str, help='model prototxt') + parser.add_argument('weights', type=str, help='model weights') + parser.add_argument('image', type=str, help='input image') + 
parser.add_argument('--coco', action='store_true', help='use coco classes') + args = parser.parse_args() + + print('model file is {}'.format(args.model)) + print('weight file is {}'.format(args.weights)) + print('image file is {}'.format(args.image)) + + detect(args.model, args.weights, args.image, args.coco) if __name__ == '__main__': - main(sys.argv[1:]) + main()