diff --git a/open_images_downloader.py b/open_images_downloader.py index b692c47f..3bae6531 100644 --- a/open_images_downloader.py +++ b/open_images_downloader.py @@ -146,16 +146,15 @@ def parse_args(): if not args.include_depiction: annotations = annotations.loc[annotations['IsDepiction'] != 1, :] - # TODO MAKE IT MORE EFFICIENT - #filter by IsGroupOf filtered = [] for class_name, group_filter, percentage in zip(class_names, group_filters, percentages): sub = annotations.loc[annotations['ClassName'] == class_name, :] - if group_filter == "group": - sub = sub.loc[sub['IsGroupOf'] == 1, :] - elif group_filter == '~group': - sub = sub.loc[sub['IsGroupOf'] == 0, :] excluded_images |= set(sub['ImageID'].sample(frac=1 - percentage)) + + if group_filter == '~group': + excluded_images |= set(sub.loc[sub['IsGroupOf'] == 1, 'ImageID']) + elif group_filter == 'group': + excluded_images |= set(sub.loc[sub['IsGroupOf'] == 0, 'ImageID']) filtered.append(sub) annotations = pd.concat(filtered)