Parametrized smoothquant #1168

Closed
wants to merge 8 commits into from
6 changes: 5 additions & 1 deletion src/brevitas/graph/base.py
@@ -20,6 +20,7 @@
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.torch_utils import WeightBiasWrapper

__all__ = [
'Transform',
@@ -259,7 +260,10 @@ def apply(self, model: GraphModule) -> GraphModule:
tensor = getattr(module, self.tensor_name).data
tensor = self.transform_module(tensor)
# Modify the weights in-place
setattr(module, self.tensor_name, torch.nn.Parameter(tensor))
if isinstance(module, WeightBiasWrapper):
getattr(module, self.tensor_name).data = tensor
else:
setattr(module, self.tensor_name, torch.nn.Parameter(tensor))

if hasattr(module, 'offload_params'):
module.offload_params(module)
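
A note on the new WeightBiasWrapper branch above: the wrapper is a frozen dataclass that only aliases a tensor owned by another module (e.g. in_proj_weight of nn.MultiheadAttention), so re-registering a fresh nn.Parameter on it would both fail and leave the owning module untouched, whereas mutating .data updates the shared storage. A minimal sketch of that behaviour (illustrative only, not part of the diff):

import torch

from brevitas.utils.torch_utils import WeightBiasWrapper

wrapper = WeightBiasWrapper(weight=torch.randn(4, 4))
# Frozen dataclass: re-assigning the attribute would raise dataclasses.FrozenInstanceError
# wrapper.weight = torch.nn.Parameter(torch.zeros(4, 4))
# In-place mutation of the aliased storage is allowed and is seen by the tensor's owner
wrapper.weight.data = torch.zeros(4, 4)
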
155 changes: 103 additions & 52 deletions src/brevitas/graph/equalize.py
@@ -42,7 +42,9 @@
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.rotation_utils import ScaleWeightParametrization
from brevitas.utils.torch_utils import KwargsForwardHook
from brevitas.utils.torch_utils import WeightBiasWrapper

# External optional dependency
try:
@@ -139,13 +141,6 @@ class EqualizationIndexes:
offset: int = 0


# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasWrapper:
weight: torch.Tensor = None
bias: torch.Tensor = None


# Required for being hashable
@dataclass(eq=True, frozen=True)
class Region:
@@ -264,24 +259,30 @@ def _select_scale_computation_fn(
class activation_equalization_mode:

def __init__(
self,
model,
alpha,
add_mul_node=True,
layerwise=True,
enabled=True,
blacklist_layers=None,
co_optimize_act_weights=False) -> None:
self,
model,
alpha,
add_mul_node=True,
layerwise=True,
enabled=True,
blacklist_layers=None,
co_optimize_act_weights=False,
use_parametrized_scaling=False,
) -> None:
self.model = model
self.alpha = alpha
self.enabled = enabled
self.add_mul_node = add_mul_node
self.co_optimize_act_weights = co_optimize_act_weights
self.use_parametrized_scaling = use_parametrized_scaling
self.rewriters = None
if layerwise:
if not self.add_mul_node:
raise ValueError("Layerwise activation equalization requires add_mul_node")
self.graph_act_eq = LayerwiseActivationEqualization(
self.model, blacklist_layers=blacklist_layers)
self.model,
blacklist_layers=blacklist_layers,
use_parametrized_scaling=self.use_parametrized_scaling)
else:
if not isinstance(self.model, (TorchGraphModule, GraphModule)):
raise TypeError(
@@ -298,7 +299,18 @@ def __enter__(self):

def __exit__(self, type, value, traceback):
if self.enabled:
self.scale_factors = self.graph_act_eq.apply(self.alpha)
self.scale_factors, self.rewriters = self.graph_act_eq.apply(self.alpha)
import torch.nn.utils.parametrize as parametrize
# Fold any registered scaling parametrizations back into the tensors they wrap
for module in self.model.modules():
for tensor_name in ("weight", "bias"):
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
return True # To propagate exceptions
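
For context on the clean-up above: remove_parametrizations(..., leave_parametrized=True) re-materializes the parametrized value before the parametrization module is dropped, so the scaled weights are what remain on the module once the context manager exits. A small standalone sketch of that mechanism, assuming the ScaleWeightParametrization added in this PR:

import torch
import torch.nn.utils.parametrize as parametrize

from brevitas.utils.rotation_utils import ScaleWeightParametrization

lin = torch.nn.Linear(4, 4, bias=False)
scale = torch.full((4, 1), 2.0)
parametrize.register_parametrization(
    lin, "weight", ScaleWeightParametrization(scaling_factor=scale, is_sink=False))
# While parametrized, lin.weight is recomputed as original_weight * scale on every access
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=True)
# After removal, the scaled values are stored directly in lin.weight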


@@ -454,24 +466,28 @@ def transpose(tensor: torch.Tensor, axis: int):


def _cross_layer_equalization(
model: nn.Module,
region: Region,
merge_bias: bool,
scale_computation_type: str,
bias_shrinkage: Optional[Union[float, str]] = None,
list_of_act_val: Optional[torch.Tensor] = None,
list_of_insert_mul_node_fn: Optional[List[Callable]] = None,
alpha: float = 0.5,
co_optimize_act_weights: bool = False) -> torch.Tensor:
co_optimize_act_weights: bool = False,
fuse_scaling: bool = True) -> torch.Tensor:
"""
Given two adjacent layers, their weights are scaled such that
the ranges of the first layer's output channels are equal to the
ranges of the second layer's input channels
"""
# Rewriters to be used when scaling is not fused
rewriters = []

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
def _no_equalize():
return torch.tensor(1., dtype=dtype)
return torch.tensor(1., dtype=dtype), rewriters

# If a module has `allocate_params` attribute, we must load the weights following that method

@@ -508,7 +524,7 @@ def _no_equalize():
# For MultiheadAttention, we support only self-attention
if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
# For sinks, we only need to modify the weight but not the bias
module = WeightBiasWrapper(module.in_proj_weight)
module = WeightBiasWrapper(module.in_proj_weight, orig_module=module)
elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
return _no_equalize()
sink_axes[name] = (module, axis)
@@ -609,56 +625,78 @@ def _no_equalize():

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
inverse_scaling_factors = torch.reciprocal(scaling_factors)
scaling_factors = sinks_range / srcs_range

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
device = list_of_act_val[0].device
for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn):
insert_mul_node_fn(
inverse_scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis)
scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis)
if len(src_axes) > 0:
for name, (module, axis) in src_axes.items():
module_device = module.weight.device
indexes = region.srcs[name]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
partial_inverse_scale = inverse_scaling_factors[channel_start:channel_end].to(
partial_scale = scaling_factors[channel_start:channel_end].to(
device=module_device, dtype=dtype)
# If scaling is fused, the transform is applied directly onto the tensor;
# otherwise the scaling is registered as a parametrization
rewriter_class = ModuleInstanceTransformTensor if fuse_scaling else ModuleInstanceRegisterParametrization
if hasattr(module, 'bias') and module.bias is not None:
_update_weights(
module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
rewriters.append(
rewriter_class(
module=module,
tensor_name="bias",
transform_module=ScaleWeightParametrization(
scaling_factor=partial_scale.view_as(module.bias), is_sink=False)))
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)

_update_weights(
module,
module.weight * torch.reshape(partial_inverse_scale, src_broadcast_size),
attr='weight')
rewriters.append(
rewriter_class(
module=module,
tensor_name="weight",
transform_module=ScaleWeightParametrization(
scaling_factor=torch.reshape(partial_scale, src_broadcast_size),
is_sink=False)))
for name, (module, axis) in sink_axes.items():
module_device = module.weight.device
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
indexes = region.sinks[name]
channel_range = indexes.end - indexes.start
partial_scaling = torch.ones(module.weight.size(axis), device='cpu', dtype=dtype)
# If scaling is fused, the transform is applied directly onto the tensor;
# otherwise the scaling is registered as a parametrization
rewriter_class = ModuleInstanceTransformTensor if fuse_scaling else ModuleInstanceRegisterParametrization
if isinstance(module, WeightBiasWrapper):
module = module.orig_module
tensor_name = 'in_proj_weight'
else:
tensor_name = 'weight'
# We replace the scaling factors of the channels we need to equalize, leaving the other to
# one (i.e., no equalization)
partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
channel_range]
partial_scaling = partial_scaling.to(device=module_device, dtype=dtype)
_update_weights(
module,
module.weight * torch.reshape(partial_scaling, sink_broadcast_size),
attr='weight')
rewriters.append(
rewriter_class(
module=module,
tensor_name=tensor_name,
transform_module=ScaleWeightParametrization(
scaling_factor=torch.reshape(partial_scaling, sink_broadcast_size),
is_sink=True)))
for r in rewriters:
r.apply(model)

# If a module has `offload_params` attribute, we must offload the weights following that method
for name in (region.srcs_names + region.sinks_names):
module = region.get_module_from_name(name)
if hasattr(module, 'offload_params'):
module.offload_params(module)

return scaling_factors
return scaling_factors, rewriters


def _update_weights(original_module, new_value, attr='weight'):
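
To make the sign convention above explicit: after the two torch.pow calls, scaling_factors = sinks_range / srcs_range equals sinks_range^(1 - alpha) / srcs_range^alpha. That factor is applied to the source side (the preceding weights, or the inserted mul node on the activation), while sinks receive its reciprocal via ScaleWeightParametrization(is_sink=True), so the composed function is unchanged. A small numeric sketch of that invariance (illustrative values, single channel, bias and non-linearities ignored):

import torch

alpha = 0.5
srcs_range, sinks_range = torch.tensor([4.0]), torch.tensor([1.0])
s = sinks_range.pow(1 - alpha) / srcs_range.pow(alpha)  # factor applied to the source side

src_w, sink_w = torch.tensor([2.0]), torch.tensor([3.0])
x = torch.tensor([5.0])
y_orig = sink_w * (src_w * x)            # original two-layer product
y_eq = (sink_w / s) * ((src_w * s) * x)  # equalized: the sink uses the reciprocal
assert torch.allclose(y_orig, y_eq)
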
@@ -682,7 +720,8 @@ def _equalize(
for i in range(iterations):
scale_factor_max = None
for region in regions:
scale_factors_region = _cross_layer_equalization(
scale_factors_region, rewriters = _cross_layer_equalization(
model,
region,
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
Expand All @@ -692,6 +731,7 @@ def _equalize(
scale_factor_max = torch.max(scale_factor_max, scale_factor_region_max)
else:
scale_factor_max = scale_factor_region_max

if threshold is not None and scale_factor_max < threshold:
break
return model
@@ -1085,13 +1125,15 @@ def __init__(
self,
model,
scale_computation_type: str = 'maxabs',
blacklist_layers: Optional[List[str]] = None):
blacklist_layers: Optional[List[str]] = None,
use_parametrized_scaling: bool = False):
super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type)
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
self.add_mul_node = True
self.blacklist_layers = blacklist_layers
self.use_parametrized_scaling = use_parametrized_scaling

regions: List[Region] = []
self.find_module(model, regions)
@@ -1135,22 +1177,28 @@ def setup(self):

def apply(self, alpha):
scale_factors = []
rewriters = []
self.remove_hooks()
for region in self.regions:
module = region.get_module_from_name('sinks0')
if self.float_act_map[module] == None:
continue
insert_mul_fn = partial(
self.insert_mul_node, region=module, batch_dim=self.batch_dim_act_map[module])
scale_factors.append(
_cross_layer_equalization(
region,
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=[self.float_act_map[module]],
list_of_insert_mul_node_fn=[insert_mul_fn],
alpha=alpha))
return scale_factors
scale_factor, region_rewriter = _cross_layer_equalization(
self.model,
region,
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=[self.float_act_map[module]],
list_of_insert_mul_node_fn=[insert_mul_fn],
alpha=alpha,
fuse_scaling=not self.use_parametrized_scaling)
scale_factors.append(scale_factor)
rewriters.extend(region_rewriter)
return scale_factors, rewriters

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
@@ -1230,6 +1278,7 @@ def setup(self):

def apply(self, alpha):
scale_factors = []
rewriters = []
self.remove_hooks()
for region in self.regions:
region_names = region.sinks_names if len(region.acts) == 0 else region.acts
@@ -1250,18 +1299,20 @@ def apply(self, alpha):
self.insert_mul_node,
act_node=act_node,
batch_dim=self.batch_dim_act_map[act_name]))

scale_factors.append(
_cross_layer_equalization(
scale_factor_region, rewriters_region = _cross_layer_equalization(
self.model,
region,
False,
scale_computation_type=self.scale_computation_type,
list_of_act_val=list_of_act_val,
list_of_insert_mul_node_fn=list_of_insert_mul_node_fn,
alpha=alpha,
co_optimize_act_weights=self.co_optimize_act_weights))

return scale_factors
co_optimize_act_weights=self.co_optimize_act_weights)
scale_factors.append(scale_factor_region)
rewriters.extend(rewriters_region)
return scale_factors, rewriters

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
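
As a usage sketch for the flag threaded through the classes above (model and calibration_loader are placeholders, not part of the diff): with use_parametrized_scaling=True the scales are registered as parametrizations through the returned rewriters rather than fused immediately, and both the scale factors and the rewriters are kept on the context manager.

from brevitas.graph.equalize import activation_equalization_mode

aem = activation_equalization_mode(
    model, alpha=0.5, layerwise=True, use_parametrized_scaling=True)
with aem:
    for batch in calibration_loader:
        model(batch)  # hooks collect activation ranges
# On exit the collected scales have been applied; the rewriters describe the
# ScaleWeightParametrization modules that were registered (and then folded back in)
scale_factors, rewriters = aem.scale_factors, aem.rewriters
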
19 changes: 19 additions & 0 deletions src/brevitas/utils/rotation_utils.py
@@ -47,3 +47,22 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Not supported yet")

return tensor


class ScaleWeightParametrization(torch.nn.Module):
r"""Scales a tensor by a specified scaling factor

Args:
scaling_factor (Tensor): scaling factor by which the tensor is multiplied
is_sink (bool): if True, the tensor is multiplied by the reciprocal of the scaling factor instead
"""

def __init__(self, scaling_factor: Tensor, is_sink: bool) -> None:
super().__init__()
self.scaling_factor = scaling_factor
self.is_sink = is_sink

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
# The reciprocal is taken on the fly so as to preserve the tie between the scale and its reciprocal
scale = torch.reciprocal(self.scaling_factor) if self.is_sink else self.scaling_factor
return tensor * scale
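
The reciprocal-on-the-fly design means a source and a sink can share the same scaling_factor tensor, so the scale and its inverse can never drift apart. A minimal sketch of that tie (assumed usage, not taken from the diff):

import torch

from brevitas.utils.rotation_utils import ScaleWeightParametrization

scale = torch.tensor([2.0])
src_p = ScaleWeightParametrization(scaling_factor=scale, is_sink=False)
sink_p = ScaleWeightParametrization(scaling_factor=scale, is_sink=True)

w = torch.tensor([3.0])
assert torch.allclose(src_p(w) * sink_p(w), w * w)  # s and 1/s cancel
scale.fill_(4.0)  # an in-place update is seen by both parametrizations
assert torch.allclose(src_p(w) * sink_p(w), w * w)
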
9 changes: 9 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from dataclasses import dataclass
from functools import wraps
from typing import List, Optional, Tuple

@@ -17,6 +18,14 @@ class StopFwdException(Exception):
pass


# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasWrapper:
weight: torch.Tensor = None
bias: torch.Tensor = None
orig_module: torch.nn.Module = None


class TupleSequential(Sequential):

def output(self, mod, input):
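
Moving the wrapper here lets both base.py and equalize.py import it; eq=True, frozen=True keeps instances hashable (the reason for the original "Required for being hashable" comment), and the new orig_module field gives the rewriters a handle back to the owning module. A short sketch of the intended use with self-attention (illustrative only):

import torch

from brevitas.utils.torch_utils import WeightBiasWrapper

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
# Only the packed in_proj_weight of a self-attention sink needs scaling, so the wrapper
# exposes it under the usual `weight` name while keeping a handle to the owning module
wrapped = WeightBiasWrapper(weight=mha.in_proj_weight, orig_module=mha)
assert wrapped.weight is mha.in_proj_weight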