Skip to content

Commit

Permalink
[F] add is_model_builder into Registry to fix the fail of function build
Browse files Browse the repository at this point in the history
  • Loading branch information
kunkunlin1221 committed Dec 24, 2024
1 parent 8f63a54 commit 57f400d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 42 deletions.
59 changes: 30 additions & 29 deletions chameleon/modules/backbones/gpunet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Optional

import torch
Expand Down Expand Up @@ -125,45 +126,45 @@ def _replace_padding(model):
out_indices=out_indices,
)

@classmethod
def build_gpunet_0(cls, **kwargs):
return cls.build_gpunet(name='gpunet_0', **kwargs)
# @classmethod
# def build_gpunet_0(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_0', **kwargs)

@classmethod
def build_gpunet_1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_1', **kwargs)
# @classmethod
# def build_gpunet_1(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_1', **kwargs)

@classmethod
def build_gpunet_2(cls, **kwargs):
return cls.build_gpunet(name='gpunet_2', **kwargs)
# @classmethod
# def build_gpunet_2(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_2', **kwargs)

@classmethod
def build_gpunet_p0(cls, **kwargs):
return cls.build_gpunet(name='gpunet_p0', **kwargs)
# @classmethod
# def build_gpunet_p0(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_p0', **kwargs)

@classmethod
def build_gpunet_p1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_p1', **kwargs)
# @classmethod
# def build_gpunet_p1(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_p1', **kwargs)

@classmethod
def build_gpunet_d1(cls, **kwargs):
return cls.build_gpunet(name='gpunet_d1', **kwargs)
# @classmethod
# def build_gpunet_d1(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_d1', **kwargs)

@classmethod
def build_gpunet_d2(cls, **kwargs):
return cls.build_gpunet(name='gpunet_d2', **kwargs)
# @classmethod
# def build_gpunet_d2(cls, **kwargs):
# return cls.build_gpunet(name='gpunet_d2', **kwargs)


GPUNETs = {
'GPUNet_0': GPUNet.build_gpunet_0,
'GPUNet_1': GPUNet.build_gpunet_1,
'GPUNet_2': GPUNet.build_gpunet_2,
'GPUNet_p0': GPUNet.build_gpunet_p0,
'GPUNet_p1': GPUNet.build_gpunet_p1,
'GPUNet_d1': GPUNet.build_gpunet_d1,
'GPUNet_d2': GPUNet.build_gpunet_d2,
'GPUNet_0': partial(GPUNet.build_gpunet, name='gpunet_0'),
'GPUNet_1': partial(GPUNet.build_gpunet, name='gpunet_1'),
'GPUNet_2': partial(GPUNet.build_gpunet, name='gpunet_2'),
'GPUNet_p0': partial(GPUNet.build_gpunet, name='gpunet_p0'),
'GPUNet_p1': partial(GPUNet.build_gpunet, name='gpunet_p1'),
'GPUNet_d1': partial(GPUNet.build_gpunet, name='gpunet_d1'),
'GPUNet_d2': partial(GPUNet.build_gpunet, name='gpunet_d2'),
}


for k, v in GPUNETs.items():
BACKBONES.register_module(name=k, module=v)
BACKBONES.register_module(name=k, module=v, is_model_builder=True)
18 changes: 9 additions & 9 deletions chameleon/modules/backbones/timm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from functools import partial

import timm
import torch.nn as nn

from ...registry import BACKBONES

models = timm.list_models()

class Timm:
@staticmethod
def build_model(*args, **kwargs) -> nn.Module:
return timm.create_model(*args, **kwargs)


timm_models = timm.list_models()
for name in timm_models:
BACKBONES.register_module(f'timm_{name}', module=Timm.build_model)
for name in models:
BACKBONES.register_module(
f'timm_{name}',
module=partial(timm.create_model, model_name=name),
is_model_builder=True,
)
25 changes: 21 additions & 4 deletions chameleon/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def build_from_cfg(cfg: dict, registry: "Registry") -> Any:
kwargs = cfg.copy()
name = kwargs.pop('name')
obj_cls = registry.get(name)
if inspect.isclass(obj_cls) or inspect.ismethod(obj_cls):
is_model_builder = registry.is_model_builder(name)

if inspect.isclass(obj_cls) or is_model_builder:
obj = obj_cls(**kwargs)
else:
obj = obj_cls
Expand All @@ -33,6 +35,7 @@ class Registry:
def __init__(self, name: str):
self._name = name
self._module_dict: Dict[str, Type] = dict()
self._type_dict: Dict[str, Type] = dict()

def __len__(self):
return len(self._module_dict)
Expand Down Expand Up @@ -73,14 +76,26 @@ def get(self, key: str) -> Optional[Type]:

return obj_cls

def is_model_builder(self, key: str) -> bool:
if not isinstance(key, str):
raise TypeError(f'key must be a str, but got {type(key)}')

is_model_builder = self._type_dict.get(key, None)

if is_model_builder is None:
raise KeyError(f'{key} is not in the {self.name} registry')

return is_model_builder

def build(self, cfg: dict) -> Any:
return build_from_cfg(cfg, registry=self)

def _register_module(
self,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = False
force: bool = False,
is_model_builder: bool = False,
) -> None:
if not callable(module):
raise TypeError(f'module must be a callable, but got {type(module)}')
Expand All @@ -94,12 +109,14 @@ def _register_module(
existed_module = self.module_dict[name]
raise KeyError(f'{name} is already registered in {self.name} at {existed_module.__module__}')
self._module_dict[name] = module
self._type_dict[name] = is_model_builder

def register_module(
self,
name: str = None,
force: bool = False,
module: Optional[Type] = None,
is_model_builder: bool = False,
) -> Union[type, Callable]:

if not (name is None or isinstance(name, str)):
Expand All @@ -110,12 +127,12 @@ def register_module(

# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
return module

# use it as a decorator: @x.register_module()
def _register(module):
self._register_module(module=module, module_name=name, force=force)
self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder)
return module

return _register
Expand Down

0 comments on commit 57f400d

Please sign in to comment.