Skip to content

Commit

Permalink
Remove legacy code
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Feb 5, 2025
1 parent 6610a0c commit 7ba38a5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 253 deletions.
223 changes: 0 additions & 223 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,229 +733,6 @@ def _equalize(
return model


# TODO: Remove
def _cross_layer_equalization_legacy(
region: Region,
merge_bias: bool,
scale_computation_type: str,
bias_shrinkage: Optional[Union[float, str]] = None,
list_of_act_val: Optional[torch.Tensor] = None,
list_of_insert_mul_node_fn: Optional[List[Callable]] = None,
alpha: float = 0.5,
co_optimize_act_weights: bool = False) -> torch.Tensor:
"""
Given two adjacent tensors', the weights are scaled such that
the ranges of the first tensors' output channel are equal to the
ranges of the second tensors' input channel
"""

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
def _no_equalize():
return torch.tensor(1., dtype=dtype)

# If a module has `allocate_params` attribute, we must load the weights following that method

for name in (region.srcs_names + region.sinks_names):
module = region.get_module_from_name(name)
if hasattr(module, 'allocate_params'):
module.allocate_params(module)

act_sink_axes = {}
act_sources_axes = {}
single_module = region.get_module_from_name(next(iter(region.sinks_names)))
dtype = next(single_module.parameters()).dtype

# If region is not valid, don't equalize. If we are inserting a standalone mul, we don't need this check
if not region.is_valid and list_of_insert_mul_node_fn is None:
return _no_equalize()

src_axes = {}
for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
# If module is not supported, do not perform graph equalization
axis = _get_output_axis(module)
act_sources_axes[name] = _get_act_axis(module)

if isinstance(module, nn.MultiheadAttention):
module = module.out_proj
src_axes[name] = (module, axis)

sink_axes = {}
for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
axis = _get_input_axis(module)
act_sink_axes[name] = _get_act_axis(module)
# For MultiheadAttention, we support only self-attetion
if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
# For sinks, we only need to modify the weight but not the bias
module = WeightBiasWrapper(module.in_proj_weight)
elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
return _no_equalize()
sink_axes[name] = (module, axis)
# If act_val is enabled, use source or sink weights to determine the activation channel
# For example, if the source is BatchNorm, we need to use the information coming from the sinks
if list_of_act_val is not None:
list_of_sink_axes = [x for x in list(act_sink_axes.values()) if x is not None]
list_of_source_axes = [x for x in list(act_sources_axes.values()) if x is not None]
if len(list_of_sink_axes) > 0:
act_axis = list_of_sink_axes[0]
elif len(list_of_source_axes) > 0:
act_axis = list_of_source_axes[0]
else:
return _no_equalize()
# If there is a mismatch in the activation channel (e.g. a transpose/flatten op in between),
# do not perform equalization
if any([act_axis != axis for axis in list_of_source_axes + list_of_sink_axes]):
return _no_equalize()

# Check if any of the axis is None, which means that the module is not supported.
# In that case, do not perform graph equalization
axes_to_check = [axis for _, axis in list(src_axes.values()) + list(sink_axes.values())]
if None in axes_to_check:
return _no_equalize()

scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = {
name: transpose(m.weight.cpu().to(torch.float32), axis)
for name, (m, axis) in sink_axes.items()}
srcs_range = -1 * torch.ones(region.max_shape_srcs, device='cpu', dtype=torch.float32)
sinks_range = -1 * torch.ones(region.max_shape_sinks, device='cpu', dtype=torch.float32)
for k, v in sink_weights.items():
# Sinks can be partially equalized, thus we need to select
# only the channels we are interested in
indexes = region.sinks[k]
# Compute the range of the channels we need to equalize
weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
# Compute the numbers of channels we are equalizing
channel_range = indexes.end - indexes.start
# Use the offset and the range to update the correct range in the sinks
sinks_range[indexes.offset:indexes.offset + channel_range] = torch.max(
sinks_range[indexes.offset:indexes.offset + channel_range], weight_range)

# Determine the srcs_range based on where we are performing activation equalization or
# weight equalization
if merge_bias:
src_weights = {
name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage,
m.bias).cpu().to(torch.float32)
for name, (m, axis) in src_axes.items()}
else:
src_weights = {
name: transpose(m.weight.cpu().to(torch.float32), axis)
for name, (m, axis) in src_axes.items()}
for k, v in src_weights.items():
# Srcs are always fully equalized, thus we simply need to apply the offset to position them
# correctly with respect to the other srcs matrices.
indexes = region.srcs[k]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
weight_range = scale_fn(v.reshape(v.size(0), -1))
srcs_range[channel_start:channel_end] = torch.max(
srcs_range[channel_start:channel_end], weight_range)
if list_of_act_val is not None:
list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val]
if len(list_of_act_val_shapes) > 0:
shape_0 = list_of_act_val_shapes[0]
if any(shape_0 != shape for shape in list_of_act_val_shapes):
return _no_equalize()
list_of_act_val = list_of_act_val = [
transpose(act_val, act_axis) for act_val in list_of_act_val]
srcs_range_act = scale_fn(
torch.cat([
act_val.reshape(act_val.size(0), -1).cpu().to(torch.float32)
for act_val in list_of_act_val],
1))

