Skip to content

Commit

Permalink
Feat (Channel-Splitting): multiple sources/sinks work with MergeBatch…
Browse files Browse the repository at this point in the history
…Norm
  • Loading branch information
fabianandresgrob committed Dec 5, 2023
1 parent 38ac31e commit b7072f0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def _split_channels(
var = running_var[id:id + 1] * split_factor
running_var = torch.cat((running_var[:id], var, var, running_var[id + 1:]))

module.num_features += 1

if bias is not None and not split_input:
channel = bias[id:id + 1] * split_factor
bias = torch.cat((bias[:id], channel, channel, bias[id + 1:]))
Expand All @@ -133,14 +135,14 @@ def _split_channels_region(
# splitting output channels
# concat all channels that are split so we can duplicate those in the input channels later
if not split_input:
for module, channels in modules_to_split.items():
channels = torch.cat(list(modules_to_split.values()))
for module in modules_to_split.keys():
_split_channels(module, channels, grid_aware=grid_aware)
# get all the channels that we have to duplicate
channels_to_duplicate = torch.cat(list(modules_to_split.values()))
channels = torch.cat(list(modules_to_split.values()))
for module in sinks:
# then duplicate the input_channels for all modules in the sink
_split_channels(
module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=True)
_split_channels(module, channels, grid_aware=False, split_factor=1, split_input=True)
else:
# what if we split input channels of the sinks, which channels of the OC srcs have to duplicated?
for module, channels in modules_to_split.items():
Expand All @@ -159,7 +161,7 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
if len(srcs_ocs) > 1:
return False

# check if ICs of sinks are all equal
# check if ICs of sinks are all equal, what if sinks does not have IC?
sinks_ics = set(module.weight.shape[1] for module in sinks)
if len(sinks_ics) > 1:
return False
Expand Down
5 changes: 4 additions & 1 deletion tests/brevitas/graph/test_channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torchvision import models

from brevitas.fx import symbolic_trace
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.ptq_algorithms.channel_splitting import *

from .equalization_fixtures import *
Expand All @@ -18,6 +18,9 @@ def test_resnet18():
expected_out = model(inp)
model = symbolic_trace(model)

# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

model = ChannelSplitting(split_ratio=0.1).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)
Expand Down

0 comments on commit b7072f0

Please sign in to comment.