Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[F] add is_model_builder into Registry to fix the fail of function build #5

Merged
merged 2 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 9 additions & 36 deletions chameleon/modules/backbones/gpunet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Optional

import torch
Expand Down Expand Up @@ -125,45 +126,17 @@ def _replace_padding(model):
out_indices=out_indices,
)

@classmethod
def build_gpunet_0(cls, **kwargs):
return cls.build_gpunet(name='gpunet_0', **kwargs)

@classmethod
def build_gpunet_1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_1', **kwargs)

@classmethod
def build_gpunet_2(cls, **kwargs):
return cls.build_gpunet(name='gpunet_2', **kwargs)

@classmethod
def build_gpunet_p0(cls, **kwargs):
return cls.build_gpunet(name='gpunet_p0', **kwargs)

@classmethod
def build_gpunet_p1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_p1', **kwargs)

@classmethod
def build_gpunet_d1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_d1', **kwargs)

@classmethod
def build_gpunet_d2(cls, **kwargs):
return cls.build_gpunet(name='gpunet_d2', **kwargs)


GPUNETs = {
'GPUNet_0': GPUNet.build_gpunet_0,
'GPUNet_1': GPUNet.build_gpunet_1,
'GPUNet_2': GPUNet.build_gpunet_2,
'GPUNet_p0': GPUNet.build_gpunet_p0,
'GPUNet_p1': GPUNet.build_gpunet_p1,
'GPUNet_d1': GPUNet.build_gpunet_d1,
'GPUNet_d2': GPUNet.build_gpunet_d2,
'GPUNet_0': partial(GPUNet.build_gpunet, name='gpunet_0'),
'GPUNet_1': partial(GPUNet.build_gpunet, name='gpunet_1'),
'GPUNet_2': partial(GPUNet.build_gpunet, name='gpunet_2'),
'GPUNet_p0': partial(GPUNet.build_gpunet, name='gpunet_p0'),
'GPUNet_p1': partial(GPUNet.build_gpunet, name='gpunet_p1'),
'GPUNet_d1': partial(GPUNet.build_gpunet, name='gpunet_d1'),
'GPUNet_d2': partial(GPUNet.build_gpunet, name='gpunet_d2'),
}


for k, v in GPUNETs.items():
BACKBONES.register_module(name=k, module=v)
BACKBONES.register_module(name=k, module=v, is_model_builder=True)
18 changes: 9 additions & 9 deletions chameleon/modules/backbones/timm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from functools import partial

import timm
import torch.nn as nn

from ...registry import BACKBONES

models = timm.list_models()

class Timm:
@staticmethod
def build_model(*args, **kwargs) -> nn.Module:
return timm.create_model(*args, **kwargs)


timm_models = timm.list_models()
for name in timm_models:
BACKBONES.register_module(f'timm_{name}', module=Timm.build_model)
for name in models:
BACKBONES.register_module(
f'timm_{name}',
module=partial(timm.create_model, model_name=name),
is_model_builder=True,
)
25 changes: 21 additions & 4 deletions chameleon/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def build_from_cfg(cfg: dict, registry: "Registry") -> Any:
kwargs = cfg.copy()
name = kwargs.pop('name')
obj_cls = registry.get(name)
if inspect.isclass(obj_cls) or inspect.ismethod(obj_cls):
is_model_builder = registry.is_model_builder(name)

if inspect.isclass(obj_cls) or is_model_builder:
obj = obj_cls(**kwargs)
else:
obj = obj_cls
Expand All @@ -33,6 +35,7 @@ class Registry:
def __init__(self, name: str):
self._name = name
self._module_dict: Dict[str, Type] = dict()
self._type_dict: Dict[str, Type] = dict()

def __len__(self):
return len(self._module_dict)
Expand Down Expand Up @@ -73,14 +76,26 @@ def get(self, key: str) -> Optional[Type]:

return obj_cls

def is_model_builder(self, key: str) -> bool:
if not isinstance(key, str):
raise TypeError(f'key must be a str, but got {type(key)}')

is_model_builder = self._type_dict.get(key, None)

if is_model_builder is None:
raise KeyError(f'{key} is not in the {self.name} registry')

return is_model_builder

def build(self, cfg: dict) -> Any:
return build_from_cfg(cfg, registry=self)

def _register_module(
self,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = False
force: bool = False,
is_model_builder: bool = False,
) -> None:
if not callable(module):
raise TypeError(f'module must be a callable, but got {type(module)}')
Expand All @@ -94,12 +109,14 @@ def _register_module(
existed_module = self.module_dict[name]
raise KeyError(f'{name} is already registered in {self.name} at {existed_module.__module__}')
self._module_dict[name] = module
self._type_dict[name] = is_model_builder

def register_module(
self,
name: str = None,
force: bool = False,
module: Optional[Type] = None,
is_model_builder: bool = False,
) -> Union[type, Callable]:

if not (name is None or isinstance(name, str)):
Expand All @@ -110,12 +127,12 @@ def register_module(

# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
return module

# use it as a decorator: @x.register_module()
def _register(module):
self._register_module(module=module, module_name=name, force=force)
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
return module

return _register
Expand Down
Loading