Skip to content

Commit

Permalink
Merge pull request #13 from simonsobs/better-rule-policy
Browse files Browse the repository at this point in the history
more flexible policy and drift plant scan
  • Loading branch information
guanyilun authored Oct 19, 2023
2 parents ea79aa6 + 0c1fca6 commit 15ba570
Show file tree
Hide file tree
Showing 16 changed files with 722 additions and 115 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
"toolz",
"pandas",
"chex",
"pyephem"
"pyephem",
"equinox",
"so3g"
],
)
14 changes: 14 additions & 0 deletions src/schedlib/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from . import utils
import yaml

def datetime_constructor(loader, node):
"""load a datetime"""
return utils.str2datetime(loader.construct_scalar(node))

def get_loader():
# add some useful tags to the loader
loader = yaml.SafeLoader
loader.add_constructor('!datetime', datetime_constructor)
# ignore unknown tags
loader.add_multi_constructor('!', lambda loader, tag_suffix, node: None)
return loader
55 changes: 45 additions & 10 deletions src/schedlib/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Union, Callable, Optional, Any, TypeVar
from chex import dataclass
from abc import ABC, abstractmethod
import datetime as dt
import numpy as np
from toolz import compose_left
import jax.tree_util as tu
import equinox
from dataclasses import dataclass, replace as dc_replace

from . import utils

@dataclass(frozen=True)
class Block:
Expand Down Expand Up @@ -37,6 +39,8 @@ def trim_right_to(self, t: dt.datetime) -> List["Block"]:
return block_trim_right_to(self, t)
def isa(self, block_type: "BlockType") -> bool:
return block_isa(block_type)(self)
def replace(self, **kwargs) -> "Block":
return dc_replace(self, **kwargs)

BlockType = type(Block)
Blocks = List[Union[Block, None, "Blocks"]] # maybe None, maybe nested
Expand Down Expand Up @@ -224,6 +228,9 @@ def seq_filter_out(op: Callable[[Block], bool], blocks: BlocksTree) -> BlocksTre
def seq_map(op, *blocks: BlocksTree) -> List[Any]:
return tu.tree_map(op, *blocks, is_leaf=is_block)

def seq_map_with_path(op, *blocks: BlocksTree) -> List[Any]:
return tu.tree_map_with_path(op, *blocks, is_leaf=is_block)

def seq_map_when(op_when: Callable[[Block], bool], op: Callable[[Block], Any], blocks: BlocksTree) -> List[Any]:
return tu.tree_map(lambda b: op(b) if op_when(b) else b, blocks, is_leaf=is_block)

Expand All @@ -233,6 +240,42 @@ def seq_replace_block(blocks: BlocksTree, source: Block, target: Block) -> Block
def seq_trim(blocks: BlocksTree, t0: dt.datetime, t1: dt.datetime) -> BlocksTree:
return seq_map(lambda b: b.trim(t0, t1), blocks)

def seq_partition(op, blocks: BlocksTree) -> List[Any]:
"""partition a blockstree into two trees, one for blocks that satisfy the predicate,
which is specified through a function that takes a block as input and returns
a boolean, and the second return is for blocks that don't match the predicate. Unmatched
values will be left as None."""
filter_spec = tu.tree_map(op, blocks, is_leaf=is_block)
return equinox.partition(blocks, filter_spec)

def seq_partition_with_path(op, blocks: BlocksTree, **kwargs) -> List[Any]:
"""partition a blockstree into two trees, one for blocks that satisfy the predicate,
which is specified through a function that takes a block and path as input and returns
a boolean, and the second return is for blocks that don't match the predicate. Unmatched
values will be left as None."""
filter_spec = tu.tree_map_with_path(op, blocks, is_leaf=is_block)
return equinox.partition(blocks, filter_spec, **kwargs)

def seq_partition_with_query(query, blocks: BlocksTree):
def path2key(path):
"""convert a path (used in tree_util.tree_map_with_path) to a dot-separated key"""
keys = []
for p in path:
if isinstance(p, tu.SequenceKey):
keys.append(p.idx)
elif isinstance(p, tu.DictKey):
keys.append(p.key)
else:
raise ValueError(f"unknown path type {type(p)}")
return ".".join([str(k) for k in keys])
return seq_partition_with_path(lambda path, block: utils.match_query(path, query), blocks)

