Skip to content

Commit

Permalink
Merge pull request #489 from dPys/enh/nibabel_splitmerge_interfaces
Browse files Browse the repository at this point in the history
ENH: Add nibabel-based split and merge interfaces
  • Loading branch information
oesteban authored Apr 16, 2020
2 parents 602a7ca + d657546 commit 7b28b4f
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 37 deletions.
135 changes: 110 additions & 25 deletions niworkflows/interfaces/nibabel.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Nibabel-based interfaces."""
from pathlib import Path
import numpy as np
import nibabel as nb
from nipype import logging
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, File,
SimpleInterface
traits,
TraitedSpec,
BaseInterfaceInputSpec,
File,
SimpleInterface,
OutputMultiObject,
InputMultiObject,
)

IFLOGGER = logging.getLogger('nipype.interface')
IFLOGGER = logging.getLogger("nipype.interface")


class _ApplyMaskInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='an image')
in_mask = File(exists=True, mandatory=True, desc='a mask')
threshold = traits.Float(0.5, usedefault=True,
desc='a threshold to the mask, if it is nonbinary')
in_file = File(exists=True, mandatory=True, desc="an image")
in_mask = File(exists=True, mandatory=True, desc="a mask")
threshold = traits.Float(
0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary"
)


class _ApplyMaskOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_file = File(exists=True, desc="masked file")


class ApplyMask(SimpleInterface):
Expand All @@ -35,8 +42,9 @@ def _run_interface(self, runtime):
msknii = nb.load(self.inputs.in_mask)
msk = msknii.get_fdata() > self.inputs.threshold

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)

if img.dataobj.shape[:3] != msk.shape:
raise ValueError("Image and mask sizes do not match.")
Expand All @@ -48,19 +56,18 @@ def _run_interface(self, runtime):
msk = msk[..., np.newaxis]

masked = img.__class__(img.dataobj * msk, None, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])
return runtime


class _BinarizeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input image')
thresh_low = traits.Float(mandatory=True,
desc='non-inclusive lower threshold')
in_file = File(exists=True, mandatory=True, desc="input image")
thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold")


class _BinarizeOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_mask = File(exists=True, desc='output mask')
out_file = File(exists=True, desc="masked file")
out_mask = File(exists=True, desc="output mask")


class Binarize(SimpleInterface):
Expand All @@ -72,20 +79,98 @@ class Binarize(SimpleInterface):
def _run_interface(self, runtime):
img = nb.load(self.inputs.in_file)

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results['out_mask'] = fname_presuffix(
self.inputs.in_file, suffix='_mask', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)
self._results["out_mask"] = fname_presuffix(
self.inputs.in_file, suffix="_mask", newpath=runtime.cwd
)

data = img.get_fdata()
mask = data > self.inputs.thresh_low
data[~mask] = 0.0
masked = img.__class__(data, img.affine, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])

img.header.set_data_dtype('uint8')
maskimg = img.__class__(mask.astype('uint8'), img.affine,
img.header)
maskimg.to_filename(self._results['out_mask'])
img.header.set_data_dtype("uint8")
maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header)
maskimg.to_filename(self._results["out_mask"])

return runtime


class _SplitSeriesInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc="input 4d image")


class _SplitSeriesOutputSpec(TraitedSpec):
out_files = OutputMultiObject(File(exists=True), desc="output list of 3d images")


class SplitSeries(SimpleInterface):
"""Split a 4D dataset along the last dimension into a series of 3D volumes."""

input_spec = _SplitSeriesInputSpec
output_spec = _SplitSeriesOutputSpec

def _run_interface(self, runtime):
in_file = self.inputs.in_file
img = nb.load(in_file)
extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,)
if len(extra_dims) != 1:
raise ValueError(f"Invalid shape {'x'.join(str(s) for s in img.shape)}")
img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims),
img.affine, img.header)

self._results["out_files"] = []
for i, img_3d in enumerate(nb.four_to_three(img)):
out_file = str(
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
)
img_3d.to_filename(out_file)
self._results["out_files"].append(out_file)

return runtime


class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
in_files = InputMultiObject(
File(exists=True, mandatory=True, desc="input list of 3d images")
)
allow_4D = traits.Bool(
True, usedefault=True, desc="whether 4D images are allowed to be concatenated"
)


class _MergeSeriesOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="output 4d image")


class MergeSeries(SimpleInterface):
"""Merge a series of 3D volumes along the last dimension into a single 4D image."""

input_spec = _MergeSeriesInputSpec
output_spec = _MergeSeriesOutputSpec

def _run_interface(self, runtime):
nii_list = []
for f in self.inputs.in_files:
filenii = nb.squeeze_image(nb.load(f))
ndim = filenii.dataobj.ndim
if ndim == 3:
nii_list.append(filenii)
continue
elif self.inputs.allow_4D and ndim == 4:
nii_list += nb.four_to_three(filenii)
continue
else:
raise ValueError(
"Input image has an incorrect number of dimensions" f" ({ndim})."
)

img_4d = nb.concat_images(nii_list)
out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged")
img_4d.to_filename(out_file)

self._results["out_file"] = out_file
return runtime
83 changes: 71 additions & 12 deletions niworkflows/interfaces/tests/test_nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import nibabel as nb
import pytest

from ..nibabel import Binarize, ApplyMask
from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries


def test_Binarize(tmp_path):
Expand All @@ -14,10 +14,10 @@ def test_Binarize(tmp_path):
mask = np.zeros((20, 20, 20), dtype=bool)
mask[5:15, 5:15, 5:15] = bool

data = np.zeros_like(mask, dtype='float32')
data = np.zeros_like(mask, dtype="float32")
data[mask] = np.random.gamma(2, size=mask.sum())

in_file = tmp_path / 'input.nii.gz'
in_file = tmp_path / "input.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run()
Expand All @@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path):
mask[8:11, 8:11, 8:11] = 1.0

# Test the 3D
in_file = tmp_path / 'input3D.nii.gz'
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

in_mask = tmp_path / 'mask.nii.gz'
in_mask = tmp_path / "mask.nii.gz"
nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask))

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3

data4d = np.stack((data, 2 * data, 3 * data), axis=-1)
# Test the 4D case
in_file4d = tmp_path / 'input4D.nii.gz'
in_file4d = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d))

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6

# Test errors
nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask))
Expand All @@ -69,3 +73,58 @@ def test_ApplyMask(tmp_path):
ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
with pytest.raises(ValueError):
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()


@pytest.mark.parametrize("shape,exp_n", [
((20, 20, 20, 15), 15),
((20, 20, 20), 1),
((20, 20, 20, 1), 1),
((20, 20, 20, 1, 3), 3),
((20, 20, 20, 3, 1), 3),
((20, 20, 20, 1, 3, 3), -1),
((20, 1, 20, 15), 15),
((20, 1, 20), 1),
((20, 1, 20, 1), 1),
((20, 1, 20, 1, 3), 3),
((20, 1, 20, 3, 1), 3),
((20, 1, 20, 1, 3, 3), -1),
])
def test_SplitSeries(tmp_path, shape, exp_n):
"""Test 4-to-3 NIfTI split interface."""
os.chdir(tmp_path)

in_file = str(tmp_path / "input.nii.gz")
nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None).to_filename(in_file)

_interface = SplitSeries(in_file=in_file)
if exp_n > 0:
split = _interface.run()
n = int(isinstance(split.outputs.out_files, str)) or len(split.outputs.out_files)
assert n == exp_n
else:
with pytest.raises(ValueError):
_interface.run()


def test_MergeSeries(tmp_path):
"""Test 3-to-4 NIfTI concatenation interface."""
os.chdir(str(tmp_path))

in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
str(in_file)
)

merge = MergeSeries(in_files=[str(in_file)] * 5).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

in_4D = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None).to_filename(
str(in_4D)
)

merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

with pytest.raises(ValueError):
MergeSeries(in_files=[str(in_file)] + [str(in_4D)], allow_4D=False).run()

0 comments on commit 7b28b4f

Please sign in to comment.