From b7072f0b2d2551349ab2e3e744b558a62913a060 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Tue, 5 Dec 2023 10:29:02 +0000 Subject: [PATCH] Feat (Channel-Splitting): multiple sources/sinks work with MergeBatchNorm --- src/brevitas/ptq_algorithms/channel_splitting.py | 12 +++++++----- tests/brevitas/graph/test_channel_splitting.py | 5 ++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/brevitas/ptq_algorithms/channel_splitting.py b/src/brevitas/ptq_algorithms/channel_splitting.py index 7a4333225..70c193393 100644 --- a/src/brevitas/ptq_algorithms/channel_splitting.py +++ b/src/brevitas/ptq_algorithms/channel_splitting.py @@ -110,6 +110,8 @@ def _split_channels( var = running_var[id:id + 1] * split_factor running_var = torch.cat((running_var[:id], var, var, running_var[id + 1:])) + module.num_features += 1 + if bias is not None and not split_input: channel = bias[id:id + 1] * split_factor bias = torch.cat((bias[:id], channel, channel, bias[id + 1:])) @@ -133,14 +135,14 @@ def _split_channels_region( # splitting output channels # concat all channels that are split so we can duplicate those in the input channels later if not split_input: - for module, channels in modules_to_split.items(): + channels = torch.cat(list(modules_to_split.values())) + for module in modules_to_split.keys(): _split_channels(module, channels, grid_aware=grid_aware) # get all the channels that we have to duplicate - channels_to_duplicate = torch.cat(list(modules_to_split.values())) + channels = torch.cat(list(modules_to_split.values())) for module in sinks: # then duplicate the input_channels for all modules in the sink - _split_channels( - module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=True) + _split_channels(module, channels, grid_aware=False, split_factor=1, split_input=True) else: # what if we split input channels of the sinks, which channels of the OC srcs have to duplicated? for module, channels in modules_to_split.items(): @@ -159,7 +161,7 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool: if len(srcs_ocs) > 1: return False - # check if ICs of sinks are all equal + # check if ICs of sinks are all equal, what if sinks does not have IC? sinks_ics = set(module.weight.shape[1] for module in sinks) if len(sinks_ics) > 1: return False diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py index e07e2c55b..106618c1a 100644 --- a/tests/brevitas/graph/test_channel_splitting.py +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -2,7 +2,7 @@ from torchvision import models from brevitas.fx import symbolic_trace -from brevitas.graph.equalize import _extract_regions +from brevitas.graph.fixed_point import MergeBatchNorm from brevitas.ptq_algorithms.channel_splitting import * from .equalization_fixtures import * @@ -18,6 +18,9 @@ def test_resnet18(): expected_out = model(inp) model = symbolic_trace(model) + # merge BN before applying channel splitting + model = MergeBatchNorm().apply(model) + model = ChannelSplitting(split_ratio=0.1).apply(model) out = model(inp) assert torch.allclose(expected_out, out, atol=ATOL)