diff --git a/chameleon/modules/backbones/gpunet.py b/chameleon/modules/backbones/gpunet.py index bccda37..70e97c2 100644 --- a/chameleon/modules/backbones/gpunet.py +++ b/chameleon/modules/backbones/gpunet.py @@ -1,3 +1,4 @@ +from functools import partial from typing import List, Optional import torch @@ -125,45 +126,45 @@ def _replace_padding(model): out_indices=out_indices, ) - @classmethod - def build_gpunet_0(cls, **kwargs): - return cls.build_gpunet(name='gpunet_0', **kwargs) + # @classmethod + # def build_gpunet_0(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_0', **kwargs) - @classmethod - def build_gpunet_1(cls, **kwargs): - return cls.build_gpunet(name='gpunet_1', **kwargs) + # @classmethod + # def build_gpunet_1(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_1', **kwargs) - @classmethod - def build_gpunet_2(cls, **kwargs): - return cls.build_gpunet(name='gpunet_2', **kwargs) + # @classmethod + # def build_gpunet_2(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_2', **kwargs) - @classmethod - def build_gpunet_p0(cls, **kwargs): - return cls.build_gpunet(name='gpunet_p0', **kwargs) + # @classmethod + # def build_gpunet_p0(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_p0', **kwargs) - @classmethod - def build_gpunet_p1(cls, **kwargs): - return cls.build_gpunet(name='gpunet_p1', **kwargs) + # @classmethod + # def build_gpunet_p1(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_p1', **kwargs) - @classmethod - def build_gpunet_d1(cls, **kwargs): - return cls.build_gpunet(name='gpunet_d1', **kwargs) + # @classmethod + # def build_gpunet_d1(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_d1', **kwargs) - @classmethod - def build_gpunet_d2(cls, **kwargs): - return cls.build_gpunet(name='gpunet_d2', **kwargs) + # @classmethod + # def build_gpunet_d2(cls, **kwargs): + # return cls.build_gpunet(name='gpunet_d2', **kwargs) GPUNETs = { - 'GPUNet_0': GPUNet.build_gpunet_0, - 'GPUNet_1': GPUNet.build_gpunet_1, - 'GPUNet_2': GPUNet.build_gpunet_2, - 'GPUNet_p0': GPUNet.build_gpunet_p0, - 'GPUNet_p1': GPUNet.build_gpunet_p1, - 'GPUNet_d1': GPUNet.build_gpunet_d1, - 'GPUNet_d2': GPUNet.build_gpunet_d2, + 'GPUNet_0': partial(GPUNet.build_gpunet, name='gpunet_0'), + 'GPUNet_1': partial(GPUNet.build_gpunet, name='gpunet_1'), + 'GPUNet_2': partial(GPUNet.build_gpunet, name='gpunet_2'), + 'GPUNet_p0': partial(GPUNet.build_gpunet, name='gpunet_p0'), + 'GPUNet_p1': partial(GPUNet.build_gpunet, name='gpunet_p1'), + 'GPUNet_d1': partial(GPUNet.build_gpunet, name='gpunet_d1'), + 'GPUNet_d2': partial(GPUNet.build_gpunet, name='gpunet_d2'), } for k, v in GPUNETs.items(): - BACKBONES.register_module(name=k, module=v) + BACKBONES.register_module(name=k, module=v, is_model_builder=True) diff --git a/chameleon/modules/backbones/timm.py b/chameleon/modules/backbones/timm.py index c17ff87..280e3ad 100644 --- a/chameleon/modules/backbones/timm.py +++ b/chameleon/modules/backbones/timm.py @@ -1,15 +1,15 @@ +from functools import partial + import timm import torch.nn as nn from ...registry import BACKBONES +models = timm.list_models() -class Timm: - @staticmethod - def build_model(*args, **kwargs) -> nn.Module: - return timm.create_model(*args, **kwargs) - - -timm_models = timm.list_models() -for name in timm_models: - BACKBONES.register_module(f'timm_{name}', module=Timm.build_model) +for name in models: + BACKBONES.register_module( + f'timm_{name}', + module=partial(timm.create_model, model_name=name), + is_model_builder=True, + ) diff --git a/chameleon/registry/registry.py b/chameleon/registry/registry.py index e092107..37a459f 100644 --- a/chameleon/registry/registry.py +++ b/chameleon/registry/registry.py @@ -21,7 +21,9 @@ def build_from_cfg(cfg: dict, registry: "Registry") -> Any: kwargs = cfg.copy() name = kwargs.pop('name') obj_cls = registry.get(name) - if inspect.isclass(obj_cls) or inspect.ismethod(obj_cls): + is_model_builder = registry.is_model_builder(name) + + if inspect.isclass(obj_cls) or is_model_builder: obj = obj_cls(**kwargs) else: obj = obj_cls @@ -33,6 +35,7 @@ class Registry: def __init__(self, name: str): self._name = name self._module_dict: Dict[str, Type] = dict() + self._type_dict: Dict[str, Type] = dict() def __len__(self): return len(self._module_dict) @@ -73,6 +76,17 @@ def get(self, key: str) -> Optional[Type]: return obj_cls + def is_model_builder(self, key: str) -> bool: + if not isinstance(key, str): + raise TypeError(f'key must be a str, but got {type(key)}') + + is_model_builder = self._type_dict.get(key, None) + + if is_model_builder is None: + raise KeyError(f'{key} is not in the {self.name} registry') + + return is_model_builder + def build(self, cfg: dict) -> Any: return build_from_cfg(cfg, registry=self) @@ -80,7 +94,8 @@ def _register_module( self, module: Type, module_name: Optional[Union[str, List[str]]] = None, - force: bool = False + force: bool = False, + is_model_builder: bool = False, ) -> None: if not callable(module): raise TypeError(f'module must be a callable, but got {type(module)}') @@ -94,12 +109,14 @@ def _register_module( existed_module = self.module_dict[name] raise KeyError(f'{name} is already registered in {self.name} at {existed_module.__module__}') self._module_dict[name] = module + self._type_dict[name] = is_model_builder def register_module( self, name: str = None, force: bool = False, module: Optional[Type] = None, + is_model_builder: bool = False, ) -> Union[type, Callable]: if not (name is None or isinstance(name, str)): @@ -110,12 +127,12 @@ def register_module( # use it as a normal method: x.register_module(module=SomeClass) if module is not None: - self._register_module(module=module, module_name=name, force=force) + self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder) return module # use it as a decorator: @x.register_module() def _register(module): - self._register_module(module=module, module_name=name, force=force) + self._register_module(module=module, module_name=name, force=force, is_model_builder=is_model_builder) return module return _register