Skip to content

Commit

Permalink
Add option to use custom weights in inference. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
veichta authored Nov 28, 2024
1 parent 57d3648 commit 6cac7b5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
26 changes: 16 additions & 10 deletions geocalib/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,24 @@ def __init__(self, weights: str = "pinhole"):
"""Initialize the model with optional config overrides.
Args:
weights (str): trained variant, "pinhole" (default) or "distorted".
weights (str): Weights to load. Can be "pinhole", "distorted" or path to a checkpoint.
Note that in case of custom weights, the architecture must match the original model.
If this is not the case, use the extractor from the 'siclib' package
(from siclib.models.extractor import GeoCalib).
"""
super().__init__()
if weights not in {"pinhole", "distorted"}:
raise ValueError(f"Unknown weights: {weights}")
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)
if weights in {"pinhole", "distorted"}:
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)
elif Path(weights).exists():
state_dict = torch.load(weights, map_location="cpu")
else:
raise ValueError(f"Invalid weights: {weights}")

self.model = Model()
self.model.flexible_load(state_dict["model"])
Expand Down
25 changes: 15 additions & 10 deletions siclib/models/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.nn.functional import interpolate

from siclib.geometry.base_camera import BaseCamera
from siclib.models import get_model
from siclib.models.networks.geocalib import GeoCalib as Model
from siclib.utils.image import ImagePreprocessor, load_image

Expand All @@ -19,20 +20,24 @@ def __init__(self, weights: str = "pinhole"):
"""Initialize the model with optional config overrides.
Args:
weights (str, optional): Weights to load. Defaults to "pinhole".
weights (str): Weights to load. Can be "pinhole", "distorted" or path to a checkpoint.
"""
super().__init__()
if weights not in {"pinhole", "distorted"}:
if weights in {"pinhole", "distorted"}:
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)
self.model = Model({})
elif Path(weights).exists():
state_dict = torch.load(weights, map_location="cpu")
self.model = get_model(state_dict["conf"]["name"])(state_dict["conf"])
else:
raise ValueError(f"Unknown weights: {weights}")
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

# load checkpoint
model_dir = f"{torch.hub.get_dir()}/geocalib"
state_dict = torch.hub.load_state_dict_from_url(
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
)

self.model = Model({})
self.model.flexible_load(state_dict["model"])
self.model.eval()

Expand Down

0 comments on commit 6cac7b5

Please sign in to comment.