Skip to content

Commit

Permalink
Backwards compatible updates for Band/Filter
Browse files Browse the repository at this point in the history
  • Loading branch information
rhiannonlynne committed Jan 19, 2025
1 parent 7615c3f commit 39732ec
Show file tree
Hide file tree
Showing 17 changed files with 259 additions and 146 deletions.
221 changes: 87 additions & 134 deletions rubin_scheduler/scheduler/basis_functions/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
"M5DiffBasisFunction",
"M5DiffAtHpixBasisFunction",
"StrictBandBasisFunction",
"StrictFilterBasisFunction",
"BandChangeBasisFunction",
"FilterChangeBasisFunction",
"SlewtimeBasisFunction",
"CadenceEnhanceBasisFunction",
"CadenceEnhanceTrapezoidBasisFunction",
Expand All @@ -21,15 +23,14 @@
"NObsPerYearBasisFunction",
"CadenceInSeasonBasisFunction",
"NearSunHighAirmassBasisFunction",
"NObsHighAmBasisFunction",
"GoodSeeingBasisFunction",
"EclipticBasisFunction",
"VisitGap",
"NGoodSeeingBasisFunction",
"AvoidDirectWind",
"BalanceVisits",
"RewardNObsSequence",
"BandDistBasisFunction",
"FilterDistBasisFunction",
"RewardRisingBasisFunction",
"send_unused_deprecation_warning",
)
Expand Down Expand Up @@ -61,7 +62,7 @@ class BaseBasisFunction:
"""Class that takes features and computes a reward function when
called."""

def __init__(self, nside=DEFAULT_NSIDE, bandname=None, **kwargs):
def __init__(self, nside=DEFAULT_NSIDE, bandname=None, filtername=None, **kwargs):
# Set if basis function needs to be recalculated if there is a new
# observation
self.update_on_newobs = True
Expand All @@ -86,6 +87,13 @@ def __init__(self, nside=DEFAULT_NSIDE, bandname=None, **kwargs):
else:
self.nside = nside

if filtername is not None:
warnings.warn(
"Use of `filtername` will be deprecated in favor of `bandname` at v4", FutureWarning
)
bandname = filtername
# Save filtername as a backup in case someone tries to access it
self.filtername = filtername
self.bandname = bandname

def add_observations_array(self, observations_array, observations_hpid):
Expand Down Expand Up @@ -291,6 +299,14 @@ def _calc_value(self, conditions, indx=None):
return result


class FilterDistBasisFunction(BandDistBasisFunction):
"""Deprecated version of BandDistBasisFunction"""

def __init__(self, filtername="r"):
warnings.warn("FilterDistBasisFunction deprecated for BandDistBasisFunction", FutureWarning)
super().__init__(bandname=filtername)


class NObsPerYearBasisFunction(BaseBasisFunction):
"""Reward areas that have not been observed N-times in the last year
Expand Down Expand Up @@ -325,7 +341,11 @@ def __init__(
season_start_hour=-4.0,
season_end_hour=2.0,
night_max=365,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(NObsPerYearBasisFunction, self).__init__(nside=nside, bandname=bandname)
self.footprint = footprint
self.n_obs = n_obs
Expand Down Expand Up @@ -406,7 +426,11 @@ def __init__(
n_obs_desired=3,
mjd_start=None,
footprint=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super().__init__(nside=nside, bandname=bandname)
self.seeing_fwhm_max = seeing_fwhm_max
self.m5_penalty_max = m5_penalty_max
Expand Down Expand Up @@ -464,76 +488,6 @@ def az_rel_point(azs, point_az):
return az_rel_moon


class NObsHighAmBasisFunction(BaseBasisFunction):
"""Reward only reward/count observations at high airmass."""

def __init__(
self,
nside=DEFAULT_NSIDE,
bandname="r",
footprint=None,
n_obs=3,
season=300.0,
am_limits=[1.5, 2.2],
out_of_bounds_val=np.nan,
):
send_unused_deprecation_warning("NObsHighAmBasisFunction")
return
super(NObsHighAmBasisFunction, self).__init__(nside=nside, bandname=bandname)
if footprint is None:
footprints, labels = get_current_footprint(self.nside)
footprint = footprints[self.bandname]
self.footprint = footprint
self.out_footprint = np.where((footprint == 0) | np.isnan(footprint))
self.am_limits = am_limits
self.season = season
self.survey_features["last_n_mjds"] = features.Last_n_obs_times(
nside=nside, bandname=bandname, n_obs=n_obs
)

self.result = np.zeros(hp.nside2npix(self.nside), dtype=float) + out_of_bounds_val
self.out_of_bounds_val = out_of_bounds_val

def add_observation(self, observation, indx=None):
# Only count the observations if they are at the airmass limits
if (observation["airmass"] > np.min(self.am_limits)) & (
observation["airmass"] < np.max(self.am_limits)
):
for feature in self.survey_features:
self.survey_features[feature].add_observation(observation, indx=indx)
if self.update_on_newobs:
self.recalc = True

def check_feasibility(self, conditions):
result = True
reward = self._calc_value(conditions)
# If there are no non-NaN values, we're not feasible now
if True not in np.isfinite(reward):
result = False

return result

def _calc_value(self, conditions, indx=None):
result = self.result.copy()
behind_pix = np.where(
(
IntRounded(conditions.mjd - self.survey_features["last_n_mjds"].feature[0])
> IntRounded(self.season)
)
& (IntRounded(conditions.airmass) > IntRounded(np.min(self.am_limits)))
& (IntRounded(conditions.airmass) < IntRounded(np.max(self.am_limits)))
)
result[behind_pix] = 1
result[self.out_footprint] = self.out_of_bounds_val

# Update the last time we had an mjd
self.mjd_last = conditions.mjd + 0
self.recalc = False
self.value = result

return result


class EclipticBasisFunction(BaseBasisFunction):
"""Mark the area around the ecliptic"""

Expand Down Expand Up @@ -566,7 +520,12 @@ class CadenceInSeasonBasisFunction(BaseBasisFunction):
How long to wait before activating the basis function (days).
"""

