Feat (Channel-Splitting): adds graph channel splitting
fabianandresgrob committed Jan 12, 2024
1 parent 8f1f5ee commit 5ef2416
Showing 6 changed files with 419 additions and 12 deletions.
16 changes: 15 additions & 1 deletion src/brevitas/graph/quantize.py
@@ -26,6 +26,7 @@
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.ptq_algorithms.channel_splitting import RegionwiseChannelSplitting
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat
@@ -263,7 +264,13 @@ def preprocess_for_quantize(
equalize_merge_bias=True,
merge_bn=True,
equalize_bias_shrinkage: str = 'vaiq',
equalize_scale_computation: str = 'maxabs',
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_grid_aware=False,
channel_splitting_split_input=True,
channel_splitting_criterion: str = 'maxabs',
channel_splitting_weight_bit_width=8):

training_state = model.training
model.eval()
@@ -285,6 +292,13 @@
merge_bias=equalize_merge_bias,
bias_shrinkage=equalize_bias_shrinkage,
scale_computation_type=equalize_scale_computation).apply(model)
if channel_splitting:
model = RegionwiseChannelSplitting(
split_ratio=channel_splitting_ratio,
grid_aware=channel_splitting_grid_aware,
split_criterion=channel_splitting_criterion,
split_input=channel_splitting_split_input,
weight_bit_width=channel_splitting_weight_bit_width).apply(model)
model.train(training_state)
return model

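For context, a minimal usage sketch of the new preprocessing options (not part of the diff; the model choice is a placeholder, and all keyword names come from the signature above):

import torchvision.models as models

from brevitas.graph.quantize import preprocess_for_quantize

model = models.resnet18(weights='DEFAULT')
# enable channel splitting on top of the usual equalization preprocessing
model = preprocess_for_quantize(
    model,
    channel_splitting=True,
    channel_splitting_ratio=0.02,  # split the 2% largest-magnitude channels
    channel_splitting_grid_aware=False,  # plain halving, no grid-aware correction
    channel_splitting_split_input=True,  # split sink input channels rather than source outputs
    channel_splitting_weight_bit_width=8)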
289 changes: 289 additions & 0 deletions src/brevitas/ptq_algorithms/channel_splitting.py
@@ -0,0 +1,289 @@
import math
from typing import Dict, List, Set, Tuple, Union
import warnings

import torch
import torch.nn as nn

from brevitas.fx import GraphModule
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _extract_regions


def _calculate_scale(weights, bit_width, clip_threshold=1.):
max_abs = weights.abs().max()
clip_max_abs = max_abs * clip_threshold
n = 2 ** (bit_width - 1) - 1
return n / clip_max_abs
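
# Example (illustrative): for bit_width=8, n = 2**7 - 1 = 127; a weight tensor
# whose largest absolute value is 0.5 (with clip_threshold=1.) gets
# scale = 127 / 0.5 = 254, i.e. weights are multiplied by 254 to land on the
# signed 8-bit integer grid.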


def _channels_maxabs(module, splits_per_layer, split_input):
# works for Conv2d and Linear
dim = 1 - int(split_input)
if isinstance(module, nn.Conv2d):
if not split_input:
# gets the max value for each output channel
max_per_channel = module.weight.data.abs().flatten(1).max(1).values
# check if max_per_channel has the same length as output channels
assert len(max_per_channel) == module.weight.shape[0]
else:
# get the max value for each input channel
max_per_channel = module.weight.data.abs().max(0).values.flatten(1).max(1).values
# check if same length as input channels
assert len(max_per_channel) == module.weight.shape[1]
elif isinstance(module, nn.Linear):
max_per_channel = module.weight.data.abs().max(dim=dim).values.flatten()
channels = torch.argsort(max_per_channel, descending=True)
return channels[:splits_per_layer]
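
# Example (illustrative): with splits_per_layer=2 and per-channel maxima
# [0.1, 0.9, 0.4, 0.7], argsort descending gives [1, 3, 2, 0], so the
# returned channel indices are tensor([1, 3]).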


def _channels_to_split(
sources: Dict[str, nn.Module],
sinks: Dict[str, nn.Module],
split_criterion: str,
split_ratio: float,
split_input: bool) -> torch.Tensor:
modules = sinks if split_input else sources
# all modules share the same number of channels along the split dim, so take it from the first one
num_channels = next(iter(modules.values())).weight.shape[int(split_input)]
splits_per_layer = int(math.ceil(split_ratio * num_channels))

if splits_per_layer == 0:
warnings.warn(f'No splits for {modules}, increasing split_ratio could help.')

module_to_channels = {}
if split_criterion == 'maxabs':
for name, module in modules.items():
module_to_channels[name] = _channels_maxabs(module, splits_per_layer, split_input)

# return tensor with the indices to split
channels_to_split = torch.cat(list(module_to_channels.values()))
return torch.unique(channels_to_split)
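
# Note: torch.unique returns the merged indices sorted in ascending order, so
# the maxabs ranking only decides membership, not the final ordering.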


def _split_single_channel(channel, grid_aware: bool, split_factor: float, scale: float = 1.):
if grid_aware:
split_channel = channel * split_factor * scale
slice1 = (split_channel - 0.25) / scale
slice2 = (split_channel + 0.25) / scale
return slice1, slice2
else:
return channel * split_factor, channel * split_factor
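
# Grid-aware note: with quantization step 1 / scale, the two slices sit a
# quarter-step below and above the halved channel, so after rounding their sum
# tracks the original quantized channel more closely than two identical halves.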


def _split_channels(
module,
channels_to_split,
grid_aware=True,
split_input=False,
split_factor=0.5,
bit_width=8) -> None:
"""
Splits the channels `channels_to_split` of the `weights`.
`split_input` specifies whether to split Input or Output channels.
Can also be used to duplicate a channel, just set split_factor to 1.
Returns: None
"""
weight = torch.clone(module.weight.data)
bias = torch.clone(module.bias.data) if module.bias is not None else None

# init scale
scale = 1.

if grid_aware:
# do a preliminary split of the channels to get the scale for the split channels
for id in channels_to_split:
if isinstance(module, torch.nn.Conv2d):
# there are four dimensions: [OC, IC, k, k]
if split_input:
channel = weight[:, id:id + 1, :, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor)
weight = torch.cat(
(weight[:, :id, :, :], slice1, slice2, weight[:, id + 1:, :, :]), dim=1)
else:
channel = weight[id:id + 1, :, :, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor)
weight = torch.cat(
(weight[:id, :, :, :], slice1, slice2, weight[id + 1:, :, :, :]), dim=0)

elif isinstance(module, torch.nn.Linear):
# there are two dimensions: [OC, IC]
if split_input:
channel = weight[:, id:id + 1]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor)
weight = torch.cat((weight[:, :id], slice1, slice2, weight[:, id + 1:]), dim=1)
else:
channel = weight[id:id + 1, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=False, split_factor=split_factor)
weight = torch.cat((weight[:id, :], slice1, slice2, weight[id + 1:, :]), dim=0)
# now calculate the scale of the split weights
scale = _calculate_scale(weight, bit_width)

# reset the weight variable
weight = torch.clone(module.weight.data)

for id in channels_to_split:
if isinstance(module, torch.nn.Conv2d):
# there are four dimensions: [OC, IC, k, k]
if split_input:
channel = weight[:, id:id + 1, :, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale)
weight = torch.cat((weight[:, :id, :, :], slice1, weight[:, id + 1:, :, :], slice2),
dim=1)
module.in_channels += 1
else:
channel = weight[id:id + 1, :, :, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale)
weight = torch.cat((weight[:id, :, :, :], slice1, weight[id + 1:, :, :, :], slice2),
dim=0)
module.out_channels += 1

elif isinstance(module, torch.nn.Linear):
# there are two dimensions: [OC, IC]
if split_input:
channel = weight[:, id:id + 1]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale)
weight = torch.cat((weight[:, :id], slice1, weight[:, id + 1:], slice2), dim=1)
module.in_features += 1
else:
channel = weight[id:id + 1, :]
slice1, slice2 = _split_single_channel(channel=channel, grid_aware=grid_aware, split_factor=split_factor, scale=scale)
weight = torch.cat((weight[:id, :], slice1, weight[id + 1:, :], slice2), dim=0)
module.out_features += 1

if bias is not None and not split_input:
channel = bias[id:id + 1] * split_factor
bias = torch.cat((bias[:id], channel, bias[id + 1:], channel))

module.weight.data = weight
if bias is not None:
module.bias.data = bias


def _split_channels_region(
sources: Dict[str, nn.Module],
sinks: Dict[str, nn.Module],
channels_to_split: torch.Tensor,
split_input: bool,
grid_aware: bool = False,
weight_bit_width: int = 8):
# splitting output channels: split the sources' output channels, then
# duplicate the corresponding input channels of the sinks
if not split_input:
for name, module in sources.items():
_split_channels(
module, channels_to_split, grid_aware=grid_aware, bit_width=weight_bit_width)
for name, module in sinks.items():
# then duplicate the input_channels for all modules in the sink
_split_channels(
module, channels_to_split, grid_aware=False, split_factor=1, split_input=True)
else:
# input channels are split in half, output channels duplicated
for name, module in sinks.items():
_split_channels(
module,
channels_to_split,
grid_aware=grid_aware,
split_input=True,
bit_width=weight_bit_width)
for name, module in sources.items():
# duplicate out_channels for all modules in the source
_split_channels(module, channels_to_split, grid_aware=False, split_factor=1)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
# check if OCs of sources are all equal
srcs_ocs = set(module.weight.shape[0] for module in srcs)
if len(srcs_ocs) > 1:
return False

# check if ICs of sinks are all equal
sinks_ics = set(module.weight.shape[1] for module in sinks)
if len(sinks_ics) > 1:
return False

return srcs_ocs == sinks_ics
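
# Example (illustrative): two parallel convs with 64 output channels feeding a
# conv with 64 input channels give srcs_ocs == sinks_ics == {64}, so the
# region is supported.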


def _split(
model: GraphModule,
regions: Set[Tuple[str]],
split_ratio: float,
split_criterion: str,
grid_aware: bool,
split_input: bool,
weight_bit_width: int) -> GraphModule:
name_to_module: Dict[str, nn.Module] = {}
name_set = set()
for region in regions:
for name in region.srcs_names:
name_set.add(name)
for name in region.sinks_names:
name_set.add(name)

for name, module in model.named_modules():
if name in name_set:
name_to_module[name] = module

for i, region in enumerate(regions):

# check if region is suitable for channel splitting
sources = {n: name_to_module[n] for n in region.srcs_names}
sinks = {n: name_to_module[n] for n in region.sinks_names}

if _is_supported(sources.values(), sinks.values()):
# get channels to split
channels_to_split = _channels_to_split(
sources=sources,
sinks=sinks,
split_criterion=split_criterion,
split_ratio=split_ratio,
split_input=split_input)
# splitting/duplicating channels
_split_channels_region(
sources=sources,
sinks=sinks,
channels_to_split=channels_to_split,
grid_aware=grid_aware,
split_input=split_input,
weight_bit_width=weight_bit_width)

return model


class RegionwiseChannelSplitting(GraphTransform):

def __init__(
self,
split_ratio=0.02,
split_criterion='maxabs',
grid_aware=False,
split_input=True,
weight_bit_width=8):
super(RegionwiseChannelSplitting, self).__init__()

self.grid_aware = grid_aware
self.split_ratio = split_ratio
self.split_criterion = split_criterion
self.split_input = split_input
self.weight_bit_width = weight_bit_width

def apply(
self,
model,
return_regions: bool = False
) -> Union[Tuple[GraphModule, Set[Tuple[str]]], GraphModule]:
regions = _extract_regions(model)
# fall back to the unmodified model if no regions are found
self.graph_model = model
if len(regions) > 0:
self.graph_model = _split(
model=model,
regions=regions,
split_ratio=self.split_ratio,
split_criterion=self.split_criterion,
grid_aware=self.grid_aware,
split_input=self.split_input,
weight_bit_width=self.weight_bit_width)
if return_regions:
return self.graph_model, regions
else:
return self.graph_model
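A standalone usage sketch of the transform (not part of the commit; it assumes the model has already been FX-traced into a GraphModule, here via torch.fx for illustration, though brevitas' own tracer may be required in practice):

import torch.nn as nn
from torch.fx import symbolic_trace

from brevitas.ptq_algorithms.channel_splitting import RegionwiseChannelSplitting

model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Conv2d(64, 128, 3))
graph_model = symbolic_trace(model)
# apply splitting and also return the regions it operated on
graph_model, regions = RegionwiseChannelSplitting(
    split_ratio=0.02, grid_aware=False, split_input=True).apply(
        graph_model, return_regions=True)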
@@ -105,7 +105,11 @@ def unique(sequence):
'accumulator_bit_width': [16], # Accumulator bit width, only in combination with GPFA2Q
'act_quant_percentile': [99.999], # Activation Quantization Percentile
'uint_sym_act_for_unsigned_values': [True], # Whether to use unsigned act quant when possible
'channel_splitting': [False],  # Use channel splitting algorithm in preprocessing step
'split_ratio': [0.01],  # Split ratio for channel splitting
'split_input': [True],  # Split input or output channels for channel splitting
'grid_aware': [False],  # Use grid aware channel splitting
'merge_bn': [True],}

parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation')
parser.add_argument('idx', type=int)
@@ -210,7 +214,13 @@ def ptq_torchvision_models(args):
model = preprocess_for_quantize(
model,
equalize_iters=config_namespace.graph_eq_iterations,
equalize_merge_bias=config_namespace.graph_eq_merge_bias,
merge_bn=config_namespace.merge_bn,
channel_splitting=config_namespace.channel_splitting,
channel_splitting_ratio=config_namespace.split_ratio,
channel_splitting_grid_aware=config_namespace.grid_aware,
channel_splitting_split_input=config_namespace.split_input,
channel_splitting_weight_bit_width=config_namespace.weight_bit_width)
else:
raise RuntimeError(f"{config_namespace.target_backend} backend not supported.")

@@ -334,6 +344,9 @@ def validate_config(config_namespace):
config_namespace.gpfa2q)
if multiple_gpxqs > 1:
is_valid = False
elif multiple_gpxqs == 0:
# no gpxq algorithm, set act order to None
config_namespace.gpxq_act_order = None

if config_namespace.act_equalization == 'layerwise' and config_namespace.target_backend == 'fx':
is_valid = False
@@ -364,6 +377,11 @@
is_valid = False
if config_namespace.act_exponent_bit_width + config_namespace.act_mantissa_bit_width != config_namespace.act_bit_width - 1:
is_valid = False
# if channel splitting is disabled, there is no need for split ratio, grid aware, or split input
if not config_namespace.channel_splitting:
config_namespace.split_ratio = None
config_namespace.grid_aware = None
config_namespace.split_input = None

config_namespace.is_valid = is_valid
return config_namespace
Expand Up @@ -197,8 +197,10 @@ def layerwise_bit_width_fn_weight(module):
act_bit_width_dict = {}
if quant_format == 'int' and backend == 'layerwise':
weight_bit_width_dict['weight_bit_width'] = layerwise_bit_width_fn_weight
if act_bit_width is not None:
act_bit_width_dict['act_bit_width'] = layerwise_bit_width_fn_act
else:
act_bit_width_dict['act_bit_width'] = None
elif quant_format == 'int' and backend != 'layerwise':
weight_bit_width_dict['weight_bit_width'] = weight_bit_width
act_bit_width_dict['act_bit_width'] = act_bit_width
@@ -291,7 +293,7 @@ def kwargs_prefix(prefix, weight_kwargs):
act_bit_width_dict['mantissa_bit_width'] = act_mantissa_bit_width

# Retrieve base input, weight, and bias quantizers
bias_quant = BIAS_BIT_WIDTH_MAP[bias_bit_width] if act_bit_width is not None else None
weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_type][weight_param_method][
weight_quant_granularity][weight_quant_type]
weight_quant = weight_quant.let(**weight_bit_width_dict)