Skip to content

Commit

Permalink
Merge pull request #15 from simonsobs/ref_file
Browse files Browse the repository at this point in the history
feat: switch to logger
  • Loading branch information
skhrg authored Nov 19, 2024
2 parents b672238 + 8561ecd commit cccb6bc
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 41 deletions.
23 changes: 23 additions & 0 deletions lat_alignment/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import argparse
import logging
import os
from functools import partial
from importlib.resources import files
Expand Down Expand Up @@ -64,18 +65,29 @@ def main():
# load config
parser = argparse.ArgumentParser()
parser.add_argument("config", help="path to config file")
parser.add_argument(
"--log_level", "-l", default="INFO", help="the log level to use"
)
args = parser.parse_args()
logging.basicConfig()
logger = logging.getLogger("lat_alignment")
logger.setLevel(args.log_level.upper())
with open(args.config) as file:
cfg = yaml.safe_load(file)

mode = cfg.get("mode", "panel")
cfgdir = os.path.dirname(os.path.abspath(args.config))
meas_file = os.path.abspath(os.path.join(cfgdir, cfg["measurement"]))
title_str = cfg["title"]
logger.info("Begining alignment %s in %s mode", title_str, mode)
logger.debug("Using measurement file: %s", meas_file)

dat_dir = os.path.abspath(os.path.join(cfgdir, cfg.get("data_dir", "/")))
if "data_dir" in cfg:
logger.info("Using data files from %s", dat_dir)
ref_path = os.path.join(dat_dir, "reference.yaml")
else:
logger.info("Using packaged data files")
ref_path = str(files("lat_alignment.data").joinpath("reference.yaml"))
with open(ref_path) as file:
reference = yaml.safe_load(file)
Expand All @@ -88,6 +100,7 @@ def main():
mnum = 2
else:
raise ValueError(f"Invalid mirror: {mirror}")
logger.info("Aligning panels for the %s mirror", mirror)

if "data_dir" in cfg:
corner_path = os.path.join(dat_dir, f"{mirror}_corners.yaml")
Expand Down Expand Up @@ -117,10 +130,12 @@ def main():
cfg.get("compensate", 0),
cfg.get("adjuster_radius", 100),
)
logger.info("Found measurements for %d panels", len(panels))
fig = mir.plot_panels(panels, title_str, vmax=cfg.get("vmax", None))
fig.savefig(os.path.join(cfgdir, f"{title_str.replace(' ', '_')}.png"))

# calc and save adjustments
logger.info("Caluculating adjustments")
_adjust = partial(adjust_panel, mnum=mnum, cfg=cfg)
adjustments = np.vstack(pqdm(panels, _adjust, n_jobs=8))
order = np.lexsort((adjustments[2], adjustments[1], adjustments[0]))
Expand All @@ -136,6 +151,7 @@ def main():
raise ValueError(f"Invalid element specified for 'align_to': {align_to}")
if align_to in ["receiver", "bearing"]:
raise NotImplementedError(f"Alignment with {align_to} not yet implemented")
logger.info("Aligning all optical elements to the %s", align_to)

# Load data and compute the transformation to align with the model
# We want to put all the transformations into opt_global
Expand Down Expand Up @@ -194,8 +210,14 @@ def main():
raise ValueError(
f"Specified 'align_to' element ({align_to}) not found in measurment. Can't align!"
)
logger.info(
"Found %d optical elements in measurement: %s",
len(elements),
str(list(elements.keys())),
)

# Now combine with the align_to alignment
logger.info("Composing transforms to align with %s fixed", align_to)
transforms = {}
align_to_inv = mt.invert_transform(*elements[align_to])
for element, full_transform in elements.items():
Expand All @@ -213,3 +235,4 @@ def main():

else:
raise ValueError(f"Invalid mode: {mode}")
logger.info("Outputs can be found in %s", cfgdir)
8 changes: 7 additions & 1 deletion lat_alignment/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from collections import defaultdict

import matplotlib.pyplot as plt
Expand All @@ -8,6 +9,8 @@

from .transforms import align_photo, coord_transform

logger = logging.getLogger("lat_alignment")


