-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_test_split.py
201 lines (186 loc) · 8.03 KB
/
train_test_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#
# NOTE: Some of this code was adapted from tensorflow/examples/image_retraining/retrain.py
# As such, although the repo as a whole is MIT Licensed, _this particular script_
# is licensed under the Apache 2 license that TensorFlow uses.
# HOWEVER, I have extensively altered this script to remove all TensorFlow dependencies.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Split an image directory into train, validation, and test sets.
Note that the split is done using hashing on the filenames so that subsequent runs with additional images in the input set will wind up hashing the resulting images into the same train/test/validation sets. The split is done per-class, so it is a stratified (balanced) split.
The output directory, once finished, will contain subdirectories named 'training', 'testing', and 'validation' - this allows the subsequent train_keras and score_keras scripts to assume inputs making their commands more streamlined.
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import hashlib
import re
import glob
from shutil import copyfile
import logging
from pathlib import Path
# Script-wide logger: INFO level, with a StreamHandler attached at import time.
logger = logging.getLogger('train_test_split')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())
# Parsed command-line flags. Stays None until the `__main__` block assigns the
# result of parser.parse_known_args() to it.
FLAGS = None
# Used as the modulus/scale when turning a file-name hash into a 0-100
# percentage bucket, and as the default value of --max_per_file.
MAX_NUM_IMAGES_PER_CLASS = 2**27 - 1 # ~134M
def create_image_lists(image_dir, testing_percentage, validation_percentage):
    """Build stable per-label training/testing/validation image lists.

    Scans each immediate sub folder of `image_dir` for JPEG files and assigns
    every file to one of three sets based on a SHA-1 hash of its path, so a
    file keeps its assignment when more files are added later.

    Args:
      image_dir: String path to a folder containing subfolders of images.
      testing_percentage: Integer percentage of the images to reserve for tests.
      validation_percentage: Integer percentage of images reserved for validation.

    Returns:
      A dictionary keyed by normalized label name (the lower-cased folder name
      with runs of non-alphanumeric characters replaced by a space). Each value
      is a dict with the original folder name under 'dir' and lists of image
      base names under 'training', 'testing', and 'validation'.
      Returns None if `image_dir` is not an existing directory.
    """
    image_dir_path = Path(image_dir)
    # NOTE: the check must *call* the method; a bare `path.exists` attribute is
    # always truthy. is_dir() also rejects a path that points at a plain file.
    if not image_dir_path.is_dir():
        logger.error("Image directory '%s' not found.", image_dir_path)
        return None

    # Fall back to the module constant when the command-line flags have not
    # been parsed (e.g. when this function is called from another module).
    max_per_class = getattr(FLAGS, 'max_per_file', None)
    if max_per_class is None:
        max_per_class = MAX_NUM_IMAGES_PER_CLASS

    extensions = ('jpg', 'jpeg', 'JPG', 'JPEG')
    result = {}
    # iterdir() only yields the children of image_dir (never image_dir itself).
    # Sorting makes the processing order independent of the file system.
    sub_dirs = sorted(p for p in image_dir_path.iterdir() if p.is_dir())
    for sub_dir in sub_dirs:
        dir_name = sub_dir.name
        logger.info("Looking for images in '{}'".format(sub_dir))

        # A set drops duplicates when two patterns match the same file (which
        # happens where glob matching is case-insensitive); sorting makes the
        # pruning below deterministic.
        found = set()
        for extension in extensions:
            found.update(sub_dir.glob('*.' + extension))
        file_list = sorted(found)

        if not file_list:
            logger.warning('No files found')
            continue
        if len(file_list) < 20:
            logger.warning(
                'WARNING: Folder has less than 20 images, which may cause issues. Skipping.'
            )
            continue
        if len(file_list) > max_per_class:
            logger.warning(
                'WARNING: Folder {} has more than {} images. Pruning.'.format(
                    dir_name, max_per_class))
            file_list = file_list[:max_per_class]

        label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = file_name.name
            # Ignore anything after '_nohash_' in the file name when deciding
            # which set an image goes in, so the data set creator can group
            # photos that are close variations of each other (for example
            # several pictures of the same leaf) into the same set.
            # The full path string (as built from `image_dir`) is hashed, so
            # the assignment also depends on how `image_dir` was spelled.
            hash_name = re.sub(r'_nohash_.*$', '', str(file_name))
            # Hash the name and map the digest onto a 0-100 value. This gives a
            # stable pseudo-random percentage that depends only on the name,
            # which keeps existing files in the same set when more files are
            # subsequently added.
            hash_name_hashed = hashlib.sha1(
                hash_name.encode(errors='replace')).hexdigest()
            percentage_hash = ((int(hash_name_hashed, 16) %
                                (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                               (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            elif percentage_hash < (
                    testing_percentage + validation_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result
def divide_images():
    """Copy the images under FLAGS.image_dir into per-split class folders.

    Reads `image_dir`, `pct_test`, `pct_validation` and `output_dir` from the
    module-level FLAGS, asks create_image_lists() for the per-class split, and
    copies every image to
    `<output_dir>/<training|testing|validation>/<label>/<file name>`.

    Returns:
      -1 if the input directory could not be read or fewer than two usable
      class folders were found (an error is logged); None on success.
    """
    img_dir = FLAGS.image_dir
    testing_pct = FLAGS.pct_test
    validation_pct = FLAGS.pct_validation
    out_dir = Path(FLAGS.output_dir)

    image_lists = create_image_lists(img_dir, testing_pct, validation_pct)
    if image_lists is None:
        # create_image_lists() has already logged why it could not proceed.
        return -1
    class_count = len(image_lists)
    if class_count == 0:
        logger.error('No valid folders of images found at ' + FLAGS.image_dir)
        return -1
    if class_count == 1:
        logger.error(
            'Only one valid folder of images found at ' + FLAGS.image_dir +
            ' - multiple classes are needed for classification.')
        return -1

    # The split names double as the keys of each class entry returned by
    # create_image_lists() and as the output sub directory names.
    split_roots = {
        split: out_dir / split
        for split in ('training', 'testing', 'validation')
    }
    for label, entry in image_lists.items():
        source_dir = Path(img_dir) / entry['dir']
        for split, split_root in split_roots.items():
            target_dir = split_root / label
            # parents/exist_ok let the script create a nested output path and
            # be re-run into an existing output directory after new images
            # were added to the input set. The class folder is created even
            # when no image was assigned to this split.
            target_dir.mkdir(parents=True, exist_ok=True)
            for img in entry[split]:
                copyfile(source_dir / img, target_dir / img)
# Script entry point: parse the command line into the module-level FLAGS,
# validate the percentages, run the split, and report failure through the
# process exit status.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--image_dir',
        type=str,
        required=True,
        help='Path to folders of labeled images.')
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='Where to save the divided images.')
    parser.add_argument(
        '--pct_test',
        type=int,
        default=10,
        help='Percentage of input set to use for testing.')
    parser.add_argument(
        '--pct_validation',
        type=int,
        default=20,
        help='Percentage of images to use in validation.')
    # NOTE(review): --seed is accepted but not read anywhere else in this
    # script (the split is hash-based, not random). Kept so existing command
    # lines keep working.
    parser.add_argument(
        '--seed', type=float, default=1337, help='Random seed.')
    parser.add_argument(
        '--max_per_file',
        type=int,
        default=MAX_NUM_IMAGES_PER_CLASS,
        help='Limit the maximum number of images in a given class')
    # parse_known_args() tolerates (and discards) unrecognized arguments.
    FLAGS, _ = parser.parse_known_args()

    for flag_name in ('pct_test', 'pct_validation'):
        if not 0 <= getattr(FLAGS, flag_name) <= 100:
            parser.error('--{} must be between 0 and 100.'.format(flag_name))
    if FLAGS.pct_test + FLAGS.pct_validation > 100:
        parser.error('--pct_test and --pct_validation must not add up to '
                     'more than 100.')

    # divide_images() returns a truthy value (-1) on failure; surface that as
    # a non-zero exit status instead of always exiting with 0.
    raise SystemExit(1 if divide_images() else 0)