[A] Add registry feature #4

Merged · 2 commits · Dec 23, 2024
1 change: 1 addition & 0 deletions chameleon/__init__.py
@@ -1,6 +1,7 @@
 from .base import *
 from .metrics import *
 from .modules import *
+from .registry import *
 from .tools import *
 
 __version__ = '0.1.0'
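
Note: the new chameleon/registry module itself does not appear in this diff view. As a rough mental model only (a sketch under assumptions, not the actual implementation), a registry compatible with the calls used in the files below — the @BLOCKS.register_module() decorator, COMPONENTS.register_module(name=..., module=...), and COMPONENTS.build(cfg) — could look something like this:

```python
# Hypothetical sketch only: the real chameleon.registry is not shown in this PR
# and may differ, e.g. in the config key that build() expects.
class Registry:
    def __init__(self, name: str):
        self.name = name
        self._modules = {}

    def register_module(self, name=None, module=None):
        # Direct form: REGISTRY.register_module(name='BatchNorm2d', module=nn.BatchNorm2d)
        if module is not None:
            self._modules[name or module.__name__] = module
            return module

        # Decorator form: @REGISTRY.register_module()
        def _register(cls):
            self._modules[name or cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg: dict):
        # Assumes the config dict carries the class name plus constructor kwargs.
        cfg = dict(cfg)
        cls = self._modules[cfg.pop('name')]
        return cls(**cfg)


BLOCKS = Registry('blocks')
COMPONENTS = Registry('components')
```
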
10 changes: 5 additions & 5 deletions chameleon/base/__init__.py
@@ -1,8 +1,8 @@
-from .blocks import build_block, list_blocks
-from .components import build_component, list_components
-from .layers import build_layer, list_layers
-from .optim import (build_lr_scheduler, build_optimizer, list_lr_schedulers,
-                    list_optimizers)
+from .blocks import *
+from .components import *
+from .layers import *
+from .ops import *
+from .optim import *
 from .power_module import PowerModule
 from .utils import (has_children, initialize_weights_, replace_module,
                     replace_module_attr_value)
24 changes: 11 additions & 13 deletions chameleon/base/blocks/__init__.py
@@ -1,21 +1,19 @@
-import fnmatch
-
 from .conv_block import Conv2dBlock, SeparableConv2dBlock
 
 # from .mamba_block import build_mamba_block
 # from .vit_block import build_vit_block
 
 
-def build_block(name, **kwargs):
-    cls = globals().get(name, None)
-    if cls is None:
-        raise ValueError(f'Block named {name} is not support.')
-    return cls(**kwargs)
+# def build_block(name, **kwargs):
+#     cls = globals().get(name, None)
+#     if cls is None:
+#         raise ValueError(f'Block named {name} is not support.')
+#     return cls(**kwargs)
 
 
-def list_blocks(filter=''):
-    block_list = [k for k in globals().keys() if 'Block' in k]
-    if len(filter):
-        return fnmatch.filter(block_list, filter)  # include these blocks
-    else:
-        return block_list
+# def list_blocks(filter=''):
+#     block_list = [k for k in globals().keys() if 'Block' in k]
+#     if len(filter):
+#         return fnmatch.filter(block_list, filter)  # include these blocks
+#     else:
+#         return block_list
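
With build_block and list_blocks commented out, block construction presumably goes through the BLOCKS registry instead (both block classes gain @BLOCKS.register_module() in conv_block.py below). A hedged before/after sketch for a caller — the 'name' config key and the illustrative constructor arguments are assumptions, not taken from this diff:

```python
from chameleon.registry import BLOCKS

# Old functional helper (commented out by this PR):
# block = build_block('Conv2dBlock', in_channels=3, out_channels=16, kernel_size=3)

# Assumed registry-based equivalent after this PR:
block = BLOCKS.build({'name': 'Conv2dBlock', 'in_channels': 3,
                      'out_channels': 16, 'kernel_size': 3})
```
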
43 changes: 22 additions & 21 deletions chameleon/base/blocks/conv_block.py
@@ -3,10 +3,11 @@
 import torch
 import torch.nn as nn
 
-from ..components import build_component
+from ...registry import BLOCKS, COMPONENTS
 from ..power_module import PowerModule
 
 
+@BLOCKS.register_module()
 class SeparableConv2dBlock(PowerModule):
 
     def __init__(
@@ -17,10 +18,10 @@ def __init__(
         stride: Union[int, Tuple[int, int]] = 1,
         padding: Union[int, Tuple[int, int]] = 1,
         bias: bool = False,
-        inner_norm: Optional[Union[dict, nn.Module]] = None,
-        inner_act: Optional[Union[dict, nn.Module]] = None,
-        norm: Optional[Union[dict, nn.Module]] = None,
-        act: Optional[Union[dict, nn.Module]] = None,
+        inner_norm: Optional[dict] = None,
+        inner_act: Optional[dict] = None,
+        norm: Optional[dict] = None,
+        act: Optional[dict] = None,
         init_type: str = 'normal',
     ):
         """
@@ -41,13 +42,13 @@ def __init__(
                 Whether to include a bias term in the convolutional layer.
                 Noted: if normalization layer is not None, bias will always be set to False.
                 Defaults to False.
-            inner_norm (dict or nn.Module, optional):
+            inner_norm (dict, optional):
                 Configuration of normalization layer between dw and pw layer. Defaults to None.
-            inner_act (dict or nn.Module, optional):
+            inner_act (dict, optional):
                 Configuration of activation layer between dw and pw layer. Defaults to None.
-            norm (dict or nn.Module, optional):
+            norm (dict, optional):
                 Configuration of normalization layer after pw layer. Defaults to None.
-            act (dict or nn.Module, optional):
+            act (dict, optional):
                 Configuration of activation layer after pw layer. Defaults to None.
             init_type (str, optional):
                 Initialization method for the model parameters. Defaults to 'normal'.
@@ -77,13 +78,13 @@ def __init__(
             bias=bias,
         )
         if inner_norm is not None:
-            self.block['inner_norm'] = build_component(**inner_norm) if isinstance(inner_norm, dict) else inner_norm
+            self.block['inner_norm'] = COMPONENTS.build(inner_norm) if isinstance(inner_norm, dict) else inner_norm
         if inner_act is not None:
-            self.block['inner_act'] = build_component(**inner_act) if isinstance(inner_act, dict) else inner_act
+            self.block['inner_act'] = COMPONENTS.build(inner_act) if isinstance(inner_act, dict) else inner_act
         if norm is not None:
-            self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm
+            self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
         if act is not None:
-            self.block['act'] = build_component(**act) if isinstance(act, dict) else act
+            self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act
         self.initialize_weights_(init_type)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -92,8 +93,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+@BLOCKS.register_module()
 class Conv2dBlock(PowerModule):
-
     def __init__(
         self,
         in_channels: Union[float, int],
@@ -105,8 +106,8 @@ def __init__(
         groups: int = 1,
         bias: bool = False,
         padding_mode: str = 'zeros',
-        norm: Union[dict, nn.Module] = None,
-        act: Union[dict, nn.Module] = None,
+        norm: Optional[dict] = None,
+        act: Optional[dict] = None,
         init_type: str = 'normal',
     ):
         """
@@ -136,13 +137,13 @@ def __init__(
             padding_mode (str, optional):
                 Options = {'zeros', 'reflect', 'replicate', 'circular'}.
                 Defaults to 'zeros'.
-            norm (Union[dict, nn.Module], optional):
+            norm (Optional[dict], optional):
                 normalization layer or a dictionary of arguments for building a
                 normalization layer. Default to None.
-            act (Union[dict, nn.Module], optional):
+            act (Optional[dict], optional):
                 Activation function or a dictionary of arguments for building an
                 activation function. Default to None.
-            pool (Union[dict, nn.Module], optional):
+            pool (Optional[dict], optional):
                 pooling layer or a dictionary of arguments for building a pooling
                 layer. Default to None.
             init_type (str):
@@ -180,9 +181,9 @@ def __init__(
             padding_mode=padding_mode,
         )
         if norm is not None:
-            self.block['norm'] = build_component(**norm) if isinstance(norm, dict) else norm
+            self.block['norm'] = COMPONENTS.build(norm) if isinstance(norm, dict) else norm
         if act is not None:
-            self.block['act'] = build_component(**act) if isinstance(act, dict) else act
+            self.block['act'] = COMPONENTS.build(act) if isinstance(act, dict) else act
 
         self.initialize_weights_(init_type)
 
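Since the norm/act arguments are now plain dicts resolved through COMPONENTS.build, usage of the updated Conv2dBlock might look like the sketch below (hypothetical: the 'name' key read by COMPONENTS.build and the exact constructor parameters are assumptions based on this diff, not confirmed API):

```python
import torch

from chameleon import Conv2dBlock  # re-exported through the package-level star imports

block = Conv2dBlock(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    norm={'name': 'BatchNorm2d', 'num_features': 16},  # resolved via COMPONENTS.build
    act={'name': 'Hswish'},                            # resolved via COMPONENTS.build
)
out = block(torch.randn(1, 3, 32, 32))  # expected shape: (1, 16, 32, 32)
```
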
24 changes: 0 additions & 24 deletions chameleon/base/components/__init__.py
@@ -1,29 +1,5 @@
-import fnmatch
-
 from .activation import *
 from .dropout import *
 from .loss import *
 from .norm import *
 from .pooling import *
-
-
-def build_component_cls(name):
-    cls = globals().get(name, None)
-    if cls is None:
-        raise ValueError(f'Component named {name} is not support.')
-    return cls
-
-
-def build_component(name, **options):
-    cls = globals().get(name, None)
-    if cls is None:
-        raise ValueError(f'Component named {name} is not support.')
-    return cls(**options)
-
-
-def list_components(filter=''):
-    component_list = [k for k in globals().keys() if 'Component' in k]
-    if len(filter):
-        return fnmatch.filter(component_list, filter)  # include these components
-    else:
-        return component_list
8 changes: 6 additions & 2 deletions chameleon/base/components/activation.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -12,6 +10,8 @@
                         Softshrink, Softsign, Tanh,
                         Tanhshrink, Threshold)
 
+from ...registry import COMPONENTS
+
 __all__ = [
     'Swish', 'Hsigmoid', 'Hswish', 'StarReLU', 'SquaredReLU',
 ]
@@ -95,3 +95,7 @@ def forward(self, x):
 
 # Ref: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
 Swish = nn.SiLU
+
+
+for k in __all__:
+    COMPONENTS.register_module(name=k, module=globals()[k])
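
The loop above bulk-registers every name in __all__ with COMPONENTS, which is what lets the dict-based norm/act configs in conv_block.py resolve to concrete classes. A small, hypothetical end-to-end check (assuming COMPONENTS.build reads a 'name' key, as in the earlier sketch):

```python
import chameleon  # importing the package runs the registration loops
from chameleon.registry import COMPONENTS

# Equivalent to COMPONENTS.register_module(name='Hswish', module=Hswish),
# done once per entry of __all__ by the loop above.
act = COMPONENTS.build({'name': 'Hswish'})
print(type(act).__name__)  # -> 'Hswish'
```
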
7 changes: 6 additions & 1 deletion chameleon/base/components/dropout.py
@@ -1,6 +1,11 @@
-import torch.nn as nn
 from torch.nn import AlphaDropout, Dropout, Dropout2d, Dropout3d
 
+from ...registry import COMPONENTS
+
 __all__ = [
     'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout',
 ]
+
+
+for k in __all__:
+    COMPONENTS.register_module(name=k, module=globals()[k])
7 changes: 6 additions & 1 deletion chameleon/base/components/loss.py
@@ -1,12 +1,13 @@
 import math
-from typing import Union
 
 import torch
 import torch.nn as nn
 from torch.nn.modules.loss import (BCELoss, BCEWithLogitsLoss,
                                    CrossEntropyLoss, CTCLoss, KLDivLoss,
                                    L1Loss, MSELoss, SmoothL1Loss)
 
+from ...registry import COMPONENTS
+
 __all__ = [
     'AWingLoss', 'WeightedAWingLoss',
     'BCELoss', 'BCEWithLogitsLoss', 'CrossEntropyLoss',
@@ -139,3 +140,7 @@ def dice_loss(self, input, target):
     def forward(self, input, target):
         dice_loss = self.dice_loss(input, target)
         return torch.log(torch.cosh(dice_loss))
+
+
+for k in __all__:
+    COMPONENTS.register_module(name=k, module=globals()[k])
6 changes: 6 additions & 0 deletions chameleon/base/components/norm.py
@@ -10,6 +10,8 @@
 from torch.nn.modules.normalization import (CrossMapLRN2d, GroupNorm,
                                             LayerNorm, LocalResponseNorm)
 
+from ...registry import COMPONENTS
+
 __all__ = [
     'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d',
     'InstanceNorm2d', 'InstanceNorm3d', 'CrossMapLRN2d', 'GroupNorm', 'LayerNorm',
@@ -42,3 +44,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                          self.weight, self.bias, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
+
+
+for k in __all__:
+    COMPONENTS.register_module(name=k, module=globals()[k])
6 changes: 6 additions & 0 deletions chameleon/base/components/pooling.py
@@ -8,6 +8,8 @@
                             AvgPool1d, AvgPool2d, AvgPool3d,
                             MaxPool1d, MaxPool2d, MaxPool3d)
 
+from ...registry import COMPONENTS
+
 __all__ = [
     'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d',
     'MaxPool2d', 'MaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d',
@@ -44,3 +46,7 @@ def __init__(self):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply global max pooling on the input tensor."""
         return self.pool(x)
+
+
+for k in __all__:
+    COMPONENTS.register_module(name=k, module=globals()[k])
17 changes: 0 additions & 17 deletions chameleon/base/layers/__init__.py
@@ -1,22 +1,5 @@
-import fnmatch
-
 from .aspp import ASPP
 from .grl import GradientReversalLayer
 from .selayer import SELayer
 from .vae import VAE
 from .weighted_sum import WeightedSum
-
-
-def build_layer(name, **options):
-    cls = globals().get(name, None)
-    if cls is None:
-        raise ValueError(f'Layer named {name} is not support.')
-    return cls(**options)
-
-
-def list_layers(filter=''):
-    layer_list = [k for k in globals().keys() if 'Layer' in k]
-    if len(filter):
-        return fnmatch.filter(layer_list, filter)  # include these layers
-    else:
-        return layer_list