preprocessing.py
import os
import glob
import utils  # project-local helpers used below: utils.mdir (create directory), utils.imread (read image)
import cv2
import numpy as np
from tqdm import tqdm
import tensorflow as tf

ALL_CLASSES = ['NORMAL', 'CNV', 'DME', 'DRUSEN']

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def encode_label(label):
    if label == 'NORMAL':
        return 0
    if label == 'CNV':
        return 1
    if label == 'DME':
        return 2
    if label == 'DRUSEN':
        return 3
    raise ValueError('Unknown label: {}'.format(label))

def _process_examples(example_data, filename: str, channels=3, pre_augm=True):
    """
    Serialize a list of example dictionaries into a TFRecord file, a TensorFlow
    format that makes it easy to save data and load it back in the training loop
    in both tf 1.x and 2.0.

    WARNING: take care with how features are encoded so they can be decoded
    without problems later, i.e. keep track of whether images were stored as
    int or float.

    :param example_data: list of dictionaries with 'image' and 'label' entries
    :param filename: output filename
    :param channels: number of channels of the image (RGB=3, grayscale=1)
    :param pre_augm: whether to apply pre-augmentation before serializing (currently disabled)
    :return: None
    """
    with tf.io.TFRecordWriter(filename) as writer:
        for ex in example_data:
            # pre-augmentation / pre-resizing hook (currently disabled)
            # image = pre_augmentation(ex['image']) if pre_augm else ex['image']
            image = ex['image'].astype(np.float32)
            image = image.tobytes()  # tostring() is deprecated in NumPy
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(ex['image'].shape[0]),
                'width': _int64_feature(ex['image'].shape[1]),
                'depth': _int64_feature(channels),
                'image': _bytes_feature(image),
                'label': _int64_feature(encode_label(ex['label']))
            }))
            writer.write(example.SerializeToString())
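
# A minimal sketch (not part of the original file) of the matching reader for the
# records written above, assuming the feature spec used in _process_examples:
# images were serialized as raw float32 bytes, so they are decoded with
# tf.io.decode_raw and reshaped from the stored height/width/depth.
def _parse_example(serialized):
    feature_spec = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    image = tf.io.decode_raw(parsed['image'], tf.float32)
    image = tf.reshape(image, tf.stack([parsed['height'], parsed['width'], parsed['depth']]))
    return image, parsed['label']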

def resize_image(img, size=(128, 128)):
    """Resize to `size` while preserving aspect ratio, zero-padding the short side."""
    h, w = img.shape[:2]
    c = img.shape[2] if len(img.shape) > 2 else 1
    if h == w:
        return cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    dif = h if h > w else w
    # INTER_AREA when shrinking, INTER_CUBIC when enlarging
    interpolation = cv2.INTER_AREA if dif > (size[0] + size[1]) // 2 else cv2.INTER_CUBIC
    x_pos = (dif - w) // 2
    y_pos = (dif - h) // 2
    if len(img.shape) == 2:
        mask = np.zeros((dif, dif), dtype=img.dtype)
        mask[y_pos:y_pos + h, x_pos:x_pos + w] = img[:h, :w]
    else:
        mask = np.zeros((dif, dif, c), dtype=img.dtype)
        mask[y_pos:y_pos + h, x_pos:x_pos + w, :] = img[:h, :w, :]
    return cv2.resize(mask, size, interpolation=interpolation)
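
# Illustrative example (not in the original file): a 100x50 grayscale image is
# first zero-padded to a 100x100 square, then resized, so the content is
# letterboxed rather than stretched:
#   img = np.zeros((100, 50), dtype=np.uint8)
#   out = resize_image(img)  # out.shape == (128, 128)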

def shard_dataset(dataset, num_records=50):
    """Split the dataset into ~num_records shards; return the shard size and start offsets."""
    chunk = max(1, len(dataset) // num_records)  # guard against datasets smaller than num_records
    parts = list(range(0, len(dataset), chunk))
    return chunk, parts

class Preprocessing(object):
    def __init__(self, data_path):
        self.data_path = data_path
        utils.mdir(os.path.join(data_path, 'preprocessing'))

    def load_images(self, filename_shard, split, label):
        data = []
        for fn in filename_shard:
            img = utils.imread(fn)
            # img = resize_image(img, size=(136, 136))
            img = cv2.resize(img, (136, 136), interpolation=cv2.INTER_CUBIC)
            assert img.shape == (136, 136, 3)
            meta = {
                'image': img,
                'filename': fn,
                'label': label,
                'dataset': split
            }
            data.append(meta)
        return data

    def write_data(self, filenames, split, label):
        chunk, parts = shard_dataset(filenames)
        for i, j in enumerate(tqdm(parts)):
            shard = filenames[j:(j + chunk)]
            shard_data = self.load_images(shard, split, label)
            # e.g. train_CNV-OCT_001-050.tfrecord
            fn = '{}_{}-{}_{:03d}-{:03d}.tfrecord'.format(split, label, 'OCT', i + 1, len(parts))
            _process_examples(shard_data, os.path.join(self.data_path, 'preprocessing', fn))
    def create_data(self, split='train'):
        data_path = os.path.join(self.data_path, split)
        for cl in ALL_CLASSES:
            sub_dir = os.path.join(data_path, cl)
            data_fns = glob.glob(os.path.join(sub_dir, '*'))
            self.write_data(data_fns, split, label=cl)

    def generate_example_sets(self):
        for split in ['train', 'test']:
            self.create_data(split)

if __name__ == '__main__':
    prep = Preprocessing(data_path='/media/miguel/ALICIUM/Miguel/DOWNLOADS/ZhangLabData/CellData/OCT')
    prep.generate_example_sets()
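
# A minimal sketch (an assumption, not part of the original script) of how the
# shards written above could be consumed with tf.data, reusing the _parse_example
# reader sketched earlier:
#   files = glob.glob(os.path.join(prep.data_path, 'preprocessing', 'train_*.tfrecord'))
#   ds = tf.data.TFRecordDataset(files).map(_parse_example).shuffle(1024).batch(32)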