Skip to content

Commit

Permalink
Feat (Channel-Splitting): cleans up and adds multiple srcs/sinks support
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 4, 2023
1 parent 959f6df commit 38ac31e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 55 deletions.
10 changes: 9 additions & 1 deletion src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.ptq_algorithms.channel_splitting import ChannelSplitting
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat
Expand Down Expand Up @@ -263,7 +264,10 @@ def preprocess_for_quantize(
equalize_merge_bias=True,
merge_bn=True,
equalize_bias_shrinkage: str = 'vaiq',
equalize_scale_computation: str = 'maxabs'):
equalize_scale_computation: str = 'maxabs',
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_criterion: str = 'maxabs'):

training_state = model.training
model.eval()
Expand All @@ -285,6 +289,10 @@ def preprocess_for_quantize(
merge_bias=equalize_merge_bias,
bias_shrinkage=equalize_bias_shrinkage,
scale_computation_type=equalize_scale_computation).apply(model)
if channel_splitting:
model = ChannelSplitting(
split_ratio=channel_splitting_ratio,
split_criterion=channel_splitting_criterion).apply(model)
model.train(training_state)
return model

Expand Down
121 changes: 74 additions & 47 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _extract_regions

_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def _channels_maxabs(module, splits_per_layer, split_input):
# works for Conv2d and Linear
Expand All @@ -27,8 +29,12 @@ def _channels_maxabs(module, splits_per_layer, split_input):


def _channels_to_split(
modules: List[nn.Module], split_criterion: str, split_ratio: float,
sources: List[nn.Module],
sinks: List[nn.Module],
split_criterion: str,
split_ratio: float,
split_input: bool) -> Dict[nn.Module, List[torch.Tensor]]:
modules = sinks if split_input else sources
# the modules are all of the same shape so we can just take the first one
num_channels = modules[0].weight.shape[int(split_input)]
total_splits = int(math.ceil(split_ratio * num_channels))
Expand All @@ -50,75 +56,101 @@ def _channels_to_split(


def _split_channels(
layer, channels_to_split, grid_aware=True, split_input=False, split_factor=0.5) -> None:
module, channels_to_split, grid_aware=True, split_input=False, split_factor=0.5) -> None:
"""
Splits the channels `channels_to_split` of the `weights`.
`split_input` specifies whether to split Input or Output channels.
Can also be used to duplicate a channel, just set split_factor to 1.
Returns: None
"""
# change it to .data attribute
weight = layer.weight.data
bias = layer.bias.data
weight = module.weight.data
bias = module.bias.data if module.bias is not None else None
if isinstance(module, _batch_norm):
running_mean = module.running_mean.data
running_var = module.running_var.data

for id in channels_to_split:
if isinstance(layer, torch.nn.Conv2d):
if isinstance(module, torch.nn.Conv2d):
# there are four dimensions: [OC, IC, k, k]
if split_input:
channel = weight[:, id:id + 1, :, :] * split_factor
weight = torch.cat(
(weight[:, :id, :, :], channel, channel, weight[:, id + 1:, :, :]), dim=1)
layer.in_channels += 1
module.in_channels += 1
else:
# split output
channel = weight[id:id + 1, :, :, :] * split_factor
# duplicate channel
weight = torch.cat(
(weight[:id, :, :, :], channel, channel, weight[id + 1:, :, :, :]), dim=0)
layer.out_channels += 1
module.out_channels += 1

elif isinstance(layer, torch.nn.Linear):
elif isinstance(module, torch.nn.Linear):
# there are two dimensions: [OC, IC]
if split_input:
# simply duplicate channel
channel = weight[:, id:id + 1] * split_factor
weight = torch.cat((weight[:, :id], channel, channel, weight[:, id + 1:]), dim=1)
layer.in_features += 1
module.in_features += 1
else:
# split output
channel = weight[id:id + 1, :] * split_factor
weight = torch.cat((weight[:id, :], channel, channel, weight[id + 1:, :]), dim=0)
layer.out_features += 1
module.out_features += 1

elif isinstance(module, _batch_norm):
# bach norm is 1d
channel = weight[id:id + 1] * split_factor
weight = torch.cat((weight[:id], channel, channel, weight[id + 1:]))
# also split running_mean and running_var
mean = running_mean[id:id + 1] * split_factor
running_mean = torch.cat((running_mean[:id], mean, mean, running_mean[id + 1:]))

var = running_var[id:id + 1] * split_factor
running_var = torch.cat((running_var[:id], var, var, running_var[id + 1:]))

if bias is not None and not split_input:
# also split bias
channel = layer.bias.data[id:id + 1] * split_factor
channel = bias[id:id + 1] * split_factor
bias = torch.cat((bias[:id], channel, channel, bias[id + 1:]))
layer.bias.data = bias

# setting the weights as the new data
layer.weight.data = weight
module.weight.data = weight
if bias is not None:
module.bias.data = bias
if isinstance(module, _batch_norm):
module.running_mean.data = running_mean
module.running_var.data = running_var


