support COCO classes, autodetect num_boxes, fix RGB order in input
The 'yolo_detect' module now supports COCO classes via the '--coco'
option, and the number of boxes per region is now detected
automatically.
Set the detection threshold to 0.25 (the Darknet default).
Removed the RGB->BGR conversion because YOLO uses RGB input
(xingwangsfu/pull/14) and fixed a bug in detection
(xingwangsfu/issues/12).
Banus committed Nov 26, 2016
1 parent 51449e3 commit 3a53156
95 changes: 48 additions & 47 deletions yolo_detect.py
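
The automatic box-count detection added in this commit relies on the output layout already used by the script: grid_size*grid_size*num_classes class scores first, then one confidence and four coordinates per box. A minimal sketch of the arithmetic, with illustrative output lengths (e.g. 1470 for 20 VOC classes and 4410 for 80 COCO classes, both corresponding to two boxes per cell):

grid_size, n_coord_box = 7, 4
for num_classes, output_len in ((20, 1470), (80, 4410)):
    sc_offset = grid_size * grid_size * num_classes                # class scores come first
    num_boxes = (output_len - sc_offset) // (grid_size * grid_size * (n_coord_box + 1))
    print(num_classes, 'classes ->', num_boxes, 'boxes per cell')  # prints 2 in both cases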
@@ -1,7 +1,7 @@
""" YOLO detection demo in Caffe """
from __future__ import print_function, division

import getopt
import argparse
import sys
from datetime import datetime

@@ -42,15 +42,18 @@ def get_boxes(output, img_size, grid_size, num_boxes):
return boxes


def parse_yolo_output(output, img_size):
def parse_yolo_output(output, img_size, num_classes):
""" convert the output of YOLO's last layer to boxes and confidence in each
class """

num_classes = 20
num_boxes = 2
n_coord_box = 4 # coordinate per bounding box
grid_size = 7

sc_offset = grid_size * grid_size * num_classes

# autodetect num_boxes
num_boxes = int((output.shape[0] - sc_offset) /
(grid_size*grid_size*(n_coord_box+1)))
box_offset = sc_offset + grid_size * grid_size * num_boxes

class_probs = np.reshape(output[0:sc_offset], (grid_size, grid_size, num_classes))
@@ -66,23 +69,40 @@ class """
return boxes, probs


def get_candidate_objects(output, img_size):
def get_candidate_objects(output, img_size, coco=False):
""" convert network output to bounding box predictions """

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
"cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
"person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

threshold = 0.2
classes_voc = [
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
"cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
"person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
classes_coco = [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign",
"parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
"surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
"knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
"couch", "potted plant", "bed", "dining table", "toilet", "tv",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
"oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush"
]
classes = classes_coco if coco else classes_voc

threshold = 0.25
iou_threshold = 0.5

boxes, probs = parse_yolo_output(output, img_size)
boxes, probs = parse_yolo_output(output, img_size, len(classes))

filter_mat_probs = (probs >= threshold)
filter_mat_boxes = np.nonzero(filter_mat_probs)[0:3]
boxes_filtered = boxes[filter_mat_boxes]
probs_filtered = probs[filter_mat_probs]
classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[filter_mat_boxes]
classes_num_filtered = np.argmax(probs, axis=3)[filter_mat_boxes]

idx = np.argsort(probs_filtered)[::-1]
boxes_filtered = boxes_filtered[idx]
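
The switch above from np.argmax(filter_mat_probs, axis=3) to np.argmax(probs, axis=3) is the detection fix referenced in the commit message: taking the argmax of the boolean mask returns the first class above threshold, not the most probable one. A small illustration with a toy probability tensor (values are made up for the example):

import numpy as np

probs = np.zeros((1, 1, 1, 3))       # 1x1 grid, 1 box, 3 toy classes
probs[0, 0, 0] = [0.30, 0.26, 0.60]  # all three classes pass the 0.25 threshold

filter_mat_probs = probs >= 0.25
filter_mat_boxes = np.nonzero(filter_mat_probs)[0:3]

print(np.argmax(filter_mat_probs, axis=3)[filter_mat_boxes])  # old: [0 0 0], first True wins
print(np.argmax(probs, axis=3)[filter_mat_boxes])             # new: [2 2 2], highest probability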
@@ -155,60 +175,41 @@ def show_results(img, results):
cv2.imshow('YOLO detection', img)


def detect(model_filename, weight_filename, img_filename):
def detect(model_filename, weight_filename, img_filename, coco=False):
""" given a YOLO caffe model and an image, detect the objects in the image
"""
net = caffe.Net(model_filename, weight_filename, caffe.TEST)
img = caffe.io.load_image(img_filename) # load the image using caffe.io

transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_channel_swap('data', (2, 1, 0))

t_start = datetime.now()
out = net.forward_all(data=np.asarray([transformer.preprocess('data', img)]))
t_end = datetime.now()
print('total time is {:.2f} milliseconds'.format((t_end-t_start).total_seconds()*1e3))

img_cv = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
results = get_candidate_objects(out['result'][0], img.shape)
results = get_candidate_objects(out['result'][0], img.shape, coco)
show_results(img_cv, results)
cv2.waitKey()


def main(argv):
def main():
""" script entry point """
model_filename = ''
weight_filename = ''
img_filename = ''
usage_message = 'Usage: yolo_detect.py -m <model_file> -w <output_file> -i <img_file>'

try:
opts, _ = getopt.getopt(argv, "hm:w:i:")
print(opts)
except getopt.GetoptError:
print(usage_message)
return
if not opts:
print(usage_message)
return

for opt, arg in opts:
if opt == '-h':
print(usage_message)
return
elif opt == "-m":
model_filename = arg
elif opt == "-w":
weight_filename = arg
elif opt == "-i":
img_filename = arg
print('model file is {}'.format(model_filename))
print('weight file is {}'.format(weight_filename))
print('image file is {}'.format(img_filename))

detect(model_filename, weight_filename, img_filename)
parser = argparse.ArgumentParser(description='Caffe-YOLO detection test')
parser.add_argument('model', type=str, help='model prototxt')
parser.add_argument('weights', type=str, help='model weights')
parser.add_argument('image', type=str, help='input image')
parser.add_argument('--coco', action='store_true', help='use coco classes')
args = parser.parse_args()

print('model file is {}'.format(args.model))
print('weight file is {}'.format(args.weights))
print('image file is {}'.format(args.image))

detect(args.model, args.weights, args.image, args.coco)


if __name__ == '__main__':
main(sys.argv[1:])
main()
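
With the move from getopt to argparse, the model, weights and image are now positional arguments and COCO class names are selected with the --coco flag. A usage sketch (file names are placeholders):

python yolo_detect.py yolo_small.prototxt yolo_small.caffemodel dog.jpg          # VOC classes
python yolo_detect.py yolo_coco.prototxt yolo_coco.caffemodel dog.jpg --coco     # COCO classes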
