Skip to content

Commit

Permalink
Merge pull request #39 from squeakus/master
Browse files Browse the repository at this point in the history
Adding functionality to train on your own VOC data
  • Loading branch information
qfgaohao authored Apr 16, 2019
2 parents 4079ba7 + ca7c900 commit 1e86eb3
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 14 deletions.
127 changes: 127 additions & 0 deletions vision/datasets/generate_vocdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import glob
import sys
import os
import xml.etree.ElementTree as ET
from random import random

def main(filename, train=0.7, val=0.2, test=0.1):
    """Build the VOC ``ImageSets/Main`` split files for the current directory.

    Reads the class labels from *filename*, collects every ``.jpg`` found in
    ``./JPEGImages``, requires a matching ``Annotations/<name>.xml`` for each
    image, randomly assigns every image to the train, validation or test set
    and writes ``train.txt``, ``val.txt``, ``trainval.txt`` and ``test.txt``
    plus ``<label>_train.txt``, ``<label>_val.txt`` and ``<label>_test.txt``
    into ``./ImageSets/Main/``.

    Args:
        filename: path of a text file holding a comma-separated label list.
        train: fraction of images intended for the training set.
        val: fraction of images intended for the validation set.
        test: fraction of images intended for the test set.

    Raises:
        SystemExit: if the three fractions do not sum to 1, or if an image
            has no annotation file.
    """
    import math  # local import: only needed for the tolerance check below

    # Exact float equality is fragile for fractional ratios, so compare with
    # a tolerance instead of ``!= 1.0``.
    if not math.isclose(train + test + val, 1.0):
        print("probabilities must equal 1")
        sys.exit(1)

    labels = _read_labels(filename)

    # Collect image names. splitext() removes only the extension, whereas
    # rstrip('.jpg') would also eat trailing '.', 'j', 'p' and 'g' characters
    # of the name itself ("dog.jpg" -> "do").
    imgnames = [
        os.path.splitext(entry)[0]
        for entry in os.listdir("./JPEGImages")
        if entry.endswith(".jpg")
    ]

    print("Labels:", labels, "imgcnt:", len(imgnames))

    # Map each image to the object names found in its annotation. This is
    # kept separate from the label list so that an image whose name happens
    # to equal a label cannot clobber anything.
    labels_by_image = {}
    for img in imgnames:
        annote = "Annotations/" + img + '.xml'
        if not os.path.isfile(annote):
            print("Missing annotation for ", annote)
            sys.exit(1)
        root = ET.parse(annote).getroot()
        # Only <object><name> entries describe classes; a wildcard parent
        # would also pick up unrelated names such as <owner><name>.
        labels_by_image[img] = {node.text for node in root.findall('object/name')}

    # Divvy up the images: one random draw per image decides its set.
    train_list = []
    val_list = []
    test_list = []
    for elem in reversed(imgnames):
        dice = random()
        if dice <= test:
            test_list.append(elem)
        elif dice <= (test + val):
            val_list.append(elem)
        else:
            train_list.append(elem)

    print("Training set:", len(train_list), "validation set:", len(val_list), "test set:", len(test_list))

    # Create the dataset files.
    out_dir = "./ImageSets/Main/"
    os.makedirs(out_dir, exist_ok=True)
    _write_names(out_dir + "train.txt", train_list)
    _write_names(out_dir + "val.txt", val_list)
    _write_names(out_dir + "trainval.txt", train_list + val_list)
    _write_names(out_dir + "test.txt", test_list)

    # Create the individual files for each label: one row per image with
    # " 1" when the label occurs in that image and " -1" otherwise.
    splits = (("train", train_list), ("val", val_list), ("test", test_list))
    for label in labels:
        for split_name, members in splits:
            with open(out_dir + label + "_" + split_name + ".txt", 'w') as outfile:
                for name in members:
                    flag = " 1\n" if label in labels_by_image[name] else " -1\n"
                    outfile.write(name + flag)


def _read_labels(label_path):
    """Return the labels listed in the comma-separated file *label_path*."""
    with open(label_path, 'r') as labelfile:
        # Lines are concatenated (trailing whitespace removed) before the
        # split, so the list may wrap onto a new line after a comma.
        label_string = "".join(line.rstrip() for line in labelfile)
    cleaned = (elem.replace(" ", "") for elem in label_string.split(','))
    # Drop empty entries (e.g. from a trailing comma) so no "_train.txt"
    # style files are produced for a nameless label.
    return [elem for elem in cleaned if elem]


