diff --git a/setup.py b/setup.py
index 6bb47ad0..45782fbc 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,8 @@
         "toolz",
         "pandas",
         "chex",
-        "pyephem"
+        "pyephem",
+        "equinox",
+        "so3g"
     ],
 )
diff --git a/src/schedlib/config.py b/src/schedlib/config.py
new file mode 100644
index 00000000..147f3ff9
--- /dev/null
+++ b/src/schedlib/config.py
@@ -0,0 +1,15 @@
+from . import utils
+import yaml
+
+def datetime_constructor(loader, node):
+    """load a datetime"""
+    return utils.str2datetime(loader.construct_scalar(node))
+
+def get_loader():
+    # subclass SafeLoader so the added tags don't mutate the shared yaml.SafeLoader
+    class Loader(yaml.SafeLoader):
+        pass
+    Loader.add_constructor('!datetime', datetime_constructor)
+    # ignore unknown tags
+    Loader.add_multi_constructor('!', lambda loader, tag_suffix, node: None)
+    return Loader
\ No newline at end of file
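
Note: the loader returned by `get_loader` is passed straight to `yaml.load`; the
`!datetime` tag resolves to a timezone-aware datetime during parsing, and any
unrecognized `!` tag quietly loads as None instead of raising. A minimal usage
sketch, mirroring tests/test_config.py further down:

    import yaml
    from schedlib import config as cfg

    config = yaml.load("date_ref: !datetime 2014-01-01 00:00:00",
                       Loader=cfg.get_loader())
    # config['date_ref'] -> datetime(2014, 1, 1, tzinfo=timezone.utc)
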
diff --git a/src/schedlib/core.py b/src/schedlib/core.py
index bdb89a31..d5491a3c 100644
--- a/src/schedlib/core.py
+++ b/src/schedlib/core.py
@@ -1,10 +1,12 @@
 from typing import List, Union, Callable, Optional, Any, TypeVar
-from chex import dataclass
 from abc import ABC, abstractmethod
 import datetime as dt
 import numpy as np
-from toolz import compose_left
 import jax.tree_util as tu
+import equinox
+from dataclasses import dataclass, replace as dc_replace
+
+from . import utils
 
 @dataclass(frozen=True)
 class Block:
@@ -37,6 +39,8 @@ def trim_right_to(self, t: dt.datetime) -> List["Block"]:
         return block_trim_right_to(self, t)
     def isa(self, block_type: "BlockType") -> bool:
         return block_isa(block_type)(self)
+    def replace(self, **kwargs) -> "Block":
+        return dc_replace(self, **kwargs)
 
 BlockType = type(Block)
 Blocks = List[Union[Block, None, "Blocks"]]  # maybe None, maybe nested
@@ -224,6 +228,9 @@ def seq_filter_out(op: Callable[[Block], bool], blocks: BlocksTree) -> BlocksTre
 def seq_map(op, *blocks: BlocksTree) -> List[Any]:
     return tu.tree_map(op, *blocks, is_leaf=is_block)
 
+def seq_map_with_path(op, *blocks: BlocksTree) -> List[Any]:
+    return tu.tree_map_with_path(op, *blocks, is_leaf=is_block)
+
 def seq_map_when(op_when: Callable[[Block], bool], op: Callable[[Block], Any], blocks: BlocksTree) -> List[Any]:
     return tu.tree_map(lambda b: op(b) if op_when(b) else b, blocks, is_leaf=is_block)
 
@@ -233,6 +240,30 @@ def seq_replace_block(blocks: BlocksTree, source: Block, target: Block) -> Block
 def seq_trim(blocks: BlocksTree, t0: dt.datetime, t1: dt.datetime) -> BlocksTree:
     return seq_map(lambda b: b.trim(t0, t1), blocks)
 
+def seq_partition(op, blocks: BlocksTree) -> List[Any]:
+    """partition a blockstree into two trees: the first holds blocks for which
+    the predicate `op` (a function from block to bool) returns True, the second
+    holds the rest. Unmatched positions in each tree are filled with None."""
+    filter_spec = tu.tree_map(op, blocks, is_leaf=is_block)
+    return equinox.partition(blocks, filter_spec)
+
+def seq_partition_with_path(op, blocks: BlocksTree, **kwargs) -> List[Any]:
+    """same as seq_partition, except the predicate `op` takes (path, block)
+    as input. Unmatched positions in each tree are filled with None."""
+    filter_spec = tu.tree_map_with_path(op, blocks, is_leaf=is_block)
+    return equinox.partition(blocks, filter_spec, **kwargs)
+
+def seq_partition_with_query(query, blocks: BlocksTree):
+    """partition a blockstree by matching a query string against the
+    dot-separated key of each block (see utils.match_query)."""
+    return seq_partition_with_path(lambda path, block: utils.match_query(path, query), blocks)
+
+def seq_combine(*blocks: BlocksTree) -> BlocksTree:
+    """combine multiple trees into one by taking the first non-None block at
+    each position; the trees must share the same structure."""
+    seq_assert_same_structure(*blocks)
+    return equinox.combine(*blocks, is_leaf=is_block)
+
 # =========================
 # Other useful Block types
 # =========================
@@ -256,14 +287,6 @@ def __call__(self, blocks: Blocks) -> Blocks:
 
 Rule = Union[BlocksTransformation, Callable[[Blocks], Blocks]]
 RuleSet = List[Rule]
 
-@dataclass(frozen=True)
-class MultiRules(BlocksTransformation):
-    rules: RuleSet
-    def apply(self, blocks: Blocks) -> Blocks:
-        """apply rules to blocks in first-to-last order"""
-        return compose_left(*self.rules)(blocks)
-
-
 @dataclass(frozen=True)
 class Policy(BlocksTransformation, ABC):
     """apply: apply policy to a tree of blocks"""
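
Note: seq_partition and seq_combine are designed as inverses: partitioning a
tree and recombining the two halves reproduces the original, since the None
left at each unmatched position of one tree is filled by the block in the
other. A small sketch, mirroring tests/test_core.py further down:

    import datetime as dt
    from schedlib import core

    blocks = [core.Block(t0=dt.datetime(2023, 1, 1), t1=dt.datetime(2023, 1, 2)),
              core.Block(t0=dt.datetime(2023, 1, 5), t1=dt.datetime(2023, 1, 6))]
    matched, unmatched = core.seq_partition(lambda b: b.t0.day < 3, blocks)
    # matched   == [Block(Jan 1), None]
    # unmatched == [None, Block(Jan 5)]
    assert core.seq_combine(matched, unmatched) == blocks
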
diff --git a/src/schedlib/instrument.py b/src/schedlib/instrument.py
index 69303355..a84194d9 100644
--- a/src/schedlib/instrument.py
+++ b/src/schedlib/instrument.py
@@ -1,17 +1,20 @@
 from __future__ import annotations
-from chex import dataclass
 from jax import tree_util as tu
+import pandas as pd
 from typing import List, TypeVar, Union, Dict
 import numpy as np
 from functools import reduce
+from dataclasses import dataclass
+from so3g.proj import quat
 
-from . import core
+from . import core, utils as u
 
 @dataclass(frozen=True)
 class ScanBlock(core.NamedBlock):
-    az: float  # deg
-    alt: float  # deg
-    throw: float  # deg
+    az: float         # deg
+    alt: float        # deg
+    throw: float      # deg
+    drift: float = 0  # deg / s
 
 @dataclass(frozen=True)
 class IVBlock(core.NamedBlock): pass
@@ -35,14 +38,13 @@ def get_spec(specs: SpecsTree, query: List[str], merge=True) -> Union[Spec, Spec
     one of the queries.
     return all matches if merge=False"""
     is_leaf = lambda x: isinstance(x, dict) and 'bounds_x' in x
     match_p = lambda key: any([p in key for p in query])
-    path2key = lambda path: ".".join([str(p.key) for p in path])
     def reduce_fn(l, r):
         res = {}
         for k in ['bounds_x', 'bounds_y']:
             res[k] = [min(l[k][0], r[k][0]), max(l[k][1], r[k][1])]
         return res
     all_matches = tu.tree_leaves(
-        tu.tree_map_with_path(lambda path, x: x if match_p(path2key(path)) else None, specs, is_leaf=is_leaf),
+        tu.tree_map_with_path(lambda path, x: x if match_p(u.path2key(path)) else None, specs, is_leaf=is_leaf),
         is_leaf=is_leaf
     )  # None is not a leaf, so it will be filtered out
     if not merge: return all_matches
@@ -51,11 +53,67 @@ def reduce_fn(l, r):
 
 def get_bounds_x_tilted(bounds_x: List[float], bounds_y: List[float], phi_tilt: Union[float, core.Arr[float]], shape: str):
     """get the effective bounds of the x-axis of the spec when covering a tilted patch"""
-    assert shape in ['ellipse']  # more to implement
+    assert shape in ['ellipse', 'rect']  # more to implement
+    a = (bounds_x[1] - bounds_x[0])/2
+    b = (bounds_y[1] - bounds_y[0])/2
     if shape == 'ellipse':
-        a = (bounds_x[1] - bounds_x[0])/2
-        b = (bounds_y[1] - bounds_y[0])/2
-        w_proj = np.sqrt(a**2 * np.cos(phi_tilt)**2 + b**2 * np.sin(phi_tilt)**2)
-        return np.array([-w_proj, w_proj]) + (bounds_x[0] + bounds_x[1])/2
+        w_proj = a * np.sqrt(1 + b**2 / a**2 * np.tan(phi_tilt)**2)
+    elif shape == 'rect':
+        w_proj = b * np.abs(np.tan(phi_tilt)) + a  # abs: a negative tilt must not shrink the bounds
     else:
         raise NotImplementedError
+    return np.array([-w_proj, w_proj]) + (bounds_x[0] + bounds_x[1])/2
+
+def make_circular_cover(xi0, eta0, R, count=50, degree=True):
+    """make a circular cover centered at xi0, eta0 with radius R"""
+    if degree: xi0, eta0, R = np.deg2rad([xi0, eta0, R])
+    dphi = 2*np.pi/count
+    phi = np.arange(count) * dphi
+    L = 1.01*R / np.cos(dphi/2)
+    xi, eta = L * np.cos(phi), L * np.sin(phi)
+    xi, eta, _ = quat.decompose_xieta(quat.rotation_xieta(xi0, eta0) * quat.rotation_xieta(xi, eta))
+    return {
+        'center': (xi0, eta0),
+        'cover': np.array([xi, eta])
+    }
+
+def array_info_merge(arrays):
+    center = np.mean(np.array([a['center'] for a in arrays]), axis=0)
+    cover = np.concatenate([a['cover'] for a in arrays], axis=1)
+    return {
+        'center': center,
+        'cover': cover
+    }
+
+def array_info_from_query(geometries, query):
+    """make an array info with geometries that match the query"""
+    is_leaf = lambda x: isinstance(x, dict) and 'center' in x
+    matched = tu.tree_leaves(tu.tree_map_with_path(
+        lambda path, x: x if u.match_query(path, query) else None,
+        geometries,
+        is_leaf=is_leaf
+    ), is_leaf=is_leaf)
+    arrays = [make_circular_cover(*g['center'], g['radius']) for g in matched]
+    return array_info_merge(arrays)
+
+def parse_sequence_from_toast(ifile):
+    """parse a master schedule file produced by toast into a list of ScanBlocks
+
+    Parameters
+    ----------
+    ifile: input master schedule from toast
+    """
+    columns = ["start_utc", "stop_utc", "rotation", "patch", "az_min", "az_max", "el", "pass", "sub"]
+    df = pd.read_csv(ifile, skiprows=3, delimiter="|", names=columns)
+    blocks = []
+    for _, row in df.iterrows():
+        block = ScanBlock(
+            name=row['patch'].strip(),
+            t0=u.str2datetime(row['start_utc']),
+            t1=u.str2datetime(row['stop_utc']),
+            alt=row['el'],
+            az=row['az_min'],
+            throw=np.abs(row['az_max'] - row['az_min']),
+        )
+        blocks.append(block)
+    return blocks
\ No newline at end of file
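
Note: array_info_from_query composes the helpers above: every geometry whose
key matches the query is turned into a circular cover, and the covers are
merged into one array info. Mirroring tests/test_instrument.py further down:

    import schedlib.instrument as inst

    geometries = {
        'w11': {'center': [0, 0], 'radius': 1.0},
        'w22': {'center': [1, 1], 'radius': 1.0},
    }
    info = inst.array_info_from_query(geometries, "w1*,*2")  # matches both
    # info['center'] is the mean of the matched centers; info['cover']
    # stacks the two (2, 50) circular covers into a (2, 100) array
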
diff --git a/src/schedlib/policies/__init__.py b/src/schedlib/policies/__init__.py
new file mode 100644
index 00000000..07fdf0d9
--- /dev/null
+++ b/src/schedlib/policies/__init__.py
@@ -0,0 +1,2 @@
+from .basic import BasePolicy, BasicPolicy
+from .flex import FlexPolicy
\ No newline at end of file
diff --git a/src/schedlib/policies.py b/src/schedlib/policies/basic.py
similarity index 93%
rename from src/schedlib/policies.py
rename to src/schedlib/policies/basic.py
index 66f8778d..6fdd607f 100644
--- a/src/schedlib/policies.py
+++ b/src/schedlib/policies/basic.py
@@ -1,10 +1,11 @@
 #!/usr/bin/env python3
-from chex import dataclass
 import datetime as dt
 from abc import ABC, abstractmethod
-from typing import List
-from . import core, utils, commands as cmd, instrument as inst, rules as ru, source as src
+from typing import List
+from dataclasses import dataclass
+from .. import core, utils, commands as cmd, instrument as inst, rules as ru, source as src, config as cfg
+
 
 @dataclass(frozen=True)
 class BasePolicy(core.Policy, ABC):
@@ -15,15 +16,6 @@ class BasePolicy(core.Policy, ABC):
     preserve the nested structure for the user to see, but we can
     also flatten the structure for the scheduler to consume."""
 
-    rules: core.RuleSet
-
-    def make_rule(self, rule_name: str, **kwargs) -> core.Rule:
-        # caller kwargs take precedence
-        if not kwargs:
-            assert rule_name in self.rules, f"Rule {rule_name} not found in rules config"
-            kwargs = self.rules[rule_name]
-        return ru.make_rule(rule_name, **kwargs)
-
     @abstractmethod
     def transform(self, blocks: core.BlocksTree) -> core.BlocksTree: ...
 
@@ -39,15 +31,23 @@ def apply(self, blocks: core.BlocksTree) -> core.Blocks:
     @abstractmethod
     def seq2cmd(self, seq: core.Blocks) -> cmd.Command: ...
 
+
 @dataclass(frozen=True)
 class BasicPolicy(BasePolicy):
-
+    rules: core.RuleSet
     master_schedule: str
     calibration_targets: List[str]
    soft_targets: List[str]
 
+    def make_rule(self, rule_name: str, **kwargs) -> core.Rule:
+        # caller kwargs take precedence over the rules config
+        if not kwargs:
+            assert rule_name in self.rules, f"Rule {rule_name} not found in rules config"
+            kwargs = self.rules[rule_name]
+        return ru.make_rule(rule_name, **kwargs)
+
     def init_seqs(self, t0: dt.datetime, t1: dt.datetime) -> core.BlocksTree:
-        master = utils.parse_sequence_from_toast(self.master_schedule)
+        master = inst.parse_sequence_from_toast(self.master_schedule)
         calibration = {k: src.source_gen_seq(k, t0, t1) for k in self.calibration_targets}
         soft = {k: src.source_gen_seq(k, t0, t1) for k in self.soft_targets}
         blocks = {
@@ -120,4 +121,4 @@ def seq2cmd(self, seq: core.Blocks):
         """map a scan to a command"""
         commands = core.seq_flatten(core.seq_map(self.block2cmd, seq))
         commands = [cmd.Preamble()] + commands
-        return cmd.CompositeCommand(commands)
+        return cmd.CompositeCommand(commands)
\ No newline at end of file
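
Note: BasicPolicy.make_rule falls back to the policy's `rules` config only when
no kwargs are given; any caller kwargs replace the config entry wholesale rather
than being merged with it. A sketch, assuming a policy whose rules config
contains, say, {"sun-avoidance": {"min_angle_az": 6, "min_angle_alt": 6}}:

    rule = policy.make_rule("sun-avoidance")   # kwargs come from policy.rules
    rule = policy.make_rule("sun-avoidance", min_angle_az=10)
    # the second call ignores the config entry entirely (no merging), so
    # min_angle_alt falls back to the rule's own default, if any
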
diff --git a/src/schedlib/policies/flex.py b/src/schedlib/policies/flex.py
new file mode 100644
index 00000000..d3d7ef84
--- /dev/null
+++ b/src/schedlib/policies/flex.py
@@ -0,0 +1,127 @@
+import yaml
+import os.path as op
+from dataclasses import dataclass
+import datetime as dt
+from typing import List
+
+from . import basic
+from .. import config as cfg, core, utils, source as src, rules as ru, commands as cmd, instrument as inst
+
+
+@dataclass(frozen=True)
+class FlexPolicy(basic.BasePolicy):
+    """a flexible, config-driven policy: `config_text` stores the raw yaml
+    config content so that blocks can be re-parsed later (see init_seqs)"""
+    config_text: str
+    rules: List[dict]
+    post_rules: List[core.Rule]
+    merge_order: List[str]
+    geometries: List[dict]
+
+    @classmethod
+    def make_rule(cls, rule_cfg, full_config={}):
+        rule_name = rule_cfg.pop('name')
+        constraint = rule_cfg.pop('constraint', None)
+
+        # rules that require randomization
+        if rule_name in ['make-source-scan', 'rephase-first']:
+            today = dt.datetime.now()
+            rng_key = utils.PRNGKey((today.year, today.month, today.day, rule_cfg.pop('seed', 0)))
+            rule_cfg['rng_key'] = rng_key
+            rule = ru.make_rule(rule_name, **rule_cfg)
+        # treat special rule
+        elif rule_name == 'make-drift-scan':
+            rule_cfg['geometries'] = full_config['geometries']
+            rule = ru.MakeCESourceScan.from_config(rule_cfg)
+        else:
+            rule = ru.make_rule(rule_name, **rule_cfg)
+
+        # if a constraint is specified, make a constrained rule instead
+        if constraint is not None:
+            rule = ru.ConstrainedRule(rule, constraint)
+        return rule
+
+    @classmethod
+    def from_config(cls, config: str):
+        """populate policy object from a yaml config file or raw yaml string"""
+        # load the text content of config into a string for later use
+        if op.isfile(config):
+            with open(config, "r") as f:
+                config_text = f.read()
+        else:
+            config_text = config
+
+        # pre-load the config to populate some common fields in the policy
+        loader = cfg.get_loader()
+        config = yaml.load(config_text, Loader=loader)
+
+        # load rules
+        rules = []
+        for rule_cfg in config.pop('rules'):
+            rules.append(cls.make_rule(rule_cfg, full_config=config))
+
+        post_rules = []
+        for rule_cfg in config.pop('post_rules', []):
+            post_rules.append(cls.make_rule(rule_cfg, full_config=config))
+
+        # remove fields that need special handling later on
+        config.pop('blocks')
+
+        # now we can construct the policy
+        return cls(config_text=config_text, rules=rules, post_rules=post_rules, **config)
+
+    def init_seqs(self, t0: dt.datetime, t1: dt.datetime) -> core.BlocksTree:
+        # prepare some specialized loaders: !source [source_name], !toast [toast schedule name]
+        def source_constructor(t0, t1, loader, node):
+            return src.source_gen_seq(loader.construct_scalar(node), t0, t1)
+        def toast_constructor(loader, node):
+            return inst.parse_sequence_from_toast(loader.construct_scalar(node))
+        loader = cfg.get_loader()
+        loader.add_constructor('!source', lambda loader, node: source_constructor(t0, t1, loader, node))
+        loader.add_constructor('!toast', lambda loader, node: toast_constructor(loader, node))
+        # load blocks for processing
+        blocks = yaml.load(self.config_text, Loader=loader)["blocks"]
+        return core.seq_trim(blocks, t0, t1)
+
+    def transform(self, blocks: core.BlocksTree) -> core.BlocksTree:
+        # apply each rule
+        for rule in self.rules:
+            blocks = rule(blocks)
+        return blocks
+
+    def merge(self, blocks: core.BlocksTree) -> core.Blocks:
+        """merge blocks into a single sequence in the order specified
+        in self.merge_order, assuming a descending priority order moving
+        down the merge_order list."""
+        seq = None
+        for query in self.merge_order[::-1]:
+            match, _ = core.seq_partition_with_query(query, blocks)
+            if seq is None:
+                seq = match
+            else:
+                # match takes precedence
+                seq = core.seq_merge(seq, match, flatten=True)
+
+        # apply transformation if needed
+        for rule in self.post_rules:
+            seq = rule(seq)
+
+        return core.seq_sort(seq)
+
+    def block2cmd(self, block: core.Block):
+        if isinstance(block, inst.ScanBlock):
+            return cmd.CompositeCommand([
+                    f"# {block.name}",
+                    cmd.Goto(block.az, block.alt),
+                    cmd.BiasDets(),
+                    cmd.Wait(block.t0),
+                    cmd.BiasStep(),
+                    cmd.Scan(block.name, block.t1, block.throw),
+                    cmd.BiasStep(),
+                    "",
+            ])
+
+    def seq2cmd(self, seq: core.Blocks):
+        """map a scan to a command"""
+        commands = core.seq_flatten(core.seq_map(self.block2cmd, seq))
+        commands = [cmd.Preamble()] + commands
+        return cmd.CompositeCommand(commands)
\ No newline at end of file
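
Note: merge walks merge_order from back to front, so the last entry forms the
base sequence and every earlier entry overrides whatever it overlaps with; in
other words merge_order is listed from highest to lowest priority. With
merge_order [moon, saturn, master], master blocks fill the timeline first,
saturn blocks displace overlapping master blocks, and moon blocks win every
remaining conflict.
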
diff --git a/src/schedlib/rules.py b/src/schedlib/rules.py
index 167eb143..906f04d8 100644
--- a/src/schedlib/rules.py
+++ b/src/schedlib/rules.py
@@ -1,9 +1,8 @@
 from typing import Tuple, Dict, List, Optional
 import numpy as np
-from chex import dataclass
-from functools import partial
 from abc import ABC, abstractmethod
 import datetime as dt
+from dataclasses import dataclass
 
 from . import core, source as src, instrument as inst, utils
 
@@ -25,6 +24,17 @@ def apply_block(self, block) -> core.Blocks: ...
     def applicable(self, block) -> bool:
         return True
 
+@dataclass(frozen=True)
+class ConstrainedRule(GreenRule):
+    """ConstrainedRule applies a rule to a subset of blocks. The constraint
+    is a query string (substring or fnmatch pattern, see utils.match_query)
+    matched against the dot-separated key of each block."""
+    rule: core.Rule
+    constraint: str
+    def apply(self, blocks: core.BlocksTree) -> core.BlocksTree:
+        matched, unmatched = core.seq_partition_with_query(self.constraint, blocks)
+        return core.seq_combine(self.rule(matched), unmatched)
+
 @dataclass(frozen=True)
 class AltRange(MappableRule):
     """Restrict the altitude range of source blocks.
@@ -91,9 +101,9 @@ def apply(self, blocks: core.BlocksTree) -> core.BlocksTree:
         # identify the first block as the first in the sorted list
         src = core.seq_sort(core.seq_flatten(blocks))[0]
         # randomize the phase of it but not too much
-        allowance = min(self.max_fraction * src.duration,
-                        max(src.duration - self.min_block_size, 0))
-        tgt = src.replace(t0=src.t0 + utils.uniform(self.rng_key, 0, allowance))
+        allowance = min(self.max_fraction * src.duration.total_seconds(),
+                        max(src.duration.total_seconds() - self.min_block_size, 0))
+        tgt = src.replace(t0=src.t0 + dt.timedelta(seconds=utils.uniform(self.rng_key, 0, allowance)))
         return core.seq_replace_block(blocks, src, tgt)
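
Note: ConstrainedRule (above) is what a `constraint:` entry in a FlexPolicy
rule config turns into: the wrapped rule only sees the subtree whose keys
match the query, and unmatched blocks pass through untouched before the two
halves are recombined. A sketch, reusing the min-duration rule exercised in
tests/test_policies.py:

    rule = ru.ConstrainedRule(ru.make_rule('min-duration', min_duration=600),
                              'calibration')
    blocks = rule(blocks)   # min-duration applied only under calibration.*
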
@@ -126,8 +136,9 @@ def apply_block(self, block: core.Block):
         # we should stop scanning when the source is at this alt
         alt_stop = alt + sign*alt_height
         # total passage time
-        obs_length = utils.interp_extra(alt_stop, alt, t) - t
-        assert np.all(obs_length >= 0), "passage time must be positive, something is wrong"
+        obs_length = utils.interp_extra(alt_stop, alt, t, fill_value=np.nan) - t
+        ok = np.logical_not(np.isnan(obs_length))
+        assert np.all(obs_length[ok] >= 0), "passage time must be positive, something is wrong"
         # this is where our boresight pointing should be to observe the passage.
         # this places our wafer set at the center of the source path, so the source
@@ -152,6 +163,7 @@ def apply_block(self, block: core.Block):
             shape=self.spec_shape,
             phi_tilt=phi_tilt,
         ) for spec in self.specs])
+        x_lo, x_hi = np.min(bounds_x[:, 0]), np.max(bounds_x[:, 1])
         # add back the projection effect to get the actual az bounds
         stretch = 1 / np.cos(np.deg2rad(alt_center))
@@ -160,7 +172,7 @@ def apply_block(self, block: core.Block):
         az_bore = az_center - az_offset
 
         # get validity ranges
-        ok = utils.within_bound(alt_stop, [alt.min(), alt.max()])
+        ok *= utils.within_bound(alt_stop, [alt.min(), alt.max()])
         if self.bounds_alt is not None:
             ok *= utils.within_bound(alt_bore, self.bounds_alt)
         if self.bounds_az_throw is not None:
@@ -231,20 +243,58 @@ class MakeSourceScan(MappableRule):
     """convert observing window to actual scan blocks and allow for
     rephasing of the block. Applicable to only ObservingWindow blocks.
     """
-    preferred_length: float  # seconds
     rng_key: utils.PRNGKey
+    preferred_length: Optional[float] = None  # seconds
+    fixed_alt: Optional[float] = None
     def apply_block(self, block: core.Block) -> core.Block:
         duration = block.duration.total_seconds()
-        preferred_len = min(self.preferred_length, duration)
-        allowance = duration - preferred_len
-        offset = utils.uniform(self.rng_key, 0, allowance)
-        t0 = block.t0 + dt.timedelta(seconds=offset)
-        return block.get_scan_starting_at(t0)
+        # make sure preferred_length and fixed_alt are not both set
+        assert not (self.preferred_length is not None and self.fixed_alt is not None)
+        if self.preferred_length is not None:
+            preferred_len = min(self.preferred_length, duration)
+            allowance = duration - preferred_len
+            offset = utils.uniform(self.rng_key, 0, allowance)
+            t0 = block.t0 + dt.timedelta(seconds=offset)
+            scan = block.get_scan_at_t0(t0)
+        elif self.fixed_alt is not None:
+            scan = block.get_scan_at_alt(self.fixed_alt)
+        else:
+            scan = block
+        return scan
 
     def applicable(self, block: core.Block) -> bool:
         return isinstance(block, src.ObservingWindow)
 
+@dataclass(frozen=True)
+class MakeCESourceScan(MappableRule):
+    """Transform SourceBlock into fixed-elevation ScanBlocks that support
+    az drift mode.
+
+    Parameters
+    ----------
+    array_info : dict
+        array information, contains 'center' and 'cover' keys
+    el_bore : float
+        elevation of the boresight in degrees
+    drift : bool
+        whether to enable drift mode
+
+    """
+    array_info: dict
+    el_bore: float  # deg
+    drift: bool = True
+    def apply_block(self, block: core.Block) -> core.Block:
+        return src.make_source_ces(block, array_info=self.array_info, el_bore=self.el_bore, enable_drift=self.drift)
+    def applicable(self, block: core.Block) -> bool:
+        return isinstance(block, src.SourceBlock)
+    @classmethod
+    def from_config(cls, config):
+        query = config.pop('array_query', "*")
+        geometries = config.pop('geometries', {})
+        array_info = inst.array_info_from_query(geometries, query)
+        return cls(array_info=array_info, **config)
+
 # global registry of rules
 RULES = {
     'alt-range': AltRange,
@@ -255,9 +305,11 @@ def applicable(self, block: core.Block) -> bool:
     'sun-avoidance': SunAvoidance,
     'make-source-plan': MakeSourcePlan,
     'make-source-scan': MakeSourceScan,
+    'make-drift-scan': MakeCESourceScan,
 }
 
 def get_rule(name: str) -> core.Rule:
     return RULES[name]
 
 def make_rule(name: str, **kwargs) -> core.Rule:
+    assert name in RULES, f"unknown rule {name}"
     return get_rule(name)(**kwargs)
diff --git a/src/schedlib/source.py b/src/schedlib/source.py
index ee44d6ca..1d623d82 100644
--- a/src/schedlib/source.py
+++ b/src/schedlib/source.py
@@ -1,19 +1,20 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations
-from chex import dataclass
+from dataclasses import dataclass
 import ephem
 from ephem import to_timezone
 import datetime as dt
 from typing import Union, Callable, NamedTuple, List, Tuple, Optional
 from scipy.interpolate import interp1d
 import numpy as np
+from scipy import interpolate
+from so3g.proj import quat
 
 from . import core, utils, instrument as inst
 
 UTC = dt.timezone.utc
 
-
 class Location(NamedTuple):
     """Location given in degrees and meters"""
     lat: float
@@ -29,10 +30,21 @@ def at(self, date: dt.datetime) -> ephem.Observer:
         obs.date = ephem.date(date)
         return obs
 
+def _debabyl(deg, arcmin, arcsec):
+    return deg + arcmin/60 + arcsec/3600
+
+SITES = {
+    'act':   Location(lat=-22.9585, lon=-67.7876, elev=5188),
+    'lat':   Location(lat=-_debabyl(22,57,39.47), lon=-_debabyl(67,47,15.68), elev=5188),
+    'satp1': Location(lat=-_debabyl(22,57,36.38), lon=-_debabyl(67,47,18.11), elev=5188),
+    'satp2': Location(lat=-_debabyl(22,57,36.35), lon=-_debabyl(67,47,17.28), elev=5188),
+    'satp3': Location(lat=-_debabyl(22,57,35.97), lon=-_debabyl(67,47,16.53), elev=5188),
+}
 DEFAULT_SITE = Location(lat=-22.958, lon=-67.786, elev=5200)
 
-def get_site() -> Location:
-    return DEFAULT_SITE
+def get_site(site='lat') -> Location:
+    """use lat as default following so3g convention"""
+    return SITES[site]
 
 # source needs to be callable to avoid side effects
 SOURCES = {
@@ -66,7 +78,9 @@ def _source_get_az_alt(source: str, times: List[dt.datetime]):
         source.compute(observer)
         az.append(np.rad2deg(source.az))
         alt.append(np.rad2deg(source.alt))
-    return np.array(az), np.array(alt)
+    az = np.unwrap(np.array(az), period=360)
+    alt = np.array(alt)
+    return az, alt
 
 def _source_az_alt_interpolators(source: str, t0: dt.datetime, t1: dt.datetime, time_step: dt.timedelta):
     times = [t0 + i * time_step for i in range(int((t1 - t0) / time_step))]
@@ -200,7 +214,7 @@ class ObservingWindow(SourceBlock):
     az_bore: core.Arr[float]
     alt_bore: core.Arr[float]
     az_throw: core.Arr[float]
-    def get_scan_starting_at(self, t0: dt.datetime) -> inst.ScanBlock:
+    def get_scan_at_t0(self, t0: dt.datetime) -> inst.ScanBlock:
         """get a possible scan starting at t0"""
         t_req = int(t0.timestamp())
         # if we start at t0, we can observe for at most obs_length
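
Note: the np.unwrap(..., period=360) change above keeps the returned azimuth
continuous across the 0/360 wrap instead of jumping; this is why azimuth
expectations in tests/test_source.py below shift by a full turn (e.g.
interp_az now returns -64.65 where it used to return 295.35, modulo small
ephemeris-level updates). A quick illustration:

    import numpy as np

    az = np.array([358.0, 359.5, 1.0, 2.5])   # wraps at 360
    np.unwrap(az, period=360)                  # -> [358., 359.5, 361., 362.5]
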
@@ -218,3 +232,96 @@
             alt=float(alt),
             throw=float(az_throw),
         )
+    def get_scan_at_alt(self, alt: float) -> inst.ScanBlock:
+        """get a possible scan at a given altitude"""
+        t0 = utils.interp_bounded(alt, self.alt_bore, self.t_start)
+        return self.get_scan_at_t0(t0)
+
+def make_source_ces(block, array_info, el_bore=50, drift_params=None, enable_drift=False, verbose=False):
+    assert 'center' in array_info and 'cover' in array_info
+    # move to the frame in which the center of the wafer is at the origin
+    q_center = quat.rotation_xieta(*array_info['center'])
+    q_cover = quat.rotation_xieta(*array_info['cover'])
+    xi_cover_array, eta_cover_array, _ = quat.decompose_xieta(~q_center * q_cover)
+    # find out the elevation of the array if boresight is at el_bore
+    _, dalt, _ = quat.decompose_lonlat(quat.rotation_lonlat(0, 0) * q_center)
+    el_array = el_bore + dalt / utils.deg
+    # get trajectory of the source
+    t, az_src, el_src = block.get_az_alt()  # degs
+    if drift_params is not None:
+        assert 't' in drift_params and 'az_speed' in drift_params
+        v_az = drift_params['az_speed']
+        az_src -= (t - drift_params['t']) * drift_params['az_speed']
+    else:
+        v_az = 0
+    # az of the source when el_src = el_array
+    if el_array > max(el_src):
+        print("Warning: source is too low")
+        return None
+    if el_array < min(el_src):
+        print("Warning: source is too high")
+        return None
+    az_array = interpolate.interp1d(el_src, az_src)(el_array)
+    # center array on the source and put it at the origin
+    q_src_ground = quat.rotation_lonlat(-az_src * utils.deg, el_src * utils.deg)
+    q_target_ground = quat.rotation_lonlat(-az_array * utils.deg, el_array * utils.deg)
+    q_src_array = ~q_target_ground * q_src_ground  # where target is at the origin
+    xi_src_array, eta_src_array, _ = quat.decompose_xieta(q_src_array)
+    # make sure a scan of the entire array is possible
+    if max(eta_cover_array) < max(eta_src_array):
+        print("Warning: source is too low")
+        return None
+    if min(eta_cover_array) > min(eta_src_array):
+        print("Warning: source is too high")
+        return None
+    # work out the tilt of the wafer at the origin
+    phi_tilt_fun = interpolate.interp1d(eta_src_array[:-1],
+                                        np.arctan2(np.diff(xi_src_array),
+                                                   np.diff(eta_src_array)),
+                                        fill_value='extrapolate')
+    # find out the boundaries of the wafer by taking a projection along the tilt axis
+    x_cross = - eta_cover_array * np.tan(phi_tilt_fun(eta_cover_array)) + xi_cover_array
+    q_A_array = quat.rotation_xieta(np.min(x_cross), 0)
+    q_B_array = quat.rotation_xieta(np.max(x_cross), 0)
+    az_A_array, _, _ = quat.decompose_lonlat(quat.rotation_lonlat(0, el_array * utils.deg) * q_A_array)
+    az_A_array *= -1
+    az_B_array, _, _ = quat.decompose_lonlat(quat.rotation_lonlat(0, el_array * utils.deg) * q_B_array)
+    az_B_array *= -1
+    # now we have all ingredients to make a source scan
+    daz, dalt, _ = quat.decompose_lonlat(quat.rotation_lonlat(0, el_bore * utils.deg) * quat.rotation_xieta(*array_info['center']))
+    daz *= -1
+    # az boresight should move between these two points
+    az_A_bore = az_array * utils.deg + az_A_array - daz  # rad
+    az_B_bore = az_array * utils.deg + az_B_array - daz  # rad
+    # get scan az and throw
+    az_start = min(az_A_bore, az_B_bore)  # rad
+    throw = abs(az_B_bore - az_A_bore)  # rad
+    q_bore_start = quat.rotation_lonlat(-az_start, el_bore * utils.deg)
+    az_cover_start, el_cover_start, _ = quat.decompose_lonlat(q_bore_start * q_cover)
+    az_cover_start *= -1
+    # get the elevation ranges
+    if block.mode == 
'rising': + el_src_start = np.min(el_cover_start) / utils.deg + el_src_stop = np.max(el_cover_start) / utils.deg + elif block.mode == 'setting': + el_src_start = np.max(el_cover_start) / utils.deg + el_src_stop = np.min(el_cover_start) / utils.deg + else: + raise ValueError(f'unsupported scan mode encountered: {block.mode}') + # get the time ranges + t_start = interpolate.interp1d(el_src, t)(el_src_start) + t_stop = interpolate.interp1d(el_src, t)(el_src_stop) + t0 = utils.ct2dt(float(t_start)) + t1 = utils.ct2dt(float(t_stop)) + if enable_drift: + az_speed_ref = np.median(np.diff(az_src) / np.diff(t)) + drift_params = {'t': t_start, 'az_speed': az_speed_ref} + return make_source_ces(block, array_info, el_bore=el_bore, drift_params=drift_params, enable_drift=False, verbose=verbose) + else: + if verbose: + print("t0 = ", t_start) + print("t1 = ", t_stop) + print("az = ", az_start / utils.deg) + print("throw = ", throw / utils.deg) + print("drift = ", v_az) + return inst.ScanBlock(name=block.name, az=az_start / utils.deg, alt=el_bore, throw=throw / utils.deg, t0=t0, t1=t1, drift=v_az) \ No newline at end of file diff --git a/src/schedlib/utils.py b/src/schedlib/utils.py index ce2cc2fb..94cc9a51 100644 --- a/src/schedlib/utils.py +++ b/src/schedlib/utils.py @@ -4,12 +4,10 @@ import numpy as np from functools import reduce from contextlib import contextmanager -from typing import Any, List from scipy import interpolate from collections.abc import Iterable - -from . import core, utils as u, instrument as inst - +from jax.tree_util import SequenceKey, DictKey +import fnmatch minute = 60 # second hour = 60 * minute @@ -29,7 +27,7 @@ def datetime2str(dtime): return dtime.strftime('%Y-%m-%dT%H:%M:%S.%f%z') def ct2dt(ctime): - if isinstance(ctime, Iterable): + if isinstance(ctime, list): return [datetime.utcfromtimestamp(t).astimezone(timezone.utc) for t in ctime] else: try: @@ -76,37 +74,17 @@ def ranges_complement(ranges, imax): """return the complement ranges""" return mask2ranges(~ranges2mask(ranges, imax)) -def parse_sequence_from_toast(ifile: str) -> core.Blocks: - """ - Parameters - ---------- - ifile: input master schedule from toast - """ - columns = ["start_utc", "stop_utc", "rotation", "patch", "az_min", "az_max", "el", "pass", "sub"] - df = pd.read_csv(ifile, skiprows=3, delimiter="|", names=columns) - blocks = [] - for _, row in df.iterrows(): - block = inst.ScanBlock( - name=row['patch'].strip(), - t0=u.str2datetime(row['start_utc']), - t1=u.str2datetime(row['stop_utc']), - alt=row['el'], - az=row['az_min'], - throw=np.abs(row['az_max'] - row['az_min']), - ) - blocks.append(block) - return blocks # convenience wrapper for interpolation: numpy-like scipy interpolate -def interp_extra(x_new, x, y): +def interp_extra(x_new, x, y, fill_value='extrapolate'): """interpolate with extrapolation""" - return interpolate.interp1d(x, y, fill_value='extrapolate', bounds_error=False, kind='cubic', assume_sorted=False)(x_new) + return interpolate.interp1d(x, y, fill_value=fill_value, bounds_error=False, kind='cubic', assume_sorted=False)(x_new) def interp_bounded(x_new, x, y): """interpolate with bounded extrapolation""" return interpolate.interp1d(x, y, fill_value=(y[0], y[-1]), bounds_error=False, kind='cubic', assume_sorted=False)(x_new) -def within_bound(x: core.Arr[Any], bounds: List[float]) -> core.Arr[bool]: +def within_bound(x, bounds): """return a boolean mask indicating whether x is within the bound""" return (x >= bounds[0]) * (x <= bounds[1]) @@ -155,7 +133,41 @@ def uniform(key: PRNGKey, 
low=0.0, high=1.0, size=None):
 def daily_static_key(t: datetime):
     return PRNGKey((t.year, t.month, t.day))
 
-def pprint(seq: core.BlocksTree):
+def pprint(seq):
     """pretty print"""
     from equinox import tree_pprint
     tree_pprint(seq)
+
+# ====================
+# path related
+# ====================
+
+def path2key(path):
+    """convert a path (used in tree_util.tree_map_with_path) to a dot-separated key"""
+    keys = []
+    for p in path:
+        if isinstance(p, SequenceKey):
+            keys.append(p.idx)
+        elif isinstance(p, DictKey):
+            keys.append(p.key)
+        else:
+            raise ValueError(f"unknown path type {type(p)}")
+    return ".".join([str(k) for k in keys])
+
+def match_query(path, query):
+    """determine whether a query matches a path; a match happens when
+    any of the following holds:
+    1. the query is a substring of the path's dot-separated key
+    2. the query is a glob pattern that matches that key
+    3. the query is a comma-separated list of multiple queries,
+       any of which satisfies condition 1 or 2
+    """
+    key = path2key(path)
+    # match each comma-separated query against the key
+    queries = query.split(",")
+    for q in queries:
+        if q in key:
+            return True
+        if fnmatch.fnmatch(key, q):
+            return True
+    return False
\ No newline at end of file
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 00000000..471278c0
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,15 @@
+import yaml
+import pytest
+import datetime as dt
+from schedlib import config as cfg
+
+@pytest.fixture
+def config_str():
+    return """
+    date_ref: !datetime 2014-01-01 00:00:00
+    """
+
+def test_loader(config_str):
+    loader = cfg.get_loader()
+    config = yaml.load(config_str, Loader=loader)
+    assert config['date_ref'] == dt.datetime(2014, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc)
\ No newline at end of file
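
Note: path2key and match_query define the key grammar that the `B.*`-style
queries in tests/test_core.py below (and the merge_order / constraint entries
in a FlexPolicy config) resolve against. A small sketch using jax's path
entries directly:

    from jax.tree_util import DictKey, SequenceKey
    from schedlib import utils

    path = (DictKey('B'), SequenceKey(1))    # tree position blocks['B'][1]
    utils.path2key(path)                     # -> 'B.1'
    utils.match_query(path, 'B.*')           # glob pattern         -> True
    utils.match_query(path, 'A,B.1')         # comma-separated list -> True
    utils.match_query(path, 'C')             # no match             -> False
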
diff --git a/tests/test_core.py b/tests/test_core.py
index cbc3124d..b8d3687a 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -559,3 +559,102 @@ def test_seq_merge():
         core.Block(t0=dt.datetime(2023, 1, 5), t1=dt.datetime(2023, 1, 6))
     ]
     assert core.seq_merge(seq5, seq6, flatten=True) == expected_result3
+
+def test_seq_partition():
+    blocks = [
+        core.Block(t0=dt.datetime(2023, 1, 1), t1=dt.datetime(2023, 1, 2)),
+        [core.Block(t0=dt.datetime(2023, 1, 3), t1=dt.datetime(2023, 1, 4))],
+        core.Block(t0=dt.datetime(2023, 1, 5), t1=dt.datetime(2023, 1, 6)),
+    ]
+
+    t0 = dt.datetime(2023, 1, 2)
+    t1 = dt.datetime(2023, 1, 5)
+
+    def in_range(block):
+        """Return True if the block is within the range [t0, t1], False otherwise."""
+        return (block.t0 > t0 and block.t1 < t1)
+
+    matched, unmatched = core.seq_partition(in_range, blocks)
+
+    # Now validate individual blocks
+    assert matched[0] is None
+    assert matched[1][0] == core.Block(t0=dt.datetime(2023, 1, 3, 0, 0), t1=dt.datetime(2023, 1, 4, 0, 0))
+    assert matched[2] is None
+
+    assert unmatched[0] == core.Block(t0=dt.datetime(2023, 1, 1, 0, 0), t1=dt.datetime(2023, 1, 2, 0, 0))
+    assert unmatched[1][0] is None
+    assert unmatched[2] == core.Block(t0=dt.datetime(2023, 1, 5, 0, 0), t1=dt.datetime(2023, 1, 6, 0, 0))
+
+    assert core.seq_combine(matched, unmatched) == blocks
+
+def test_seq_partition_with_path():
+    from schedlib import utils
+
+    blocks = {
+        "A": core.Block(t0=dt.datetime(2023, 1, 1), t1=dt.datetime(2023, 1, 2)),
+        "B": [core.Block(t0=dt.datetime(2023, 1, 3), t1=dt.datetime(2023, 1, 4)),
+              core.Block(t0=dt.datetime(2023, 1, 4), t1=dt.datetime(2023, 1, 5))],
+        "C": core.Block(t0=dt.datetime(2023, 1, 5), t1=dt.datetime(2023, 1, 6)),
+    }
+
+    t0 = dt.datetime(2023, 1, 2)
+    t1 = dt.datetime(2023, 1, 5)
+
+    def in_range(path, block):
+        """Return True if the block is within the range [t0, t1], False otherwise."""
+        if utils.path2key(path) == 'B.1':
+            return False
+        return (block.t0 > t0 and block.t1 < t1)
+
+    matched, unmatched = core.seq_partition_with_path(in_range, blocks)
+
+    # Now validate individual blocks
+    assert matched['A'] is None
+    assert matched['B'][0] == core.Block(t0=dt.datetime(2023, 1, 3, 0, 0), t1=dt.datetime(2023, 1, 4, 0, 0))
+    assert matched['B'][1] is None
+    assert matched['C'] is None
+
+    assert unmatched['A'] == core.Block(t0=dt.datetime(2023, 1, 1, 0, 0), t1=dt.datetime(2023, 1, 2, 0, 0))
+    assert unmatched['B'][0] is None
+    assert unmatched['B'][1] == core.Block(t0=dt.datetime(2023, 1, 4, 0, 0), t1=dt.datetime(2023, 1, 5, 0, 0))
+    assert unmatched['C'] == core.Block(t0=dt.datetime(2023, 1, 5, 0, 0), t1=dt.datetime(2023, 1, 6, 0, 0))
+
+    assert core.seq_combine(matched, unmatched) == blocks
+
+def test_seq_partition_with_query():
+    blocks = {
+        "A": core.Block(t0=dt.datetime(2023, 1, 1), t1=dt.datetime(2023, 1, 2)),
+        "B": [core.Block(t0=dt.datetime(2023, 1, 3), t1=dt.datetime(2023, 1, 4)),
+              core.Block(t0=dt.datetime(2023, 1, 4), t1=dt.datetime(2023, 1, 5))],
+        "C": core.Block(t0=dt.datetime(2023, 1, 5), t1=dt.datetime(2023, 1, 6)),
+    }
+
+    matched, unmatched = core.seq_partition_with_query("B.*", blocks)
+
+    # Now validate individual blocks
+    assert matched['A'] is None
+    assert matched['B'][0] == core.Block(t0=dt.datetime(2023, 1, 3, 0, 0), t1=dt.datetime(2023, 1, 4, 0, 0))
+    assert matched['B'][1] == core.Block(t0=dt.datetime(2023, 1, 4), t1=dt.datetime(2023, 1, 5))
+    assert matched['C'] is None
+
+    assert unmatched['A'] == core.Block(t0=dt.datetime(2023, 1, 1, 0, 0), t1=dt.datetime(2023, 1, 2, 0, 0))
+    assert unmatched['B'][0] is None
+    assert unmatched['B'][1] is None
+    assert unmatched['C'] == core.Block(t0=dt.datetime(2023, 1, 5, 0, 0), t1=dt.datetime(2023, 1, 6, 0, 0))
+
+    assert core.seq_combine(matched, unmatched) == blocks
+
+    # case 2:
+    matched, unmatched = core.seq_partition_with_query("B.1", blocks)
+    # Now validate individual blocks
+    assert matched['A'] is None
+    assert matched['B'][0] is None
+    assert matched['B'][1] == core.Block(t0=dt.datetime(2023, 1, 4), t1=dt.datetime(2023, 1, 5))
+    assert matched['C'] is None
+
+    assert unmatched['A'] == core.Block(t0=dt.datetime(2023, 1, 1, 0, 0), t1=dt.datetime(2023, 1, 2, 0, 0))
+    assert unmatched['B'][0] == core.Block(t0=dt.datetime(2023, 1, 3, 0, 0), t1=dt.datetime(2023, 1, 4, 0, 0))
+    assert unmatched['B'][1] is None
+    assert unmatched['C'] == core.Block(t0=dt.datetime(2023, 1, 5, 0, 0), t1=dt.datetime(2023, 1, 6, 0, 0))
+
+    assert core.seq_combine(matched, unmatched) == blocks
diff --git a/tests/test_instrument.py b/tests/test_instrument.py
index 6ff9368b..1461533e 100644
--- a/tests/test_instrument.py
+++ b/tests/test_instrument.py
@@ -1,4 +1,6 @@
-from schedlib.instrument import get_spec
+import numpy as np
+import os.path as op
+import schedlib.instrument as inst
 
 def test_get_spec():
     specs = {
@@ -13,25 +15,69 @@
         },
       }
     }
-    spec = get_spec(specs, ["platform1"])
+    spec = inst.get_spec(specs, ["platform1"])
     assert spec == {
         'bounds_x': [-2.0, 1.0],
         'bounds_y': [-2.0, 1.0],
     }
-    spec = get_spec(specs, ["wafer1"])
+    spec = inst.get_spec(specs, ["wafer1"])
     assert spec == {
         'bounds_x': [-1.0, 1.0],
         'bounds_y': [-1.0, 1.0],
     }
-    spec = get_spec(specs, 
["platform1.wafer2"]) + spec = inst.get_spec(specs, ["platform1.wafer2"]) assert spec == { 'bounds_x': [-2.0, 1.0], 'bounds_y': [-2.0, 1.0], } - spec = get_spec(specs, ["wafer"]) + spec = inst.get_spec(specs, ["wafer"]) assert spec == { 'bounds_x': [-2.0, 1.0], 'bounds_y': [-2.0, 1.0], } - spec = get_spec(specs, ["wafer3"]) + spec = inst.get_spec(specs, ["wafer3"]) assert spec == {} + +def test_parse_sequence_from_toast(): + ifile = op.join(op.abspath(op.dirname(__file__)), "data/schedule_test.txt") + seq = inst.parse_sequence_from_toast(ifile) + print(seq) + assert len(seq) == 17 + +def test_array_info(): + geometries = { + 'w11': { + 'center': [0, 0], + 'radius': 1.0, + }, + 'w22': { + 'center': [1, 1], + 'radius': 1.0, + } + } + array_info1 = inst.make_circular_cover(*geometries['w11']['center'], geometries['w11']['radius']) + array_info2 = inst.make_circular_cover(*geometries['w22']['center'], geometries['w22']['radius']) + assert array_info1['cover'].shape == (2, 50) + query = "w11,w22" + array_info = inst.array_info_from_query(geometries, query) + assert array_info['cover'].shape == (2, 100) + assert np.allclose(array_info['center'], np.mean([array_info1['center'], array_info2['center']], axis=0)) + assert np.allclose(array_info['cover'], np.concatenate([array_info1['cover'], array_info2['cover']], axis=1)) + + query = "w1*" + array_info = inst.array_info_from_query(geometries, query) + assert array_info['cover'].shape == (2, 50) + assert np.allclose(array_info['center'], array_info1['center']) + assert np.allclose(array_info['cover'], array_info1['cover']) + + query = "*2" + array_info = inst.array_info_from_query(geometries, query) + assert array_info['cover'].shape == (2, 50) + assert np.allclose(array_info['center'], array_info2['center']) + assert np.allclose(array_info['cover'], array_info2['cover']) + + query = "*2,*1" + array_info = inst.array_info_from_query(geometries, query) + assert array_info['cover'].shape == (2, 100) + assert np.allclose(array_info['center'], np.mean([array_info1['center'], array_info2['center']], axis=0)) + assert np.allclose(array_info['cover'], np.concatenate([array_info1['cover'], array_info2['cover']], axis=1)) \ No newline at end of file diff --git a/tests/test_policies.py b/tests/test_policies.py index 416d5117..5aa093e2 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -46,3 +46,47 @@ def test_basic_policy(): dt.datetime(2023, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc), dt.datetime(2023, 1, 10, 1, 0, 0, tzinfo=dt.timezone.utc)) policy.apply(seqs) + +def test_flex_policy(): + config = """ + blocks: + master: !toast data/schedule_test.txt + calibration: + saturn: !source saturn + moon: !source moon + rules: + - name: sun-avoidance + min_angle_az: 6 + min_angle_alt: 6 + time_step: 30 + n_buffer: 10 + - name: day-mod + day: 0 + day_mod: 1 + day_ref: !datetime "2014-01-01 00:00:00" + - name: make-drift-scan + constraint: calibration + array_query: full + el_bore: 50 + drift: true + post_rules: + - name: min-duration + min_duration: 600 + merge_order: + - moon + - saturn + - master + geometries: + full: + center: + - 0 + - 0 + radius: 7 + """ + policy = policies.FlexPolicy.from_config(config) + seqs = policy.init_seqs( + dt.datetime(2023, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc), + dt.datetime(2023, 1, 10, 1, 0, 0, tzinfo=dt.timezone.utc)) + policy.apply(seqs) + seqs = policy.transform(seqs) + seqs = policy.merge(seqs) diff --git a/tests/test_source.py b/tests/test_source.py index f75f7198..789b3bed 100644 --- a/tests/test_source.py +++ 
b/tests/test_source.py @@ -32,7 +32,7 @@ def test_source_get_az_alt(): source = 'sun' times = [dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)] expected_az = [240.15972382] - expected_alt = [-9.02030811] + expected_alt = [-9.01748468] az, alt = src._source_get_az_alt(source, times) assert np.allclose(az, expected_az) assert np.allclose(alt, expected_alt) @@ -49,7 +49,7 @@ def test_source_get_az_alt(): dt.datetime(2023, 1, 2, 0, 0, 0, tzinfo=dt.timezone.utc), dt.datetime(2023, 1, 3, 0, 0, 0, tzinfo=dt.timezone.utc) ] - expected_az = [302.09675348, 301.27609259, 300.48474687] + expected_az = [302.1017805 , 301.28101033, 300.48955533] expected_alt = [52.57749508, 51.85943046, 51.13723016] az, alt = src._source_get_az_alt(source, times) assert np.allclose(az, expected_az) @@ -62,26 +62,26 @@ def test_source_get_blocks(): blocks = src.source_get_blocks(source, t0, t1) assert blocks == [ src.SourceBlock( - t0=dt.datetime(2022, 12, 31, 9, 48, 9, 902594, tzinfo=dt.timezone.utc), - t1=dt.datetime(2022, 12, 31, 16, 34, 11, 308132, tzinfo=dt.timezone.utc), + t0=dt.datetime(2022, 12, 31, 9, 48, 9, 937409, tzinfo=dt.timezone.utc), + t1=dt.datetime(2022, 12, 31, 16, 34, 11, 713598, tzinfo=dt.timezone.utc), name='sun', mode='rising' ), src.SourceBlock( - t0=dt.datetime(2022, 12, 31, 16, 34, 11, 308132, tzinfo=dt.timezone.utc), - t1=dt.datetime(2022, 12, 31, 23, 20, 7, 342350, tzinfo=dt.timezone.utc), + t0=dt.datetime(2022, 12, 31, 16, 34, 11, 713598, tzinfo=dt.timezone.utc), + t1=dt.datetime(2022, 12, 31, 23, 20, 8, 117595, tzinfo=dt.timezone.utc), name='sun', mode='setting' ), src.SourceBlock( - t0=dt.datetime(2023, 1, 1, 9, 48, 48, 46183, tzinfo=dt.timezone.utc), - t1=dt.datetime(2023, 1, 1, 16, 34, 39, 671616, tzinfo=dt.timezone.utc), + t0=dt.datetime(2023, 1, 1, 9, 48, 48, 82430, tzinfo=dt.timezone.utc), + t1=dt.datetime(2023, 1, 1, 16, 34, 40, 77082, tzinfo=dt.timezone.utc), name='sun', mode='rising' ), src.SourceBlock( - t0=dt.datetime(2023, 1, 1, 16, 34, 39, 671616, tzinfo=dt.timezone.utc), - t1=dt.datetime(2023, 1, 1, 23, 20, 25, 379071, tzinfo=dt.timezone.utc), + t0=dt.datetime(2023, 1, 1, 16, 34, 40, 77082, tzinfo=dt.timezone.utc), + t1=dt.datetime(2023, 1, 1, 23, 20, 26, 152794, tzinfo=dt.timezone.utc), name='sun', mode='setting' ) @@ -95,7 +95,7 @@ def test_precomputed_source(): assert len(source.blocks) == 4 t = int(dt.datetime(2023, 1, 1, 5, 0, 0, tzinfo=dt.timezone.utc).timestamp()) - assert np.allclose(source.interp_az(t), [295.34634092]) + assert np.allclose(source.interp_az(t), [-64.65207448]) assert np.allclose(source.interp_alt(t), [15.48141435]) assert 'uranus' in src.PRECOMPUTED_SOURCES @@ -122,13 +122,13 @@ def test_source_gen_seq(): assert blocks == [ src.SourceBlock( t0=dt.datetime(2023, 1, 1, 0, 0, tzinfo=dt.timezone.utc), - t1=dt.datetime(2023, 1, 1, 0, 40, 38, 505632, tzinfo=dt.timezone.utc), + t1=dt.datetime(2023, 1, 1, 0, 40, 38, 909838, tzinfo=dt.timezone.utc), name='uranus', mode='rising' ), src.SourceBlock( - t0=dt.datetime(2023, 1, 1, 0, 40, 38, 505632, tzinfo=dt.timezone.utc), - t1=dt.datetime(2023, 1, 1, 6, 14, 17, 34180, tzinfo=dt.timezone.utc), + t0=dt.datetime(2023, 1, 1, 0, 40, 38, 909838, tzinfo=dt.timezone.utc), + t1=dt.datetime(2023, 1, 1, 6, 14, 17, 199762, tzinfo=dt.timezone.utc), name='uranus', mode='setting' ) @@ -143,5 +143,5 @@ def test_source_block_get_az_alt(): ) times, az, alt = src.source_block_get_az_alt(srcblk) assert len(times) == 67 - assert np.allclose(az[:5], [83.97807382, 395.73085271, 349.995829, 362.00775543, 358.54644389]) - assert 
np.allclose(alt[:5], [51.01782363, 51.01763853, 51.01706878, 51.01611489, 51.01477258]) + assert np.allclose(az[:5], [0.00580815, -0.18562059, -0.37705974, -0.56848984, -0.75989556]) + assert np.allclose(alt[:5], [51.01486071, 51.01468336, 51.01411836, 51.01316534, 51.01183098]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 884a5150..cd2aa947 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -53,12 +53,6 @@ def test_mask2ranges_digest(): for i_left, i_right in mask2ranges(mask): assert np.all(mask[i_left:i_right]) -def test_parse_sequence_from_toast(): - ifile = op.join(op.abspath(op.dirname(__file__)), "data/schedule_test.txt") - seq = parse_sequence_from_toast(ifile) - print(seq) - assert len(seq) == 17 - def test_ranges_pad(): mask = np.array([False, False, True, True, False, False, False, True, True, True, False, False]) ranges = mask2ranges(mask)
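
Note: taken together, the new FlexPolicy pipeline runs exactly as exercised in
test_flex_policy above; a sketch of the intended call sequence, assuming a yaml
config like the one in that test:

    import datetime as dt
    from schedlib import policies

    policy = policies.FlexPolicy.from_config(config)   # yaml path or raw content
    t0 = dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc)
    t1 = dt.datetime(2023, 1, 10, tzinfo=dt.timezone.utc)

    seqs = policy.init_seqs(t0, t1)   # parse `blocks` from the config over [t0, t1]
    seqs = policy.transform(seqs)     # apply each rule in order
    seqs = policy.merge(seqs)         # flatten according to merge_order priority
    cmds = policy.seq2cmd(seqs)       # render the final command sequence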