def load_photo(
path: str,
Expand Down Expand Up @@ -51,6 +54,7 @@ def load_photo(
The first element is a rotation matrix and
the second is the shift.
"""
logger.info("Loading measurement data")
labels = np.genfromtxt(path, dtype=str, delimiter=",", usecols=(0,))
coords = np.genfromtxt(path, dtype=np.float32, delimiter=",", usecols=(1, 2, 3))
errs = np.genfromtxt(path, dtype=np.float32, delimiter=",", usecols=(4, 5, 6))
Expand All @@ -61,7 +65,7 @@ def load_photo(

if align:
labels, coords, msk, alignment = align_photo(
labels, coords, reference, **kwargs
labels, coords, reference, plot=plot, **kwargs
)
err = err[msk]
else:
Expand All @@ -73,6 +77,7 @@ def load_photo(

err_msk = err < err_thresh * np.median(err)
labels, coords, err = labels[err_msk], coords[err_msk], err[err_msk]
logger.info("\t%d points loaded", len(coords))

# Lets find and remove doubles
# Dumb brute force
Expand All @@ -90,6 +95,7 @@ def load_photo(
else:
to_kill += [i]
msk = ~np.isin(np.arange(len(coords), dtype=int), to_kill)
logger.info("\tFound and removed %d doubles", len(to_kill))
labels, coords = labels[msk], coords[msk]

if plot:
Expand Down
45 changes: 26 additions & 19 deletions lat_alignment/mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions to describe the mirror surface.
"""

import logging
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
Expand All @@ -24,6 +25,8 @@
from scipy.optimize import minimize
from scipy.spatial.transform import Rotation

logger = logging.getLogger("lat_alignment")

# fmt: off
a = {'primary' :
np.array([
Expand Down Expand Up @@ -373,7 +376,6 @@ def remove_cm(
thresh: float = 10,
cut_thresh: float = 50,
niters: int = 10,
verbose=False,
) -> tuple[
dict[str, NDArray[np.float32]], tuple[NDArray[np.float32], NDArray[np.float32]]
]:
Expand All @@ -396,8 +398,6 @@ def remove_cm(
considered an outlier.
niters : int, default: 10
How many iterations of common mode fitting to do.
verbose : bool, default: False
If True print the transformation for each iteration.
Returns
-------
Expand All @@ -408,6 +408,7 @@ def remove_cm(
The first element is an affine matrix and
the second is the shift.
"""
logger.info("Removing common mode for %s", mirror)

def _cm(x, panel):
panel.measurements[:] -= x[1:4]
Expand Down Expand Up @@ -445,37 +446,38 @@ def _opt(x, panel):
)
data = data.copy()
data_clean = data.copy()
logger.info("\tRemoved %d points not on mirror surface", np.sum(~msk))

x0 = np.hstack([np.ones(1), np.zeros(6)])
bounds = [(-0.95, 1.05)] + [(-100, 100)] * 3 + [(0, 2 * np.pi)] * 3

for i in range(niters):
if len(panel.measurements) < 3:
raise ValueError
print(f"iter {i} for common mode fit")
logger.debug("\titer %d for common mode fit", i)
cut = panel.res_norm > thresh * np.median(panel.res_norm)
if np.sum(cut) > 0:
# print(f"\tRemoving {np.sum(cut)} points from mirror")
panel.measurements = panel.measurements[~cut]
# labels = labels[~cut]
data = data[~cut]

if verbose:
print(f"\tRemoving a naive common mode shift of {panel.shift}")
logger.debug("\t\tRemoving a naive common mode shift of %s", str(panel.shift))
panel.measurements -= panel.shift
panel.measurements @= panel.rot.T

res = minimize(_opt, x0, (panel,), bounds=bounds)
if verbose:
print(
f"\tRemoving a fit common mode with scale {res.x[0]}, shift {res.x[1:4]}, and rotation {res.x[4:]}"
)
logger.debug(
"\t\tRemoving a fit common mode with scale %f, shift %s, and rotation %s",
res.x[0],
str(res.x[1:4]),
str(res.x[4:]),
)
_cm(res.x, panel)

if verbose:
print(
f"\tRemoving a secondary common mode shift of {panel.shift} and rotation of {decompose_rotation(panel.rot)}"
)
logger.debug(
"\t\tRemoving a secondary common mode shift of %s and rotation of %s",
str(panel.shift),
str(np.rad2deg(decompose_rotation(panel.rot))),
)
panel.measurements -= panel.shift
panel.measurements @= panel.rot.T

Expand All @@ -484,15 +486,20 @@ def _opt(x, panel):
)
scale, shear, rot = decompose_affine(aff)
rot = decompose_rotation(rot)
print(
f"Full common mode is:\n\tshift = {sft} mm\n\tscale = {scale}\n\tshear = {shear}\n\trot = {np.rad2deg(rot)} deg"
logger.info(
"\tFull common mode is:\n\t\t\tshift = %s mm\n\t\t\tscale = %s\n\t\t\tshear = %s\n\t\t\trot = %s deg",
str(sft),
str(scale),
str(shear),
str(np.rad2deg(rot)),
)

panel.measurements = apply_transform(data_clean, aff, sft)
cut = panel.res_norm > cut_thresh * np.median(panel.res_norm)
if np.sum(cut) > 0:
print(f"Removing {np.sum(cut)} points from mirror")
logger.info("\tRemoving %d bad points from mirror", np.sum(cut))
panel.measurements = panel.measurements[~cut]
logger.info("\tMirror has %d good points", len(panel.measurements))

return {l: d for l, d in zip(labels, panel.measurements)}, (aff, sft)

Expand Down
48 changes: 27 additions & 21 deletions lat_alignment/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
- va_secondary
"""

import logging
from functools import cache, partial
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from megham.transform import apply_transform, get_affine, get_rigid
from megham.utils import make_edm
from numpy.typing import NDArray

logger = logging.getLogger("lat_alignment")

opt_sm1 = np.array((0, 0, 3600), np.float32) # mm
opt_sm2 = np.array((0, -4800, 0), np.float32) # mm
opt_am1 = -np.arctan(0.5)
Expand Down Expand Up @@ -385,6 +388,7 @@ def align_photo(
coords: NDArray[np.float32],
reference: dict,
*,
plot: bool = True,
mirror: str = "primary",
max_dist: float = 100.0,
) -> tuple[
Expand Down Expand Up @@ -415,6 +419,9 @@ def align_photo(
of the point in the global coordinate system.
The second is a list of nearby coded targets that can be used
to identify the point.
plot : bool, default: True
If True show a diagnostic plot of how well the reference points
are aligned.
mirror : str, default: 'primary'
The mirror that these points belong to.
Should be either: 'primary' or 'secondary'.
Expand All @@ -437,6 +444,7 @@ def align_photo(
The first element is a rotation matrix and
the second is the shift.
"""
logger.info("\tAligning with reference points for %s", mirror)
if mirror not in ["primary", "secondary"]:
raise ValueError(f"Invalid mirror: {mirror}")
if len(reference) == 0:
Expand Down Expand Up @@ -464,43 +472,41 @@ def align_photo(
if np.sum(have) == 0:
continue
coded = coords[np.where(labels == codes[np.where(have)[0][0]])[0][0]]
print(codes[np.where(have)[0][0]])
# Find the closest point
dist = np.linalg.norm(coords[trg_idx] - coded, axis=-1)
if np.min(dist) > max_dist:
continue
print(np.min(dist))
ref += [rpoint]
pts += [coords[trg_idx][np.argmin(dist)]]
invars += [labels[trg_idx][np.argmin(dist)]]
if len(ref) < 4:
raise ValueError(f"Only {len(ref)} reference points found! Can't align!")
logger.debug(
"\t\tFound %d reference points in measurements with labels:\n\t\t\t%s",
len(pts),
str(invars),
)
pts = np.vstack(pts)
ref = np.vstack(ref)
pts = np.vstack((pts, np.mean(pts, 0)))
ref = np.vstack((ref, np.mean(ref, 0)))
ref = transform(ref)
print("Reference points in mirror coords:")
print(ref[:-1])
print(make_edm(ref) / make_edm(pts))
print(make_edm(ref) - make_edm(pts))
print(np.nanmedian(make_edm(ref) / make_edm(pts)))
pts *= np.nanmedian(make_edm(ref) / make_edm(pts))
print(make_edm(ref) / make_edm(pts))
print(make_edm(ref) - make_edm(pts))
print(np.nanmedian(make_edm(ref) / make_edm(pts)))
logger.debug("\t\tReference points in mirror coords:\n%s", str(ref[:-1]))
triu_idx = np.triu_indices(len(pts), 1)
scale_fac = np.nanmedian(make_edm(ref)[triu_idx] / make_edm(pts)[triu_idx])
logger.debug("\t\tScale factor of %f applied", scale_fac)
pts *= scale_fac

rot, sft = get_rigid(pts, ref, method="mean")
pts_t = apply_transform(pts, rot, sft)
import matplotlib.pyplot as plt

plt.scatter(pts_t[:, 0], pts_t[:, 1], color="b")
plt.scatter(ref[:, 0], ref[:, 1], color="r")
plt.show()
print(pts_t[:-1])
print(pts_t - ref)
print(
f"RMS of reference points after alignment: {np.sqrt(np.mean((pts_t - ref)**2))}"

if plot:
plt.scatter(pts_t[:, 0], pts_t[:, 1], color="b")
plt.scatter(ref[:, 0], ref[:, 1], color="r")
plt.show()
logger.info(
"\t\tRMS of reference points after alignment: %f",
np.sqrt(np.mean((pts_t - ref) ** 2)),
)
coords_transformed = apply_transform(coords, rot, sft)

Expand Down

0 comments on commit cccb6bc

Please sign in to comment.