Skip to content

Commit

Permalink
[A] Add registry feature
Browse files Browse the repository at this point in the history
  • Loading branch information
kunkunlin1221 committed Dec 23, 2024
1 parent 090cb26 commit 6818fa2
Show file tree
Hide file tree
Showing 41 changed files with 611 additions and 383 deletions.
1 change: 1 addition & 0 deletions chameleon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base import *
from .metrics import *
from .modules import *
from .registry import *
from .tools import *

__version__ = '0.1.0'
10 changes: 5 additions & 5 deletions chameleon/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .blocks import build_block, list_blocks
from .components import build_component, list_components
from .layers import build_layer, list_layers
from .optim import (build_lr_scheduler, build_optimizer, list_lr_schedulers,
list_optimizers)
from .blocks import *
from .components import *
from .layers import *
from .ops import *
from .optim import *
from .power_module import PowerModule
from .utils import (has_children, initialize_weights_, replace_module,
replace_module_attr_value)
24 changes: 11 additions & 13 deletions chameleon/base/blocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import fnmatch

from .conv_block import Conv2dBlock, SeparableConv2dBlock

# from .mamba_block import build_mamba_block
# from .vit_block import build_vit_block


def build_block(name, **kwargs):
cls = globals().get(name, None)
if cls is None:
raise ValueError(f'Block named {name} is not support.')
return cls(**kwargs)
# def build_block(name, **kwargs):
# cls = globals().get(name, None)
# if cls is None:
# raise ValueError(f'Block named {name} is not support.')
# return cls(**kwargs)


def list_blocks(filter=''):
block_list = [k for k in globals().keys() if 'Block' in k]
if len(filter):
return fnmatch.filter(block_list, filter) # include these blocks
else:
return block_list
# def list_blocks(filter=''):
# block_list = [k for k in globals().keys() if 'Block' in k]
# if len(filter):
# return fnmatch.filter(block_list, filter) # include these blocks
# else:
# return block_list
43 changes: 22 additions & 21 deletions chameleon/base/blocks/conv_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import torch
import torch.nn as nn

from ..components import build_component
from ...registry import BLOCKS, COMPONENTS
from ..power_module import PowerModule


@BLOCKS.register_module()
class SeparableConv2dBlock(PowerModule):

