From aa993ef2390a3806109ac173ecc176d6bb3d428d Mon Sep 17 00:00:00 2001 From: Emanuele Plebani Date: Wed, 19 Apr 2017 21:54:37 +0200 Subject: [PATCH] convert avgpool and softmax layers, support classification mode in demo The conversion script supports the global average pooling (avgpool) and softmax layers required to convert the darknet and tiny networks. In the demo the --coco option has been replaced by the more generic 'mode'; the 'darknet' mode implements classification on the 1k Imagenet classes. --- create_yolo_prototxt.py | 12 +- imagenet.shortnames | 1000 +++++++++++++++++++++++++++++++++++++++ yolo_detect.py | 112 +++-- 3 files changed, 1087 insertions(+), 37 deletions(-) create mode 100644 imagenet.shortnames diff --git a/create_yolo_prototxt.py b/create_yolo_prototxt.py index 4f4fab7..fdd33ff 100644 --- a/create_yolo_prototxt.py +++ b/create_yolo_prototxt.py @@ -130,6 +130,12 @@ def max_pooling_layer(previous, name, params): kernel_size=int(params["size"]), stride=int(params["stride"])) +def global_pooling_layer(previous, name, mode="avg"): + """ create a Global Pooling Layer """ + pool = cp.Pooling.AVE if mode == "avg" else cp.Pooling.MAX + return cl.Pooling(previous, name=name, pool=pool, global_pooling=True) + + def dense_layer(previous, name, params, train=False): """ create a densse layer """ fields = dict(num_output=int(params["output"])) @@ -191,6 +197,10 @@ def convert_configuration(config, train=False, loc_layer=False): elif section == "maxpool": layers.append(max_pooling_layer(layers[-1], "pool{0}".format(count), params)) + elif section == "avgpool": + layers.append(global_pooling_layer(layers[-1], "pool{0}".format(count))) + elif section == "softmax": + layers.append(cl.Softmax(layers[-1], name="softmax{0}".format(count))) elif section == "connected": count += 1 add_dense_layer(layers, count, params, train) @@ -207,7 +217,7 @@ def convert_configuration(config, train=False, loc_layer=False): model = caffe.NetSpec() for layer in layers: setattr(model, layer.fn.params["name"], layer) - model.result = layers[-1] + model.result = layers[-1] return model diff --git a/imagenet.shortnames b/imagenet.shortnames new file mode 100644 index 0000000..119d6fe --- /dev/null +++ b/imagenet.shortnames @@ -0,0 +1,1000 @@ +kit fox +English setter +Siberian husky +Australian terrier +English springer +grey whale +lesser panda +Egyptian cat +ibex +Persian cat +cougar +gazelle +porcupine +sea lion +malamute +badger +Great Dane +Walker hound +Welsh springer spaniel +whippet +Scottish deerhound +killer whale +mink +African elephant +Weimaraner +soft-coated wheaten terrier +Dandie Dinmont +red wolf +Old English sheepdog +jaguar +otterhound +bloodhound +Airedale +hyena +meerkat +giant schnauzer +titi +three-toed sloth +sorrel +black-footed ferret +dalmatian +black-and-tan coonhound +papillon +skunk +Staffordshire bullterrier +Mexican hairless +Bouvier des Flandres +weasel +miniature poodle +Cardigan +malinois +bighorn +fox squirrel +colobus +tiger cat +Lhasa +impala +coyote +Yorkshire terrier +Newfoundland +brown bear +red fox +Norwegian elkhound +Rottweiler +hartebeest +Saluki +grey fox +schipperke +Pekinese +Brabancon griffon +West Highland white terrier +Sealyham terrier +guenon +mongoose +indri +tiger +Irish wolfhound +wild boar +EntleBucher +zebra +ram +French bulldog +orangutan +basenji +leopard +Bernese mountain dog +Maltese dog +Norfolk terrier +toy terrier +vizsla +cairn +squirrel monkey +groenendael +clumber +Siamese cat +chimpanzee +komondor +Afghan hound +Japanese spaniel +proboscis monkey +guinea pig +white wolf +ice bear +gorilla +borzoi +toy poodle +Kerry blue terrier +ox +Scotch terrier +Tibetan mastiff +spider monkey +Doberman +Boston bull +Greater Swiss Mountain dog +Appenzeller +Shih-Tzu +Irish water spaniel +Pomeranian +Bedlington terrier +warthog +Arabian camel +siamang +miniature schnauzer +collie +golden retriever +Irish terrier +affenpinscher +Border collie +hare +boxer +silky terrier +beagle +Leonberg +German short-haired pointer +patas +dhole +baboon +macaque +Chesapeake Bay retriever +bull mastiff +kuvasz +capuchin +pug +curly-coated retriever +Norwich terrier +flat-coated retriever +hog +keeshond +Eskimo dog +Brittany spaniel +standard poodle +Lakeland terrier +snow leopard +Gordon setter +dingo +standard schnauzer +hamster +Tibetan terrier +Arctic fox +wire-haired fox terrier +basset +water buffalo +American black bear +Angora +bison +howler monkey +hippopotamus +chow +giant panda +American Staffordshire terrier +Shetland sheepdog +Great Pyrenees +Chihuahua +tabby +marmoset +Labrador retriever +Saint Bernard +armadillo +Samoyed +bluetick +redbone +polecat +marmot +kelpie +gibbon +llama +miniature pinscher +wood rabbit +Italian greyhound +lion +cocker spaniel +Irish setter +dugong +Indian elephant +beaver +Sussex spaniel +Pembroke +Blenheim spaniel +Madagascar cat +Rhodesian ridgeback +lynx +African hunting dog +langur +Ibizan hound +timber wolf +cheetah +English foxhound +briard +sloth bear +Border terrier +German shepherd +otter +koala +tusker +echidna +wallaby +platypus +wombat +revolver +umbrella +schooner +soccer ball +accordion +ant +starfish +chambered nautilus +grand piano +laptop +strawberry +airliner +warplane +airship +balloon +space shuttle +fireboat +gondola +speedboat +lifeboat +canoe +yawl +catamaran +trimaran +container ship +liner +pirate +aircraft carrier +submarine +wreck +half track +tank +missile +bobsled +dogsled +bicycle-built-for-two +mountain bike +freight car +passenger car +barrow +shopping cart +motor scooter +forklift +electric locomotive +steam locomotive +amphibian +ambulance +beach wagon +cab +convertible +jeep +limousine +minivan +Model T +racer +sports car +go-kart +golfcart +moped +snowplow +fire engine +garbage truck +pickup +tow truck +trailer truck +moving van +police van +recreational vehicle +streetcar +snowmobile +tractor +mobile home +tricycle +unicycle +horse cart +jinrikisha +oxcart +bassinet +cradle +crib +four-poster +bookcase +china cabinet +medicine chest +chiffonier +table lamp +file +park bench +barber chair +throne +folding chair +rocking chair +studio couch +toilet seat +desk +pool table +dining table +entertainment center +wardrobe +Granny Smith +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +acorn +hip +ear +rapeseed +corn +buckeye +organ +upright +chime +drum +gong +maraca +marimba +steel drum +banjo +cello +violin +harp +acoustic guitar +electric guitar +cornet +French horn +trombone +harmonica +ocarina +panpipe +bassoon +oboe +sax +flute +daisy +yellow lady's slipper +cliff +valley +alp +volcano +promontory +sandbar +coral reef +lakeside +seashore +geyser +hatchet +cleaver +letter opener +plane +power drill +lawn mower +hammer +corkscrew +can opener +plunger +screwdriver +shovel +plow +chain saw +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +white stork +black stork +spoonbill +flamingo +American egret +little blue heron +bittern +crane +limpkin +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +European gallinule +pelican +king penguin +albatross +great white shark +tiger shark +hammerhead +electric ray +stingray +barracouta +coho +tench +goldfish +eel +rock beauty +anemone fish +lionfish +puffer +sturgeon +gar +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +triceratops +African crocodile +American alligator +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +whistle +wing +paintbrush +hand blower +oxygen mask +snorkel +loudspeaker +microphone +screen +mouse +electric fan +oil filter +strainer +space heater +stove +guillotine +barometer +rule +odometer +scale +analog clock +digital clock +wall clock +hourglass +sundial +parking meter +stopwatch +digital watch +stethoscope +syringe +magnetic compass +binoculars +projector +sunglasses +loupe +radio telescope +bow +cannon +assault rifle +rifle +projectile +computer keyboard +typewriter keyboard +crane +lighter +abacus +cash machine +slide rule +desktop computer +hand-held computer +notebook +web site +harvester +thresher +printer +slot +vending machine +sewing machine +joystick +switch +hook +car wheel +paddlewheel +pinwheel +potter's wheel +gas pump +carousel +swing +reel +radiator +puck +hard disc +sunglass +pick +car mirror +solar dish +remote control +disk brake +buckle +hair slide +knot +combination lock +padlock +nail +safety pin +screw +muzzle +seat belt +ski +candle +jack-o'-lantern +spotlight +torch +neck brace +pier +tripod +maypole +mousetrap +spider web +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +isopod +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +sea urchin +sea cucumber +iron +espresso maker +microwave +Dutch oven +rotisserie +toaster +waffle iron +vacuum +dishwasher +refrigerator +washer +Crock Pot +frying pan +wok +caldron +coffeepot +teapot +spatula +altar +triumphal arch +patio +steel arch bridge +suspension bridge +viaduct +barn +greenhouse +palace +monastery +library +apiary +boathouse +church +mosque +stupa +planetarium +restaurant +cinema +home theater +lumbermill +coil +obelisk +totem pole +castle +prison +grocery store +bakery +barbershop +bookshop +butcher shop +confectionery +shoe shop +tobacco shop +toyshop +fountain +cliff dwelling +yurt +dock +brass +megalith +bannister +breakwater +dam +chainlink fence +picket fence +worm fence +stone wall +grille +sliding door +turnstile +mountain tent +scoreboard +honeycomb +plate rack +pedestal +beacon +mashed potato +bell pepper +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +cardoon +mushroom +shower curtain +jean +carton +handkerchief +sandal +ashcan +safe +plate +necklace +croquet ball +fur coat +thimble +pajama +running shoe +cocktail shaker +chest +manhole cover +modem +tub +tray +balance beam +bagel +prayer rug +kimono +hot pot +whiskey jug +knee pad +book jacket +spindle +ski mask +beer bottle +crash helmet +bottlecap +tile roof +mask +maillot +Petri dish +football helmet +bathing cap +teddy +holster +pop bottle +photocopier +vestment +crossword puzzle +golf ball +trifle +suit +water tower +feather boa +cloak +red wine +drumstick +shield +Christmas stocking +hoopskirt +menu +stage +bonnet +meat loaf +baseball +face powder +scabbard +sunscreen +beer glass +hen-of-the-woods +guacamole +lampshade +wool +hay +bow tie +mailbag +water jug +bucket +dishrag +soup bowl +eggnog +mortar +trench coat +paddle +chain +swab +mixing bowl +potpie +wine bottle +shoji +bulletproof vest +drilling platform +binder +cardigan +sweatshirt +pot +birdhouse +hamper +ping-pong ball +pencil box +pay-phone +consomme +apron +punching bag +backpack +groom +bearskin +pencil sharpener +broom +mosquito net +abaya +mortarboard +poncho +crutch +Polaroid camera +space bar +cup +racket +traffic light +quill +radio +dough +cuirass +military uniform +lipstick +shower cap +monitor +oscilloscope +mitten +brassiere +French loaf +vase +milk can +rugby ball +paper towel +earthstar +envelope +miniskirt +cowboy hat +trolleybus +perfume +bathtub +hotdog +coral fungus +bullet train +pillow +toilet tissue +cassette +carpenter's kit +ladle +stinkhorn +lotion +hair spray +academic gown +dome +crate +wig +burrito +pill bottle +chain mail +theater curtain +window shade +barrel +washbasin +ballpoint +basketball +bath towel +cowboy boot +gown +window screen +agaric +cellular telephone +nipple +barbell +mailbox +lab coat +fire screen +minibus +packet +maze +pole +horizontal bar +sombrero +pickelhaube +rain barrel +wallet +cassette player +comic book +piggy bank +street sign +bell cote +fountain pen +Windsor tie +volleyball +overskirt +sarong +purse +bolo tie +bib +parachute +sleeping bag +television +swimming trunks +measuring cup +espresso +pizza +breastplate +shopping basket +wooden spoon +saltshaker +chocolate sauce +ballplayer +goblet +gyromitra +stretcher +water bottle +dial telephone +soap dispenser +jersey +school bus +jigsaw puzzle +plastic bag +reflex camera +diaper +Band Aid +ice lolly +velvet +tennis ball +gasmask +doormat +Loafer +ice cream +pretzel +quilt +maillot +tape player +clog +iPod +bolete +scuba diver +pitcher +matchstick +bikini +sock +CD player +lens cap +thatch +vault +beaker +bubble +cheeseburger +parallel bars +flagpole +coffee mug +rubber eraser +stole +carbonara +dumbbell \ No newline at end of file diff --git a/yolo_detect.py b/yolo_detect.py index b3051ea..e00a7c9 100644 --- a/yolo_detect.py +++ b/yolo_detect.py @@ -20,6 +20,42 @@ caffe.set_mode_cpu() +def load_names(filename): + """ load names from a text file (one per line) """ + with open(filename, 'r') as fid: + names = [l.strip() for l in fid] + return names + + +PRESETS = { + 'coco': { 'classes': [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", + "truck", "boat", "traffic light", "fire hydrant", "stop sign", + "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", + "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", + "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", + "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", + "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", + "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", + "couch", "potted plant", "bed", "dining table", "toilet", "tv", + "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", + "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" + ], 'anchors': [[0.738768, 2.42204, 4.30971, 10.246, 12.6868], + [0.874946, 2.65704, 7.04493, 4.59428, 11.8741]] + }, + 'voc': { 'classes': [ + "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", + "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", + "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"], + 'anchors': [[1.08, 3.42, 6.63, 9.42, 16.62], + [1.19, 4.41, 11.38, 5.11, 10.52]] + }, + 'darknet': { 'classes': load_names('imagenet.shortnames'), 'anchors': []} +} + + def get_boxes(output, img_size, grid_size, num_boxes): """ extract bounding boxes from the last layer """ @@ -128,40 +164,15 @@ class """ raise ValueError(" output format not recognized") -def get_candidate_objects(output, img_size, coco=False): +def get_candidate_objects(output, img_size, mode): """ convert network output to bounding box predictions """ - classes_voc = [ - "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", - "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", - "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] - classes_coco = [ - "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", - "truck", "boat", "traffic light", "fire hydrant", "stop sign", - "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", - "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", - "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", - "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", - "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", - "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", - "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", - "couch", "potted plant", "bed", "dining table", "toilet", "tv", - "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", - "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", - "scissors", "teddy bear", "hair drier", "toothbrush" - ] - if coco: - classes = classes_coco - anchors = [[0.738768, 2.42204, 4.30971, 10.246, 12.6868], - [0.874946, 2.65704, 7.04493, 4.59428, 11.8741]] - else: - classes = classes_voc - anchors = [[1.08, 3.42, 6.63, 9.42, 16.62], - [1.19, 4.41, 11.38, 5.11, 10.52]] - threshold = 0.2 iou_threshold = 0.4 + classes = PRESETS[mode]['classes'] + anchors = PRESETS[mode]['anchors'] + boxes, probs = parse_yolo_output(output, img_size, len(classes), anchors) filter_mat_probs = (probs >= threshold) @@ -258,11 +269,27 @@ def show_results(img, results): cv2.imshow('YOLO detection', img) -def detect(model_filename, weight_filename, img_filename, coco=False): +def crop_max(img, shape): + """ crop the largest dimension to avoid stretching """ + net_h, net_w = shape + height, width = img.shape[:2] + aratio = net_w / net_h + + if width > height * aratio: + diff = int((width - height * aratio) / 2) + return img[:, diff:-diff, :] + else: + diff = int((height - width / aratio) / 2) + return img[diff:-diff, :, :] + + +def detect(model_filename, weight_filename, img_filename, mode): """ given a YOLO caffe model and an image, detect the objects in the image """ net = caffe.Net(model_filename, weight_filename, caffe.TEST) img = caffe.io.load_image(img_filename) # load the image using caffe.io + if mode == 'darknet': + img = crop_max(img, net.blobs['data'].data.shape[-2:]) transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) transformer.set_transpose('data', (2, 0, 1)) @@ -272,10 +299,20 @@ def detect(model_filename, weight_filename, img_filename, coco=False): t_end = datetime.now() print('total time is {:.2f} milliseconds'.format((t_end-t_start).total_seconds()*1e3)) - img_cv = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - results = get_candidate_objects(out['result'][0], img.shape, coco) - show_results(img_cv, results) - cv2.waitKey() + if mode == 'darknet': + net_output = out[out.keys()[0]] # get first out layer + if len(net_output.shape) > 2: + net_output = np.squeeze(net_output)[np.newaxis, :] + + ids = np.argsort(net_output[0])[-1:-6:-1] + print('predicted classes: {}'.format( + [(PRESETS[mode]['classes'][cls_id], net_output[0][cls_id]) + for cls_id in ids])) + else: + img_cv = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + results = get_candidate_objects(out['result'][0], img.shape, mode) + show_results(img_cv, results) + cv2.waitKey() def main(): @@ -284,14 +321,17 @@ def main(): parser.add_argument('model', type=str, help='model prototxt') parser.add_argument('weights', type=str, help='model weights') parser.add_argument('image', type=str, help='input image') - parser.add_argument('--coco', action='store_true', help='use coco classes') + parser.add_argument('--mode', type=str, help='preset to use', default='coco') args = parser.parse_args() + if args.mode not in PRESETS.keys(): + raise ValueError(" Preset not supported: {}".format(args.mode)) + print('model file is {}'.format(args.model)) print('weight file is {}'.format(args.weights)) print('image file is {}'.format(args.image)) - detect(args.model, args.weights, args.image, args.coco) + detect(args.model, args.weights, args.image, args.mode) if __name__ == '__main__':