Skip to content

Commit

Permalink
adding IntRounded as needed
Browse files Browse the repository at this point in the history
  • Loading branch information
yoachim committed Jan 12, 2024
1 parent dba3c3e commit 1816808
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
15 changes: 8 additions & 7 deletions rubin_scheduler/scheduler/surveys/pointings_survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rubin_scheduler.scheduler.detailers import ParallacticRotationDetailer
from rubin_scheduler.skybrightness_pre import dark_m5
from rubin_scheduler.utils import _angular_separation, _approx_ra_dec2_alt_az
from rubin_scheduler.scheduler.utils import IntRounded

from .base_survey import BaseSurvey

Expand Down Expand Up @@ -133,7 +134,7 @@ def _check_feasibility(self, conditions):
result = True

# if the sun is too high
if conditions.sun_alt > self.sun_alt_limit:
if IntRounded(conditions.sun_alt) > IntRounded(self.sun_alt_limit):
return False

reward = self.calc_reward_function(conditions)
Expand Down Expand Up @@ -199,14 +200,14 @@ def ha_limit(self, conditions):
"""Apply hour angle limits."""
result = self.zeros.copy()
# apply hour angle limits
result[np.where((self.ha > self.ha_max) & (self.ha < self.ha_min))] = np.nan
result[np.where((IntRounded(self.ha) > IntRounded(self.ha_max)) & (IntRounded(self.ha) < IntRounded(self.ha_min)))] = np.nan
return result

def alt_limit(self, conditions):
"""Apply altitude limits."""
result = self.zeros.copy()
result[np.where(self.alt > self.alt_max)] = np.nan
result[np.where(self.alt < self.alt_min)] = np.nan
result[np.where(IntRounded(self.alt) > IntRounded(self.alt_max))] = np.nan
result[np.where(IntRounded(self.alt) < IntRounded(self.alt_min))] = np.nan
return result

def moon_limit(self, conditions):
Expand All @@ -215,7 +216,7 @@ def moon_limit(self, conditions):
dists = _angular_separation(
self.observations["RA"], self.observations["dec"], conditions.moon_ra, conditions.moon_dec
)
result[np.where(dists < self.moon_dist_limit)] = np.nan
result[np.where(IntRounded(dists) < IntRounded(self.moon_dist_limit))] = np.nan
return result

def wind_limit(self, conditions):
Expand All @@ -225,15 +226,15 @@ def wind_limit(self, conditions):
return result
wind_pressure = conditions.wind_speed * np.cos(self.az - conditions.wind_direction)
result -= wind_pressure**2.0
mask = wind_pressure > self.wind_speed_maximum
mask = IntRounded(wind_pressure) > IntRounded(self.wind_speed_maximum)
result[mask] = np.nan

return result

def visit_gap(self, conditions):
"""Enforce a minimum visit gap."""
diff = conditions.mjd - self.last_observed
too_soon = np.where(diff < self.gap_min)[0]
too_soon = np.where(IntRounded(diff) < IntRounded(self.gap_min))[0]
result = self.zeros.copy()
# Using NaN makes it a hard limit
# could have a weight and just subtract from the reward
Expand Down
14 changes: 7 additions & 7 deletions rubin_scheduler/scheduler/surveys/scripted_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from rubin_scheduler.scheduler.surveys import BaseSurvey
from rubin_scheduler.scheduler.utils import empty_observation, set_default_nside
from rubin_scheduler.scheduler.utils import empty_observation, set_default_nside, IntRounded
from rubin_scheduler.utils import _approx_ra_dec2_alt_az

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -179,10 +179,10 @@ def _check_alts_ha(self, observation, conditions):
HA[np.where(HA > 24)] -= 24
HA[np.where(HA < 0)] += 24
in_range = np.where(
(self.alt < observation["alt_max"])
& (self.alt > observation["alt_min"])
& ((HA > observation["HA_max"]) | (HA < observation["HA_min"]))
& (conditions.sun_alt < observation["sun_alt_max"])
(IntRounded(self.alt) < IntRounded(observation["alt_max"]))
& (IntRounded(self.alt) > IntRounded(observation["alt_min"]))
& ((IntRounded(HA) > IntRounded(observation["HA_max"])) | (IntRounded(HA) < IntRounded(observation["HA_min"])))
& (IntRounded(conditions.sun_alt) < IntRounded(observation["sun_alt_max"]))
)[0]
return in_range

Expand All @@ -192,8 +192,8 @@ def _check_list(self, conditions):
if self.obs_wanted is not None:
# Scheduled observations that are in the right time window and have not been executed
in_time_window = np.where(
(self.mjd_start < conditions.mjd)
& (self.obs_wanted["flush_by_mjd"] > conditions.mjd)
(IntRounded(self.mjd_start) < IntRounded(conditions.mjd))
& (IntRounded(self.obs_wanted["flush_by_mjd"]) > IntRounded(conditions.mjd))
& (~self.obs_wanted["observed"])
)[0]

Expand Down
2 changes: 2 additions & 0 deletions tests/scheduler/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def test_example(self):
observatory, scheduler, observations = run_sched(scheduler, mjd_start=mjd_start, survey_length=5)
u_notes = np.unique(observations["note"])

import pdb ; pdb.set_trace()

# Note that some of these may change and need to be updated if survey
# start date changes, e.g., different DDFs in season, or different lunar phase
# means different filters get picked for the blobs
Expand Down

0 comments on commit 1816808

Please sign in to comment.