"""
Preprocess the input images
"""
import copy
import logging
import warnings
from pathlib import Path
from typing import Callable, Collection, Dict, List, Optional, Tuple
import numpy as np
import SimpleITK as sitk
import yaml
from tqdm import tqdm
from . import config as cfg
# configure logger
logger = logging.getLogger(__name__)
def load_image(path: Path):
    """Load an image from disk, raising FileNotFoundError if it does not exist."""
    if not path.exists():
        raise FileNotFoundError(f"{path} not found")
    return sitk.ReadImage(str(path))
def save_image(image: sitk.Image, path: Path):
    """Write an image to disk using SimpleITK."""
    sitk.WriteImage(image, str(path))
def combine_images(
images: List[sitk.Image],
labels: Optional[sitk.Image],
resample=True,
target_spacing=(1, 1, 3),
cut_to_overlap=True,
) -> Tuple[sitk.Image, Optional[sitk.Image]]:
"""Combine images by calculating the best overlap and the combining the images
into one composed image. The images will also be resampled if specified.
The first image will be used as reference image for the direction.
Parameters
----------
images : List[sitk.Image]
List of images
labels : Optional[sitk.Image]
Labels, which will be resampled to the same coordinates as the images
resample : bool, optional
if resampling should be done, by default True
target_spacing : tuple, optional
The spacing for the resampling, by default (1, 1, 3), if any values are None,
the original spacing is used
cut_to_overlap : bool, optional
        If the sample should be cut to the overlapping region of all images, by default True
Returns
-------
Tuple[sitk.Image, Optional[sitk.Image]]
        The combined image and the resampled labels; if no labels were given, the second element is None.
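    Examples
    --------
    Illustrative sketch (the file names are placeholders); two modalities are
    combined into a single two-channel image at 1 x 1 x 3 mm spacing::
        t1 = load_image(Path("patient01/t1.nii.gz"))
        t2 = load_image(Path("patient01/t2.nii.gz"))
        seg = load_image(Path("patient01/seg.nii.gz"))
        combined, seg_resampled = combine_images(
            images=[t1, t2],
            labels=seg,
            resample=True,
            target_spacing=(1, 1, 3),
        )
        # combined is a vector image with one component per input modality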
"""
# take the first image as reference image
if cut_to_overlap:
overlap = get_overlap(images)
reference_image = images[0][overlap]
else:
reference_image = images[0]
if labels is not None:
num_labels = sitk.GetArrayFromImage(labels).astype(bool).sum()
if num_labels == 0:
logger.warning("No labels found in the image.")
# resample reference image to the new spacing if wanted
if resample:
assert target_spacing is not None
# calculate spacing
orig_size = np.array(reference_image.GetSize())
orig_spacing = np.array(reference_image.GetSpacing())
        # replace any None entries in target_spacing with the original spacing
if np.any([ts is None for ts in target_spacing]):
ts_list = []
for num, spc in enumerate(target_spacing):
if spc is None:
ts_list.append(orig_spacing[num])
else:
ts_list.append(spc)
target_spacing = tuple(ts_list)
# set new sizes
physical_size = orig_size * orig_spacing
new_size = (physical_size / target_spacing).astype(int)
new_physical_size = new_size * target_spacing
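        # e.g. a 512 x 512 x 60 volume at 0.7 x 0.7 x 3.0 mm resampled to (1, 1, 3)
        # gives a new size of 358 x 358 x 60 voxels (sizes are rounded down)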
        # shift the origin so the slightly smaller resampled extent stays centred
orig_origin = np.array(reference_image.GetOrigin())
size_diff = physical_size - new_physical_size
direction = np.array(reference_image.GetDirection()).reshape((3, 3)).T
shift = np.dot((size_diff / 2), direction)
new_origin = orig_origin + shift
# see how much the number of labels is reduced
if labels is not None:
orig_label_spacing = np.array(labels.GetSpacing())
            # see how much the number of voxels in the same area is reduced
reduction_factor = np.prod(target_spacing / orig_label_spacing)
resample_method = sitk.ResampleImageFilter()
resample_method.SetDefaultPixelValue(0)
resample_method.SetInterpolator(sitk.sitkBSplineResampler)
resample_method.SetOutputDirection(reference_image.GetDirection())
        # SetSize needs plain Python ints, so convert the numpy integers
resample_method.SetSize([int(n) for n in new_size])
resample_method.SetOutputOrigin(list(new_origin))
resample_method.SetOutputPixelType(sitk.sitkFloat32)
resample_method.SetOutputSpacing(target_spacing)
reference_image_resized = resample_method.Execute(reference_image)
if labels is not None:
# also resample labels, but change interpolator
resample_method.SetInterpolator(sitk.sitkNearestNeighbor)
resample_method.SetOutputPixelType(sitk.sitkUInt8)
labels_resampled = resample_method.Execute(labels)
else:
labels_resampled = None
else:
# just resample the labels
reference_image_resized = sitk.Cast(reference_image, sitk.sitkFloat32)
if labels is not None:
resample_method = sitk.ResampleImageFilter()
resample_method.SetDefaultPixelValue(0)
resample_method.SetInterpolator(sitk.sitkNearestNeighbor)
resample_method.SetOutputDirection(reference_image_resized.GetDirection())
            # set the output geometry explicitly to match the resized reference image
resample_method.SetSize(reference_image_resized.GetSize())
resample_method.SetOutputOrigin(reference_image_resized.GetOrigin())
resample_method.SetOutputPixelType(sitk.sitkUInt8)
resample_method.SetOutputSpacing(reference_image_resized.GetSpacing())
labels_resampled = resample_method.Execute(labels)
reduction_factor = 1
else:
labels_resampled = None
# sample all images to the reference image
resample_method = sitk.ResampleImageFilter()
resample_method.SetReferenceImage(reference_image_resized)
resample_method.SetOutputPixelType(sitk.sitkFloat32)
resample_method.SetInterpolator(sitk.sitkBSplineResampler)
images_resampled = [reference_image_resized] + [
resample_method.Execute(img) for img in images[1:]
]
image_combine = sitk.Compose(images_resampled)
# check labels
if labels is not None:
num_labels_res = sitk.GetArrayFromImage(labels_resampled).astype(bool).sum()
if num_labels != 0 and num_labels_res < 10:
warnings.warn("Labels were not in the overlapping region")
# account for the differences in the spacing when seeing if labels are missing
if num_labels_res < num_labels / reduction_factor * 0.85:
logger.warning("Less labelled voxels in the resampled image.")
return image_combine, labels_resampled
def generate_mask(image: sitk.Image, lower_quantile=0.25) -> sitk.Image:
"""Generate a binary mask by only considering pixels between the lower quantile
and the maximum value
Parameters
----------
image : sitk.Image
The image
lower_quantile : float, optional
The quantile to use for the low value, by default 0.25
Returns
-------
sitk.Image
The mask
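    Examples
    --------
    Rough sketch with a hypothetical input file::
        image = load_image(Path("patient01/t1.nii.gz"))
        mask = generate_mask(image, lower_quantile=0.25)
        foreground_fraction = sitk.GetArrayFromImage(mask).mean()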
"""
image_np = sitk.GetArrayFromImage(image)
upper = np.max(image_np)
lower = np.quantile(image_np, lower_quantile)
mask_image = sitk.BinaryThreshold(
image1=image,
lowerThreshold=lower + 1e-9, # add small value to exclude the limit
upperThreshold=upper + 1e-9, # for the higher values, include the limit
)
mask_closed = sitk.BinaryClosingByReconstruction(
image1=mask_image, kernelRadius=[5, 5, 5]
)
mask_opened = sitk.BinaryOpeningByReconstruction(
image1=mask_closed, kernelRadius=[2, 2, 2]
)
return mask_opened
def get_overlap(images: List[sitk.Image]) -> list:
"""Calculate the area of overlap between all three images
Parameters
----------
images : List[sitk.Image]
        The images; the coordinate system of the first image will be used
Returns
-------
list
        List of slices (in x, y, z order) with the start and stop of the valid indices
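    Examples
    --------
    Sketch showing how the returned slices can be used to crop the first image
    to the common field of view (file names are placeholders)::
        images = [load_image(Path(p)) for p in ("t1.nii.gz", "t2.nii.gz")]
        overlap = get_overlap(images)
        cropped = images[0][overlap]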
"""
images_mask = [generate_mask(img) for img in images]
# resample to the first image
masks_resampled = [images_mask[0]]
for img in images_mask[1:]:
img_resampled = sitk.Resample(
image1=img,
referenceImage=images_mask[0],
transform=sitk.Euler3DTransform(),
interpolator=sitk.sitkNearestNeighbor,
)
masks_resampled.append(img_resampled)
# turn into numpy arrays
mask_np = np.array([sitk.GetArrayFromImage(img) for img in masks_resampled]).astype(
bool
)
all_mask = np.all(mask_np, axis=0)
valid_indices = []
    # numpy arrays are indexed (z, y, x), the opposite order to SimpleITK's (x, y, z)
axes_names = ["z", "y", "x"]
for i in range(2, -1, -1):
avg_axes = tuple(set(range(3)) - set([i]))
dir_name = axes_names[i]
valid = all_mask.mean(axis=avg_axes) > 0.2
# get the first good value (start will include it)
start = np.argmax(valid)
# get the last good value (the end will exclude the limit, but size is off by 1)
end = valid.size - np.argmax(np.flip(valid))
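        # e.g. valid = [F, F, T, T, T, F] gives start = 2, end = 5, so slice(2, 5)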
valid_indices.append(slice(start, end))
frac_removed = (valid.size - (end - start)) / valid.size
if frac_removed > 0.5:
warnings.warn(
f"More that 50% ({frac_removed} %) removed in direction {dir_name}."
)
elif frac_removed > 0.3:
logger.warning(
"More than 30 %% (%i %%) removed in direction %s.",
int(frac_removed * 100),
dir_name,
)
return valid_indices
def preprocess_dataset(
data_set: dict,
num_channels: int,
base_dir: Path,
data_dir: Path,
preprocessed_dir: Path,
train_dataset: Collection,
preprocessing_parameters: dict,
pass_modality=False,
cut_to_overlap=True,
):
"""Preprocess the images by applying the normalization and then combining
them into one image.
Parameters
----------
data_set : dict
The images to process, should contain "images" key with a list of images
and can contain a "labels" key with the path to the labels file
num_channels : int
The number of channels (has to be the length of the "images" list)
base_dir : Path
The dir all other directories are relative to, usually the experiment dir
data_dir : Path
The path everything in the dataset is relative to
preprocessed_dir : Path
The directory to save the preprocessed data to (relative to base dir)
train_dataset : Collection
If the normalization is trained (like histogram matching), use these
images for training
preprocessing_parameters : dict
        The parameters for the normalization: "normalizing_method" is the NORMALIZING
        enumerator to use (the available classes are in Normalization) and
        "normalization_parameters" (a dict, or a list with one dict per channel) is passed to the class.
pass_modality : bool, optional
If the number of the modality should be passed to the normalization
initializer as mod_num, by default False
cut_to_overlap : bool, optional
        If the sample should be cut to the overlapping region of all images, by default True
Returns
-------
dict
        A new dictionary containing the processed images. All other keys of each
        individual patient entry are carried over unchanged.
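    Examples
    --------
    Minimal sketch of the expected inputs; the paths, patient names and the
    normalization setup are placeholders that depend on the project configuration::
        data_set = {
            "patient01": {
                "images": ["patient01/t1.nii.gz", "patient01/t2.nii.gz"],
                "labels": "patient01/seg.nii.gz",
            },
            "patient02": {
                "images": ["patient02/t1.nii.gz", "patient02/t2.nii.gz"],
                "labels": None,
            },
        }
        processed = preprocess_dataset(
            data_set=data_set,
            num_channels=2,
            base_dir=Path("experiments/exp1"),
            data_dir=Path("data"),
            preprocessed_dir=Path("preprocessed"),
            train_dataset=["patient01"],
            preprocessing_parameters={
                "normalizing_method": NORMALIZING.QUANTILE,  # hypothetical enum member
                "normalization_parameters": {},
                "resample": True,
                "target_spacing": (1, 1, 3),
            },
        )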
"""
preprocessed_dir_abs = base_dir / preprocessed_dir
if not preprocessed_dir_abs.exists():
preprocessed_dir_abs.mkdir(parents=True)
# load normalization
normalization_method = preprocessing_parameters["normalizing_method"]
normalization_class = normalization_method.get_class()
# make lists to train the normalization
norm_train_set: List[List[Path]] = [[] for n in range(num_channels)]
for name in train_dataset:
images = data_set[name]["images"]
assert len(images) == num_channels, "Number of modalities inconsistent"
for num, img in enumerate(images):
norm_train_set[num].append(data_dir / img)
# train the normalization
normalizations = []
norm_params = preprocessing_parameters["normalization_parameters"]
for num in range(num_channels):
if isinstance(norm_params, dict):
norm_params_channel = norm_params
elif isinstance(norm_params, list):
norm_params_channel = norm_params[num]
else:
raise TypeError("normalization_parameters should be dict or list")
norm_file = preprocessed_dir_abs / f"normalization_mod{num}.yaml"
if not norm_file.parent.exists():
norm_file.parent.mkdir(parents=True)
if norm_file.exists():
norm = normalization_class.from_file(norm_file)
# make sure the parameters are correct
parameters_from_file = norm.get_parameters()
            for key, value in parameters_from_file.items():
                if norm_params_channel.get(key) != value:
                    # ignore parameters that were not explicitly requested
                    if key not in norm_params_channel:
                        continue
                    raise ValueError(
                        f"Normalization of preprocessed images has different parameters for {key}."
                    )
else:
if pass_modality:
norm = normalization_class(mod_num=num, **norm_params_channel)
else:
norm = normalization_class(**norm_params_channel)
norm.train_normalization(norm_train_set[num])
# and save it
norm.to_file(norm_file)
normalizations.append(norm)
# remember preprocessed images
preprocessed_dict = {}
# resample and apply normalization
for name, data in tqdm(data_set.items(), unit="image", desc="preprocess images"):
# define paths
image_paths = [data_dir / img for img in data["images"]]
image_rel_path = preprocessed_dir / str(
cfg.sample_file_name_prefix + name + cfg.file_suffix
)
image_processed_path = base_dir / image_rel_path
        labels_exist = True
        if data.get("labels", None) is not None:
            labels_path = data_dir / data["labels"]
            label_rel_path = preprocessed_dir / str(
                cfg.label_file_name_prefix + name + cfg.file_suffix
            )
            labels_processed_path = base_dir / label_rel_path
        else:
            labels_path = None
            labels_processed_path = None
            labels_exist = False
# preprocess images
preprocess_image(
image_paths=image_paths,
image_processed_path=image_processed_path,
labels_path=labels_path,
labels_processed_path=labels_processed_path,
normalizations=normalizations,
preprocessing_parameters=preprocessing_parameters,
cut_to_overlap=cut_to_overlap,
)
if image_processed_path.exists():
preprocessed_dict[str(name)] = {
k: v for k, v in data.items() if k not in ("images", "labels")
}
preprocessed_dict[str(name)]["image"] = image_rel_path
else:
raise FileNotFoundError(f"{image_processed_path} not found after preprocessing")
if labels_exist:
if labels_processed_path.exists():
preprocessed_dict[str(name)]["labels"] = label_rel_path
else:
raise FileNotFoundError(
f"{labels_processed_path} not found after preprocessing"
)
else:
# set labels to None if it previously was None
if "labels" in data:
preprocessed_dict[str(name)]["labels"] = None
# add additional keys
additional_keys = set(data.keys()) - set(("labels", "images"))
for key in additional_keys:
preprocessed_dict[str(name)][key] = copy.deepcopy(data[key])
logger.info("Preprocessing finished.")
dataset_file = preprocessed_dir_abs / "preprocessing_dataset.yaml"
with open(dataset_file, "w", encoding="utf8") as f:
yaml.dump(preprocessed_dict, f, sort_keys=False)
return preprocessed_dict
def preprocess_image(
image_paths: List[Path],
image_processed_path: Path,
labels_path: Optional[Path],
labels_processed_path: Optional[Path],
normalizations: List[Callable],
preprocessing_parameters: Dict,
cut_to_overlap=True,
):
"""Preprocess a single image, it will be normalized and then all images
will be combined into one
Parameters
----------
name : str
The name (used in the filename)
image_paths : List[Path]
The path of the image
labels_path : Optional[Path]
The path of the labels (or None, then no labels are exported)
normalizations : List
A list of the normalizations used to process each image, should be the
same length as the images
cut_to_overlap : bool, optional
If the sample should be cut to the overlapping region of all images, by default true
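    Examples
    --------
    Sketch for a single sample; all paths are placeholders and identity lambdas
    stand in for the real Normalization callables::
        preprocess_image(
            image_paths=[Path("data/patient01/t1.nii.gz"), Path("data/patient01/t2.nii.gz")],
            image_processed_path=Path("preprocessed/sample-patient01.nii.gz"),
            labels_path=Path("data/patient01/seg.nii.gz"),
            labels_processed_path=Path("preprocessed/label-patient01.nii.gz"),
            normalizations=[lambda img: img, lambda img: img],
            preprocessing_parameters={"resample": True, "target_spacing": (1, 1, 3)},
            cut_to_overlap=True,
        )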
"""
# preprocess if it does not exist yet
if labels_path is None:
already_processed = image_processed_path.exists()
else:
already_processed = image_processed_path.exists() and labels_processed_path.exists()
if not already_processed:
# load and normalize images
images = [load_image(img) for img in image_paths]
images_normalized = [norm(img) for img, norm in zip(images, normalizations)]
if labels_path is not None:
labels = load_image(labels_path)
else:
labels = None
logger.info("Processing image %s", image_processed_path.name)
res_image, res_labels = combine_images(
images=images_normalized,
labels=labels,
resample=preprocessing_parameters["resample"],
target_spacing=preprocessing_parameters["target_spacing"],
cut_to_overlap=cut_to_overlap,
)
save_image(res_image, image_processed_path)
if labels is not None:
assert (
labels_processed_path is not None
), "if there are labels, also provide a path."
save_image(res_labels, labels_processed_path)