diff --git a/cocosplit.py b/cocosplit.py index aeb6282..92aa76c 100644 --- a/cocosplit.py +++ b/cocosplit.py @@ -4,7 +4,9 @@ from sklearn.model_selection import train_test_split from skmultilearn.model_selection import iterative_train_test_split import numpy as np +from tqdm import tqdm +np.random.seed(42) def save_coco(file, info, licenses, images, annotations, categories): with open(file, 'wt', encoding='UTF-8') as coco: @@ -13,14 +15,14 @@ def save_coco(file, info, licenses, images, annotations, categories): def filter_annotations(annotations, images): image_ids = funcy.lmap(lambda i: int(i['id']), images) - return funcy.lfilter(lambda a: int(a['image_id']) in image_ids, annotations) - + filtered_annotations = funcy.lfilter(lambda a: int(a['image_id']) in image_ids, tqdm(annotations, desc='Filtering Annotations')) + return filtered_annotations def filter_images(images, annotations): - annotation_ids = funcy.lmap(lambda i: int(i['image_id']), annotations) + filtered_images = funcy.lfilter(lambda a: int(a['id']) in annotation_ids, tqdm(images, desc='Filtering Images')) + return filtered_images - return funcy.lfilter(lambda a: int(a['id']) in annotation_ids, images) parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.') @@ -67,7 +69,7 @@ def main(args): annotations = funcy.lremove(lambda i: i['category_id'] not in annotation_categories , annotations) - X_train, y_train, X_test, y_test = iterative_train_test_split(np.array([annotations]).T,np.array([ annotation_categories]).T, test_size = 1-args.split) + X_train, y_train, X_test, y_test = iterative_train_test_split(np.array([annotations]).T,np.array([ annotation_categories]).T, test_size = 1-args.split, random_state=42) save_coco(args.train, info, licenses, filter_images(images, X_train.reshape(-1)), X_train.reshape(-1).tolist(), categories) save_coco(args.test, info, licenses, filter_images(images, X_test.reshape(-1)), X_test.reshape(-1).tolist(), categories) @@ -76,7 +78,7 @@ def main(args): else: - X_train, X_test = train_test_split(images, train_size=args.split) + X_train, X_test = train_test_split(images, train_size=args.split, random_state=42) anns_train = filter_annotations(annotations, X_train) anns_test=filter_annotations(annotations, X_test) diff --git a/requirements.txt b/requirements.txt index 47d6288..40b21b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ sklearn funcy argparse scikit-multilearn +tqdm \ No newline at end of file