Skip to content

Commit

Permalink
fix: squeeze image with np.squeeze / change input name for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Apr 8, 2020
1 parent 03ebb6d commit 9446d47
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
11 changes: 8 additions & 3 deletions niworkflows/interfaces/nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _run_interface(self, runtime):

class _FourToThreeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input 4d image')
accept_3D = traits.Bool(False, usedefault=True, desc='do not fail if a 3D volume is passed in')
allow_3D = traits.Bool(False, usedefault=True, desc='do not fail if a 3D volume is passed in')


class _FourToThreeOutputSpec(TraitedSpec):
Expand All @@ -110,17 +110,22 @@ class SplitSeries(SimpleInterface):

def _run_interface(self, runtime):
filenii = nb.squeeze_image(nb.load(self.inputs.in_file))
filenii = filenii.__class__(
np.squeeze(filenii.dataobj), filenii.affine, filenii.header
)
ndim = filenii.dataobj.ndim
if ndim != 4:
if self.inputs.accept_3D and ndim == 3:
if self.inputs.allow_3D and ndim == 3:
out_file = str(
Path(fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")).absolute()
)
self._results['out_files'] = out_file
filenii.to_filename(out_file)
return runtime
raise RuntimeError(
f"Input image image is {ndim}D ({'x'.join(['%d' % s for s in filenii.shape])}).")
f"Input image <{self.inputs.in_file}> is {ndim}D "
f"({'x'.join(['%d' % s for s in filenii.shape])})."
)

files_3d = nb.four_to_three(filenii)
self._results['out_files'] = []
Expand Down
7 changes: 4 additions & 3 deletions niworkflows/interfaces/tests/test_nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_SplitSeries(tmp_path):
with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file)).run()

split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
assert isinstance(split.outputs.out_files, str)

# Test the 3D
Expand All @@ -102,7 +102,7 @@ def test_SplitSeries(tmp_path):
with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file)).run()

split = SplitSeries(in_file=str(in_file), accept_3D=True).run()
split = SplitSeries(in_file=str(in_file), allow_3D=True).run()
assert isinstance(split.outputs.out_files, str)

# Test the 5D
Expand All @@ -114,7 +114,7 @@ def test_SplitSeries(tmp_path):
SplitSeries(in_file=str(in_file)).run()

with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file), accept_3D=True).run()
SplitSeries(in_file=str(in_file), allow_3D=True).run()

# Test splitting ANTs warpfields
data = np.ones((20, 20, 20, 1, 3), dtype=float)
Expand All @@ -124,6 +124,7 @@ def test_SplitSeries(tmp_path):
split = SplitSeries(in_file=str(in_file)).run()
assert len(split.outputs.out_files) == 3


def test_MergeSeries(tmp_path):
"""Test 3-to-4 NIfTI concatenation interface."""
os.chdir(str(tmp_path))
Expand Down

0 comments on commit 9446d47

Please sign in to comment.