def _write_names(path, names):
    """Write *names* to *path*, one per line."""
    with open(path, 'w') as outfile:
        outfile.writelines(name + "\n" for name in names)

def create_folder(foldername):
    """Create *foldername* (including missing parents) unless it already exists.

    Prints a notice when the directory is already present.

    Raises:
        FileExistsError: if *foldername* exists but is not a directory.
    """
    # Attempt the creation first instead of checking beforehand: this avoids
    # the window in which another process could create the path between the
    # check and the makedirs call.
    try:
        os.makedirs(foldername)
    except FileExistsError:
        # A regular file at this path is not a usable folder; surface that
        # instead of reporting it as an existing folder.
        if not os.path.isdir(foldername):
            raise
        print('folder already exists:', foldername)

# Script entry point: expects the path of the label file as first argument.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        # sys.exit with a string writes it to stderr and exits with status 1,
        # so a missing argument is reported as a failure to the calling shell.
        # The bare exit() helper is injected by the site module and is not
        # guaranteed to exist.
        sys.exit("usage: python generate_vocdata.py <labelfile>")
    main(sys.argv[1])
55 changes: 41 additions & 14 deletions vision/datasets/voc_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import numpy as np
import logging
import pathlib
import xml.etree.ElementTree as ET
import cv2
import os


class VOCDataset:

def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False):
def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False, label_file=None):
"""Dataset for VOC data.
Args:
root: the root of the VOC2007 or VOC2012 dataset, the directory contains the following sub-directories:
Expand All @@ -22,13 +24,34 @@ def __init__(self, root, transform=None, target_transform=None, is_test=False, k
self.ids = VOCDataset._read_image_ids(image_sets_file)
self.keep_difficult = keep_difficult

self.class_names = ('BACKGROUND',
# if the labels file exists, read in the class names
label_file_name = self.root / "labels.txt"

if os.path.isfile(label_file_name):
class_string = ""
with open(label_file_name, 'r') as infile:
for line in infile:
class_string += line.rstrip()

# classes should be a comma separated list

classes = class_string.split(',')
# prepend BACKGROUND as first class
classes.insert(0, 'BACKGROUND')
classes = [ elem.replace(" ", "") for elem in classes]
self.class_names = tuple(classes)
logging.info("VOC Labels read from file: " + str(self.class_names))

else:
logging.info("No labels file, using default VOC classes.")
self.class_names = ('BACKGROUND',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor'
)
'sheep', 'sofa', 'train', 'tvmonitor')


self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}

def __getitem__(self, index):
Expand Down Expand Up @@ -74,16 +97,20 @@ def _get_annotation(self, image_id):
is_difficult = []
for object in objects:
class_name = object.find('name').text.lower().strip()
bbox = object.find('bndbox')
# VOC dataset format follows Matlab, in which indexes start from 0
x1 = float(bbox.find('xmin').text) - 1
y1 = float(bbox.find('ymin').text) - 1
x2 = float(bbox.find('xmax').text) - 1
y2 = float(bbox.find('ymax').text) - 1
boxes.append([x1, y1, x2, y2])
labels.append(self.class_dict[class_name])
is_difficult_str = object.find('difficult').text
is_difficult.append(int(is_difficult_str) if is_difficult_str else 0)
# we're only concerned with classes in our list
if class_name in self.class_dict:
bbox = object.find('bndbox')

# VOC dataset format follows Matlab, in which indexes start from 1, so subtract 1 to get 0-based coordinates
x1 = float(bbox.find('xmin').text) - 1
y1 = float(bbox.find('ymin').text) - 1
x2 = float(bbox.find('xmax').text) - 1
y2 = float(bbox.find('ymax').text) - 1
boxes.append([x1, y1, x2, y2])

labels.append(self.class_dict[class_name])
is_difficult_str = object.find('difficult').text
is_difficult.append(int(is_difficult_str) if is_difficult_str else 0)

return (np.array(boxes, dtype=np.float32),
np.array(labels, dtype=np.int64),
Expand Down

0 comments on commit 1e86eb3

Please sign in to comment.