def __init__(
Expand All @@ -17,10 +18,10 @@ def __init__(
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 1,
bias: bool = False,
inner_norm: Optional[Union[dict, nn.Module]] = None,
inner_act: Optional[Union[dict, nn.Module]] = None,
norm: Optional[Union[dict, nn.Module]] = None,
act: Optional[Union[dict, nn.Module]] = None,
inner_norm: Optional[dict] = None,
inner_act: Optional[dict] = None,
norm: Optional[dict] = None,
act: Optional[dict] = None,
init_type: str = 'normal',
):
"""
Expand All @@ -41,13 +42,13 @@ def __init__(
Whether to include a bias term in the convolutional layer.
Noted: if normalization layer is not None, bias will always be set to False.
Defaults to False.
inner_norm (dict or nn.Module, optional):
inner_norm (dict, optional):
Configuration of normalization layer between dw and pw layer. Defaults to None.
inner_act (dict or nn.Module, optional):
inner_act (dict, optional):
Configuration of activation layer between dw and pw layer. Defaults to None.
norm (dict or nn.Module, optional):
norm (dict, optional):
Configuration of normalization layer after pw layer. Defaults to None.
act (dict or nn.Module, optional):
act (dict, optional):
Configuration of activation layer after pw layer. Defaults to None.
init_type (str, optional):
Initialization method for the model parameters. Defaults to 'normal'.
Expand Down Expand Up @@ -77,13 +78,13 @@ def __init__(
bias=bias,
)
if inner_norm is not None:
self.block['inner_norm'] = build_component(**inner_norm) if isinstance(inner_norm, dict) else inner_norm
self.block['inner_norm'] = COMPONENTS.build(inner_norm) if isinstance(inner_norm, dict) else inner_norm
if inner_act is not None:
self.block['inner_act'] = build_component(**inner_act) if isinstance(inner_act, dict) else inner_act
self.block['inner_act'] = COMPONENTS.build(inner_act) if isinstance(inner_act, dict) else inner_act
if norm is not None:
self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm
self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
if act is not None:
self.block['act'] = build_component(**act) if isinstance(act, dict) else act
self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act
self.initialize_weights_(init_type)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -92,8 +93,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@BLOCKS.register_module()
class Conv2dBlock(PowerModule):

def __init__(
self,
in_channels: Union[float, int],
Expand All @@ -105,8 +106,8 @@ def __init__(
groups: int = 1,
bias: bool = False,
padding_mode: str = 'zeros',
norm: Union[dict, nn.Module] = None,
act: Union[dict, nn.Module] = None,
norm: Optional[dict] = None,
act: Optional[dict] = None,
init_type: str = 'normal',
):
"""
Expand Down Expand Up @@ -136,13 +137,13 @@ def __init__(
padding_mode (str, optional):
Options = {'zeros', 'reflect', 'replicate', 'circular'}.
Defaults to 'zeros'.
norm (Union[dict, nn.Module], optional):
norm (Optional[dict], optional):
normalization layer or a dictionary of arguments for building a
normalization layer. Default to None.
act (Union[dict, nn.Module], optional):
act (Optional[dict], optional):
Activation function or a dictionary of arguments for building an
activation function. Default to None.
pool (Union[dict, nn.Module], optional):
pool (Optional[dict], optional):
pooling layer or a dictionary of arguments for building a pooling
layer. Default to None.
init_type (str):
Expand Down Expand Up @@ -180,9 +181,9 @@ def __init__(
padding_mode=padding_mode,
)
if norm is not None:
self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm
self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
if act is not None:
self.block['act'] = build_component(**act) if isinstance(act, dict) else act
self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act

self.initialize_weights_(init_type)

Expand Down
24 changes: 0 additions & 24 deletions chameleon/base/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,5 @@
import fnmatch

from .activation import *
from .dropout import *
from .loss import *
from .norm import *
from .pooling import *


def build_component_cls(name):
cls = globals().get(name, None)
if cls is None:
raise ValueError(f'Component named {name} is not support.')
return cls


def build_component(name, **options):
cls = globals().get(name, None)
if cls is None:
raise ValueError(f'Component named {name} is not support.')
return cls(**options)


def list_components(filter=''):
component_list = [k for k in globals().keys() if 'Component' in k]
if len(filter):
return fnmatch.filter(component_list, filter) # include these components
else:
return component_list
8 changes: 6 additions & 2 deletions chameleon/base/components/activation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -12,6 +10,8 @@
Softshrink, Softsign, Tanh,
Tanhshrink, Threshold)

from ...registry import COMPONENTS

__all__ = [
'Swish', 'Hsigmoid', 'Hswish', 'StarReLU', 'SquaredReLU',
]
Expand Down Expand Up @@ -95,3 +95,7 @@ def forward(self, x):

# Ref: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
Swish = nn.SiLU


for k in __all__:
COMPONENTS.register_module(name=k, module=globals()[k])
7 changes: 6 additions & 1 deletion chameleon/base/components/dropout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch.nn as nn
from torch.nn import AlphaDropout, Dropout, Dropout2d, Dropout3d

from ...registry import COMPONENTS

__all__ = [
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',
]


for k in __all__:
COMPONENTS.register_module(name=k, module=globals()[k])
7 changes: 6 additions & 1 deletion chameleon/base/components/loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import math
from typing import Union

import torch
import torch.nn as nn
from torch.nn.modules.loss import (BCELoss, BCEWithLogitsLoss,
CrossEntropyLoss, CTCLoss, KLDivLoss,
L1Loss, MSELoss, SmoothL1Loss)

from ...registry import COMPONENTS

__all__ = [
'AWingLoss', 'WeightedAWingLoss',
'BCELoss', 'BCEWithLogitsLoss', 'CrossEntropyLoss',
Expand Down Expand Up @@ -139,3 +140,7 @@ def dice_loss(self, input, target):
def forward(self, input, target):
dice_loss = self.dice_loss(input, target)
return torch.log(torch.cosh(dice_loss))


for k in __all__:
COMPONENTS.register_module(name=k, module=globals()[k])
6 changes: 6 additions & 0 deletions chameleon/base/components/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch.nn.modules.normalization import (CrossMapLRN2d, GroupNorm,
LayerNorm, LocalResponseNorm)

from ...registry import COMPONENTS

__all__ = [
'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'CrossMapLRN2d', 'GroupNorm', 'LayerNorm',
Expand Down Expand Up @@ -42,3 +44,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x


for k in __all__:
COMPONENTS.register_module(name=k, module=globals()[k])
6 changes: 6 additions & 0 deletions chameleon/base/components/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AvgPool1d, AvgPool2d, AvgPool3d,
MaxPool1d, MaxPool2d, MaxPool3d)

from ...registry import COMPONENTS

__all__ = [
'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d',
'MaxPool2d', 'MaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
Expand Down Expand Up @@ -44,3 +46,7 @@ def __init__(self):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply global max pooling on the input tensor."""
return self.pool(x)


for k in __all__:
COMPONENTS.register_module(name=k, module=globals()[k])
17 changes: 0 additions & 17 deletions chameleon/base/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,5 @@
import fnmatch

from .aspp import ASPP
from .grl import GradientReversalLayer
from .selayer import SELayer
from .vae import VAE
from .weighted_sum import WeightedSum


def build_layer(name, **options):
cls = globals().get(name, None)
if cls is None:
raise ValueError(f'Layer named {name} is not support.')
return cls(**options)


def list_layers(filter=''):
layer_list = [k for k in globals().keys() if 'Layer' in k]
if len(filter):
return fnmatch.filter(layer_list, filter) # include these layers
else:
return layer_list
Loading

0 comments on commit 6818fa2

Please sign in to comment.