Skip to content

Commit

Permalink
Feat (Channel-Splitting): add filter for groupwise conv and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Jan 19, 2024
1 parent 6eb89b2 commit 029c2a0
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 40 deletions.
84 changes: 57 additions & 27 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,34 @@

from brevitas.fx import GraphModule
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _channel_maxabs
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _get_input_axis
from brevitas.graph.equalize import _get_output_axis
from brevitas.graph.equalize import Region
from brevitas.graph.equalize import transpose

_conv = (
nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)


def old_maxabs(module, splits_per_layer, split_input):
# works for Conv2d and Linear
dim = 1 - int(split_input)
if isinstance(module, nn.Conv2d):
if not split_input:
# gets the max value for each output channel
max_per_channel = module.weight.data.abs().flatten(1).max(1).values
# check if max_per_channel has the same length as output channels
assert len(max_per_channel) == module.weight.shape[0]
else:
# getting max value for each input channel
max_per_channel = module.weight.data.abs().max(0).values.flatten(1).max(1).values
# check if same length as input channels
assert len(max_per_channel) == module.weight.shape[1]
elif isinstance(module, nn.Linear):
max_per_channel = module.weight.data.abs().max(dim=dim).values.flatten()
channels = torch.argsort(max_per_channel, descending=True)
return channels[:splits_per_layer]


def _channels_to_split(
sources: Dict[str, nn.Module],
sinks: Dict[str, nn.Module],
split_criterion: str,
split_ratio: float,
split_input: bool) -> Dict[nn.Module, List[torch.Tensor]]:
"""
This method computes the channels that will be split based on `split_criterion`.
"""
modules = sinks if split_input else sources
_get_axis = _get_input_axis if split_input else _get_output_axis
# the modules are all of the same shape so we can just take the first one
single_module = next(iter(modules.values()))
num_channels = single_module.weight.shape[_get_axis(single_module)]
splits_per_layer = int(math.ceil(split_ratio * num_channels))

if splits_per_layer == 0:
warnings.warn(f'No splits for {modules}, increasing split_ratio could help.')

module_to_channels = {}
if split_criterion == 'maxabs':
for name, module in modules.items():
Expand Down Expand Up @@ -153,7 +135,24 @@ def _split_channels_region(
_split_channels(module, channels_to_split, split_factor=1, split_input=False)


def _is_groupwise(module: nn.Module):
# only Conv layers can be groupwise
return isinstance(module, _conv) and module.groups > 1


def _is_batchnorm(module: nn.Module):
return isinstance(module, _batch_norm)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
# groupwise convolutions are not supported so filter them out
if any(map(_is_groupwise, srcs + sinks)):
return False

# bn layers aren't allowed
if any(map(_is_batchnorm, sinks + srcs)):
return False

# check if OCs of sources are all equal
srcs_ocs = set(module.weight.shape[_get_output_axis(module)] for module in srcs)
if len(srcs_ocs) > 1:
Expand All @@ -179,7 +178,7 @@ def _split(
sources = {src: region.get_module_from_name(src) for src in region.srcs_names}
sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names}

if _is_supported(sources.values(), sinks.values()):
if _is_supported(list(sources.values()), list(sinks.values())):
# get channels to split
channels_to_split = _channels_to_split(
sources=sources,
Expand All @@ -197,6 +196,36 @@ def _split(
return model


def _clean_regions(regions: List[Region]):
"""
This method checks whether the list of regions is compatible with channel splitting.
If a module is in the sinks/sources of multiple regions, these regions will be removed.
"""
# idea: map modules to their regions and check whether it appears in multiple regions
regions_to_del = set()
source_modules = dict()
sink_modules = dict()
for i, region in enumerate(regions):
# add srcs to source_modules
for src in region.srcs_names:
# if not yet in the dict, instantiate new list for keeping track
if src not in source_modules:
source_modules[src] = [i]
else:
# we know the module has been in sources before, so region needs to be deleted
source_modules[src].append(i)
regions_to_del.update({i, *source_modules[src]})
for sink in region.sinks_names:
if sink not in sink_modules:
sink_modules[sink] = [i]
else:
sink_modules[sink].append(i)
regions_to_del.update({i, *sink_modules[sink]})

regions = [regions[i] for i, _ in enumerate(regions) if i not in regions_to_del]
return regions


class RegionwiseChannelSplitting(GraphTransform):

def __init__(self, split_ratio=0.02, split_criterion='maxabs', split_input=True):
Expand All @@ -212,14 +241,15 @@ def apply(
return_regions: bool = False
) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(model)
regions = _clean_regions(regions)
if len(regions) > 0:
self.graph_model = _split(
model = _split(
model=model,
regions=regions,
split_ratio=self.split_ratio,
split_criterion=self.split_criterion,
split_input=self.split_input)
if return_regions:
return self.graph_model, regions
return model, regions
else:
return self.graph_model
return model
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def unique(sequence):
'channel_splitting': [False], # Use channel splitting algorithm in preprocessing step
'split_ratio': [0.01], # Split ratio for channel splitting
'split_input': [True], # Split input or output channels for channel splitting
'grid_aware': [False], # Use grid aware channel splitting
'merge_bn': [True],}

parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation')
Expand Down Expand Up @@ -218,9 +217,7 @@ def ptq_torchvision_models(args):
merge_bn=config_namespace.merge_bn,
channel_splitting=config_namespace.channel_splitting,
channel_splitting_ratio=config_namespace.split_ratio,
channel_splitting_grid_aware=config_namespace.grid_aware,
channel_splitting_split_input=config_namespace.split_input,
channel_splitting_weight_bit_width=config_namespace.weight_bit_width)
channel_splitting_split_input=config_namespace.split_input)
else:
raise RuntimeError(f"{config_namespace.target_backend} backend not supported.")

Expand Down Expand Up @@ -377,10 +374,9 @@ def validate_config(config_namespace):
is_valid = False
if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
is_valid = False
# if channel splitting is false, no need for split ratio, grid aware, or iteratively split
# if channel splitting is false, no need for split ratio/input
if not config_namespace.channel_splitting:
config_namespace.split_ratio = None
config_namespace.grid_aware = None
config_namespace.split_input = None

config_namespace.is_valid = is_valid
Expand Down
83 changes: 76 additions & 7 deletions tests/brevitas/graph/test_channel_splitting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pytest_cases import fixture_union
import torch
from torchvision import models

Expand All @@ -8,7 +9,6 @@
from .equalization_fixtures import *


# @pytest.mark.skip()
@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2])
@pytest.mark.parametrize('split_input', [False, True])
def test_resnet18(split_ratio, split_input):
Expand All @@ -19,8 +19,8 @@ def test_resnet18(split_ratio, split_input):

