Skip to content

Commit

Permalink
Revert "Add option to use custom weights in inference. (#15)"
Browse files Browse the repository at this point in the history
This reverts commit 6cac7b5.
  • Loading branch information
veichta authored Nov 28, 2024
1 parent 6cac7b5 commit 63578b8
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 31 deletions.
26 changes: 10 additions & 16 deletions geocalib/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,18 @@ def __init__(self, weights: str = "pinhole"):
"""Initialize the model with optional config overrides.
Args:
weights (str): Weights to load. Can be "pinhole", "distorted" or path to a checkpoint.
Note that in case of custom weights, the architecture must match the original model.
If this is not the case, use the extractor from the 'siclib' package
(from siclib.models.extractor import GeoCalib).
weights (str): trained variant, "pinhole" (default) or "distorted".
"""
super().__init__()
if weights in {"pinhole", "distorted"}:
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)
elif Path(weights).exists():
state_dict = torch.load(weights, map_location="cpu")
else:
raise ValueError(f"Invalid weights: {weights}")
if weights not in {"pinhole", "distorted"}:
raise ValueError(f"Unknown weights: {weights}")
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)

self.model = Model()
self.model.flexible_load(state_dict["model"])
Expand Down
25 changes: 10 additions & 15 deletions siclib/models/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.nn.functional import interpolate

from siclib.geometry.base_camera import BaseCamera
from siclib.models import get_model
from siclib.models.networks.geocalib import GeoCalib as Model
from siclib.utils.image import ImagePreprocessor, load_image

Expand All @@ -20,24 +19,20 @@ def __init__(self, weights: str = "pinhole"):
"""Initialize the model with optional config overrides.
Args:
weights (str): Weights to load. Can be "pinhole", "distorted" or path to a checkpoint.
weights (str, optional): Weights to load. Defaults to "pinhole".
"""
super().__init__()
if weights in {"pinhole", "distorted"}:
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)
self.model = Model({})
elif Path(weights).exists():
state_dict = torch.load(weights, map_location="cpu")
self.model = get_model(state_dict["conf"]["name"])(state_dict["conf"])
else:
if weights not in {"pinhole", "distorted"}:
raise ValueError(f"Unknown weights: {weights}")
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)

self.model = Model({})
self.model.flexible_load(state_dict["model"])
self.model.eval()

Expand Down

0 comments on commit 63578b8

Please sign in to comment.