diff --git a/chameleon/modules/backbones/__init__.py b/chameleon/modules/backbones/__init__.py index 4c62191..ff94830 100644 --- a/chameleon/modules/backbones/__init__.py +++ b/chameleon/modules/backbones/__init__.py @@ -1,14 +1,4 @@ -from functools import partial - -import timm - -from ...registry import BACKBONES from .gpunet import GPUNet - -timm_models = timm.list_models() -for name in timm_models: - create_func = partial(timm.create_model, model_name=name) - BACKBONES.register_module(f'timm_{name}', module=create_func) - +from .timm import * __all__ = ['GPUNet'] diff --git a/chameleon/modules/backbones/timm.py b/chameleon/modules/backbones/timm.py new file mode 100644 index 0000000..c17ff87 --- /dev/null +++ b/chameleon/modules/backbones/timm.py @@ -0,0 +1,15 @@ +import timm +import torch.nn as nn + +from ...registry import BACKBONES + + +class Timm: + @staticmethod + def build_model(*args, **kwargs) -> nn.Module: + return timm.create_model(*args, **kwargs) + + +timm_models = timm.list_models() +for name in timm_models: + BACKBONES.register_module(f'timm_{name}', module=Timm.build_model) diff --git a/tests/registry/test_root.py b/tests/registry/test_root.py index 28f952b..e15a3a0 100644 --- a/tests/registry/test_root.py +++ b/tests/registry/test_root.py @@ -1,4 +1,5 @@ import torch +import torch.nn as nn from torchmetrics.metric import Metric from chameleon import ASPP, FPN, AdamW, AWingLoss, Conv2dBlock, GPUNet diff --git a/tests/tools/test_calflops.py b/tests/tools/test_calflops.py index d5c92e2..17e1635 100644 --- a/tests/tools/test_calflops.py +++ b/tests/tools/test_calflops.py @@ -5,7 +5,8 @@ def test_calcualte_flops(): - model = BACKBONES.build({'name': 'timm_resnet50'}) + timm_create = BACKBONES.build({'name': 'timm_resnet50'}) + model = timm_create(model_name='resnet50') flops, macs, params = calculate_flops(model, (1, 3, 224, 224)) assert flops == '8.21 GFLOPS' assert macs == '4.09 GMACs'