From 5ef2416710ab25b8deae283170058a4df19aaabf Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Tue, 28 Nov 2023 17:36:09 +0000 Subject: [PATCH] Feat (Channel-Splitting): adds graph channel splitting --- src/brevitas/graph/quantize.py | 16 +- .../ptq_algorithms/channel_splitting.py | 289 ++++++++++++++++++ .../benchmark/ptq_benchmark_torchvision.py | 22 +- .../imagenet_classification/ptq/ptq_common.py | 8 +- .../ptq/ptq_evaluate.py | 48 ++- .../brevitas/graph/test_channel_splitting.py | 48 +++ 6 files changed, 419 insertions(+), 12 deletions(-) create mode 100644 src/brevitas/ptq_algorithms/channel_splitting.py create mode 100644 tests/brevitas/graph/test_channel_splitting.py diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 63143c4e5..9a8ba5eac 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -26,6 +26,7 @@ from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.nn import quant_layer import brevitas.nn as qnn +from brevitas.ptq_algorithms.channel_splitting import RegionwiseChannelSplitting from brevitas.quant import Int8ActPerTensorFloat from brevitas.quant import Int8ActPerTensorFloatMinMaxInit from brevitas.quant import Int8WeightPerTensorFloat @@ -263,7 +264,13 @@ def preprocess_for_quantize( equalize_merge_bias=True, merge_bn=True, equalize_bias_shrinkage: str = 'vaiq', - equalize_scale_computation: str = 'maxabs'): + equalize_scale_computation: str = 'maxabs', + channel_splitting=False, + channel_splitting_ratio=0.02, + channel_splitting_grid_aware=False, + channel_splitting_split_input=True, + channel_splitting_criterion: str = 'maxabs', + channel_splitting_weight_bit_width=8): training_state = model.training model.eval() @@ -285,6 +292,13 @@ def preprocess_for_quantize( merge_bias=equalize_merge_bias, bias_shrinkage=equalize_bias_shrinkage, scale_computation_type=equalize_scale_computation).apply(model) + if channel_splitting: + model = RegionwiseChannelSplitting( + split_ratio=channel_splitting_ratio, + grid_aware=channel_splitting_grid_aware, + split_criterion=channel_splitting_criterion, + split_input=channel_splitting_split_input, + weight_bit_width=channel_splitting_weight_bit_width).apply(model) model.train(training_state) return model diff --git a/src/brevitas/ptq_algorithms/channel_splitting.py b/src/brevitas/ptq_algorithms/channel_splitting.py new file mode 100644 index 000000000..1cec6ffc0 --- /dev/null +++ b/src/brevitas/ptq_algorithms/channel_splitting.py @@ -0,0 +1,289 @@ +import math +from typing import Dict, List, Set, Tuple, Union +import warnings + +import torch +import torch.nn as nn + +from brevitas.fx import GraphModule +from brevitas.graph.base import GraphTransform +from brevitas.graph.equalize import _extract_regions + + +def _calculate_scale(weights, bit_width, clip_threshold=1.): + max_abs = weights.abs().max() + clip_max_abs = max_abs * clip_threshold + n = 2 ** (bit_width - 1) - 1 + return n / clip_max_abs + + +def _channels_maxabs(module, splits_per_layer, split_input): + # works for Conv2d and Linear + dim = 1 - int(split_input) + if isinstance(module, nn.Conv2d): + if not split_input: + # gets the max value for each output channel + max_per_channel = module.weight.data.abs().flatten(1).max(1).values + # check if max_per_channel has the same length as output channels + assert len(max_per_channel) == module.weight.shape[0] + else: + # getting max value for each input channel + max_per_channel = module.weight.data.abs().max(0).values.flatten(1).max(1).values + # check 
if same length as input channels + assert len(max_per_channel) == module.weight.shape[1] + elif isinstance(module, nn.Linear): + max_per_channel = module.weight.data.abs().max(dim=dim).values.flatten() + channels = torch.argsort(max_per_channel, descending=True) + return channels[:splits_per_layer] + + +def _channels_to_split( + sources: Dict[str, nn.Module], + sinks: Dict[str, nn.Module], + split_criterion: str, + split_ratio: float, + split_input: bool) -> Dict[nn.Module, List[torch.Tensor]]: + modules = sinks if split_input else sources + # the modules are all of the same shape so we can just take the first one + num_channels = next(iter(modules.values())).weight.shape[int(split_input)] + splits_per_layer = int(math.ceil(split_ratio * num_channels)) + + if splits_per_layer == 0: + warnings.warn(f'No splits for {modules}, increasing split_ratio could help.') + + module_to_channels = {} + if split_criterion == 'maxabs': + for name, module in modules.items(): + module_to_channels[name] = _channels_maxabs(module, splits_per_layer, split_input) + + # return tensor with the indices to split + channels_to_split = torch.cat(list(module_to_channels.values())) + return torch.unique(channels_to_split) + + +def _split_single_channel(channel, grid_aware: bool, split_factor: float, scale: float = 1.): + if grid_aware: + split_channel = channel * split_factor * scale + slice1 = (split_channel - 0.25) / scale + slice2 = (split_channel + 0.25) / scale + return slice1, slice2 + else: + return channel * split_factor, channel * split_factor + + +def _split_channels( + module, + channels_to_split, + grid_aware=True, + split_input=False, + split_factor=0.5, + bit_width=8) -> None: + """ + Splits the channels `channels_to_split` of the `weights`. + `split_input` specifies whether to split Input or Output channels. + Can also be used to duplicate a channel, just set split_factor to 1. + Returns: None + """ + weight = torch.clone(module.weight.data) + bias = torch.clone(module.bias.data) if module.bias is not None else None + + # init scale + scale = 1. 
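+    # Grid-aware splitting is done in two passes: a preliminary plain halving of the
+    # selected channels yields the post-split tensor, from which the integer quantization
+    # scale is derived; the split is then redone with the two halves offset by -0.25 and
+    # +0.25 of a quantization step (the offsets cancel in the exact sum), so that the
+    # rounded halves reconstruct the original channel more closely on the weight grid.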
+ + if grid_aware: + # do a preliminary split of the channels to get the scale for the split channels + for id in channels_to_split: + if isinstance(module, torch.nn.Conv2d): + # there are four dimensions: [OC, IC, k, k] + if split_input: + channel = weight[:, id:id + 1, :, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor) + weight = torch.cat( + (weight[:, :id, :, :], slice1, slice2, weight[:, id + 1:, :, :]), dim=1) + else: + channel = weight[id:id + 1, :, :, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor) + weight = torch.cat( + (weight[:id, :, :, :], slice1, slice2, weight[id + 1:, :, :, :]), dim=0) + + elif isinstance(module, torch.nn.Linear): + # there are two dimensions: [OC, IC] + if split_input: + channel = weight[:, id:id + 1] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor) + weight = torch.cat((weight[:, :id], slice1, slice2, weight[:, id + 1:]), dim=1) + else: + channel = weight[id:id + 1, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor) + weight = torch.cat((weight[:id, :], slice1, slice2, weight[id + 1:, :]), dim=0) + # now calculate the scale of the split weights + scale = _calculate_scale(weight, bit_width) + + # reset the weight variable + weight = torch.clone(module.weight.data) + + for id in channels_to_split: + if isinstance(module, torch.nn.Conv2d): + # there are four dimensions: [OC, IC, k, k] + if split_input: + channel = weight[:, id:id + 1, :, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale) + weight = torch.cat((weight[:, :id, :, :], slice1, weight[:, id + 1:, :, :], slice2), + dim=1) + module.in_channels += 1 + else: + channel = weight[id:id + 1, :, :, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale) + weight = torch.cat((weight[:id, :, :, :], slice1, weight[id + 1:, :, :, :], slice2), + dim=0) + module.out_channels += 1 + + elif isinstance(module, torch.nn.Linear): + # there are two dimensions: [OC, IC] + if split_input: + channel = weight[:, id:id + 1] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale) + weight = torch.cat((weight[:, :id], slice1, weight[:, id + 1:], slice2), dim=1) + module.in_features += 1 + else: + channel = weight[id:id + 1, :] + slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale) + weight = torch.cat((weight[:id, :], slice1, weight[id + 1:, :], slice2), dim=0) + module.out_features += 1 + + if bias is not None and not split_input: + channel = bias[id:id + 1] * split_factor + bias = torch.cat((bias[:id], channel, bias[id + 1:], channel)) + + module.weight.data = weight + if bias is not None: + module.bias.data = bias + + +def _split_channels_region( + sources: Dict[str, nn.Module], + sinks: Dict[str, nn.Module], + channels_to_split: torch.tensor, + split_input: bool, + grid_aware: bool = False, + weight_bit_width: int = 8): + # splitting output channels + # concat all channels that are split so we can duplicate those in the input channels later + if not split_input: + for name, module in sources.items(): + _split_channels( + module, channels_to_split, grid_aware=grid_aware, bit_width=weight_bit_width) + for name, module in sinks.items(): 
+ # then duplicate the input_channels for all modules in the sink + _split_channels( + module, channels_to_split, grid_aware=False, split_factor=1, split_input=True) + else: + # input channels are split in half, output channels duplicated + for name, module in sinks.items(): + _split_channels( + module, + channels_to_split, + grid_aware=grid_aware, + split_input=True, + bit_width=weight_bit_width) + for name, module in sources.items(): + # duplicate out_channels for all modules in the source + _split_channels(module, channels_to_split, grid_aware=False, split_factor=1) + + +def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool: + # check if OCs of sources are all equal + srcs_ocs = set(module.weight.shape[0] for module in srcs) + if len(srcs_ocs) > 1: + return False + + # check if ICs of sinks are all equal + sinks_ics = set(module.weight.shape[1] for module in sinks) + if len(sinks_ics) > 1: + return False + + return srcs_ocs == sinks_ics + + +def _split( + model: GraphModule, + regions: Set[Tuple[str]], + split_ratio: float, + split_criterion: str, + grid_aware: bool, + split_input: bool, + weight_bit_width: int) -> GraphModule: + name_to_module: Dict[str, nn.Module] = {} + name_set = set() + for region in regions: + for name in region.srcs_names: + name_set.add(name) + for name in region.sinks_names: + name_set.add(name) + + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + + for i, region in enumerate(regions): + + # check if region is suitable for channel splitting + sources = {n: name_to_module[n] for n in region.srcs_names} + sinks = {n: name_to_module[n] for n in region.sinks_names} + + if _is_supported(sources.values(), sinks.values()): + # get channels to split + channels_to_split = _channels_to_split( + sources=sources, + sinks=sinks, + split_criterion=split_criterion, + split_ratio=split_ratio, + split_input=split_input) + # splitting/duplicating channels + _split_channels_region( + sources=sources, + sinks=sinks, + channels_to_split=channels_to_split, + grid_aware=grid_aware, + split_input=split_input, + weight_bit_width=weight_bit_width) + + return model + + +class RegionwiseChannelSplitting(GraphTransform): + + def __init__( + self, + split_ratio=0.02, + split_criterion='maxabs', + grid_aware=False, + split_input=True, + weight_bit_width=8): + super(RegionwiseChannelSplitting, self).__init__() + + self.grid_aware = grid_aware + self.split_ratio = split_ratio + self.split_criterion = split_criterion + self.split_input = split_input + self.weight_bit_width = weight_bit_width + + def apply( + self, + model, + return_regions: bool = False + ) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]: + regions = _extract_regions(model) + if len(regions) > 0: + self.graph_model = _split( + model=model, + regions=regions, + split_ratio=self.split_ratio, + split_criterion=self.split_criterion, + grid_aware=self.grid_aware, + split_input=self.split_input, + weight_bit_width=self.weight_bit_width) + if return_regions: + return self.graph_model, regions + else: + return self.graph_model diff --git a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py index f5fe652cb..277de3a61 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py +++ b/src/brevitas_examples/imagenet_classification/ptq/benchmark/ptq_benchmark_torchvision.py @@ -105,7 +105,11 @@ 
def unique(sequence): 'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q 'act_quant_percentile': [99.999], # Activation Quantization Percentile 'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible -} + 'channel_splitting': [False], # Use channel splitting algorithm in preprocessing step + 'split_ratio': [0.01], # Split ratio for channel splitting + 'split_input': [True], # Split input or output channels for channel splitting + 'grid_aware': [False], # Use grid aware channel splitting + 'merge_bn': [True],} parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation') parser.add_argument('idx', type=int) @@ -210,7 +214,13 @@ def ptq_torchvision_models(args): model = preprocess_for_quantize( model, equalize_iters=config_namespace.graph_eq_iterations, - equalize_merge_bias=config_namespace.graph_eq_merge_bias) + equalize_merge_bias=config_namespace.graph_eq_merge_bias, + merge_bn=config_namespace.merge_bn, + channel_splitting=config_namespace.channel_splitting, + channel_splitting_ratio=config_namespace.split_ratio, + channel_splitting_grid_aware=config_namespace.grid_aware, + channel_splitting_split_input=config_namespace.split_input, + channel_splitting_weight_bit_width=config_namespace.weight_bit_width) else: raise RuntimeError(f"{config_namespace.target_backend} backend not supported.") @@ -334,6 +344,9 @@ def validate_config(config_namespace): config_namespace.gpfa2q) if multiple_gpxqs > 1: is_valid = False + elif multiple_gpxqs == 0: + # no gpxq algorithm, set act order to None + config_namespace.gpxq_act_order = None if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx': is_valid = False @@ -364,6 +377,11 @@ def validate_config(config_namespace): is_valid = False if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1: is_valid = False + # if channel splitting is false, no need for split ratio, grid aware, or iteratively split + if not config_namespace.channel_splitting: + config_namespace.split_ratio = None + config_namespace.grid_aware = None + config_namespace.split_input = None config_namespace.is_valid = is_valid return config_namespace diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 0b626e485..541f0085f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -197,8 +197,10 @@ def layerwise_bit_width_fn_weight(module): act_bit_width_dict = {} if quant_format == 'int' and backend == 'layerwise': weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight - act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act - + if act_bit_width is not None: + act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act + else: + act_bit_width_dict['act_bit_width'] = None elif quant_format == 'int' and backend != 'layerwise': weight_bit_width_dict['weight_bit_width'] = weight_bit_width act_bit_width_dict['act_bit_width'] = act_bit_width @@ -291,7 +293,7 @@ def kwargs_prefix(prefix, weight_kwargs): act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width # Retrieve base input, weight, and bias quantizers - bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] + bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] if act_bit_width is not None else None weight_quant = 
WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][ weight_quant_granularity][weight_quant_type] weight_quant = weight_quant.let(**weight_bit_width_dict) diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index c821dad33..daadcce67 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -218,6 +218,11 @@ type=int, help='Accumulator Bit Width for GPFA2Q (default: None)') parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version') +parser.add_argument( + '--split-ratio', + default=0.02, + type=float, + help='Split Ratio for Channel Splitting (default: 0.02)') add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)') @@ -225,6 +230,23 @@ parser, 'gpxq-act-order', default=False, help='GPxQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') +add_bool_arg( + parser, + 'channel-splitting', + default=False, + help='Apply Channel Splitting before Quantization (default: disabled)') +add_bool_arg( + parser, 'grid-aware', default=False, help='Grid-aware channel splitting (default: disabled)') +add_bool_arg( + parser, + 'split-input', + default=False, + help='Input Channels Splitting for channel splitting (default: disabled)') +add_bool_arg( + parser, + 'merge-bn', + default=True, + help='Merge BN layers before quantizing the model (default: enabled)') def main(): @@ -239,6 +261,9 @@ def main(): else: act_quant_calib_config = args.act_quant_calibration_type + if args.act_bit_width == 0: + args.act_bit_width = None + config = ( f"{args.model_name}_" f"{args.target_backend}_" @@ -264,7 +289,8 @@ def main(): f"{'mb_' if args.graph_eq_merge_bias else ''}" f"{act_quant_calib_config}_" f"{args.weight_quant_calibration_type}_" - f"{'bnc' if args.calibrate_bn else ''}") + f"{'bnc_' if args.calibrate_bn else ''}" + f"{'channel_splitting' if args.channel_splitting else ''}") print( f"Model: {args.model_name} - " @@ -288,7 +314,12 @@ def main(): f"Merge bias in graph equalization: {args.graph_eq_merge_bias} - " f"Activation quant calibration type: {act_quant_calib_config} - " f"Weight quant calibration type: {args.weight_quant_calibration_type} - " - f"Calibrate BN: {args.calibrate_bn}") + f"Calibrate BN: {args.calibrate_bn} - " + f"Channel Splitting: {args.channel_splitting} - " + f"Split Ratio: {args.split_ratio} - " + f"Grid Aware: {args.grid_aware} - " + f"Split Input: {args.split_input} - " + f"Merge BN: {args.merge_bn}") # Get model-specific configurations about input shapes and normalization model_config = get_model_config(args.model_name) @@ -332,7 +363,12 @@ def main(): model, equalize_iters=args.graph_eq_iterations, equalize_merge_bias=args.graph_eq_merge_bias, - merge_bn=not args.calibrate_bn) + merge_bn=not args.calibrate_bn, + channel_splitting=args.channel_splitting, + channel_splitting_grid_aware=args.grid_aware, + channel_splitting_split_input=args.split_input, + channel_splitting_ratio=args.split_ratio, + channel_splitting_weight_bit_width=args.weight_bit_width) else: raise 
RuntimeError(f"{args.target_backend} backend not supported.") @@ -400,9 +436,9 @@ def main(): iters=args.learned_round_iters, optimizer_lr=args.learned_round_lr) - if args.calibrate_bn: - print("Calibrate BN:") - calibrate_bn(calib_loader, quant_model) + # if args.calibrate_bn: + # print("Calibrate BN:") + # calibrate_bn(calib_loader, quant_model) if args.bias_corr: print("Applying bias correction:") diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py new file mode 100644 index 000000000..9e3bd1d59 --- /dev/null +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -0,0 +1,48 @@ +import torch +from torchvision import models + +from brevitas.fx import symbolic_trace +from brevitas.graph.fixed_point import MergeBatchNorm +from brevitas.ptq_algorithms.channel_splitting import * + +from .equalization_fixtures import * + + +@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2]) +@pytest.mark.parametrize('split_input', [False, True]) +def test_resnet18(split_ratio, split_input): + model = models.resnet18(pretrained=True) + + torch.manual_seed(SEED) + inp = torch.randn(IN_SIZE_CONV) + + model.eval() + expected_out = model(inp) + model = symbolic_trace(model) + + # merge BN before applying channel splitting + model = MergeBatchNorm().apply(model) + + model = RegionwiseChannelSplitting( + split_ratio=split_ratio, split_input=split_input).apply(model) + out = model(inp) + assert torch.allclose(expected_out, out, atol=ATOL) + + +@pytest.mark.parametrize('split_ratio', [0.05, 0.1]) +@pytest.mark.parametrize('split_input', [False, True]) +def test_alexnet(split_ratio, split_input): + model = models.alexnet(pretrained=True) + + torch.manual_seed(SEED) + inp = torch.randn(IN_SIZE_CONV) + + model.eval() + expected_out = model(inp) + model = symbolic_trace(model) + + # set split_ratio to 0.2 to def have some splits + model = RegionwiseChannelSplitting( + split_ratio=split_ratio, split_input=split_input).apply(model) + out = model(inp) + assert torch.allclose(expected_out, out, atol=ATOL)