Skip to content

Commit

Permalink
ENH: added Elastix registration (issue SlicerProstate#334)
Browse files Browse the repository at this point in the history
* SliceTrackerRegistrationLogic expects registration algorithm as a
  parameter in order to be interchangeable
* IRegistrationAlgorithm is the interface that needs to be implemented
  in order to add new registration algorithm
* new configuration entries
* adapted SliceTrackerRegistration widget and cli
* gentle handling if preferred algorithm doesn't exist with giving
  notification and using fallback algorithm (default: BRAINS)

TODO: use own parameter files needs to possible
  • Loading branch information
che85 committed Jan 24, 2018
1 parent e13f104 commit f8f8c07
Showing 1 changed file with 69 additions and 185 deletions.
254 changes: 69 additions & 185 deletions SliceTracker/SliceTrackerRegistration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import argparse, sys, os, logging
import qt, slicer
from slicer.ScriptedLoadableModule import *
from SlicerDevelopmentToolboxUtils.mixins import ModuleLogicMixin, ModuleWidgetMixin
from SliceTrackerUtils.sessionData import *
from SliceTrackerUtils.constants import SliceTrackerConstants

from SlicerDevelopmentToolboxUtils.mixins import ModuleWidgetMixin
from SlicerDevelopmentToolboxUtils.decorators import onReturnProcessEvents

from SliceTrackerUtils.constants import SliceTrackerConstants
from SliceTrackerUtils.sessionData import RegistrationResult
import SliceTrackerUtils.algorithms.registration as registration


class SliceTrackerRegistration(ScriptedLoadableModule):

Expand All @@ -14,7 +17,7 @@ def __init__(self, parent):
self.parent.title = "SliceTracker Registration"
self.parent.categories = ["Radiology"]
self.parent.dependencies = ["SlicerDevelopmentToolbox"]
self.parent.contributors = ["Peter Behringer (SPL), Christian Herz (SPL), Andriy Fedorov (SPL)"]
self.parent.contributors = ["Christian Herz (SPL), Peter Behringer (SPL), Andriy Fedorov (SPL)"]
self.parent.helpText = """ SliceTracker Registration facilitates support of MRI-guided targeted prostate biopsy. """
self.parent.acknowledgementText = """Surgical Planning Laboratory, Brigham and Women's Hospital, Harvard
Medical School, Boston, USA This work was supported in part by the National
Expand All @@ -37,7 +40,8 @@ class SliceTrackerRegistrationWidget(ScriptedLoadableModuleWidget, ModuleWidgetM

def __init__(self, parent=None):
ScriptedLoadableModuleWidget.__init__(self, parent)
self.logic = SliceTrackerRegistrationLogic()
self.registrationAlgorithm = None
self.counter = 1

def setup(self):
ScriptedLoadableModuleWidget.setup(self)
Expand All @@ -59,27 +63,26 @@ def setup(self):
self.fiducialSelector = self.createComboBox(nodeTypes=["vtkMRMLMarkupsFiducialNode", ""], noneEnabled=True,
showChildNodeTypes=False, selectNodeUponCreation=False,
toolTip="Select the Targets")
self.initialTransformSelector = self.createComboBox(nodeTypes=["vtkMRMLTransformNode", "vtkMRMLBSplineTransformNode",
"vtkMRMLLinearTransformNode", ""],
noneEnabled=True,
showChildNodeTypes=False, selectNodeUponCreation=False,
toolTip="Select the initial transform")

self.algorithmSelector = qt.QComboBox()
self.algorithmSelector.addItems(registration.__algorithms__.keys())

self.applyRegistrationButton = self.createButton("Run Registration")
self.registrationGroupBoxLayout.addRow("Moving Image Volume: ", self.movingVolumeSelector)
self.registrationGroupBoxLayout.addRow("Moving Label Volume: ", self.movingLabelSelector)
self.registrationGroupBoxLayout.addRow("Fixed Image Volume: ", self.fixedVolumeSelector)
self.registrationGroupBoxLayout.addRow("Fixed Label Volume: ", self.fixedLabelSelector)
self.registrationGroupBoxLayout.addRow("Initial Transform: ", self.initialTransformSelector)
self.registrationGroupBoxLayout.addRow("Targets: ", self.fiducialSelector)
self.registrationGroupBoxLayout.addRow("Algorithm:", self.algorithmSelector)
self.registrationGroupBoxLayout.addRow(self.applyRegistrationButton)
self.layout.addWidget(self.registrationGroupBox)
self.layout.addStretch()
self.setupConnections()
self.updateButton()
self.onAlgorithmSelected(0)

def setupConnections(self):
self.applyRegistrationButton.clicked.connect(self.runRegistration)
self.movingVolumeSelector.connect('currentNodeChanged(bool)', self.updateButton)
self.algorithmSelector.currentIndexChanged.connect(self.onAlgorithmSelected)
self.movingVolumeSelector.connect('currentNodeChanged(bool)', self.updateButton)
self.fixedVolumeSelector.connect('currentNodeChanged(bool)', self.updateButton)
self.fixedLabelSelector.connect('currentNodeChanged(bool)', self.updateButton)
Expand All @@ -99,7 +102,17 @@ def updateButton(self):
self.yellowCompositeNode.SetBackgroundVolumeID(self.fixedVolumeSelector.currentNode().GetID())
if self.fixedLabelSelector.currentNode():
self.yellowCompositeNode.SetLabelVolumeID(self.fixedLabelSelector.currentNode().GetID())
self.applyRegistrationButton.enabled = self.isRegistrationPossible()
self.applyRegistrationButton.enabled = self.isRegistrationPossible() and self.registrationAlgorithm is not None

def onAlgorithmSelected(self, index):
text = self.algorithmSelector.itemText(index)
algorithm = registration.__algorithms__[text]
if algorithm.isAlgorithmAvailable():
self.registrationAlgorithm = algorithm
else:
logging.info("Selected algorithm {} seems not to be available due to missing dependencies".format(text))
self.registrationAlgorithm = None
self.updateButton()

def isRegistrationPossible(self):
return self.movingVolumeSelector.currentNode() and self.fixedVolumeSelector.currentNode() and \
Expand All @@ -108,19 +121,19 @@ def isRegistrationPossible(self):
def runRegistration(self):
logging.debug("Starting Registration")
self.progress = self.createProgressDialog(value=1, maximum=4)
parameterNode = slicer.vtkMRMLScriptedModuleNode()
parameterNode.SetAttribute('FixedImageNodeID', self.fixedVolumeSelector.currentNode().GetID())
parameterNode.SetAttribute('FixedLabelNodeID', self.fixedLabelSelector.currentNode().GetID())
parameterNode.SetAttribute('MovingImageNodeID', self.movingVolumeSelector.currentNode().GetID())
parameterNode.SetAttribute('MovingLabelNodeID', self.movingLabelSelector.currentNode().GetID())
if self.fiducialSelector.currentNode():
parameterNode.SetAttribute('TargetsNodeID', self.fiducialSelector.currentNode().GetID())
if self.initialTransformSelector.currentNode():
parameterNode.SetAttribute('InitialTransformNodeID', self.initialTransformSelector.currentNode().GetID())
self.logic.runReRegistration(parameterNode, progressCallback=self.updateProgressBar)
else:
self.logic.run(parameterNode, progressCallback=self.updateProgressBar)

logic = SliceTrackerRegistrationLogic(self.registrationAlgorithm())

parameterNode = logic.initializeParameterNode(self.fixedVolumeSelector.currentNode(),
self.fixedLabelSelector.currentNode(),
self.movingVolumeSelector.currentNode(),
self.movingLabelSelector.currentNode(),
self.fiducialSelector.currentNode())

logic.run(parameterNode, result=RegistrationResult("{}: RegistrationResult".format(str(self.counter))),
progressCallback=self.updateProgressBar)
self.progress.close()
self.counter += 1

@onReturnProcessEvents
def updateProgressBar(self, **kwargs):
Expand All @@ -130,165 +143,33 @@ def updateProgressBar(self, **kwargs):
setattr(self.progress, key, value)


class SliceTrackerRegistrationLogic(ScriptedLoadableModuleLogic, ModuleLogicMixin):
class SliceTrackerRegistrationLogic(ScriptedLoadableModuleLogic):

counter = 1
@staticmethod
def initializeParameterNode(fixedVolume, fixedLabel, movingVolume, movingLabel, targets=None):
parameterNode = slicer.vtkMRMLScriptedModuleNode()
parameterNode.SetAttribute('FixedImageNodeID', fixedVolume.GetID())
parameterNode.SetAttribute('FixedLabelNodeID', fixedLabel.GetID())
parameterNode.SetAttribute('MovingImageNodeID', movingVolume.GetID())
parameterNode.SetAttribute('MovingLabelNodeID', movingLabel.GetID())
if targets:
parameterNode.SetAttribute('TargetsNodeID', targets.GetID())
return parameterNode

def __init__(self):
def __init__(self, algorithm):
ScriptedLoadableModuleLogic.__init__(self)
self.registrationResult = None

def _processParameterNode(self, parameterNode):
if not self.registrationResult:
self.registrationResult = RegistrationResult("01: RegistrationResult")
result = self.registrationResult
result.volumes.fixed = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('FixedImageNodeID'))
result.labels.fixed = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('FixedLabelNodeID'))
result.labels.moving = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('MovingLabelNodeID'))
movingVolume = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('MovingImageNodeID'))
result.volumes.moving = self.volumesLogic.CloneVolume(slicer.mrmlScene, movingVolume,
"temp-movingVolume_" + str(self.counter))
self.counter += 1
self.registrationAlgorithm = algorithm

logging.debug("Fixed Image Name: %s" % result.volumes.fixed.GetName())
logging.debug("Fixed Label Name: %s" % result.labels.fixed.GetName())
logging.debug("Moving Image Name: %s" % movingVolume.GetName())
logging.debug("Moving Label Name: %s" % result.labels.moving.GetName())
initialTransform = parameterNode.GetAttribute('InitialTransformNodeID')
if initialTransform:
initialTransform = slicer.mrmlScene.GetNodeByID(initialTransform)
logging.debug("Initial Registration Name: %s" % initialTransform.GetName())
return result

def run(self, parameterNode, progressCallback=None):
self.progressCallback = progressCallback
result = self._processParameterNode(parameterNode)

registrationTypes = ['rigid', 'affine', 'bSpline']
self.createVolumeAndTransformNodes(registrationTypes, prefix=str(result.seriesNumber), suffix=result.suffix)

self.doRigidRegistration(movingBinaryVolume=result.labels.moving, initializeTransformMode="useCenterOfROIAlign")
self.doAffineRegistration()
self.doBSplineRegistration(initialTransform=result.transforms.affine)

targetsNodeID = parameterNode.GetAttribute('TargetsNodeID')
if targetsNodeID:
result.targets.original = slicer.mrmlScene.GetNodeByID(targetsNodeID)
self.transformTargets(registrationTypes, result.targets.original, str(result.seriesNumber), suffix=result.suffix)
result.volumes.moving = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('MovingImageNodeID'))

def runReRegistration(self, parameterNode, progressCallback=None):
logging.debug("Starting Re-Registration")

self.progressCallback = progressCallback

self._processParameterNode(parameterNode)
result = self.registrationResult

registrationTypes = ['rigid', 'bSpline']
self.createVolumeAndTransformNodes(registrationTypes, prefix=str(result.seriesNumber), suffix=result.suffix)
initialTransform = parameterNode.GetAttribute('InitialTransformNodeID')

if initialTransform:
initialTransform = slicer.mrmlScene.GetNodeByID(initialTransform)

# TODO: label value should be delivered by parameterNode
self.dilateMask(result.labels.fixed, dilateValue=1)
self.doRigidRegistration(movingBinaryVolume=result.labels.moving,
initialTransform=initialTransform if initialTransform else None)
self.doBSplineRegistration(initialTransform=result.transforms.rigid, useScaleVersor3D=True, useScaleSkewVersor3D=True,
useAffine=True)

targetsNodeID = parameterNode.GetAttribute('TargetsNodeID')
if targetsNodeID:
result.targets.original = slicer.mrmlScene.GetNodeByID(targetsNodeID)
self.transformTargets(registrationTypes, result.originalTargets, str(result.seriesNumber), suffix=result.suffix)
result.movingVolume = slicer.mrmlScene.GetNodeByID(parameterNode.GetAttribute('MovingImageNodeID'))

def createVolumeAndTransformNodes(self, registrationTypes, prefix, suffix=""):
for regType in registrationTypes:
self.registrationResult.setVolume(regType, self.createScalarVolumeNode(prefix + '-VOLUME-' + regType + suffix))
transformName = prefix + '-TRANSFORM-' + regType + suffix
transform = self.createBSplineTransformNode(transformName) if regType == 'bSpline' \
else self.createLinearTransformNode(transformName)
self.registrationResult.setTransform(regType, transform)

def transformTargets(self, registrations, targets, prefix, suffix=""):
if targets:
for registration in registrations:
name = prefix + '-TARGETS-' + registration + suffix
clone = self.cloneFiducialAndTransform(name, targets, self.registrationResult.getTransform(registration))
clone.SetLocked(True)
self.registrationResult.setTargets(registration, clone)

def cloneFiducialAndTransform(self, cloneName, originalTargets, transformNode):
tfmLogic = slicer.modules.transforms.logic()
clonedTargets = self.cloneFiducials(originalTargets, cloneName)
clonedTargets.SetAndObserveTransformNodeID(transformNode.GetID())
tfmLogic.hardenTransform(clonedTargets)
return clonedTargets

def doRigidRegistration(self, **kwargs):
self.updateProgress(labelText='\nRigid registration', value=2)
paramsRigid = {'fixedVolume': self.registrationResult.volumes.fixed,
'movingVolume': self.registrationResult.volumes.moving,
'fixedBinaryVolume': self.registrationResult.labels.fixed,
'outputTransform': self.registrationResult.transforms.rigid.GetID(),
'outputVolume': self.registrationResult.volumes.rigid.GetID(),
'maskProcessingMode': "ROI",
'useRigid': True}
for key, value in kwargs.iteritems():
paramsRigid[key] = value
slicer.cli.run(slicer.modules.brainsfit, None, paramsRigid, wait_for_completion=True)
self.registrationResult.cmdArguments += "Rigid Registration Parameters: %s" % str(paramsRigid) + "\n\n"

def doAffineRegistration(self):
self.updateProgress(labelText='\nAffine registration', value=2)
paramsAffine = {'fixedVolume': self.registrationResult.volumes.fixed,
'movingVolume': self.registrationResult.volumes.moving,
'fixedBinaryVolume': self.registrationResult.labels.fixed,
'movingBinaryVolume': self.registrationResult.labels.moving,
'outputTransform': self.registrationResult.transforms.affine.GetID(),
'outputVolume': self.registrationResult.volumes.affine.GetID(),
'maskProcessingMode': "ROI",
'useAffine': True,
'initialTransform': self.registrationResult.transforms.rigid}
slicer.cli.run(slicer.modules.brainsfit, None, paramsAffine, wait_for_completion=True)
self.registrationResult.cmdArguments += "Affine Registration Parameters: %s" % str(paramsAffine) + "\n\n"

def doBSplineRegistration(self, initialTransform, **kwargs):
self.updateProgress(labelText='\nBSpline registration', value=3)
paramsBSpline = {'fixedVolume': self.registrationResult.volumes.fixed,
'movingVolume': self.registrationResult.volumes.moving,
'outputVolume': self.registrationResult.volumes.bSpline.GetID(),
'bsplineTransform': self.registrationResult.transforms.bSpline.GetID(),
'fixedBinaryVolume': self.registrationResult.labels.fixed,
'movingBinaryVolume': self.registrationResult.labels.moving,
'useROIBSpline': True,
'useBSpline': True,
'splineGridSize': "3,3,3",
'maskProcessing': "ROI",
'minimumStepLength': "0.005",
'maximumStepLength': "0.2",
'costFunctionConvergenceFactor': "1.00E+09",
'maskProcessingMode': "ROI",
'initialTransform': initialTransform}
for key, value in kwargs.iteritems():
paramsBSpline[key] = value

slicer.cli.run(slicer.modules.brainsfit, None, paramsBSpline, wait_for_completion=True)
self.registrationResult.cmdArguments += "BSpline Registration Parameters: %s" % str(paramsBSpline) + "\n\n"

self.updateProgress(labelText='\nCompleted registration', value=4)

def updateProgress(self, **kwargs):
if self.progressCallback:
self.progressCallback(**kwargs)
def run(self, parameterNode, result, progressCallback=None):
self.registrationAlgorithm.run(parameterNode, result, progressCallback)

def getResult(self):
return self.registrationAlgorithm.registrationResult


def main(argv):
try:
parser = argparse.ArgumentParser(description="Slicetracker Registration")
parser = argparse.ArgumentParser(description="SliceTracker Registration")
parser.add_argument("-fl", "--fixed-label", dest="fixed_label", metavar="PATH", default="-", required=True,
help="Fixed label to be used for registration")
parser.add_argument("-ml", "--moving-label", dest="moving_label", metavar="PATH", default="-", required=True,
Expand All @@ -301,6 +182,9 @@ def main(argv):
required=False, help="Initial rigid transform for re-registration")
parser.add_argument("-o", "--output-directory", dest="output_directory", metavar="PATH", default="-",
required=False, help="Output directory for registration result")
parser.add_argument("-al", "--algorithm", dest="algorithm", metavar="PATH", default="BRAINS",
choices=registration.__algorithms__.keys(), required=False,
help="Algorithm to be used for registration (default: %(default)s)")

args = parser.parse_args(argv)

Expand All @@ -313,17 +197,17 @@ def main(argv):
success, fixedVolume = slicer.util.loadVolume(args.fixed_volume, returnNode=True)
success, movingVolume = slicer.util.loadVolume(args.moving_volume, returnNode=True)

parameterNode = slicer.vtkMRMLScriptedModuleNode()
parameterNode.SetAttribute('FixedImageNodeID', fixedVolume.GetID())
parameterNode.SetAttribute('FixedLabelNodeID', fixedLabel.GetID())
parameterNode.SetAttribute('MovingImageNodeID', movingVolume.GetID())
parameterNode.SetAttribute('MovingLabelNodeID', movingLabel.GetID())
algorithm = registration.__algorithms__[args.algorithm]

if not algorithm.isAlgorithmAvailable():
raise RuntimeError("Registration algorithm {} cannot be executed due to missing dependencies.".format(args.algorithm))

logic = SliceTrackerRegistrationLogic()
logic.run(parameterNode)
logic = SliceTrackerRegistrationLogic(algorithm())
parameterNode = logic.initializeParameterNode(fixedVolume, fixedLabel, movingVolume, movingLabel)
logic.run(parameterNode, result=RegistrationResult("01: RegistrationResult"))

if args.output_directory != "-":
logic.registrationResult.save(args.output_directory)
logic.getResult().save(args.output_directory)

except Exception, e:
print e
Expand Down

0 comments on commit f8f8c07

Please sign in to comment.