def _split_channels_region(
module_to_split: Dict[nn.Module, torch.tensor],
modules_to_duplicate: [nn.Module],
sources: List[nn.Module],
sinks: List[nn.Module],
modules_to_split: Dict[nn.Module, torch.tensor],
split_input: bool,
grid_aware: bool = False):
# we are getting a dict[Module, channels to split]
# splitting output channels
# concat all channels that are split so we can duplicate those in the input channels later
if not split_input:
input_channels = torch.cat(list(module_to_split.values()))
for module, channels in module_to_split.items():
for module, channels in modules_to_split.items():
_split_channels(module, channels, grid_aware=grid_aware)
for module in modules_to_duplicate:
# get all the channels that we have to duplicate
channels_to_duplicate = torch.cat(list(modules_to_split.values()))
for module in sinks:
# then duplicate the input_channels for all modules in the sink
_split_channels(
module, input_channels, grid_aware=False, split_factor=1, split_input=True)
module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=True)
else:
# what if we split input channels of the sinks, which channels of the OC srcs have to duplicated?
pass
for module, channels in modules_to_split.items():
_split_channels(module, channels, grid_aware=grid_aware)
# TODO duplicating the channels in the output channels of the sources could be tricky
channels_to_duplicate = torch.cat(list(modules_to_split.values()))
for module in sources:
# then duplicate the input_channels for all modules in the sink
_split_channels(
module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=False)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
Expand Down Expand Up @@ -154,56 +186,51 @@ def _split(
if name in name_set:
name_to_module[name] = module

for region in regions:
for i, region in enumerate(regions):

# check if region is suitable for channel splitting
srcs = [name_to_module[n] for n in region.srcs]
sources = [name_to_module[n] for n in region.srcs]
sinks = [name_to_module[n] for n in region.sinks]

if _is_supported(srcs, sinks):

if _is_supported(sources, sinks):
# problem: if region[0] has bn modules as sources, we split them but the input from prev layer is not the correct shape anymore! So we need to skip the first region in case that happens
# get channels to split
if split_input:
# we will have
mod_to_channels = _channels_to_split(
sinks, split_criterion, split_ratio, split_input)
else:
mod_to_channels = _channels_to_split(srcs, split_criterion, split_ratio, False)
_split_channels_region(
module_to_split=mod_to_channels,
modules_to_duplicate=sinks,
split_input=split_input)

# now splits those channels that we just selected!
modules_to_split = _channels_to_split(
sources=sources,
sinks=sinks,
split_criterion=split_criterion,
split_ratio=split_ratio,
split_input=split_input)
# splitting/duplicating channels
_split_channels_region(
sources=sources,
sinks=sinks,
modules_to_split=modules_to_split,
split_input=split_input)

return model


class ChannelSplitting(GraphTransform):

def __init__(
self,
model,
split_ratio=0.02,
split_criterion='maxabs',
grid_aware=False,
split_input=False):
self, split_ratio=0.02, split_criterion='maxabs', grid_aware=False, split_input=False):
super(ChannelSplitting, self).__init__()
self.graph_model = model
self.grid_aware = grid_aware

self.grid_aware = grid_aware
self.split_ratio = split_ratio
self.split_criterion = split_criterion
self.split_input = split_input

def apply(
self,
model,
return_regions: bool = False
) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(self.graph_model)
regions = _extract_regions(model)
if len(regions) > 0:
self.graph_model = _split(
model=self.graph_model,
model=model,
regions=regions,
split_ratio=self.split_ratio,
split_criterion=self.split_criterion,
Expand Down
21 changes: 18 additions & 3 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@
default=3,
type=int,
help='Exponent bit width used with float quantization for activations (default: 3)')
parser.add_argument(
'--split-ratio',
default=0.02,
type=float,
help='Split Ratio for Channel Splitting (default: 0.02)')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(
Expand All @@ -215,6 +220,11 @@
parser, 'gpfq-act-order', default=False, help='GPFQ Act order heuristic (default: disabled)')
add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
add_bool_arg(
parser,
'channel-splitting',
default=False,
help='Apply Channel Splitting before Quantization (default: disabled)')


def main():
Expand Down Expand Up @@ -254,7 +264,8 @@ def main():
f"{'mb_' if args.graph_eq_merge_bias else ''}"
f"{act_quant_calib_config}_"
f"{args.weight_quant_calibration_type}_"
f"{'bnc' if args.calibrate_bn else ''}")
f"{'bnc_' if args.calibrate_bn else ''}"
f"{'channel_splitting' if args.channel_splitting else ''}")

print(
f"Model: {args.model_name} - "
Expand All @@ -277,7 +288,9 @@ def main():
f"Merge bias in graph equalization: {args.graph_eq_merge_bias} - "
f"Activation quant calibration type: {act_quant_calib_config} - "
f"Weight quant calibration type: {args.weight_quant_calibration_type} - "
f"Calibrate BN: {args.calibrate_bn}")
f"Calibrate BN: {args.calibrate_bn} - "
f"Channel Splitting: {args.channel_splitting} - "
f"Split Ratio: {args.split_ratio} - ")

# Get model-specific configurations about input shapes and normalization
model_config = get_model_config(args.model_name)
Expand Down Expand Up @@ -320,7 +333,9 @@ def main():
model,
equalize_iters=args.graph_eq_iterations,
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn)
merge_bn=not args.calibrate_bn,
channel_splitting=args.channel_splitting,
channel_splitting_ratio=args.split_ratio)
else:
raise RuntimeError(f"{args.target_backend} backend not supported.")

Expand Down
8 changes: 4 additions & 4 deletions tests/brevitas/graph/test_channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .equalization_fixtures import *


@pytest.mark.skip(reason="Focus on alexnet first")
def test_resnet18():
model = models.resnet18(pretrained=True)

Expand All @@ -19,9 +18,9 @@ def test_resnet18():
expected_out = model(inp)
model = symbolic_trace(model)

ChannelSplitting(model).apply()
model = ChannelSplitting(split_ratio=0.1).apply(model)
out = model(inp)
assert expected_out == out
assert torch.allclose(expected_out, out, atol=ATOL)


def test_alexnet():
Expand All @@ -34,6 +33,7 @@ def test_alexnet():
expected_out = model(inp)
model = symbolic_trace(model)

ChannelSplitting(model, split_ratio=0.2).apply()
# set split_ratio to 0.2 to def have some splits
model = ChannelSplitting(split_ratio=0.2).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

0 comments on commit 38ac31e

Please sign in to comment.