model.eval()
expected_out = model(inp)
model = symbolic_trace(model)

model = symbolic_trace(model)
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

Expand All @@ -30,8 +30,7 @@ def test_resnet18(split_ratio, split_input):
assert torch.allclose(expected_out, out, atol=ATOL)


# @pytest.mark.skip()
@pytest.mark.parametrize('split_ratio', [0.05, 0.1])
@pytest.mark.parametrize('split_ratio', [0.05])
@pytest.mark.parametrize('split_input', [False, True])
def test_alexnet(split_ratio, split_input):
model = models.alexnet(pretrained=True)
Expand All @@ -49,9 +48,9 @@ def test_alexnet(split_ratio, split_input):
assert torch.allclose(expected_out, out, atol=ATOL)


# @pytest.mark.skip()
@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2])
@pytest.mark.parametrize('split_input', [False, True])
def test_transpose_layers(split_input):
def test_transpose_layers(split_ratio, split_input):
torch.manual_seed(SEED)

model = torch.nn.Sequential(
Expand All @@ -63,8 +62,78 @@ def test_transpose_layers(split_input):

model.eval()
expected_out = model(inp)

model = symbolic_trace(model)

model = RegionwiseChannelSplitting(
split_ratio=split_ratio, split_input=split_input).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2])
@pytest.mark.parametrize('split_input', [False, True])
def test_models(toy_model, split_ratio, split_input, request):
test_id = request.node.callspec.id
if 'mha' in test_id:
pytest.skip('MHA not supported.')
torch.manual_seed(SEED)

model_class = toy_model
model = model_class()
inp = torch.randn(IN_SIZE_CONV)

model.eval()
expected_out = model(inp)

model = symbolic_trace(model)
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

model = RegionwiseChannelSplitting(
split_ratio=split_ratio, split_input=split_input).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest_cases.fixture
def convgroupconv_model():

class ConvGroupConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3)
self.conv_0 = nn.Conv2d(16, 32, kernel_size=1, groups=2)
self.conv_1 = nn.Conv2d(32, 64, kernel_size=1, groups=4)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return ConvGroupConvModel


@pytest.mark.parametrize('split_input', [False, True])
def test_groupwise_models(convgroupconv_model, split_input):
torch.manual_seed(SEED)

model_class = convgroupconv_model
model = model_class()
inp = torch.randn(IN_SIZE_CONV)

model.eval()
expected_out = model(inp)

model = symbolic_trace(model)
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

model = RegionwiseChannelSplitting(split_ratio=0.05, split_input=split_input).apply(model)
model = RegionwiseChannelSplitting(split_ratio=0.1, split_input=split_input).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

0 comments on commit 029c2a0

Please sign in to comment.