if list_of_act_val is not None:
if co_optimize_act_weights and len(src_axes) > 0:
srcs_range = .5 * srcs_range + .5 * srcs_range_act
else:
srcs_range = srcs_range_act

# If there is a mismatch between srcs and sinks values, exit
if srcs_range.shape != sinks_range.shape:
warnings.warn(
"Detected source and sink with non compatible shapes, equalization is skipped")
return _no_equalize()

# Instead of clipping very low values, which would cause their reciprocal to be very large
# thus hindering quantization, we set both sources and sinks to one,
# which is the no-op equivalent for equalization.
channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), sinks_range)
srcs_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
device = list_of_act_val[0].device
for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn):
insert_mul_node_fn(
inverse_scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis)
if len(src_axes) > 0:
for name, (module, axis) in src_axes.items():
module_device = module.weight.device
indexes = region.srcs[name]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
partial_inverse_scale = inverse_scaling_factors[channel_start:channel_end].to(
device=module_device, dtype=dtype)
if hasattr(module, 'bias') and module.bias is not None:
_update_weights(
module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)

_update_weights(
module,
module.weight * torch.reshape(partial_inverse_scale, src_broadcast_size),
attr='weight')
for name, (module, axis) in sink_axes.items():
module_device = module.weight.device
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
indexes = region.sinks[name]
channel_range = indexes.end - indexes.start
partial_scaling = torch.ones(module.weight.size(axis), device='cpu', dtype=dtype)
# We replace the scaling factors of the channels we need to equalize, leaving the other to
# one (i.e., no equalization)
partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
channel_range]
partial_scaling = partial_scaling.to(device=module_device, dtype=dtype)
_update_weights(
module,
module.weight * torch.reshape(partial_scaling, sink_broadcast_size),
attr='weight')

# If a module has `offload_params` attribute, we must offload the weights following that method
for name in (region.srcs_names + region.sinks_names):
module = region.get_module_from_name(name)
if hasattr(module, 'offload_params'):
module.offload_params(module)

return scaling_factors


# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasWrapper:
weight: torch.Tensor = None
bias: torch.Tensor = None


def _update_weights(original_module, new_value, attr='weight'):
if isinstance(original_module, WeightBiasWrapper):
setattr(getattr(original_module, attr), 'data', new_value)
else:
setattr(original_module, attr, nn.Parameter(new_value))


def _is_supported_module(
graph_model: GraphModule, node: Node, supported_layers: Set = _supported_layers) -> bool:
if node.op == 'call_module':
Expand Down
32 changes: 2 additions & 30 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,45 +38,17 @@
IN_SIZE_CONV_SMALL = (1, 3, 32, 32)


#TODO: Remove legacy stuff
def equalize_test(model, regions, merge_bias, bias_shrinkage, scale_computation_type):
def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
scale_factors_regions = []
import copy

from brevitas.graph.equalize import _cross_layer_equalization_legacy
model_copy = copy.deepcopy(model)
for i in range(3):
for region in regions:
scale_factors_region, _ = _cross_layer_equalization(
model,
region,
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
scale_computation_type=scale_computation_type,
fuse_scaling=True)
if i == 0:
scale_factors_regions.append(scale_factors_region)
scale_factors_regions_legacy = []
regions_copy = copy.deepcopy(regions)
for r in regions_copy:
for name in r.name_to_module.keys():
from brevitas.utils.python_utils import recurse_getattr
r.name_to_module[name] = recurse_getattr(model_copy, name)
for i in range(3):
for region in regions_copy:
scale_factors_region_legacy = _cross_layer_equalization_legacy(
region,
merge_bias=merge_bias,
bias_shrinkage=bias_shrinkage,
scale_computation_type=scale_computation_type)
if i == 0:
scale_factors_regions_legacy.append(scale_factors_region_legacy)
# Add asserts
for sfr, sfrl in zip(scale_factors_regions, scale_factors_regions_legacy):
assert torch.allclose(sfr, sfrl, atol=0.0, rtol=0.0)
for name, param_legacy in model_copy.named_parameters():
param = recurse_getattr(model, name)
assert torch.allclose(param, param_legacy, atol=0.0, rtol=0.0)
scale_factors_regions.append(scale_factors_region)
return scale_factors_regions


Expand Down

0 comments on commit 7ba38a5

Please sign in to comment.