def seq_combine(*blocks: BlocksTree) -> BlocksTree:
"""combine blocks from multiple trees into a single tree, where the blocks are
combined in a list. The trees must have the same structure."""
seq_assert_same_structure(*blocks)
return equinox.combine(*blocks, is_leaf=is_block)

# =========================
# Other useful Block types
# =========================
Expand All @@ -256,14 +299,6 @@ def __call__(self, blocks: Blocks) -> Blocks:
Rule = Union[BlocksTransformation, Callable[[Blocks], Blocks]]
RuleSet = List[Rule]

@dataclass(frozen=True)
class MultiRules(BlocksTransformation):
rules: RuleSet
def apply(self, blocks: Blocks) -> Blocks:
"""apply rules to blocks in first-to-last order"""
return compose_left(*self.rules)(blocks)


@dataclass(frozen=True)
class Policy(BlocksTransformation, ABC):
"""apply: apply policy to a tree of blocks"""
Expand Down
81 changes: 69 additions & 12 deletions src/schedlib/instrument.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import annotations
from chex import dataclass
from jax import tree_util as tu
import pandas as pd
from typing import List, TypeVar, Union, Dict
import numpy as np
from functools import reduce
from dataclasses import dataclass
from so3g.proj import quat

from . import core
from . import core, utils as u

@dataclass(frozen=True)
class ScanBlock(core.NamedBlock):
az: float # deg
alt: float # deg
throw: float # deg
az: float # deg
alt: float # deg
throw: float # deg
drift: float = 0 # deg / s

