diff --git a/python/lsst/drp/tasks/assemble_coadd.py b/python/lsst/drp/tasks/assemble_coadd.py index f56da360..1eab9ad9 100644 --- a/python/lsst/drp/tasks/assemble_coadd.py +++ b/python/lsst/drp/tasks/assemble_coadd.py @@ -19,37 +19,30 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -__all__ = [ - "AssembleCoaddTask", - "AssembleCoaddConnections", +__all__ = ( "AssembleCoaddConfig", - "CompareWarpAssembleCoaddTask", - "CompareWarpAssembleCoaddConfig", -] + "AssembleCoaddConnections", + "AssembleCoaddTask", +) -import copy import logging import warnings import lsst.afw.geom as afwGeom import lsst.afw.image as afwImage import lsst.afw.math as afwMath -import lsst.afw.table as afwTable import lsst.coadd.utils as coaddUtils import lsst.geom as geom import lsst.meas.algorithms as measAlg import lsst.pex.config as pexConfig import lsst.pex.exceptions as pexExceptions import lsst.pipe.base as pipeBase -import lsst.utils as utils -import lsstDebug import numpy from deprecated.sphinx import deprecated -from lsst.meas.algorithms import AccumulatorMeanStack, ScaleVarianceTask, SourceDetectionTask -from lsst.pipe.tasks.coaddBase import CoaddBaseTask, makeSkyInfo, reorderAndPadList, subBBoxIter +from lsst.meas.algorithms import AccumulatorMeanStack +from lsst.pipe.tasks.coaddBase import CoaddBaseTask, makeSkyInfo, subBBoxIter from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask from lsst.pipe.tasks.interpImage import InterpImageTask -from lsst.pipe.tasks.maskStreaks import MaskStreaksTask from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask from lsst.skymap import BaseSkyMap from lsst.utils.timer import timeMethod @@ -57,7 +50,7 @@ log = logging.getLogger(__name__) -class AssembleCoaddConnections( +class BaseAssembleCoaddConnections( pipeBase.PipelineTaskConnections, dimensions=("tract", "patch", "band", "skymap"), defaultTemplates={ @@ -67,6 +60,8 @@ class AssembleCoaddConnections( "warpTypeSuffix": "", }, ): + """Connections to define common input connections for coaddition tasks.""" + inputWarps = pipeBase.connectionTypes.Input( doc=( "Input list of warps to be assemebled i.e. stacked." @@ -101,54 +96,20 @@ class AssembleCoaddConnections( dimensions=("tract", "patch", "skymap", "band"), minimum=0, ) - coaddExposure = pipeBase.connectionTypes.Output( - doc="Output coadded exposure, produced by stacking input warps", - name="{outputCoaddName}Coadd{warpTypeSuffix}", - storageClass="ExposureF", - dimensions=("tract", "patch", "skymap", "band"), - ) - nImage = pipeBase.connectionTypes.Output( - doc="Output image of number of input images per pixel", - name="{outputCoaddName}Coadd_nImage", - storageClass="ImageU", - dimensions=("tract", "patch", "skymap", "band"), - ) - inputMap = pipeBase.connectionTypes.Output( - doc="Output healsparse map of input images", - name="{outputCoaddName}Coadd_inputMap", - storageClass="HealSparseMap", - dimensions=("tract", "patch", "skymap", "band"), - ) - - def __init__(self, *, config=None): - super().__init__(config=config) - - if not config.doMaskBrightObjects: - self.prerequisiteInputs.remove("brightObjectMask") - - if not config.doSelectVisits: - self.inputs.remove("selectedVisits") - if not config.doNImage: - self.outputs.remove("nImage") - if not self.config.doInputMap: - self.outputs.remove("inputMap") - - -class AssembleCoaddConfig( - CoaddBaseTask.ConfigClass, pipeBase.PipelineTaskConfig, pipelineConnections=AssembleCoaddConnections +class BaseAssembleCoaddConfig( + CoaddBaseTask.ConfigClass, pipeBase.PipelineTaskConfig, pipelineConnections=BaseAssembleCoaddConnections ): warpType = pexConfig.Field( doc="Warp name: one of 'direct' or 'psfMatched'", dtype=str, default="direct", ) - subregionSize = pexConfig.ListField( - dtype=int, + subregionSize = pexConfig.ListField[int]( doc="Width, height of stack subregion size; " "make small enough that a full stack of images will fit into memory " - " at once.", + " at once. Relevant only if `doOnlineForMean` is False.", length=2, default=(2000, 2000), ) @@ -205,8 +166,8 @@ class AssembleCoaddConfig( ) doNImage = pexConfig.Field( doc="Create image of number of contributing exposures for each pixel", - dtype=bool, default=False, + dtype=bool, ) doUsePsfMatchedPolygons = pexConfig.Field( doc="Use ValidPolygons from shrunk Psf-Matched Calexps? Should be set " @@ -303,7 +264,7 @@ def validate(self): ) -class AssembleCoaddTask(CoaddBaseTask, pipeBase.PipelineTask): +class BaseAssembleCoaddTask(CoaddBaseTask, pipeBase.PipelineTask): """Assemble a coadded image from a set of warps. Each Warp that goes into a coadd will typically have an independent @@ -339,8 +300,8 @@ class AssembleCoaddTask(CoaddBaseTask, pipeBase.PipelineTask): documentation for the subtasks for further information. """ - ConfigClass = AssembleCoaddConfig - _DefaultName = "assembleCoadd" + ConfigClass = BaseAssembleCoaddConfig + _DefaultName = "baseAssembleCoadd" def __init__(self, *args, **kwargs): # TODO: DM-17415 better way to handle previously allowed passed args @@ -471,7 +432,7 @@ def _makeSupplementaryData(self, butlerQC, inputRefs, outputRefs): def makeSupplementaryDataGen3(self, butlerQC, inputRefs, outputRefs): return self._makeSupplementaryData(butlerQC, inputRefs, outputRefs) - def prepareInputs(self, refList): + def prepareInputs(self, refList, bbox=None): """Prepare the input warps for coaddition by measuring the weight for each warp and the scaling for the photometric zero point. @@ -482,8 +443,10 @@ def prepareInputs(self, refList): Parameters ---------- - refList : `list` + refList : `list` [`~lsst.daf.butler.DeferredDatasetHandle`] List of data references to tempExp. + bbox : `lsst.geom.Box2I`, optional + Bounding box to use for each warp. Returns ------- @@ -511,7 +474,10 @@ def prepareInputs(self, refList): imageScalerList = [] tempExpName = self.getTempExpDatasetName(self.warpType) for tempExpRef in refList: - tempExp = tempExpRef.get() + if bbox: + tempExp = tempExpRef.get(parameters={"bbox": bbox}) + else: + tempExp = tempExpRef.get() # Ignore any input warp that is empty of data if numpy.isnan(tempExp.image.array).all(): continue @@ -580,154 +546,96 @@ def prepareStats(self, mask=None): statsFlags = afwMath.stringToStatisticsProperty(self.config.statistic) return pipeBase.Struct(ctrl=statsCtrl, flags=statsFlags) - @timeMethod - def run( + def assembleSubregion( self, - skyInfo, + coaddExposure, + bbox, tempExpRefList, imageScalerList, weightList, - altMaskList=None, - mask=None, - supplementaryData=None, + altMaskList, + statsFlags, + statsCtrl, + nImage=None, ): - """Assemble a coadd from input warps. + """Assemble the coadd for a sub-region. - Assemble the coadd using the provided list of coaddTempExps. Since - the full coadd covers a patch (a large area), the assembly is - performed over small areas on the image at a time in order to - conserve memory usage. Iterate over subregions within the outer - bbox of the patch using `assembleSubregion` to stack the corresponding - subregions from the coaddTempExps with the statistic specified. - Set the edge bits the coadd mask based on the weight map. + For each coaddTempExp, check for (and swap in) an alternative mask + if one is passed. Remove mask planes listed in + `config.removeMaskPlanes`. Finally, stack the actual exposures using + `lsst.afw.math.statisticsStack` with the statistic specified by + statsFlags. Typically, the statsFlag will be one of lsst.afw.math.MEAN + for a mean-stack or `lsst.afw.math.MEANCLIP` for outlier rejection + using an N-sigma clipped mean where N and iterations are specified by + statsCtrl. Assign the stacked subregion back to the coadd. Parameters ---------- - skyInfo : `~lsst.pipe.base.Struct` - Struct with geometric information about the patch. + coaddExposure : `lsst.afw.image.Exposure` + The target exposure for the coadd. + bbox : `lsst.geom.Box` + Sub-region to coadd. tempExpRefList : `list` - List of data references to Warps (previously called CoaddTempExps). + List of data reference to tempExp. imageScalerList : `list` List of image scalers. weightList : `list` List of weights. - altMaskList : `list`, optional + altMaskList : `list` List of alternate masks to use rather than those stored with - tempExp. - mask : `int`, optional - Bit mask value to exclude from coaddition. - supplementaryData : `~lsst.pipe.base.Struct`, optional - Struct with additional data products needed to assemble coadd. - Only used by subclasses that implement ``_makeSupplementaryData`` - and override `run`. - - Returns - ------- - result : `~lsst.pipe.base.Struct` - Results as a struct with attributes: - - ``coaddExposure`` - Coadded exposure (`~lsst.afw.image.Exposure`). - ``nImage`` - Exposure count image (`~lsst.afw.image.Image`), if requested. - ``inputMap`` - Bit-wise map of inputs, if requested. - ``warpRefList`` - Input list of refs to the warps - (`~lsst.daf.butler.DeferredDatasetHandle`) (unmodified). - ``imageScalerList`` - Input list of image scalers (`list`) (unmodified). - ``weightList`` - Input list of weights (`list`) (unmodified). - - Raises - ------ - lsst.pipe.base.NoWorkFound - Raised if no data references are provided. + tempExp, or None. Each element is dict with keys = mask plane + name to which to add the spans. + statsFlags : `lsst.afw.math.Property` + Property object for statistic for coadd. + statsCtrl : `lsst.afw.math.StatisticsControl` + Statistics control object for coadd. + nImage : `lsst.afw.image.ImageU`, optional + Keeps track of exposure count for each pixel. """ - tempExpName = self.getTempExpDatasetName(self.warpType) - self.log.info("Assembling %s %s", len(tempExpRefList), tempExpName) - if not tempExpRefList: - raise pipeBase.NoWorkFound("No exposures provided for co-addition.") - - stats = self.prepareStats(mask=mask) + self.log.debug("Computing coadd over %s", bbox) - if altMaskList is None: - altMaskList = [None] * len(tempExpRefList) + coaddExposure.mask.addMaskPlane("REJECTED") + coaddExposure.mask.addMaskPlane("CLIPPED") + coaddExposure.mask.addMaskPlane("SENSOR_EDGE") + maskMap = self.setRejectedMaskMapping(statsCtrl) + clipped = afwImage.Mask.getPlaneBitMask("CLIPPED") + maskedImageList = [] + if nImage is not None: + subNImage = afwImage.ImageU(bbox.getWidth(), bbox.getHeight()) + for tempExpRef, imageScaler, altMask in zip(tempExpRefList, imageScalerList, altMaskList): + exposure = tempExpRef.get(parameters={"bbox": bbox}) - coaddExposure = afwImage.ExposureF(skyInfo.bbox, skyInfo.wcs) - coaddExposure.setPhotoCalib(self.scaleZeroPoint.getPhotoCalib()) - coaddExposure.getInfo().setCoaddInputs(self.inputRecorder.makeCoaddInputs()) - self.assembleMetadata(coaddExposure, tempExpRefList, weightList) - coaddMaskedImage = coaddExposure.getMaskedImage() - subregionSizeArr = self.config.subregionSize - subregionSize = geom.Extent2I(subregionSizeArr[0], subregionSizeArr[1]) - # if nImage is requested, create a zero one which can be passed to - # assembleSubregion. - if self.config.doNImage: - nImage = afwImage.ImageU(skyInfo.bbox) - else: - nImage = None - # If inputMap is requested, create the initial version that can be - # masked in assembleSubregion. - if self.config.doInputMap: - self.inputMapper.build_ccd_input_map( - skyInfo.bbox, skyInfo.wcs, coaddExposure.getInfo().getCoaddInputs().ccds - ) + maskedImage = exposure.getMaskedImage() + mask = maskedImage.getMask() + if altMask is not None: + self.applyAltMaskPlanes(mask, altMask) + imageScaler.scaleMaskedImage(maskedImage) - if self.config.doOnlineForMean and self.config.statistic == "MEAN": - try: - self.assembleOnlineMeanCoadd( - coaddExposure, - tempExpRefList, - imageScalerList, - weightList, - altMaskList, - stats.ctrl, - nImage=nImage, - ) - except Exception as e: - self.log.exception("Cannot compute online coadd %s", e) - raise - else: - for subBBox in subBBoxIter(skyInfo.bbox, subregionSize): - try: - self.assembleSubregion( - coaddExposure, - subBBox, - tempExpRefList, - imageScalerList, - weightList, - altMaskList, - stats.flags, - stats.ctrl, - nImage=nImage, - ) - except Exception as e: - self.log.exception("Cannot compute coadd %s: %s", subBBox, e) - raise + # Add 1 for each pixel which is not excluded by the exclude mask. + # In legacyCoadd, pixels may also be excluded by + # afwMath.statisticsStack. + if nImage is not None: + subNImage.getArray()[maskedImage.getMask().getArray() & statsCtrl.getAndMask() == 0] += 1 + if self.config.removeMaskPlanes: + self.removeMaskPlanes(maskedImage) + maskedImageList.append(maskedImage) - # If inputMap is requested, we must finalize the map after the - # accumulation. - if self.config.doInputMap: - self.inputMapper.finalize_ccd_input_map_mask() - inputMap = self.inputMapper.ccd_input_map - else: - inputMap = None + if self.config.doInputMap: + visit = exposure.getInfo().getCoaddInputs().visits[0].getId() + self.inputMapper.mask_warp_bbox(bbox, visit, mask, statsCtrl.getAndMask()) - self.setInexactPsf(coaddMaskedImage.getMask()) - # Despite the name, the following doesn't really deal with "EDGE" - # pixels: it identifies pixels that didn't receive any unmasked inputs - # (as occurs around the edge of the field). - coaddUtils.setCoaddEdgeBits(coaddMaskedImage.getMask(), coaddMaskedImage.getVariance()) - return pipeBase.Struct( - coaddExposure=coaddExposure, - nImage=nImage, - warpRefList=tempExpRefList, - imageScalerList=imageScalerList, - weightList=weightList, - inputMap=inputMap, - ) + with self.timer("stack"): + coaddSubregion = afwMath.statisticsStack( + maskedImageList, + statsFlags, + statsCtrl, + weightList, + clipped, # also set output to CLIPPED if sigma-clipped + maskMap, + ) + coaddExposure.maskedImage.assign(coaddSubregion, bbox) + if nImage is not None: + nImage.assign(subNImage, bbox) def assembleMetadata(self, coaddExposure, tempExpRefList, weightList): """Set the metadata for the coadd. @@ -736,7 +644,7 @@ def assembleMetadata(self, coaddExposure, tempExpRefList, weightList): Parameters ---------- - coaddExposure : `lsst.afw.image.Exposure` + coaddExposure : `~lsst.afw.image.Exposure` The target exposure for the coadd. tempExpRefList : `list` List of data references to tempExp. @@ -799,169 +707,8 @@ def assembleMetadata(self, coaddExposure, tempExpRefList, weightList): transmissionCurve = measAlg.makeCoaddTransmissionCurve(coaddExposure.getWcs(), coaddInputs.ccds) coaddExposure.getInfo().setTransmissionCurve(transmissionCurve) - def assembleSubregion( - self, - coaddExposure, - bbox, - tempExpRefList, - imageScalerList, - weightList, - altMaskList, - statsFlags, - statsCtrl, - nImage=None, - ): - """Assemble the coadd for a sub-region. - - For each coaddTempExp, check for (and swap in) an alternative mask - if one is passed. Remove mask planes listed in - `config.removeMaskPlanes`. Finally, stack the actual exposures using - `lsst.afw.math.statisticsStack` with the statistic specified by - statsFlags. Typically, the statsFlag will be one of lsst.afw.math.MEAN - for a mean-stack or `lsst.afw.math.MEANCLIP` for outlier rejection - using an N-sigma clipped mean where N and iterations are specified by - statsCtrl. Assign the stacked subregion back to the coadd. - - Parameters - ---------- - coaddExposure : `lsst.afw.image.Exposure` - The target exposure for the coadd. - bbox : `lsst.geom.Box` - Sub-region to coadd. - tempExpRefList : `list` - List of data reference to tempExp. - imageScalerList : `list` - List of image scalers. - weightList : `list` - List of weights. - altMaskList : `list` - List of alternate masks to use rather than those stored with - tempExp, or None. Each element is dict with keys = mask plane - name to which to add the spans. - statsFlags : `lsst.afw.math.Property` - Property object for statistic for coadd. - statsCtrl : `lsst.afw.math.StatisticsControl` - Statistics control object for coadd. - nImage : `lsst.afw.image.ImageU`, optional - Keeps track of exposure count for each pixel. - """ - self.log.debug("Computing coadd over %s", bbox) - - coaddExposure.mask.addMaskPlane("REJECTED") - coaddExposure.mask.addMaskPlane("CLIPPED") - coaddExposure.mask.addMaskPlane("SENSOR_EDGE") - maskMap = self.setRejectedMaskMapping(statsCtrl) - clipped = afwImage.Mask.getPlaneBitMask("CLIPPED") - maskedImageList = [] - if nImage is not None: - subNImage = afwImage.ImageU(bbox.getWidth(), bbox.getHeight()) - for tempExpRef, imageScaler, altMask in zip(tempExpRefList, imageScalerList, altMaskList): - exposure = tempExpRef.get(parameters={"bbox": bbox}) - - maskedImage = exposure.getMaskedImage() - mask = maskedImage.getMask() - if altMask is not None: - self.applyAltMaskPlanes(mask, altMask) - imageScaler.scaleMaskedImage(maskedImage) - - # Add 1 for each pixel which is not excluded by the exclude mask. - # In legacyCoadd, pixels may also be excluded by - # afwMath.statisticsStack. - if nImage is not None: - subNImage.getArray()[maskedImage.getMask().getArray() & statsCtrl.getAndMask() == 0] += 1 - if self.config.removeMaskPlanes: - self.removeMaskPlanes(maskedImage) - maskedImageList.append(maskedImage) - - if self.config.doInputMap: - visit = exposure.getInfo().getCoaddInputs().visits[0].getId() - self.inputMapper.mask_warp_bbox(bbox, visit, mask, statsCtrl.getAndMask()) - - with self.timer("stack"): - coaddSubregion = afwMath.statisticsStack( - maskedImageList, - statsFlags, - statsCtrl, - weightList, - clipped, # also set output to CLIPPED if sigma-clipped - maskMap, - ) - coaddExposure.maskedImage.assign(coaddSubregion, bbox) - if nImage is not None: - nImage.assign(subNImage, bbox) - - def assembleOnlineMeanCoadd( - self, coaddExposure, tempExpRefList, imageScalerList, weightList, altMaskList, statsCtrl, nImage=None - ): - """Assemble the coadd using the "online" method. - - This method takes a running sum of images and weights to save memory. - It only works for MEAN statistics. - - Parameters - ---------- - coaddExposure : `lsst.afw.image.Exposure` - The target exposure for the coadd. - tempExpRefList : `list` - List of data reference to tempExp. - imageScalerList : `list` - List of image scalers. - weightList : `list` - List of weights. - altMaskList : `list` - List of alternate masks to use rather than those stored with - tempExp, or None. Each element is dict with keys = mask plane - name to which to add the spans. - statsCtrl : `lsst.afw.math.StatisticsControl` - Statistics control object for coadd. - nImage : `lsst.afw.image.ImageU`, optional - Keeps track of exposure count for each pixel. - """ - self.log.debug("Computing online coadd.") - - coaddExposure.mask.addMaskPlane("REJECTED") - coaddExposure.mask.addMaskPlane("CLIPPED") - coaddExposure.mask.addMaskPlane("SENSOR_EDGE") - maskMap = self.setRejectedMaskMapping(statsCtrl) - thresholdDict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl) - - bbox = coaddExposure.maskedImage.getBBox() - - stacker = AccumulatorMeanStack( - coaddExposure.image.array.shape, - statsCtrl.getAndMask(), - mask_threshold_dict=thresholdDict, - mask_map=maskMap, - no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(), - calc_error_from_input_variance=self.config.calcErrorFromInputVariance, - compute_n_image=(nImage is not None), - ) - - for tempExpRef, imageScaler, altMask, weight in zip( - tempExpRefList, imageScalerList, altMaskList, weightList - ): - exposure = tempExpRef.get() - maskedImage = exposure.getMaskedImage() - mask = maskedImage.getMask() - if altMask is not None: - self.applyAltMaskPlanes(mask, altMask) - imageScaler.scaleMaskedImage(maskedImage) - if self.config.removeMaskPlanes: - self.removeMaskPlanes(maskedImage) - - stacker.add_masked_image(maskedImage, weight=weight) - - if self.config.doInputMap: - visit = exposure.getInfo().getCoaddInputs().visits[0].getId() - self.inputMapper.mask_warp_bbox(bbox, visit, mask, statsCtrl.getAndMask()) - - stacker.fill_stacked_masked_image(coaddExposure.maskedImage) - - if nImage is not None: - nImage.array[:, :] = stacker.n_image - - def removeMaskPlanes(self, maskedImage): - """Unset the mask of an image for mask planes specified in the config. + def removeMaskPlanes(self, maskedImage): + """Unset the mask of an image for mask planes specified in the config. Parameters ---------- @@ -1111,27 +858,6 @@ def setBrightObjectMasks(self, exposure, brightObjectMasks, dataId=None): continue spans.clippedTo(mask.getBBox()).setMask(mask, self.brightObjectBitmask) - def setInexactPsf(self, mask): - """Set INEXACT_PSF mask plane. - - If any of the input images isn't represented in the coadd (due to - clipped pixels or chip gaps), the `CoaddPsf` will be inexact. Flag - these pixels. - - Parameters - ---------- - mask : `lsst.afw.image.Mask` - Coadded exposure's mask, modified in-place. - """ - mask.addMaskPlane("INEXACT_PSF") - inexactPsf = mask.getPlaneBitMask("INEXACT_PSF") - sensorEdge = mask.getPlaneBitMask("SENSOR_EDGE") # chip edges (so PSF is discontinuous) - clipped = mask.getPlaneBitMask("CLIPPED") # pixels clipped from coadd - rejected = mask.getPlaneBitMask("REJECTED") # pixels rejected from coadd due to masks - array = mask.getArray() - selected = array & (sensorEdge | clipped | rejected) > 0 - array[selected] |= inexactPsf - def filterWarps(self, inputs, goodVisits): """Return list of only inputRefs with visitId in goodVisits ordered by goodVisit. @@ -1158,672 +884,335 @@ def filterWarps(self, inputs, goodVisits): return filteredInputs -def countMaskFromFootprint(mask, footprint, bitmask, ignoreMask): - """Function to count the number of pixels with a specific mask in a - footprint. - - Find the intersection of mask & footprint. Count all pixels in the mask - that are in the intersection that have bitmask set but do not have - ignoreMask set. Return the count. - - Parameters - ---------- - mask : `lsst.afw.image.Mask` - Mask to define intersection region by. - footprint : `lsst.afw.detection.Footprint` - Footprint to define the intersection region by. - bitmask : `Unknown` - Specific mask that we wish to count the number of occurances of. - ignoreMask : `Unknown` - Pixels to not consider. - - Returns - ------- - result : `int` - Number of pixels in footprint with specified mask. +class AssembleCoaddConnections(BaseAssembleCoaddConnections): + """Connections to define input and output connections for coaddition + tasks. """ - bbox = footprint.getBBox() - bbox.clip(mask.getBBox(afwImage.PARENT)) - fp = afwImage.Mask(bbox) - subMask = mask.Factory(mask, bbox, afwImage.PARENT) - footprint.spans.setMask(fp, bitmask) - return numpy.logical_and( - (subMask.getArray() & fp.getArray()) > 0, (subMask.getArray() & ignoreMask) == 0 - ).sum() - -class CompareWarpAssembleCoaddConnections(AssembleCoaddConnections): - psfMatchedWarps = pipeBase.connectionTypes.Input( - doc=( - "PSF-Matched Warps are required by CompareWarp regardless of the coadd type requested. " - "Only PSF-Matched Warps make sense for image subtraction. " - "Therefore, they must be an additional declared input." - ), - name="{inputCoaddName}Coadd_psfMatchedWarp", + coaddExposure = pipeBase.connectionTypes.Output( + doc="Output coadded exposure, produced by stacking input warps", + name="{outputCoaddName}Coadd{warpTypeSuffix}", storageClass="ExposureF", - dimensions=("tract", "patch", "skymap", "visit"), - deferLoad=True, - multiple=True, + dimensions=("tract", "patch", "skymap", "band"), ) - templateCoadd = pipeBase.connectionTypes.Output( - doc=( - "Model of the static sky, used to find temporal artifacts. Typically a PSF-Matched, " - "sigma-clipped coadd. Written if and only if assembleStaticSkyModel.doWrite=True" - ), - name="{outputCoaddName}CoaddPsfMatched", - storageClass="ExposureF", + nImage = pipeBase.connectionTypes.Output( + doc="Output image of number of input images per pixel", + name="{outputCoaddName}Coadd_nImage", + storageClass="ImageU", + dimensions=("tract", "patch", "skymap", "band"), + ) + inputMap = pipeBase.connectionTypes.Output( + doc="Output healsparse map of input images", + name="{outputCoaddName}Coadd_inputMap", + storageClass="HealSparseMap", dimensions=("tract", "patch", "skymap", "band"), ) def __init__(self, *, config=None): super().__init__(config=config) - if not config.assembleStaticSkyModel.doWrite: - self.outputs.remove("templateCoadd") - config.validate() + if config: + if not config.doMaskBrightObjects: + self.prerequisiteInputs.remove("brightObjectMask") -class CompareWarpAssembleCoaddConfig( - AssembleCoaddConfig, pipelineConnections=CompareWarpAssembleCoaddConnections -): - assembleStaticSkyModel = pexConfig.ConfigurableField( - target=AssembleCoaddTask, - doc="Task to assemble an artifact-free, PSF-matched Coadd to serve as " - "a naive/first-iteration model of the static sky.", - ) - detect = pexConfig.ConfigurableField( - target=SourceDetectionTask, - doc="Detect outlier sources on difference between each psfMatched warp and static sky model", - ) - detectTemplate = pexConfig.ConfigurableField( - target=SourceDetectionTask, - doc="Detect sources on static sky model. Only used if doPreserveContainedBySource is True", - ) - maskStreaks = pexConfig.ConfigurableField( - target=MaskStreaksTask, - doc="Detect streaks on difference between each psfMatched warp and static sky model. Only used if " - "doFilterMorphological is True. Adds a mask plane to an exposure, with the mask plane name set by" - "streakMaskName", - ) - streakMaskName = pexConfig.Field(dtype=str, default="STREAK", doc="Name of mask bit used for streaks") - maxNumEpochs = pexConfig.Field( - doc="Charactistic maximum local number of epochs/visits in which an artifact candidate can appear " - "and still be masked. The effective maxNumEpochs is a broken linear function of local " - "number of epochs (N): min(maxFractionEpochsLow*N, maxNumEpochs + maxFractionEpochsHigh*N). " - "For each footprint detected on the image difference between the psfMatched warp and static sky " - "model, if a significant fraction of pixels (defined by spatialThreshold) are residuals in more " - "than the computed effective maxNumEpochs, the artifact candidate is deemed persistant rather " - "than transient and not masked.", - dtype=int, - default=2, - ) - maxFractionEpochsLow = pexConfig.RangeField( - doc="Fraction of local number of epochs (N) to use as effective maxNumEpochs for low N. " - "Effective maxNumEpochs = " - "min(maxFractionEpochsLow * N, maxNumEpochs + maxFractionEpochsHigh * N)", - dtype=float, - default=0.4, - min=0.0, - max=1.0, - ) - maxFractionEpochsHigh = pexConfig.RangeField( - doc="Fraction of local number of epochs (N) to use as effective maxNumEpochs for high N. " - "Effective maxNumEpochs = " - "min(maxFractionEpochsLow * N, maxNumEpochs + maxFractionEpochsHigh * N)", - dtype=float, - default=0.03, - min=0.0, - max=1.0, - ) - spatialThreshold = pexConfig.RangeField( - doc="Unitless fraction of pixels defining how much of the outlier region has to meet the " - "temporal criteria. If 0, clip all. If 1, clip none.", - dtype=float, - default=0.5, - min=0.0, - max=1.0, - inclusiveMin=True, - inclusiveMax=True, - ) - doScaleWarpVariance = pexConfig.Field( - doc="Rescale Warp variance plane using empirical noise?", - dtype=bool, - default=True, - ) - scaleWarpVariance = pexConfig.ConfigurableField( - target=ScaleVarianceTask, - doc="Rescale variance on warps", - ) - doPreserveContainedBySource = pexConfig.Field( - doc="Rescue artifacts from clipping that completely lie within a footprint detected" - "on the PsfMatched Template Coadd. Replicates a behavior of SafeClip.", - dtype=bool, - default=True, - ) - doPrefilterArtifacts = pexConfig.Field( - doc="Ignore artifact candidates that are mostly covered by the bad pixel mask, " - "because they will be excluded anyway. This prevents them from contributing " - "to the outlier epoch count image and potentially being labeled as persistant." - "'Mostly' is defined by the config 'prefilterArtifactsRatio'.", - dtype=bool, - default=True, - ) - prefilterArtifactsMaskPlanes = pexConfig.ListField( - doc="Prefilter artifact candidates that are mostly covered by these bad mask planes.", - dtype=str, - default=("NO_DATA", "BAD", "SAT", "SUSPECT"), - ) - prefilterArtifactsRatio = pexConfig.Field( - doc="Prefilter artifact candidates with less than this fraction overlapping good pixels", - dtype=float, - default=0.05, - ) - doFilterMorphological = pexConfig.Field( - doc="Filter artifact candidates based on morphological criteria, i.g. those that appear to " - "be streaks.", - dtype=bool, + if not config.doSelectVisits: + self.inputs.remove("selectedVisits") + + if not config.doNImage: + self.outputs.remove("nImage") + + if not self.config.doInputMap: + self.outputs.remove("inputMap") + + +class AssembleCoaddConfig(BaseAssembleCoaddConfig): + doNImage = pexConfig.Field[bool]( + doc="Create image of number of contributing exposures for each pixel", default=False, ) - growStreakFp = pexConfig.Field( - doc="Grow streak footprints by this number multiplied by the PSF width", dtype=float, default=5 + subregionSize = pexConfig.ListField[int]( + doc="Width, height of stack subregion size; " + "make small enough that a full stack of images will fit into memory " + " at once. Relevant only if `doOnlineForMean` is False.", + length=2, + default=(2000, 2000), ) - def setDefaults(self): - AssembleCoaddConfig.setDefaults(self) - self.statistic = "MEAN" - self.doUsePsfMatchedPolygons = True - - # Real EDGE removed by psfMatched NO_DATA border half the width of the - # matching kernel. CompareWarp applies psfMatched EDGE pixels to - # directWarps before assembling. - if "EDGE" in self.badMaskPlanes: - self.badMaskPlanes.remove("EDGE") - self.removeMaskPlanes.append("EDGE") - self.assembleStaticSkyModel.badMaskPlanes = [ - "NO_DATA", - ] - self.assembleStaticSkyModel.warpType = "psfMatched" - self.assembleStaticSkyModel.connections.warpType = "psfMatched" - self.assembleStaticSkyModel.statistic = "MEANCLIP" - self.assembleStaticSkyModel.sigmaClip = 2.5 - self.assembleStaticSkyModel.clipIter = 3 - self.assembleStaticSkyModel.calcErrorFromInputVariance = False - self.assembleStaticSkyModel.doWrite = False - self.detect.doTempLocalBackground = False - self.detect.reEstimateBackground = False - self.detect.returnOriginalFootprints = False - self.detect.thresholdPolarity = "both" - self.detect.thresholdValue = 5 - self.detect.minPixels = 4 - self.detect.isotropicGrow = True - self.detect.thresholdType = "pixel_stdev" - self.detect.nSigmaToGrow = 0.4 - # The default nSigmaToGrow for SourceDetectionTask is already 2.4, - # Explicitly restating because ratio with detect.nSigmaToGrow matters - self.detectTemplate.nSigmaToGrow = 2.4 - self.detectTemplate.doTempLocalBackground = False - self.detectTemplate.reEstimateBackground = False - self.detectTemplate.returnOriginalFootprints = False - def validate(self): - super().validate() - if self.assembleStaticSkyModel.doNImage: - raise ValueError( - "No dataset type exists for a PSF-Matched Template N Image." - "Please set assembleStaticSkyModel.doNImage=False" - ) - - if self.assembleStaticSkyModel.doWrite and (self.warpType == self.assembleStaticSkyModel.warpType): - raise ValueError( - "warpType (%s) == assembleStaticSkyModel.warpType (%s) and will compete for " - "the same dataset name. Please set assembleStaticSkyModel.doWrite to False " - "or warpType to 'direct'. assembleStaticSkyModel.warpType should ways be " - "'PsfMatched'" % (self.warpType, self.assembleStaticSkyModel.warpType) - ) - - -class CompareWarpAssembleCoaddTask(AssembleCoaddTask): - """Assemble a compareWarp coadded image from a set of warps - by masking artifacts detected by comparing PSF-matched warps. - - In ``AssembleCoaddTask``, we compute the coadd as an clipped mean (i.e., - we clip outliers). The problem with doing this is that when computing the - coadd PSF at a given location, individual visit PSFs from visits with - outlier pixels contribute to the coadd PSF and cannot be treated correctly. - In this task, we correct for this behavior by creating a new badMaskPlane - 'CLIPPED' which marks pixels in the individual warps suspected to contain - an artifact. We populate this plane on the input warps by comparing - PSF-matched warps with a PSF-matched median coadd which serves as a - model of the static sky. Any group of pixels that deviates from the - PSF-matched template coadd by more than config.detect.threshold sigma, - is an artifact candidate. The candidates are then filtered to remove - variable sources and sources that are difficult to subtract such as - bright stars. This filter is configured using the config parameters - ``temporalThreshold`` and ``spatialThreshold``. The temporalThreshold is - the maximum fraction of epochs that the deviation can appear in and still - be considered an artifact. The spatialThreshold is the maximum fraction of - pixels in the footprint of the deviation that appear in other epochs - (where other epochs is defined by the temporalThreshold). If the deviant - region meets this criteria of having a significant percentage of pixels - that deviate in only a few epochs, these pixels have the 'CLIPPED' bit - set in the mask. These regions will not contribute to the final coadd. - Furthermore, any routine to determine the coadd PSF can now be cognizant - of clipped regions. Note that the algorithm implemented by this task is - preliminary and works correctly for HSC data. Parameter modifications and - or considerable redesigning of the algorithm is likley required for other - surveys. - - ``CompareWarpAssembleCoaddTask`` sub-classes - ``AssembleCoaddTask`` and instantiates ``AssembleCoaddTask`` - as a subtask to generate the TemplateCoadd (the model of the static sky). +class AssembleCoaddTask(BaseAssembleCoaddTask): + ConfigClass = AssembleCoaddConfig + _DefaultName = "assembleCoadd" - Notes - ----- - Debugging: - This task supports the following debug variables: - - ``saveCountIm`` - If True then save the Epoch Count Image as a fits file in the `figPath` - - ``figPath`` - Path to save the debug fits images and figures - """ + @timeMethod + def run( + self, + skyInfo, + tempExpRefList, + imageScalerList, + weightList, + altMaskList=None, + mask=None, + supplementaryData=None, + ): + """Assemble a coadd from input warps. - ConfigClass = CompareWarpAssembleCoaddConfig - _DefaultName = "compareWarpAssembleCoadd" + Assemble the coadd using the provided list of coaddTempExps. Since + the full coadd covers a patch (a large area), the assembly is + performed over small areas on the image at a time in order to + conserve memory usage. Iterate over subregions within the outer + bbox of the patch using `assembleSubregion` to stack the corresponding + subregions from the coaddTempExps with the statistic specified. + Set the edge bits the coadd mask based on the weight map. - def __init__(self, *args, **kwargs): - AssembleCoaddTask.__init__(self, *args, **kwargs) - self.makeSubtask("assembleStaticSkyModel") - detectionSchema = afwTable.SourceTable.makeMinimalSchema() - self.makeSubtask("detect", schema=detectionSchema) - if self.config.doPreserveContainedBySource: - self.makeSubtask("detectTemplate", schema=afwTable.SourceTable.makeMinimalSchema()) - if self.config.doScaleWarpVariance: - self.makeSubtask("scaleWarpVariance") - if self.config.doFilterMorphological: - self.makeSubtask("maskStreaks") - - @utils.inheritDoc(AssembleCoaddTask) - def _makeSupplementaryData(self, butlerQC, inputRefs, outputRefs): - """Generate a templateCoadd to use as a naive model of static sky to - subtract from PSF-Matched warps. + Parameters + ---------- + skyInfo : `~lsst.pipe.base.Struct` + Struct with geometric information about the patch. + tempExpRefList : `list` + List of data references to Warps (previously called CoaddTempExps). + imageScalerList : `list` + List of image scalers. + weightList : `list` + List of weights. + altMaskList : `list`, optional + List of alternate masks to use rather than those stored with + tempExp. + mask : `int`, optional + Bit mask value to exclude from coaddition. + supplementaryData : `~lsst.pipe.base.Struct`, optional + Struct with additional data products needed to assemble coadd. + Only used by subclasses that implement ``_makeSupplementaryData`` + and override `run`. Returns ------- result : `~lsst.pipe.base.Struct` Results as a struct with attributes: - ``templateCoadd`` - Coadded exposure (`lsst.afw.image.Exposure`). + ``coaddExposure`` + Coadded exposure (`~lsst.afw.image.Exposure`). ``nImage`` - Keeps track of exposure count for each pixel - (`lsst.afw.image.ImageU`). + Exposure count image (`~lsst.afw.image.Image`), if requested. + ``inputMap`` + Bit-wise map of inputs, if requested. + ``warpRefList`` + Input list of refs to the warps + (`~lsst.daf.butler.DeferredDatasetHandle`) (unmodified). + ``imageScalerList`` + Input list of image scalers (`list`) (unmodified). + ``weightList`` + Input list of weights (`list`) (unmodified). Raises ------ - RuntimeError - Raised if ``templateCoadd`` is `None`. + lsst.pipe.base.NoWorkFound + Raised if no data references are provided. """ - # Ensure that psfMatchedWarps are used as input warps for template - # generation. - staticSkyModelInputRefs = copy.deepcopy(inputRefs) - staticSkyModelInputRefs.inputWarps = inputRefs.psfMatchedWarps - - # Because subtasks don't have connections we have to make one. - # The main task's `templateCoadd` is the subtask's `coaddExposure` - staticSkyModelOutputRefs = copy.deepcopy(outputRefs) - if self.config.assembleStaticSkyModel.doWrite: - staticSkyModelOutputRefs.coaddExposure = staticSkyModelOutputRefs.templateCoadd - # Remove template coadd from both subtask's and main tasks outputs, - # because it is handled by the subtask as `coaddExposure` - del outputRefs.templateCoadd - del staticSkyModelOutputRefs.templateCoadd - - # A PSF-Matched nImage does not exist as a dataset type - if "nImage" in staticSkyModelOutputRefs.keys(): - del staticSkyModelOutputRefs.nImage - - templateCoadd = self.assembleStaticSkyModel.runQuantum( - butlerQC, staticSkyModelInputRefs, staticSkyModelOutputRefs - ) - if templateCoadd is None: - raise RuntimeError(self._noTemplateMessage(self.assembleStaticSkyModel.warpType)) - - return pipeBase.Struct( - templateCoadd=templateCoadd.coaddExposure, - nImage=templateCoadd.nImage, - warpRefList=templateCoadd.warpRefList, - imageScalerList=templateCoadd.imageScalerList, - weightList=templateCoadd.weightList, - ) - - def _noTemplateMessage(self, warpType): - warpName = warpType[0].upper() + warpType[1:] - message = """No %(warpName)s warps were found to build the template coadd which is - required to run CompareWarpAssembleCoaddTask. To continue assembling this type of coadd, - first either rerun makeCoaddTempExp with config.make%(warpName)s=True or - coaddDriver with config.makeCoadTempExp.make%(warpName)s=True, before assembleCoadd. + tempExpName = self.getTempExpDatasetName(self.warpType) + self.log.info("Assembling %s %s", len(tempExpRefList), tempExpName) + if not tempExpRefList: + raise pipeBase.NoWorkFound("No exposures provided for co-addition.") - Alternatively, to use another algorithm with existing warps, retarget the CoaddDriverConfig to - another algorithm like: + stats = self.prepareStats(mask=mask) - from lsst.pipe.tasks.assembleCoadd import SafeClipAssembleCoaddTask - config.assemble.retarget(SafeClipAssembleCoaddTask) - """ % { - "warpName": warpName - } - return message + if altMaskList is None: + altMaskList = [None] * len(tempExpRefList) - @utils.inheritDoc(AssembleCoaddTask) - @timeMethod - def run(self, skyInfo, tempExpRefList, imageScalerList, weightList, supplementaryData): - """Notes - ----- - Assemble the coadd. - - Find artifacts and apply them to the warps' masks creating a list of - alternative masks with a new "CLIPPED" plane and updated "NO_DATA" - plane. Then pass these alternative masks to the base class's ``run`` - method. - """ - # Check and match the order of the supplementaryData - # (PSF-matched) inputs to the order of the direct inputs, - # so that the artifact mask is applied to the right warp - dataIds = [ref.dataId for ref in tempExpRefList] - psfMatchedDataIds = [ref.dataId for ref in supplementaryData.warpRefList] - - if dataIds != psfMatchedDataIds: - self.log.info("Reordering and or/padding PSF-matched visit input list") - supplementaryData.warpRefList = reorderAndPadList( - supplementaryData.warpRefList, psfMatchedDataIds, dataIds - ) - supplementaryData.imageScalerList = reorderAndPadList( - supplementaryData.imageScalerList, psfMatchedDataIds, dataIds + coaddExposure = afwImage.ExposureF(skyInfo.bbox, skyInfo.wcs) + # coaddExposure.setPhotoCalib(self.scaleZeroPoint.getPhotoCalib()) + coaddExposure.getInfo().setCoaddInputs(self.inputRecorder.makeCoaddInputs()) + self.assembleMetadata(coaddExposure, tempExpRefList, weightList) + coaddMaskedImage = coaddExposure.getMaskedImage() + subregionSizeArr = self.config.subregionSize + subregionSize = geom.Extent2I(subregionSizeArr[0], subregionSizeArr[1]) + # if nImage is requested, create a zero one which can be passed to + # assembleSubregion. + if self.config.doNImage: + nImage = afwImage.ImageU(skyInfo.bbox) + else: + nImage = None + # If inputMap is requested, create the initial version that can be + # masked in assembleSubregion. + if self.config.doInputMap: + self.inputMapper.build_ccd_input_map( + skyInfo.bbox, skyInfo.wcs, coaddExposure.getInfo().getCoaddInputs().ccds ) - # Use PSF-Matched Warps (and corresponding scalers) and coadd to find - # artifacts. - spanSetMaskList = self.findArtifacts( - supplementaryData.templateCoadd, supplementaryData.warpRefList, supplementaryData.imageScalerList - ) + if self.config.doOnlineForMean and self.config.statistic == "MEAN": + try: + self.assembleOnlineMeanCoadd( + coaddExposure, + tempExpRefList, + imageScalerList, + weightList, + altMaskList, + stats.ctrl, + nImage=nImage, + ) + except Exception as e: + self.log.exception("Cannot compute online coadd %s", e) + raise + else: + for subBBox in subBBoxIter(skyInfo.bbox, subregionSize): + try: + self.assembleSubregion( + coaddExposure, + subBBox, + tempExpRefList, + imageScalerList, + weightList, + altMaskList, + stats.flags, + stats.ctrl, + nImage=nImage, + ) + except Exception as e: + self.log.exception("Cannot compute coadd %s: %s", subBBox, e) + raise - badMaskPlanes = self.config.badMaskPlanes[:] - badMaskPlanes.append("CLIPPED") - badPixelMask = afwImage.Mask.getPlaneBitMask(badMaskPlanes) + # If inputMap is requested, we must finalize the map after the + # accumulation. + if self.config.doInputMap: + self.inputMapper.finalize_ccd_input_map_mask() + inputMap = self.inputMapper.ccd_input_map + else: + inputMap = None - result = AssembleCoaddTask.run( - self, skyInfo, tempExpRefList, imageScalerList, weightList, spanSetMaskList, mask=badPixelMask + self.setInexactPsf(coaddMaskedImage.getMask()) + # Despite the name, the following doesn't really deal with "EDGE" + # pixels: it identifies pixels that didn't receive any unmasked inputs + # (as occurs around the edge of the field). + coaddUtils.setCoaddEdgeBits(coaddMaskedImage.getMask(), coaddMaskedImage.getVariance()) + return pipeBase.Struct( + coaddExposure=coaddExposure, + nImage=nImage, + warpRefList=tempExpRefList, + imageScalerList=imageScalerList, + weightList=weightList, + inputMap=inputMap, ) - # Propagate PSF-matched EDGE pixels to coadd SENSOR_EDGE and - # INEXACT_PSF. Psf-Matching moves the real edge inwards. - self.applyAltEdgeMask(result.coaddExposure.maskedImage.mask, spanSetMaskList) - return result - - def applyAltEdgeMask(self, mask, altMaskList): - """Propagate alt EDGE mask to SENSOR_EDGE AND INEXACT_PSF planes. + def assembleOnlineMeanCoadd( + self, coaddExposure, tempExpRefList, imageScalerList, weightList, altMaskList, statsCtrl, nImage=None + ): + """Assemble the coadd using the "online" method. - Parameters - ---------- - mask : `lsst.afw.image.Mask` - Original mask. - altMaskList : `list` of `dict` - List of Dicts containing ``spanSet`` lists. - Each element contains the new mask plane name (e.g. "CLIPPED - and/or "NO_DATA") as the key, and list of ``SpanSets`` to apply to - the mask. - """ - maskValue = mask.getPlaneBitMask(["SENSOR_EDGE", "INEXACT_PSF"]) - for visitMask in altMaskList: - if "EDGE" in visitMask: - for spanSet in visitMask["EDGE"]: - spanSet.clippedTo(mask.getBBox()).setMask(mask, maskValue) - - def findArtifacts(self, templateCoadd, tempExpRefList, imageScalerList): - """Find artifacts. - - Loop through warps twice. The first loop builds a map with the count - of how many epochs each pixel deviates from the templateCoadd by more - than ``config.chiThreshold`` sigma. The second loop takes each - difference image and filters the artifacts detected in each using - count map to filter out variable sources and sources that are - difficult to subtract cleanly. + This method takes a running sum of images and weights to save memory. + It only works for MEAN statistics. Parameters ---------- - templateCoadd : `lsst.afw.image.Exposure` - Exposure to serve as model of static sky. + coaddExposure : `lsst.afw.image.Exposure` + The target exposure for the coadd. tempExpRefList : `list` - List of data references to warps. + List of data reference to tempExp. imageScalerList : `list` List of image scalers. - - Returns - ------- - altMasks : `list` of `dict` - List of dicts containing information about CLIPPED - (i.e., artifacts), NO_DATA, and EDGE pixels. + weightList : `list` + List of weights. + altMaskList : `list` + List of alternate masks to use rather than those stored with + tempExp, or None. Each element is dict with keys = mask plane + name to which to add the spans. + statsCtrl : `lsst.afw.math.StatisticsControl` + Statistics control object for coadd. + nImage : `lsst.afw.image.ImageU`, optional + Keeps track of exposure count for each pixel. """ - self.log.debug("Generating Count Image, and mask lists.") - coaddBBox = templateCoadd.getBBox() - slateIm = afwImage.ImageU(coaddBBox) - epochCountImage = afwImage.ImageU(coaddBBox) - nImage = afwImage.ImageU(coaddBBox) - spanSetArtifactList = [] - spanSetNoDataMaskList = [] - spanSetEdgeList = [] - spanSetBadMorphoList = [] - badPixelMask = self.getBadPixelMask() - - # mask of the warp diffs should = that of only the warp - templateCoadd.mask.clearAllMaskPlanes() - - if self.config.doPreserveContainedBySource: - templateFootprints = self.detectTemplate.detectFootprints(templateCoadd) - else: - templateFootprints = None - - for warpRef, imageScaler in zip(tempExpRefList, imageScalerList): - warpDiffExp = self._readAndComputeWarpDiff(warpRef, imageScaler, templateCoadd) - if warpDiffExp is not None: - # This nImage only approximates the final nImage because it - # uses the PSF-matched mask. - nImage.array += ( - numpy.isfinite(warpDiffExp.image.array) * ((warpDiffExp.mask.array & badPixelMask) == 0) - ).astype(numpy.uint16) - fpSet = self.detect.detectFootprints(warpDiffExp, doSmooth=False, clearMask=True) - fpSet.positive.merge(fpSet.negative) - footprints = fpSet.positive - slateIm.set(0) - spanSetList = [footprint.spans for footprint in footprints.getFootprints()] - - # Remove artifacts due to defects before they contribute to - # the epochCountImage. - if self.config.doPrefilterArtifacts: - spanSetList = self.prefilterArtifacts(spanSetList, warpDiffExp) - - # Clear mask before adding prefiltered spanSets - self.detect.clearMask(warpDiffExp.mask) - for spans in spanSetList: - spans.setImage(slateIm, 1, doClip=True) - spans.setMask(warpDiffExp.mask, warpDiffExp.mask.getPlaneBitMask("DETECTED")) - epochCountImage += slateIm - - if self.config.doFilterMorphological: - maskName = self.config.streakMaskName - _ = self.maskStreaks.run(warpDiffExp) - streakMask = warpDiffExp.mask - spanSetStreak = afwGeom.SpanSet.fromMask( - streakMask, streakMask.getPlaneBitMask(maskName) - ).split() - # Pad the streaks to account for low-surface brightness - # wings. - psf = warpDiffExp.getPsf() - for s, sset in enumerate(spanSetStreak): - psfShape = psf.computeShape(sset.computeCentroid()) - dilation = self.config.growStreakFp * psfShape.getDeterminantRadius() - sset_dilated = sset.dilated(int(dilation)) - spanSetStreak[s] = sset_dilated - - # PSF-Matched warps have less available area (~the matching - # kernel) because the calexps undergo a second convolution. - # Pixels with data in the direct warp but not in the - # PSF-matched warp will not have their artifacts detected. - # NaNs from the PSF-matched warp therefore must be masked in - # the direct warp. - nans = numpy.where(numpy.isnan(warpDiffExp.maskedImage.image.array), 1, 0) - nansMask = afwImage.makeMaskFromArray(nans.astype(afwImage.MaskPixel)) - nansMask.setXY0(warpDiffExp.getXY0()) - edgeMask = warpDiffExp.mask - spanSetEdgeMask = afwGeom.SpanSet.fromMask(edgeMask, edgeMask.getPlaneBitMask("EDGE")).split() - else: - # If the directWarp has <1% coverage, the psfMatchedWarp can - # have 0% and not exist. In this case, mask the whole epoch. - nansMask = afwImage.MaskX(coaddBBox, 1) - spanSetList = [] - spanSetEdgeMask = [] - spanSetStreak = [] - - spanSetNoDataMask = afwGeom.SpanSet.fromMask(nansMask).split() - - spanSetNoDataMaskList.append(spanSetNoDataMask) - spanSetArtifactList.append(spanSetList) - spanSetEdgeList.append(spanSetEdgeMask) - if self.config.doFilterMorphological: - spanSetBadMorphoList.append(spanSetStreak) - - if lsstDebug.Info(__name__).saveCountIm: - path = self._dataRef2DebugPath("epochCountIm", tempExpRefList[0], coaddLevel=True) - epochCountImage.writeFits(path) - - for i, spanSetList in enumerate(spanSetArtifactList): - if spanSetList: - filteredSpanSetList = self.filterArtifacts( - spanSetList, epochCountImage, nImage, templateFootprints - ) - spanSetArtifactList[i] = filteredSpanSetList - if self.config.doFilterMorphological: - spanSetArtifactList[i] += spanSetBadMorphoList[i] + self.log.debug("Computing online coadd.") - altMasks = [] - for artifacts, noData, edge in zip(spanSetArtifactList, spanSetNoDataMaskList, spanSetEdgeList): - altMasks.append({"CLIPPED": artifacts, "NO_DATA": noData, "EDGE": edge}) - return altMasks + coaddExposure.mask.addMaskPlane("REJECTED") + coaddExposure.mask.addMaskPlane("CLIPPED") + coaddExposure.mask.addMaskPlane("SENSOR_EDGE") + maskMap = self.setRejectedMaskMapping(statsCtrl) + thresholdDict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl) - def prefilterArtifacts(self, spanSetList, exp): - """Remove artifact candidates covered by bad mask plane. + bbox = coaddExposure.maskedImage.getBBox() - Any future editing of the candidate list that does not depend on - temporal information should go in this method. + stacker = AccumulatorMeanStack( + coaddExposure.image.array.shape, + statsCtrl.getAndMask(), + mask_threshold_dict=thresholdDict, + mask_map=maskMap, + no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(), + calc_error_from_input_variance=self.config.calcErrorFromInputVariance, + compute_n_image=(nImage is not None), + ) - Parameters - ---------- - spanSetList : `list` [`lsst.afw.geom.SpanSet`] - List of SpanSets representing artifact candidates. - exp : `lsst.afw.image.Exposure` - Exposure containing mask planes used to prefilter. + for tempExpRef, imageScaler, altMask, weight in zip( + tempExpRefList, imageScalerList, altMaskList, weightList + ): + exposure = tempExpRef.get() + maskedImage = exposure.getMaskedImage() + mask = maskedImage.getMask() + if altMask is not None: + self.applyAltMaskPlanes(mask, altMask) + imageScaler.scaleMaskedImage(maskedImage) + if self.config.removeMaskPlanes: + self.removeMaskPlanes(maskedImage) - Returns - ------- - returnSpanSetList : `list` [`lsst.afw.geom.SpanSet`] - List of SpanSets with artifacts. - """ - badPixelMask = exp.mask.getPlaneBitMask(self.config.prefilterArtifactsMaskPlanes) - goodArr = (exp.mask.array & badPixelMask) == 0 - returnSpanSetList = [] - bbox = exp.getBBox() - x0, y0 = exp.getXY0() - for i, span in enumerate(spanSetList): - y, x = span.clippedTo(bbox).indices() - yIndexLocal = numpy.array(y) - y0 - xIndexLocal = numpy.array(x) - x0 - goodRatio = numpy.count_nonzero(goodArr[yIndexLocal, xIndexLocal]) / span.getArea() - if goodRatio > self.config.prefilterArtifactsRatio: - returnSpanSetList.append(span) - return returnSpanSetList - - def filterArtifacts(self, spanSetList, epochCountImage, nImage, footprintsToExclude=None): - """Filter artifact candidates. + stacker.add_masked_image(maskedImage, weight=weight) - Parameters - ---------- - spanSetList : `list` [`lsst.afw.geom.SpanSet`] - List of SpanSets representing artifact candidates. - epochCountImage : `lsst.afw.image.Image` - Image of accumulated number of warpDiff detections. - nImage : `lsst.afw.image.ImageU` - Image of the accumulated number of total epochs contributing. + if self.config.doInputMap: + visit = exposure.getInfo().getCoaddInputs().visits[0].getId() + self.inputMapper.mask_warp_bbox(bbox, visit, mask, statsCtrl.getAndMask()) - Returns - ------- - maskSpanSetList : `list` [`lsst.afw.geom.SpanSet`] - List of SpanSets with artifacts. - """ - maskSpanSetList = [] - x0, y0 = epochCountImage.getXY0() - for i, span in enumerate(spanSetList): - y, x = span.indices() - yIdxLocal = [y1 - y0 for y1 in y] - xIdxLocal = [x1 - x0 for x1 in x] - outlierN = epochCountImage.array[yIdxLocal, xIdxLocal] - totalN = nImage.array[yIdxLocal, xIdxLocal] - - # effectiveMaxNumEpochs is broken line (fraction of N) with - # characteristic config.maxNumEpochs. - effMaxNumEpochsHighN = self.config.maxNumEpochs + self.config.maxFractionEpochsHigh * numpy.mean( - totalN - ) - effMaxNumEpochsLowN = self.config.maxFractionEpochsLow * numpy.mean(totalN) - effectiveMaxNumEpochs = int(min(effMaxNumEpochsLowN, effMaxNumEpochsHighN)) - nPixelsBelowThreshold = numpy.count_nonzero((outlierN > 0) & (outlierN <= effectiveMaxNumEpochs)) - percentBelowThreshold = nPixelsBelowThreshold / len(outlierN) - if percentBelowThreshold > self.config.spatialThreshold: - maskSpanSetList.append(span) - - if self.config.doPreserveContainedBySource and footprintsToExclude is not None: - # If a candidate is contained by a footprint on the template coadd, - # do not clip. - filteredMaskSpanSetList = [] - for span in maskSpanSetList: - doKeep = True - for footprint in footprintsToExclude.positive.getFootprints(): - if footprint.spans.contains(span): - doKeep = False - break - if doKeep: - filteredMaskSpanSetList.append(span) - maskSpanSetList = filteredMaskSpanSetList - - return maskSpanSetList - - def _readAndComputeWarpDiff(self, warpRef, imageScaler, templateCoadd): - """Fetch a warp from the butler and return a warpDiff. + stacker.fill_stacked_masked_image(coaddExposure.maskedImage) + + if nImage is not None: + nImage.array[:, :] = stacker.n_image + + def setInexactPsf(self, mask): + """Set INEXACT_PSF mask plane. + + If any of the input images isn't represented in the coadd (due to + clipped pixels or chip gaps), the `CoaddPsf` will be inexact. Flag + these pixels. Parameters ---------- - warpRef : `lsst.daf.butler.DeferredDatasetHandle` - Handle for the warp. - imageScaler : `lsst.pipe.tasks.scaleZeroPoint.ImageScaler` - An image scaler object. - templateCoadd : `lsst.afw.image.Exposure` - Exposure to be substracted from the scaled warp. - - Returns - ------- - warp : `lsst.afw.image.Exposure` - Exposure of the image difference between the warp and template. + mask : `lsst.afw.image.Mask` + Coadded exposure's mask, modified in-place. """ - # If the PSF-Matched warp did not exist for this direct warp - # None is holding its place to maintain order in Gen 3 - if warpRef is None: - return None - - warp = warpRef.get() - # direct image scaler OK for PSF-matched Warp - imageScaler.scaleMaskedImage(warp.getMaskedImage()) - mi = warp.getMaskedImage() - if self.config.doScaleWarpVariance: - try: - self.scaleWarpVariance.run(mi) - except Exception as exc: - self.log.warning("Unable to rescale variance of warp (%s); leaving it as-is", exc) - mi -= templateCoadd.getMaskedImage() - return warp + mask.addMaskPlane("INEXACT_PSF") + inexactPsf = mask.getPlaneBitMask("INEXACT_PSF") + sensorEdge = mask.getPlaneBitMask("SENSOR_EDGE") # chip edges (so PSF is discontinuous) + clipped = mask.getPlaneBitMask("CLIPPED") # pixels clipped from coadd + rejected = mask.getPlaneBitMask("REJECTED") # pixels rejected from coadd due to masks + array = mask.getArray() + selected = array & (sensorEdge | clipped | rejected) > 0 + array[selected] |= inexactPsf + + +def countMaskFromFootprint(mask, footprint, bitmask, ignoreMask): + """Function to count the number of pixels with a specific mask in a + footprint. + + Find the intersection of mask & footprint. Count all pixels in the mask + that are in the intersection that have bitmask set but do not have + ignoreMask set. Return the count. + + Parameters + ---------- + mask : `lsst.afw.image.Mask` + Mask to define intersection region by. + footprint : `lsst.afw.detection.Footprint` + Footprint to define the intersection region by. + bitmask : `Unknown` + Specific mask that we wish to count the number of occurances of. + ignoreMask : `Unknown` + Pixels to not consider. + + Returns + ------- + result : `int` + Number of pixels in footprint with specified mask. + """ + bbox = footprint.getBBox() + bbox.clip(mask.getBBox(afwImage.PARENT)) + fp = afwImage.Mask(bbox) + subMask = mask.Factory(mask, bbox, afwImage.PARENT) + footprint.spans.setMask(fp, bitmask) + return numpy.logical_and( + (subMask.getArray() & fp.getArray()) > 0, (subMask.getArray() & ignoreMask) == 0 + ).sum() diff --git a/python/lsst/drp/tasks/compare_warp.py b/python/lsst/drp/tasks/compare_warp.py new file mode 100644 index 00000000..f8fd6e5e --- /dev/null +++ b/python/lsst/drp/tasks/compare_warp.py @@ -0,0 +1,699 @@ +# This file is part of drp_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +__all__ = ( + "CompareWarpAssembleCoaddConfig", + "CompareWarpAssembleCoaddTask", +) + +import copy +import logging + +import lsst.afw.geom as afwGeom +import lsst.afw.image as afwImage +import lsst.afw.table as afwTable +import lsst.pex.config as pexConfig +import lsst.pipe.base as pipeBase +import lsst.utils as utils +import lsstDebug +import numpy +from lsst.meas.algorithms import ScaleVarianceTask, SourceDetectionTask +from lsst.pipe.tasks.coaddBase import reorderAndPadList +from lsst.pipe.tasks.maskStreaks import MaskStreaksTask +from lsst.utils.timer import timeMethod + +from .assemble_coadd import AssembleCoaddConnections, AssembleCoaddTask, BaseAssembleCoaddConfig + +log = logging.getLogger(__name__) + + +class CompareWarpAssembleCoaddConnections(AssembleCoaddConnections): + psfMatchedWarps = pipeBase.connectionTypes.Input( + doc=( + "PSF-Matched Warps are required by CompareWarp regardless of the coadd type requested. " + "Only PSF-Matched Warps make sense for image subtraction. " + "Therefore, they must be an additional declared input." + ), + name="{inputCoaddName}Coadd_psfMatchedWarp", + storageClass="ExposureF", + dimensions=("tract", "patch", "skymap", "visit"), + deferLoad=True, + multiple=True, + ) + templateCoadd = pipeBase.connectionTypes.Output( + doc=( + "Model of the static sky, used to find temporal artifacts. Typically a PSF-Matched, " + "sigma-clipped coadd. Written if and only if assembleStaticSkyModel.doWrite=True" + ), + name="{outputCoaddName}CoaddPsfMatched", + storageClass="ExposureF", + dimensions=("tract", "patch", "skymap", "band"), + ) + + def __init__(self, *, config=None): + super().__init__(config=config) + if config and not config.assembleStaticSkyModel.doWrite: + self.outputs.remove("templateCoadd") + config.validate() + + +class CompareWarpAssembleCoaddConfig( + BaseAssembleCoaddConfig, pipelineConnections=CompareWarpAssembleCoaddConnections +): + assembleStaticSkyModel = pexConfig.ConfigurableField( + target=AssembleCoaddTask, + doc="Task to assemble an artifact-free, PSF-matched Coadd to serve as " + "a naive/first-iteration model of the static sky.", + ) + assembleCoadd = pexConfig.ConfigurableField( + target=AssembleCoaddTask, + doc="Task to assemble a coadd from a set of warps", + ) + detect = pexConfig.ConfigurableField( + target=SourceDetectionTask, + doc="Detect outlier sources on difference between each psfMatched warp and static sky model", + ) + detectTemplate = pexConfig.ConfigurableField( + target=SourceDetectionTask, + doc="Detect sources on static sky model. Only used if doPreserveContainedBySource is True", + ) + maskStreaks = pexConfig.ConfigurableField( + target=MaskStreaksTask, + doc="Detect streaks on difference between each psfMatched warp and static sky model. Only used if " + "doFilterMorphological is True. Adds a mask plane to an exposure, with the mask plane name set by" + "streakMaskName", + ) + streakMaskName = pexConfig.Field( + dtype=str, + default="STREAK", + doc="Name of mask bit used for streaks", + ) + maxNumEpochs = pexConfig.Field( + doc="Charactistic maximum local number of epochs/visits in which an artifact candidate can appear " + "and still be masked. The effective maxNumEpochs is a broken linear function of local " + "number of epochs (N): min(maxFractionEpochsLow*N, maxNumEpochs + maxFractionEpochsHigh*N). " + "For each footprint detected on the image difference between the psfMatched warp and static sky " + "model, if a significant fraction of pixels (defined by spatialThreshold) are residuals in more " + "than the computed effective maxNumEpochs, the artifact candidate is deemed persistant rather " + "than transient and not masked.", + dtype=int, + default=2, + ) + maxFractionEpochsLow = pexConfig.RangeField( + doc="Fraction of local number of epochs (N) to use as effective maxNumEpochs for low N. " + "Effective maxNumEpochs = " + "min(maxFractionEpochsLow * N, maxNumEpochs + maxFractionEpochsHigh * N)", + dtype=float, + default=0.4, + min=0.0, + max=1.0, + ) + maxFractionEpochsHigh = pexConfig.RangeField( + doc="Fraction of local number of epochs (N) to use as effective maxNumEpochs for high N. " + "Effective maxNumEpochs = " + "min(maxFractionEpochsLow * N, maxNumEpochs + maxFractionEpochsHigh * N)", + dtype=float, + default=0.03, + min=0.0, + max=1.0, + ) + spatialThreshold = pexConfig.RangeField( + doc="Unitless fraction of pixels defining how much of the outlier region has to meet the " + "temporal criteria. If 0, clip all. If 1, clip none.", + dtype=float, + default=0.5, + min=0.0, + max=1.0, + inclusiveMin=True, + inclusiveMax=True, + ) + doScaleWarpVariance = pexConfig.Field( + doc="Rescale Warp variance plane using empirical noise?", + dtype=bool, + default=True, + ) + scaleWarpVariance = pexConfig.ConfigurableField( + target=ScaleVarianceTask, + doc="Rescale variance on warps", + ) + doPreserveContainedBySource = pexConfig.Field( + doc="Rescue artifacts from clipping that completely lie within a footprint detected" + "on the PsfMatched Template Coadd. Replicates a behavior of SafeClip.", + dtype=bool, + default=True, + ) + doPrefilterArtifacts = pexConfig.Field( + doc="Ignore artifact candidates that are mostly covered by the bad pixel mask, " + "because they will be excluded anyway. This prevents them from contributing " + "to the outlier epoch count image and potentially being labeled as persistant." + "'Mostly' is defined by the config 'prefilterArtifactsRatio'.", + dtype=bool, + default=True, + ) + prefilterArtifactsMaskPlanes = pexConfig.ListField( + doc="Prefilter artifact candidates that are mostly covered by these bad mask planes.", + dtype=str, + default=("NO_DATA", "BAD", "SAT", "SUSPECT"), + ) + prefilterArtifactsRatio = pexConfig.Field( + doc="Prefilter artifact candidates with less than this fraction overlapping good pixels", + dtype=float, + default=0.05, + ) + doFilterMorphological = pexConfig.Field( + doc="Filter artifact candidates based on morphological criteria, i.g. those that appear to " + "be streaks.", + dtype=bool, + default=False, + ) + growStreakFp = pexConfig.Field( + doc="Grow streak footprints by this number multiplied by the PSF width", + dtype=float, + default=5, + ) + + def setDefaults(self): + super().setDefaults() + self.statistic = "MEAN" + self.doUsePsfMatchedPolygons = True + + # Real EDGE removed by psfMatched NO_DATA border half the width of the + # matching kernel. CompareWarp applies psfMatched EDGE pixels to + # directWarps before assembling. + if "EDGE" in self.badMaskPlanes: + self.badMaskPlanes.remove("EDGE") + self.removeMaskPlanes.append("EDGE") + self.assembleStaticSkyModel.badMaskPlanes = [ + "NO_DATA", + ] + self.assembleStaticSkyModel.warpType = "psfMatched" + self.assembleStaticSkyModel.connections.warpType = "psfMatched" + self.assembleStaticSkyModel.statistic = "MEANCLIP" + self.assembleStaticSkyModel.sigmaClip = 2.5 + self.assembleStaticSkyModel.clipIter = 3 + self.assembleStaticSkyModel.calcErrorFromInputVariance = False + self.assembleStaticSkyModel.doWrite = False + self.detect.doTempLocalBackground = False + self.detect.reEstimateBackground = False + self.detect.returnOriginalFootprints = False + self.detect.thresholdPolarity = "both" + self.detect.thresholdValue = 5 + self.detect.minPixels = 4 + self.detect.isotropicGrow = True + self.detect.thresholdType = "pixel_stdev" + self.detect.nSigmaToGrow = 0.4 + # The default nSigmaToGrow for SourceDetectionTask is already 2.4, + # Explicitly restating because ratio with detect.nSigmaToGrow matters + self.detectTemplate.nSigmaToGrow = 2.4 + self.detectTemplate.doTempLocalBackground = False + self.detectTemplate.reEstimateBackground = False + self.detectTemplate.returnOriginalFootprints = False + + def validate(self): + super().validate() + if self.assembleStaticSkyModel.doNImage: + raise ValueError( + "No dataset type exists for a PSF-Matched Template N Image." + "Please set assembleStaticSkyModel.doNImage=False" + ) + + if self.assembleStaticSkyModel.doWrite and (self.warpType == self.assembleStaticSkyModel.warpType): + raise ValueError( + "warpType (%s) == assembleStaticSkyModel.warpType (%s) and will compete for " + "the same dataset name. Please set assembleStaticSkyModel.doWrite to False " + "or warpType to 'direct'. assembleStaticSkyModel.warpType should ways be " + "'PsfMatched'" % (self.warpType, self.assembleStaticSkyModel.warpType) + ) + + +class CompareWarpAssembleCoaddTask(AssembleCoaddTask): + """Assemble a compareWarp coadded image from a set of warps + by masking artifacts detected by comparing PSF-matched warps. + + In ``AssembleCoaddTask``, we compute the coadd as an clipped mean (i.e., + we clip outliers). The problem with doing this is that when computing the + coadd PSF at a given location, individual visit PSFs from visits with + outlier pixels contribute to the coadd PSF and cannot be treated correctly. + In this task, we correct for this behavior by creating a new badMaskPlane + 'CLIPPED' which marks pixels in the individual warps suspected to contain + an artifact. We populate this plane on the input warps by comparing + PSF-matched warps with a PSF-matched median coadd which serves as a + model of the static sky. Any group of pixels that deviates from the + PSF-matched template coadd by more than config.detect.threshold sigma, + is an artifact candidate. The candidates are then filtered to remove + variable sources and sources that are difficult to subtract such as + bright stars. This filter is configured using the config parameters + ``temporalThreshold`` and ``spatialThreshold``. The temporalThreshold is + the maximum fraction of epochs that the deviation can appear in and still + be considered an artifact. The spatialThreshold is the maximum fraction of + pixels in the footprint of the deviation that appear in other epochs + (where other epochs is defined by the temporalThreshold). If the deviant + region meets this criteria of having a significant percentage of pixels + that deviate in only a few epochs, these pixels have the 'CLIPPED' bit + set in the mask. These regions will not contribute to the final coadd. + Furthermore, any routine to determine the coadd PSF can now be cognizant + of clipped regions. Note that the algorithm implemented by this task is + preliminary and works correctly for HSC data. Parameter modifications and + or considerable redesigning of the algorithm is likley required for other + surveys. + + ``CompareWarpAssembleCoaddTask`` sub-classes + ``AssembleCoaddTask`` and instantiates ``AssembleCoaddTask`` + as a subtask to generate the TemplateCoadd (the model of the static sky). + + Notes + ----- + Debugging: + This task supports the following debug variables: + - ``saveCountIm`` + If True then save the Epoch Count Image as a fits file in the `figPath` + - ``figPath`` + Path to save the debug fits images and figures + """ + + ConfigClass = CompareWarpAssembleCoaddConfig + _DefaultName = "compareWarpAssembleCoadd" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.makeSubtask("assembleStaticSkyModel") + self.makeSubtask("assembleCoadd") + detectionSchema = afwTable.SourceTable.makeMinimalSchema() + self.makeSubtask("detect", schema=detectionSchema) + if self.config.doPreserveContainedBySource: + self.makeSubtask("detectTemplate", schema=afwTable.SourceTable.makeMinimalSchema()) + if self.config.doScaleWarpVariance: + self.makeSubtask("scaleWarpVariance") + if self.config.doFilterMorphological: + self.makeSubtask("maskStreaks") + + @utils.inheritDoc(AssembleCoaddTask) + def _makeSupplementaryData(self, butlerQC, inputRefs, outputRefs): + """Generate a templateCoadd to use as a naive model of static sky to + subtract from PSF-Matched warps. + + Returns + ------- + result : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``templateCoadd`` + Coadded exposure (`lsst.afw.image.Exposure`). + ``nImage`` + Keeps track of exposure count for each pixel + (`lsst.afw.image.ImageU`). + + Raises + ------ + RuntimeError + Raised if ``templateCoadd`` is `None`. + """ + # Ensure that psfMatchedWarps are used as input warps for template + # generation. + staticSkyModelInputRefs = copy.deepcopy(inputRefs) + staticSkyModelInputRefs.inputWarps = inputRefs.psfMatchedWarps + + # Because subtasks don't have connections we have to make one. + # The main task's `templateCoadd` is the subtask's `coaddExposure` + staticSkyModelOutputRefs = copy.deepcopy(outputRefs) + if self.config.assembleStaticSkyModel.doWrite: + staticSkyModelOutputRefs.coaddExposure = staticSkyModelOutputRefs.templateCoadd + # Remove template coadd from both subtask's and main tasks outputs, + # because it is handled by the subtask as `coaddExposure` + del outputRefs.templateCoadd + del staticSkyModelOutputRefs.templateCoadd + + # A PSF-Matched nImage does not exist as a dataset type + if "nImage" in staticSkyModelOutputRefs.keys(): + del staticSkyModelOutputRefs.nImage + + templateCoadd = self.assembleStaticSkyModel.runQuantum( + butlerQC, staticSkyModelInputRefs, staticSkyModelOutputRefs + ) + if templateCoadd is None: + raise RuntimeError(self._noTemplateMessage(self.assembleStaticSkyModel.warpType)) + + return pipeBase.Struct( + templateCoadd=templateCoadd.coaddExposure, + nImage=templateCoadd.nImage, + warpRefList=templateCoadd.warpRefList, + imageScalerList=templateCoadd.imageScalerList, + weightList=templateCoadd.weightList, + ) + + def _noTemplateMessage(self, warpType): + warpName = warpType[0].upper() + warpType[1:] + message = """No %(warpName)s warps were found to build the template coadd which is + required to run CompareWarpAssembleCoaddTask. To continue assembling this type of coadd, + first either rerun makeCoaddTempExp with config.make%(warpName)s=True or + coaddDriver with config.makeCoadTempExp.make%(warpName)s=True, before assembleCoadd. + + Alternatively, to use another algorithm with existing warps, retarget the CoaddDriverConfig to + another algorithm like: + + from lsst.pipe.tasks.assembleCoadd import SafeClipAssembleCoaddTask + config.assemble.retarget(SafeClipAssembleCoaddTask) + """ % { + "warpName": warpName + } + return message + + @utils.inheritDoc(AssembleCoaddTask) + @timeMethod + def run(self, skyInfo, tempExpRefList, imageScalerList, weightList, supplementaryData): + """Notes + ----- + Assemble the coadd. + + Find artifacts and apply them to the warps' masks creating a list of + alternative masks with a new "CLIPPED" plane and updated "NO_DATA" + plane. Then pass these alternative masks to the base class's ``run`` + method. + """ + # Check and match the order of the supplementaryData + # (PSF-matched) inputs to the order of the direct inputs, + # so that the artifact mask is applied to the right warp + dataIds = [ref.dataId for ref in tempExpRefList] + psfMatchedDataIds = [ref.dataId for ref in supplementaryData.warpRefList] + + if dataIds != psfMatchedDataIds: + self.log.info("Reordering and or/padding PSF-matched visit input list") + supplementaryData.warpRefList = reorderAndPadList( + supplementaryData.warpRefList, psfMatchedDataIds, dataIds + ) + supplementaryData.imageScalerList = reorderAndPadList( + supplementaryData.imageScalerList, psfMatchedDataIds, dataIds + ) + + # Use PSF-Matched Warps (and corresponding scalers) and coadd to find + # artifacts. + spanSetMaskList = self.findArtifacts( + supplementaryData.templateCoadd, + supplementaryData.warpRefList, + supplementaryData.imageScalerList, + ) + + badMaskPlanes = self.config.badMaskPlanes[:] + badMaskPlanes.append("CLIPPED") + badPixelMask = afwImage.Mask.getPlaneBitMask(badMaskPlanes) + result = self.assembleCoadd.run( + skyInfo, + tempExpRefList, + imageScalerList, + weightList, + spanSetMaskList, + mask=badPixelMask, + ) + + # Propagate PSF-matched EDGE pixels to coadd SENSOR_EDGE and + # INEXACT_PSF. Psf-Matching moves the real edge inwards. + self.applyAltEdgeMask(result.coaddExposure.maskedImage.mask, spanSetMaskList) + return result + + def applyAltEdgeMask(self, mask, altMaskList): + """Propagate alt EDGE mask to SENSOR_EDGE AND INEXACT_PSF planes. + + Parameters + ---------- + mask : `lsst.afw.image.Mask` + Original mask. + altMaskList : `list` of `dict` + List of Dicts containing ``spanSet`` lists. + Each element contains the new mask plane name (e.g. "CLIPPED + and/or "NO_DATA") as the key, and list of ``SpanSets`` to apply to + the mask. + """ + maskValue = mask.getPlaneBitMask(["SENSOR_EDGE", "INEXACT_PSF"]) + for visitMask in altMaskList: + if "EDGE" in visitMask: + for spanSet in visitMask["EDGE"]: + spanSet.clippedTo(mask.getBBox()).setMask(mask, maskValue) + + def findArtifacts(self, templateCoadd, tempExpRefList, imageScalerList): + """Find artifacts. + + Loop through warps twice. The first loop builds a map with the count + of how many epochs each pixel deviates from the templateCoadd by more + than ``config.chiThreshold`` sigma. The second loop takes each + difference image and filters the artifacts detected in each using + count map to filter out variable sources and sources that are + difficult to subtract cleanly. + + Parameters + ---------- + templateCoadd : `lsst.afw.image.Exposure` + Exposure to serve as model of static sky. + tempExpRefList : `list` + List of data references to warps. + imageScalerList : `list` + List of image scalers. + + Returns + ------- + altMasks : `list` of `dict` + List of dicts containing information about CLIPPED + (i.e., artifacts), NO_DATA, and EDGE pixels. + """ + self.log.debug("Generating Count Image, and mask lists.") + coaddBBox = templateCoadd.getBBox() + slateIm = afwImage.ImageU(coaddBBox) + epochCountImage = afwImage.ImageU(coaddBBox) + nImage = afwImage.ImageU(coaddBBox) + spanSetArtifactList = [] + spanSetNoDataMaskList = [] + spanSetEdgeList = [] + spanSetBadMorphoList = [] + badPixelMask = self.getBadPixelMask() + + # mask of the warp diffs should = that of only the warp + templateCoadd.mask.clearAllMaskPlanes() + + if self.config.doPreserveContainedBySource: + templateFootprints = self.detectTemplate.detectFootprints(templateCoadd) + else: + templateFootprints = None + + for warpRef, imageScaler in zip(tempExpRefList, imageScalerList): + warpDiffExp = self._readAndComputeWarpDiff(warpRef, imageScaler, templateCoadd) + if warpDiffExp is not None: + # This nImage only approximates the final nImage because it + # uses the PSF-matched mask. + nImage.array += ( + numpy.isfinite(warpDiffExp.image.array) * ((warpDiffExp.mask.array & badPixelMask) == 0) + ).astype(numpy.uint16) + fpSet = self.detect.detectFootprints(warpDiffExp, doSmooth=False, clearMask=True) + fpSet.positive.merge(fpSet.negative) + footprints = fpSet.positive + slateIm.set(0) + spanSetList = [footprint.spans for footprint in footprints.getFootprints()] + + # Remove artifacts due to defects before they contribute to + # the epochCountImage. + if self.config.doPrefilterArtifacts: + spanSetList = self.prefilterArtifacts(spanSetList, warpDiffExp) + + # Clear mask before adding prefiltered spanSets + self.detect.clearMask(warpDiffExp.mask) + for spans in spanSetList: + spans.setImage(slateIm, 1, doClip=True) + spans.setMask(warpDiffExp.mask, warpDiffExp.mask.getPlaneBitMask("DETECTED")) + epochCountImage += slateIm + + if self.config.doFilterMorphological: + maskName = self.config.streakMaskName + _ = self.maskStreaks.run(warpDiffExp) + streakMask = warpDiffExp.mask + spanSetStreak = afwGeom.SpanSet.fromMask( + streakMask, streakMask.getPlaneBitMask(maskName) + ).split() + # Pad the streaks to account for low-surface brightness + # wings. + psf = warpDiffExp.getPsf() + for s, sset in enumerate(spanSetStreak): + psfShape = psf.computeShape(sset.computeCentroid()) + dilation = self.config.growStreakFp * psfShape.getDeterminantRadius() + sset_dilated = sset.dilated(int(dilation)) + spanSetStreak[s] = sset_dilated + + # PSF-Matched warps have less available area (~the matching + # kernel) because the calexps undergo a second convolution. + # Pixels with data in the direct warp but not in the + # PSF-matched warp will not have their artifacts detected. + # NaNs from the PSF-matched warp therefore must be masked in + # the direct warp. + nans = numpy.where(numpy.isnan(warpDiffExp.maskedImage.image.array), 1, 0) + nansMask = afwImage.makeMaskFromArray(nans.astype(afwImage.MaskPixel)) + nansMask.setXY0(warpDiffExp.getXY0()) + edgeMask = warpDiffExp.mask + spanSetEdgeMask = afwGeom.SpanSet.fromMask(edgeMask, edgeMask.getPlaneBitMask("EDGE")).split() + else: + # If the directWarp has <1% coverage, the psfMatchedWarp can + # have 0% and not exist. In this case, mask the whole epoch. + nansMask = afwImage.MaskX(coaddBBox, 1) + spanSetList = [] + spanSetEdgeMask = [] + spanSetStreak = [] + + spanSetNoDataMask = afwGeom.SpanSet.fromMask(nansMask).split() + + spanSetNoDataMaskList.append(spanSetNoDataMask) + spanSetArtifactList.append(spanSetList) + spanSetEdgeList.append(spanSetEdgeMask) + if self.config.doFilterMorphological: + spanSetBadMorphoList.append(spanSetStreak) + + if lsstDebug.Info(__name__).saveCountIm: + path = self._dataRef2DebugPath("epochCountIm", tempExpRefList[0], coaddLevel=True) + epochCountImage.writeFits(path) + + for i, spanSetList in enumerate(spanSetArtifactList): + if spanSetList: + filteredSpanSetList = self.filterArtifacts( + spanSetList, epochCountImage, nImage, templateFootprints + ) + spanSetArtifactList[i] = filteredSpanSetList + if self.config.doFilterMorphological: + spanSetArtifactList[i] += spanSetBadMorphoList[i] + + altMasks = [] + for artifacts, noData, edge in zip(spanSetArtifactList, spanSetNoDataMaskList, spanSetEdgeList): + altMasks.append({"CLIPPED": artifacts, "NO_DATA": noData, "EDGE": edge}) + return altMasks + + def prefilterArtifacts(self, spanSetList, exp): + """Remove artifact candidates covered by bad mask plane. + + Any future editing of the candidate list that does not depend on + temporal information should go in this method. + + Parameters + ---------- + spanSetList : `list` [`lsst.afw.geom.SpanSet`] + List of SpanSets representing artifact candidates. + exp : `lsst.afw.image.Exposure` + Exposure containing mask planes used to prefilter. + + Returns + ------- + returnSpanSetList : `list` [`lsst.afw.geom.SpanSet`] + List of SpanSets with artifacts. + """ + badPixelMask = exp.mask.getPlaneBitMask(self.config.prefilterArtifactsMaskPlanes) + goodArr = (exp.mask.array & badPixelMask) == 0 + returnSpanSetList = [] + bbox = exp.getBBox() + x0, y0 = exp.getXY0() + for i, span in enumerate(spanSetList): + y, x = span.clippedTo(bbox).indices() + yIndexLocal = numpy.array(y) - y0 + xIndexLocal = numpy.array(x) - x0 + goodRatio = numpy.count_nonzero(goodArr[yIndexLocal, xIndexLocal]) / span.getArea() + if goodRatio > self.config.prefilterArtifactsRatio: + returnSpanSetList.append(span) + return returnSpanSetList + + def filterArtifacts(self, spanSetList, epochCountImage, nImage, footprintsToExclude=None): + """Filter artifact candidates. + + Parameters + ---------- + spanSetList : `list` [`lsst.afw.geom.SpanSet`] + List of SpanSets representing artifact candidates. + epochCountImage : `lsst.afw.image.Image` + Image of accumulated number of warpDiff detections. + nImage : `lsst.afw.image.ImageU` + Image of the accumulated number of total epochs contributing. + + Returns + ------- + maskSpanSetList : `list` [`lsst.afw.geom.SpanSet`] + List of SpanSets with artifacts. + """ + maskSpanSetList = [] + x0, y0 = epochCountImage.getXY0() + for i, span in enumerate(spanSetList): + y, x = span.indices() + yIdxLocal = [y1 - y0 for y1 in y] + xIdxLocal = [x1 - x0 for x1 in x] + outlierN = epochCountImage.array[yIdxLocal, xIdxLocal] + totalN = nImage.array[yIdxLocal, xIdxLocal] + + # effectiveMaxNumEpochs is broken line (fraction of N) with + # characteristic config.maxNumEpochs. + effMaxNumEpochsHighN = self.config.maxNumEpochs + self.config.maxFractionEpochsHigh * numpy.mean( + totalN + ) + effMaxNumEpochsLowN = self.config.maxFractionEpochsLow * numpy.mean(totalN) + effectiveMaxNumEpochs = int(min(effMaxNumEpochsLowN, effMaxNumEpochsHighN)) + nPixelsBelowThreshold = numpy.count_nonzero((outlierN > 0) & (outlierN <= effectiveMaxNumEpochs)) + percentBelowThreshold = nPixelsBelowThreshold / len(outlierN) + if percentBelowThreshold > self.config.spatialThreshold: + maskSpanSetList.append(span) + + if self.config.doPreserveContainedBySource and footprintsToExclude is not None: + # If a candidate is contained by a footprint on the template coadd, + # do not clip. + filteredMaskSpanSetList = [] + for span in maskSpanSetList: + doKeep = True + for footprint in footprintsToExclude.positive.getFootprints(): + if footprint.spans.contains(span): + doKeep = False + break + if doKeep: + filteredMaskSpanSetList.append(span) + maskSpanSetList = filteredMaskSpanSetList + + return maskSpanSetList + + def _readAndComputeWarpDiff(self, warpRef, imageScaler, templateCoadd): + """Fetch a warp from the butler and return a warpDiff. + + Parameters + ---------- + warpRef : `lsst.daf.butler.DeferredDatasetHandle` + Handle for the warp. + imageScaler : `lsst.pipe.tasks.scaleZeroPoint.ImageScaler` + An image scaler object. + templateCoadd : `lsst.afw.image.Exposure` + Exposure to be substracted from the scaled warp. + + Returns + ------- + warp : `lsst.afw.image.Exposure` + Exposure of the image difference between the warp and template. + """ + # If the PSF-Matched warp did not exist for this direct warp + # None is holding its place to maintain order in Gen 3 + if warpRef is None: + return None + + warp = warpRef.get() + # direct image scaler OK for PSF-matched Warp + imageScaler.scaleMaskedImage(warp.getMaskedImage()) + mi = warp.getMaskedImage() + if self.config.doScaleWarpVariance: + try: + self.scaleWarpVariance.run(mi) + except Exception as exc: + self.log.warning("Unable to rescale variance of warp (%s); leaving it as-is", exc) + mi -= templateCoadd.getMaskedImage() + return warp diff --git a/python/lsst/drp/tasks/dcr_assemble_coadd.py b/python/lsst/drp/tasks/dcr_assemble_coadd.py index 567d9efd..51973b50 100644 --- a/python/lsst/drp/tasks/dcr_assemble_coadd.py +++ b/python/lsst/drp/tasks/dcr_assemble_coadd.py @@ -21,8 +21,6 @@ __all__ = ["DcrAssembleCoaddConnections", "DcrAssembleCoaddTask", "DcrAssembleCoaddConfig"] -from math import ceil - import lsst.afw.image as afwImage import lsst.afw.table as afwTable import lsst.coadd.utils as coaddUtils @@ -37,14 +35,11 @@ from lsst.pipe.tasks.coaddBase import makeSkyInfo, subBBoxIter from lsst.pipe.tasks.measurePsf import MeasurePsfTask from lsst.utils.timer import timeMethod +from math import ceil from scipy import ndimage -from .assemble_coadd import ( - AssembleCoaddConnections, - AssembleCoaddTask, - CompareWarpAssembleCoaddConfig, - CompareWarpAssembleCoaddTask, -) +from .assemble_coadd import AssembleCoaddConnections, AssembleCoaddTask +from .compare_warp import CompareWarpAssembleCoaddConfig, CompareWarpAssembleCoaddTask class DcrAssembleCoaddConnections( @@ -235,7 +230,7 @@ class DcrAssembleCoaddConfig(CompareWarpAssembleCoaddConfig, pipelineConnections ) def setDefaults(self): - CompareWarpAssembleCoaddConfig.setDefaults(self) + super().setDefaults() self.assembleStaticSkyModel.retarget(CompareWarpAssembleCoaddTask) self.doNImage = True self.assembleStaticSkyModel.warpType = self.warpType diff --git a/tests/test_assemble_coadd.py b/tests/test_assemble_coadd.py index ee0ccba0..6ae5e52f 100644 --- a/tests/test_assemble_coadd.py +++ b/tests/test_assemble_coadd.py @@ -27,12 +27,8 @@ import lsst.utils.tests import numpy as np from assemble_coadd_test_utils import MockCoaddTestData, makeMockSkyInfo -from lsst.drp.tasks.assemble_coadd import ( - AssembleCoaddConfig, - AssembleCoaddTask, - CompareWarpAssembleCoaddConfig, - CompareWarpAssembleCoaddTask, -) +from lsst.drp.tasks.assemble_coadd import AssembleCoaddConfig, AssembleCoaddTask +from lsst.drp.tasks.compare_warp import CompareWarpAssembleCoaddConfig, CompareWarpAssembleCoaddTask from lsst.drp.tasks.dcr_assemble_coadd import DcrAssembleCoaddConfig, DcrAssembleCoaddTask __all__ = [ @@ -105,6 +101,8 @@ def setDefaults(self): super().setDefaults() self.assembleStaticSkyModel.retarget(MockAssembleCoaddTask) self.assembleStaticSkyModel.doWrite = False + self.assembleCoadd.retarget(MockAssembleCoaddTask) + self.assembleCoadd.doWrite = False self.doWrite = False @@ -176,7 +174,7 @@ def __init__(self, *args, **kwargs): class MockInputMapAssembleCoaddConfig(MockCompareWarpAssembleCoaddConfig): def setDefaults(self): super().setDefaults() - self.doInputMap = True + self.assembleCoadd.doInputMap = True class MockInputMapAssembleCoaddTask(MockCompareWarpAssembleCoaddTask):