Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add nibabel-based split and merge interfaces #489

Merged
merged 22 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0de7374
[ENH] Add nibabel-based split and merge interfaces per https://github…
Mar 24, 2020
bfcd29c
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
30ddf03
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
5e3c93c
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
40f6de1
Update niworkflows/interfaces/nibabel.py
dPys Mar 24, 2020
f3572e6
Change naming to ConcatImages, fix in/out spec namings
Mar 24, 2020
49ff6bd
Update niworkflows/interfaces/nibabel.py
dPys Mar 31, 2020
87f1cb3
Update niworkflows/interfaces/nibabel.py
dPys Mar 31, 2020
b4fcdb7
rename ConcatImages to MergeSeries, correct typo in description of ou…
Mar 31, 2020
f51f510
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
9aa0655
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
5a31b01
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
53db81a
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
1c27f20
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
05724d5
Update niworkflows/interfaces/nibabel.py
dPys Apr 5, 2020
7263d03
Apply suggestions from code review [skip ci]
oesteban Apr 8, 2020
03ebb6d
fix: added few bugfixes and regression tests
oesteban Apr 8, 2020
9446d47
fix: squeeze image with np.squeeze / change input name for consistency
oesteban Apr 8, 2020
0ae5351
sty(black): standardize formatting a bit
oesteban Apr 8, 2020
fc42351
enh: make i/o specs of SplitSeries more consistent [skip ci]
oesteban Apr 8, 2020
1d12dd3
Update niworkflows/interfaces/tests/test_nibabel.py [skip ci]
oesteban Apr 9, 2020
d657546
fix: apply review comments from @effigies, add parameterized tests
oesteban Apr 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 127 additions & 25 deletions niworkflows/interfaces/nibabel.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Nibabel-based interfaces."""
from pathlib import Path
import numpy as np
import nibabel as nb
from nipype import logging
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, File,
SimpleInterface
traits,
TraitedSpec,
BaseInterfaceInputSpec,
File,
SimpleInterface,
OutputMultiObject,
InputMultiObject,
)

IFLOGGER = logging.getLogger('nipype.interface')
IFLOGGER = logging.getLogger("nipype.interface")


class _ApplyMaskInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='an image')
in_mask = File(exists=True, mandatory=True, desc='a mask')
threshold = traits.Float(0.5, usedefault=True,
desc='a threshold to the mask, if it is nonbinary')
in_file = File(exists=True, mandatory=True, desc="an image")
in_mask = File(exists=True, mandatory=True, desc="a mask")
threshold = traits.Float(
0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary"
)


class _ApplyMaskOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_file = File(exists=True, desc="masked file")


class ApplyMask(SimpleInterface):
Expand All @@ -35,8 +42,9 @@ def _run_interface(self, runtime):
msknii = nb.load(self.inputs.in_mask)
msk = msknii.get_fdata() > self.inputs.threshold

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)

if img.dataobj.shape[:3] != msk.shape:
raise ValueError("Image and mask sizes do not match.")
Expand All @@ -48,19 +56,18 @@ def _run_interface(self, runtime):
msk = msk[..., np.newaxis]

masked = img.__class__(img.dataobj * msk, None, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])
return runtime


class _BinarizeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input image')
thresh_low = traits.Float(mandatory=True,
desc='non-inclusive lower threshold')
in_file = File(exists=True, mandatory=True, desc="input image")
thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold")


class _BinarizeOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_mask = File(exists=True, desc='output mask')
out_file = File(exists=True, desc="masked file")
out_mask = File(exists=True, desc="output mask")


class Binarize(SimpleInterface):
Expand All @@ -72,20 +79,115 @@ class Binarize(SimpleInterface):
def _run_interface(self, runtime):
img = nb.load(self.inputs.in_file)

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results['out_mask'] = fname_presuffix(
self.inputs.in_file, suffix='_mask', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)
self._results["out_mask"] = fname_presuffix(
self.inputs.in_file, suffix="_mask", newpath=runtime.cwd
)

data = img.get_fdata()
mask = data > self.inputs.thresh_low
data[~mask] = 0.0
masked = img.__class__(data, img.affine, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])

img.header.set_data_dtype('uint8')
maskimg = img.__class__(mask.astype('uint8'), img.affine,
img.header)
maskimg.to_filename(self._results['out_mask'])
img.header.set_data_dtype("uint8")
maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header)
maskimg.to_filename(self._results["out_mask"])

return runtime


class _SplitSeriesInputSpec(BaseInterfaceInputSpec):
    """Inputs accepted by :class:`SplitSeries`."""

    # Image to be split; the file must exist when the interface is run.
    in_file = File(desc="input 4d image", mandatory=True, exists=True)
    # Opt-in switch: when set, a 3D input is passed through instead of failing.
    allow_3D = traits.Bool(
        False,
        desc="do not fail if a 3D volume is passed in",
        usedefault=True,
    )


class _SplitSeriesOutputSpec(TraitedSpec):
    """Outputs produced by :class:`SplitSeries`."""

    # One path per volume written by the interface.
    out_files = OutputMultiObject(
        File(exists=True),
        desc="output list of 3d images",
    )


class SplitSeries(SimpleInterface):
"""Split a 4D dataset along the last dimension into a series of 3D volumes."""

input_spec = _SplitSeriesInputSpec
output_spec = _SplitSeriesOutputSpec

def _run_interface(self, runtime):
filenii = nb.squeeze_image(nb.load(self.inputs.in_file))
filenii = filenii.__class__(
np.squeeze(filenii.dataobj), filenii.affine, filenii.header
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that non-terminal 1-dimensions are left in place by squeeze_image is to preserve the meaning of each dimension. For example, if I have a time series of a single slice (64, 64, 1, 48), squeeze_image will preserve the meaning of i, j, k, n, but np.squeeze will recast n as k.

What's the use case that you're taking care of here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the use case I'm contemplating is separating out the three components of deformation fields and other model-based nonlinear transforms. There is one example of this in the tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found it. I guess I would think splitting along the fifth dimension is a different task from splitting along the fourth, but I see that it's convenient to have a single interface. I'm not sure that risking dropping meaningful dimensions is a good idea.

What about forcing to 4D like:

extra_dims = tuple(dim for dim in img.shape[3:] if dim > 1) or (1,)
if len(extra_dims) != 1:
    raise ValueError("Invalid shape")
img = img.__class__(img.dataobj.reshape(img.shape[:3] + extra_dims),
                    img.affine, img.header)

This coerces a 3D image to (x, y, z, 1) and a 4+D image to (x, y, z, n) assuming that dimensions 4-7 are all 1, n or absent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it guaranteed that the spatial dimensions will not be affected on that reshape? If so, I'm down with this solution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trivial dimensions have no effect on indexing, whether it's C- or Fortran ordered. If you don't change the order, it's fine.

ndim = filenii.dataobj.ndim
if ndim != 4:
if self.inputs.allow_3D and ndim == 3:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is odd, as the above will coerce a valid 4D (x, y, z, 1) image to 3D (x, y, z), requiring you to then allow_3D.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that odd? It is indeed a 3D volume, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, there are a fair number of T1w images in OpenNeuro with (x, y, z, 1) dimensions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a 4D series. You should be able to split it into one 3D volume without special casing it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the above snippet you wrote would make this particular use-case a standard one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say checking for 3D volumes should happen before reshaping the image.

out_file = str(
Path(
fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")
).absolute()
)
self._results["out_files"] = out_file
filenii.to_filename(out_file)
return runtime
raise RuntimeError(
f"Input image <{self.inputs.in_file}> is {ndim}D "
f"({'x'.join(['%d' % s for s in filenii.shape])})."
)

files_3d = nb.four_to_three(filenii)
self._results["out_files"] = []
in_file = self.inputs.in_file
for i, file_3d in enumerate(files_3d):
out_file = str(
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
)
file_3d.to_filename(out_file)
self._results["out_files"].append(out_file)

return runtime


class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
    """Inputs accepted by :class:`MergeSeries`."""

    # ``mandatory`` and ``desc`` describe the ``in_files`` input as a whole, so
    # they are attached to the outer multi-object trait. Given to the inner
    # ``File`` trait they only decorate the element type and do not become
    # metadata of the input itself.
    in_files = InputMultiObject(
        File(exists=True),
        mandatory=True,
        desc="input list of 3d images",
    )
    # When enabled (the default), 4D inputs are accepted and contribute all of
    # their volumes to the merged series.
    allow_4D = traits.Bool(
        True, usedefault=True, desc="whether 4D images are allowed to be concatenated"
    )


class _MergeSeriesOutputSpec(TraitedSpec):
    """Outputs produced by :class:`MergeSeries`."""

    # Path of the concatenated image written by the interface.
    out_file = File(desc="output 4d image", exists=True)


class MergeSeries(SimpleInterface):
    """Merge a series of 3D volumes along the last dimension into a single 4D image."""

    input_spec = _MergeSeriesInputSpec
    output_spec = _MergeSeriesOutputSpec

    def _run_interface(self, runtime):
        """
        Concatenate all input images and write the result to the working directory.

        Each input is loaded and passed through ``nb.squeeze_image``. 3D images
        are appended as they are; 4D images are expanded into their 3D volumes
        when ``allow_4D`` is set. The merged image is named after the first
        input with a ``_merged`` suffix.

        Raises
        ------
        ValueError
            If an input is neither 3D nor (with ``allow_4D`` set) 4D.

        """
        nii_list = []
        for f in self.inputs.in_files:
            filenii = nb.squeeze_image(nb.load(f))
            ndim = filenii.dataobj.ndim
            if ndim == 3:
                nii_list.append(filenii)
            elif self.inputs.allow_4D and ndim == 4:
                nii_list += nb.four_to_three(filenii)
            else:
                raise ValueError(
                    f"Input image has an incorrect number of dimensions ({ndim})."
                )

        img_4d = nb.concat_images(nii_list)
        # Write into the working directory (as ApplyMask and Binarize do)
        # rather than next to the first input file.
        out_file = fname_presuffix(
            self.inputs.in_files[0], suffix="_merged", newpath=runtime.cwd
        )
        img_4d.to_filename(out_file)

        self._results["out_file"] = out_file
        return runtime
109 changes: 97 additions & 12 deletions niworkflows/interfaces/tests/test_nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import nibabel as nb
import pytest

from ..nibabel import Binarize, ApplyMask
from ..nibabel import Binarize, ApplyMask, SplitSeries, MergeSeries


def test_Binarize(tmp_path):
Expand All @@ -14,10 +14,10 @@ def test_Binarize(tmp_path):
mask = np.zeros((20, 20, 20), dtype=bool)
mask[5:15, 5:15, 5:15] = bool

data = np.zeros_like(mask, dtype='float32')
data = np.zeros_like(mask, dtype="float32")
data[mask] = np.random.gamma(2, size=mask.sum())

in_file = tmp_path / 'input.nii.gz'
in_file = tmp_path / "input.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run()
Expand All @@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path):
mask[8:11, 8:11, 8:11] = 1.0

# Test the 3D
in_file = tmp_path / 'input3D.nii.gz'
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

in_mask = tmp_path / 'mask.nii.gz'
in_mask = tmp_path / "mask.nii.gz"
nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask))

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3

data4d = np.stack((data, 2 * data, 3 * data), axis=-1)
# Test the 4D case
in_file4d = tmp_path / 'input4D.nii.gz'
in_file4d = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d))

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6

# Test errors
nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask))
Expand All @@ -69,3 +73,84 @@ def test_ApplyMask(tmp_path):
ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
with pytest.raises(ValueError):
ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()


def test_SplitSeries(tmp_path):
    """Check SplitSeries on 4D, 3D, singleton-4D, 5D and warpfield-shaped inputs."""
    os.chdir(tmp_path)

    def _write_ones(name, shape):
        """Save an all-ones image of *shape* under ``tmp_path`` and return its path."""
        target = tmp_path / name
        image = nb.Nifti1Image(np.ones(shape, dtype=float), np.eye(4), None)
        image.to_filename(str(target))
        return str(target)

    # A regular 4D series yields one file per volume.
    series = _write_ones("input4D.nii.gz", (20, 20, 20, 15))
    result = SplitSeries(in_file=series).run()
    assert len(result.outputs.out_files) == 15

    # A plain 3D volume and a 4D image with a singleton last dimension are
    # both expected to be rejected unless ``allow_3D`` is set.
    for shape in ((20, 20, 20), (20, 20, 20, 1)):
        volume = _write_ones("input3D.nii.gz", shape)

        with pytest.raises(RuntimeError):
            SplitSeries(in_file=volume).run()

        result = SplitSeries(in_file=volume, allow_3D=True).run()
        assert isinstance(result.outputs.out_files, str)

    # A 5D image with two non-trivial extra dimensions is always an error.
    five_d = _write_ones("input5D.nii.gz", (20, 20, 20, 2, 2))

    with pytest.raises(RuntimeError):
        SplitSeries(in_file=five_d).run()

    with pytest.raises(RuntimeError):
        SplitSeries(in_file=five_d, allow_3D=True).run()

    # Test splitting ANTs warpfields
    warp = _write_ones("warpfield.nii.gz", (20, 20, 20, 1, 3))
    result = SplitSeries(in_file=warp).run()
    assert len(result.outputs.out_files) == 3


def test_MergeSeries(tmp_path):
    """Check MergeSeries on 3D-only and mixed 3D/4D inputs, and its ``allow_4D`` switch."""
    os.chdir(str(tmp_path))

    identity = np.eye(4)

    # Five copies of the same 3D volume give a five-volume series.
    vol_3d = str(tmp_path / "input3D.nii.gz")
    nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), identity, None).to_filename(
        vol_3d
    )

    merged = MergeSeries(in_files=[vol_3d] * 5).run()
    assert nb.load(merged.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

    # One 3D volume followed by a four-volume series also adds up to five.
    vol_4d = str(tmp_path / "input4D.nii.gz")
    nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), identity, None).to_filename(
        vol_4d
    )
    mixed_inputs = [vol_3d, vol_4d]

    merged = MergeSeries(in_files=mixed_inputs).run()
    assert nb.load(merged.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

    # With 4D inputs disallowed, the same list must be rejected.
    with pytest.raises(ValueError):
        MergeSeries(in_files=mixed_inputs, allow_4D=False).run()