@dataclass(frozen=True)
class IVBlock(core.NamedBlock): pass
Expand All @@ -35,14 +38,13 @@ def get_spec(specs: SpecsTree, query: List[str], merge=True) -> Union[Spec, Spec
one of the queries. return all matches if merge=False"""
is_leaf = lambda x: isinstance(x, dict) and 'bounds_x' in x
match_p = lambda key: any([p in key for p in query])
path2key = lambda path: ".".join([str(p.key) for p in path])
def reduce_fn(l, r):
res = {}
for k in ['bounds_x', 'bounds_y']:
res[k] = [min(l[k][0], r[k][0]), max(l[k][1], r[k][1])]
return res
all_matches = tu.tree_leaves(
tu.tree_map_with_path(lambda path, x: x if match_p(path2key(path)) else None, specs, is_leaf=is_leaf),
tu.tree_map_with_path(lambda path, x: x if match_p(u.path2key(path)) else None, specs, is_leaf=is_leaf),
is_leaf=is_leaf
) # None is not a leaf, so it will be filtered out
if not merge: return all_matches
Expand All @@ -51,11 +53,66 @@ def reduce_fn(l, r):

def get_bounds_x_tilted(bounds_x: List[float], bounds_y: List[float], phi_tilt: Union[float, core.Arr[float]], shape: str):
"""get the effective bounds of the x-axis of the spec when covering a tilted patch"""
assert shape in ['ellipse'] # more to implement
assert shape in ['ellipse', 'rect'] # more to implement
a = (bounds_x[1] - bounds_x[0])/2
b = (bounds_y[1] - bounds_y[0])/2
if shape == 'ellipse':
a = (bounds_x[1] - bounds_x[0])/2
b = (bounds_y[1] - bounds_y[0])/2
w_proj = np.sqrt(a**2 * np.cos(phi_tilt)**2 + b**2 * np.sin(phi_tilt)**2)
return np.array([-w_proj, w_proj]) + (bounds_x[0] + bounds_x[1])/2
w_proj = a * np.sqrt(1 + b**2 / a**2 * np.tan(phi_tilt)**2)
elif shape == 'rect':
w_proj = b * np.tan(phi_tilt) + a
else:
raise NotImplementedError
return np.array([-w_proj, w_proj]) + (bounds_x[0] + bounds_x[1])/2

def make_circular_cover(xi0, eta0, R, count=50, degree=True):
"""make a circular cover centered at xi0, eta0 with radius R"""
if degree: xi0, eta0, R = np.deg2rad([xi0, eta0, R])
dphi = 2*np.pi/count
phi = np.arange(count) * dphi
L = 1.01*R / np.cos(dphi/2)
xi, eta = L * np.cos(phi), L * np.sin(phi)
xi, eta, _ = quat.decompose_xieta(quat.rotation_xieta(xi0, eta0) * quat.rotation_xieta(xi, eta))
return {
'center': (xi0, eta0),
'cover': np.array([xi, eta])
}

def array_info_merge(arrays):
center = np.mean(np.array([a['center'] for a in arrays]), axis=0)
cover = np.concatenate([a['cover'] for a in arrays], axis=1)
return {
'center': center,
'cover': cover
}

def array_info_from_query(geometries, query):
"""make an array info with geometries that match the query"""
is_leaf = lambda x: isinstance(x, dict) and 'center' in x
matched = tu.tree_leaves(tu.tree_map_with_path(
lambda path, x: x if u.match_query(path, query) else None,
geometries,
is_leaf=is_leaf
), is_leaf=is_leaf)
arrays = [make_circular_cover(*g['center'], g['radius']) for g in matched]
return array_info_merge(arrays)

def parse_sequence_from_toast(ifile):
"""
Parameters
----------
ifile: input master schedule from toast
"""
columns = ["start_utc", "stop_utc", "rotation", "patch", "az_min", "az_max", "el", "pass", "sub"]
df = pd.read_csv(ifile, skiprows=3, delimiter="|", names=columns)
blocks = []
for _, row in df.iterrows():
block = ScanBlock(
name=row['patch'].strip(),
t0=u.str2datetime(row['start_utc']),
t1=u.str2datetime(row['stop_utc']),
alt=row['el'],
az=row['az_min'],
throw=np.abs(row['az_max'] - row['az_min']),
)
blocks.append(block)
return blocks
2 changes: 2 additions & 0 deletions src/schedlib/policies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .basic import BasePolicy, BasicPolicy
from .flex import FlexPolicy
31 changes: 16 additions & 15 deletions src/schedlib/policies.py → src/schedlib/policies/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3

from chex import dataclass
import datetime as dt
from abc import ABC, abstractmethod
from typing import List
from . import core, utils, commands as cmd, instrument as inst, rules as ru, source as src
from typing import List
from dataclasses import dataclass
from .. import core, utils, commands as cmd, instrument as inst, rules as ru, source as src, config as cfg


@dataclass(frozen=True)
class BasePolicy(core.Policy, ABC):
Expand All @@ -15,15 +16,6 @@ class BasePolicy(core.Policy, ABC):
preserve the nested structure for the user to see, but we can
also flatten the structure for the scheduler to consume."""

rules: core.RuleSet

def make_rule(self, rule_name: str, **kwargs) -> core.Rule:
# caller kwargs take precedence
if not kwargs:
assert rule_name in self.rules, f"Rule {rule_name} not found in rules config"
kwargs = self.rules[rule_name]
return ru.make_rule(rule_name, **kwargs)

@abstractmethod
def transform(self, blocks: core.BlocksTree) -> core.BlocksTree: ...

Expand All @@ -39,15 +31,24 @@ def apply(self, blocks: core.BlocksTree) -> core.Blocks:
@abstractmethod
def seq2cmd(self, seq: core.Blocks) -> cmd.Command: ...


@dataclass(frozen=True)
class BasicPolicy(BasePolicy):

rules: core.RuleSet
master_schedule: str
calibration_targets: List[str]
soft_targets: List[str]

def make_rule(self, rule_name: str, **kwargs) -> core.Rule:
# caller kwargs take precedence
print(self.rules)
if not kwargs:
assert rule_name in self.rules, f"Rule {rule_name} not found in rules config"
kwargs = self.rules[rule_name]
return ru.make_rule(rule_name, **kwargs)

def init_seqs(self, t0: dt.datetime, t1: dt.datetime) -> core.BlocksTree:
master = utils.parse_sequence_from_toast(self.master_schedule)
master = inst.parse_sequence_from_toast(self.master_schedule)
calibration = {k: src.source_gen_seq(k, t0, t1) for k in self.calibration_targets}
soft = {k: src.source_gen_seq(k, t0, t1) for k in self.soft_targets}
blocks = {
Expand Down Expand Up @@ -120,4 +121,4 @@ def seq2cmd(self, seq: core.Blocks):
"""map a scan to a command"""
commands = core.seq_flatten(core.seq_map(self.block2cmd, seq))
commands = [cmd.Preamble()] + commands
return cmd.CompositeCommand(commands)
return cmd.CompositeCommand(commands)
Loading

0 comments on commit 15ba570

Please sign in to comment.