def __init__(self, drive_map, bandname="griz", season_span=2.5, cadence=2.5, nside=DEFAULT_NSIDE):
def __init__(
self, drive_map, bandname="griz", season_span=2.5, cadence=2.5, nside=DEFAULT_NSIDE, filtername=None
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(CadenceInSeasonBasisFunction, self).__init__(nside=nside, bandname=bandname)
self.drive_map = drive_map
self.season_span = season_span / 12.0 * np.pi # To radians
Expand Down Expand Up @@ -630,7 +589,11 @@ def __init__(
n_per_season=3,
mjd_start=None,
season_frac_start=0.5,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
send_unused_deprecation_warning("SeasonCoverageBasisFunction")
super().__init__(nside=nside, bandname=bandname)

Expand Down Expand Up @@ -693,7 +656,10 @@ class AvoidFastRevisitsBasisFunction(BaseBasisFunction):
Will be masked if set to np.nan (default).
"""

def __init__(self, bandname="r", nside=DEFAULT_NSIDE, gap_min=25.0, penalty_val=np.nan):
def __init__(self, bandname="r", nside=DEFAULT_NSIDE, gap_min=25.0, penalty_val=np.nan, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super().__init__(nside=nside, bandname=bandname)

self.bandname = bandname
Expand Down Expand Up @@ -767,7 +733,12 @@ class VisitRepeatBasisFunction(BaseBasisFunction):
The number of pairs of observations to attempt to gather
"""

def __init__(self, gap_min=25.0, gap_max=45.0, bandname="r", nside=DEFAULT_NSIDE, npairs=1):
def __init__(
self, gap_min=25.0, gap_max=45.0, bandname="r", nside=DEFAULT_NSIDE, npairs=1, filtername=None
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(VisitRepeatBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.gap_min = IntRounded(gap_min / 60.0 / 24.0)
Expand Down Expand Up @@ -824,7 +795,10 @@ class M5DiffBasisFunction(BaseBasisFunction):
Default None uses `set_default_nside()`.
"""

def __init__(self, bandname="r", fiducial_FWHMEff=0.7, nside=DEFAULT_NSIDE):
def __init__(self, bandname="r", fiducial_FWHMEff=0.7, nside=DEFAULT_NSIDE, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super().__init__(nside=nside, bandname=bandname)
# The dark sky surface brightness values
self.dark_map = None
Expand Down Expand Up @@ -908,6 +882,16 @@ def _calc_value(self, conditions, **kwargs):
return result


class StrictFilterBasisFunction(StrictBandBasisFunction):
"""Deprecated in favor of StrictBandBasisFunction"""

def __init__(self, time_lag=10.0, filtername="r", twi_change=-18.0, note_free="DD"):
warnings.warn(
"StrictFilterBasisFunction deprecated in favor of StrictBandBasisFunction", FutureWarning
)
super().__init__(time_lag=time_lag, bandname=filtername, twi_change=twi_change, note_free=note_free)


class BandChangeBasisFunction(BaseBasisFunction):
"""Reward staying in the current band."""

Expand All @@ -922,6 +906,16 @@ def _calc_value(self, conditions, **kwargs):
return result


class FilterChangeBasisFunction(BandChangeBasisFunction):
"""Deprecated in favor of BandChangeBasisFunction"""

def __init__(self, filtername="r"):
warnings.warn(
"FilterChangeBasisFunction deprecated in favor of BandChangeBasisFunction", FutureWarning
)
super().__init__(bandname=filtername)


class SlewtimeBasisFunction(BaseBasisFunction):
"""Reward slews that take little time
Expand All @@ -941,7 +935,10 @@ class SlewtimeBasisFunction(BaseBasisFunction):
Default None will use `set_default_nside()`.
"""

def __init__(self, max_time=135.0, bandname="r", nside=DEFAULT_NSIDE):
def __init__(self, max_time=135.0, bandname="r", nside=DEFAULT_NSIDE, filtername=None):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(SlewtimeBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.maxtime = max_time
Expand Down Expand Up @@ -998,7 +995,11 @@ def __init__(
enhance_window=[2.1, 3.2],
enhance_val=1.0,
apply_area=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(CadenceEnhanceBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.supress_window = np.sort(supress_window)
Expand Down Expand Up @@ -1091,7 +1092,11 @@ def __init__(
enhance_amp=1.0,
apply_area=None,
season_limit=None,
filtername=None,
):
if filtername is not None:
warnings.warn("filtername deprecated in favor of bandname", FutureWarning)
bandname = filtername
super(CadenceEnhanceTrapezoidBasisFunction, self).__init__(nside=nside, bandname=bandname)

self.delay_width = delay_width
Expand Down Expand Up @@ -1280,61 +1285,6 @@ def _calc_value(self, conditions, indx=None):
return result


class GoodSeeingBasisFunction(BaseBasisFunction):
"""Drive observations in good seeing conditions"""

def __init__(
self,
nside=DEFAULT_NSIDE,
bandname="r",
footprint=None,
fwhm_eff_limit=0.8,
mag_diff=0.75,
):
send_unused_deprecation_warning("GoodSeeingBasisFunction")
return
super(GoodSeeingBasisFunction, self).__init__(nside=nside)

self.bandname = bandname
self.fwhm_eff_limit = IntRounded(fwhm_eff_limit)
if footprint is None:
footprints, labels = get_current_footprint(nside=self.nside)
fp = footprints[self.bandname]
else:
fp = footprint
self.out_of_bounds = np.where(fp == 0)[0]
self.result = fp * 0

self.mag_diff = IntRounded(mag_diff)
self.survey_features = {}
self.survey_features["coadd_depth_all"] = features.CoaddedDepth(
bandname=self.bandname, nside=self.nside
)
self.survey_features["coadd_depth_good"] = features.CoaddedDepth(
bandname=self.bandname, nside=self.nside, fwhm_eff_limit=fwhm_eff_limit
)

def _calc_value(self, conditions, **kwargs):
# Seeing is "bad"
if IntRounded(conditions.FWHMeff[self.bandname].min()) > self.fwhm_eff_limit:
return 0
result = self.result.copy()

diff = (
self.survey_features["coadd_depth_all"].feature - self.survey_features["coadd_depth_good"].feature
)
# Where are there things we want to observe?
good_pix = np.where(
(IntRounded(diff) > self.mag_diff)
& (IntRounded(conditions.FWHMeff[self.bandname]) <= self.fwhm_eff_limit)
)
# Hm, should this scale by the mag differences? Probably.
result[good_pix] = diff[good_pix]
result[self.out_of_bounds] = 0

return result


class VisitGap(BaseBasisFunction):
"""Basis function to create a visit gap based on the survey note field.
Expand All @@ -1358,7 +1308,10 @@ class VisitGap(BaseBasisFunction):
the last observation was at least gap in the past.
"""

def __init__(self, note, band_names=None, gap_min=25.0, penalty_val=np.nan):
def __init__(self, note, band_names=None, gap_min=25.0, penalty_val=np.nan, filter_names=None):
if filter_names is not None:
warnings.warn("filter_names deprecated in favor of band_names", FutureWarning)
band_names = filter_names
super().__init__()
self.penalty_val = penalty_val

Expand Down
Loading

0 comments on commit 39732ec

Please sign in to comment.