Tests with unfused parametrizations
pablomlago committed Feb 5, 2025
1 parent 7ba38a5 commit deab1b8
Showing 5 changed files with 65 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/graph/equalize.py
@@ -1526,14 +1526,14 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):
     return model
 
 
-def fuse_parametrized_rotations(model: nn.Module) -> nn.Module:
+def fuse_parametrizations(model: nn.Module) -> nn.Module:
     # First of all, parameters that have parametrizations need to be untied
     model = _untie_parameters_with_parametrizations(model)
     # Then, parametrizations can be safely removed
     for module in model.modules():
         if parametrize.is_parametrized(module):
             # Names of the tensors that can potentially be parametrized
-            tensor_names = ["weight", "bias"]
+            tensor_names = ["weight", "in_proj_weight", "bias"]
             # Remove parametrizations from each tensor
             for tensor_name in tensor_names:
                 if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
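Not part of the commit, but useful context for readers new to `torch.nn.utils.parametrize`: the sketch below, using a toy, hypothetical `RotateWeight` parametrization, shows what "fusing" means here. `remove_parametrizations(..., leave_parametrized=True)` materializes the parametrized value into a plain parameter, which is essentially what `fuse_parametrizations` does for each tensor name in the list above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class RotateWeight(nn.Module):
    """Toy parametrization: right-multiply the weight by a fixed orthogonal matrix."""

    def __init__(self, rot: torch.Tensor):
        super().__init__()
        self.register_buffer("rot", rot)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot


linear = nn.Linear(4, 4, bias=False)
rot, _ = torch.linalg.qr(torch.randn(4, 4))  # random orthogonal matrix
parametrize.register_parametrization(linear, "weight", RotateWeight(rot))
assert parametrize.is_parametrized(linear, "weight")

# Fusing: materialize weight @ rot into a plain nn.Parameter and drop the parametrization
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
assert not parametrize.is_parametrized(linear)
```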
@@ -1634,7 +1634,7 @@ def __init__(
             # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
             # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
             # of the weights should not operate on parametrized modules. In this situation, parametrizations
-            # need to be removed beforehand by invoking fuse_parametrized_rotations
+            # need to be removed beforehand by invoking fuse_parametrizations
             warnings.warn(
                 "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
             )
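As an aside (again not from the commit), the type-checking caveat in the comment above can be seen directly with the torch parametrization API: registering a parametrization replaces the module's class with a dynamically generated subclass, so exact `type(module)` comparisons break while `type_before_parametrizations` still recovers the original class. A minimal sketch:

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Noop(nn.Module):
    """Identity parametrization, just to trigger the class swap."""

    def forward(self, x):
        return x


linear = nn.Linear(2, 2)
parametrize.register_parametrization(linear, "weight", Noop())

# The class becomes a generated 'ParametrizedLinear'; exact type() checks no longer match
# (isinstance still works, since the generated class subclasses nn.Linear)
assert type(linear) is not nn.Linear
# ...but the original class is still recoverable:
assert parametrize.type_before_parametrizations(linear) is nn.Linear
```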
4 changes: 2 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -24,7 +24,7 @@
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
 from brevitas.graph import load_quant_model_mode
 from brevitas.graph.base import ModuleInstanceTransformTensor
-from brevitas.graph.equalize import fuse_parametrized_rotations
+from brevitas.graph.equalize import fuse_parametrizations
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import functional_quantization_mode
@@ -501,7 +501,7 @@ def quantize_llm(args, extra_args=None):
         # Offload model before fusing the rotations
         model = offload_model(model)
         # Fuse rotations with weights
-        model = fuse_parametrized_rotations(model)
+        model = fuse_parametrizations(model)
 
     if args.act_calibration and not args.load_checkpoint:
         print("Apply act calibration...")
7 changes: 5 additions & 2 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -38,15 +38,18 @@
 IN_SIZE_CONV_SMALL = (1, 3, 32, 32)
 
 
-def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type):
+def equalize_test(
+        model, regions, merge_bias, bias_shrinkage, scale_computation_type, fuse_scaling=True):
     scale_factors_regions = []
     for i in range(3):
         for region in regions:
             scale_factors_region, _ = _cross_layer_equalization(
+                model,
                 region,
                 merge_bias=merge_bias,
                 bias_shrinkage=bias_shrinkage,
-                scale_computation_type=scale_computation_type)
+                scale_computation_type=scale_computation_type,
+                fuse_scaling=fuse_scaling)
             if i == 0:
                 scale_factors_regions.append(scale_factors_region)
     return scale_factors_regions
57 changes: 54 additions & 3 deletions tests/brevitas/graph/test_equalization.py
@@ -26,7 +26,7 @@
 from brevitas.graph.equalize import _supported_layers
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import EqualizationIndexes
-from brevitas.graph.equalize import fuse_parametrized_rotations
+from brevitas.graph.equalize import fuse_parametrizations
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import MergeLnAffine
 from brevitas.graph.equalize import random_orthogonal_matrix
@@ -145,6 +145,57 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool):
     assert all([shape != () for shape in shape_scale_regions])
 
 
+@requires_pt_ge('2.3.1')
+@pytest_cases.parametrize("merge_bias", [True, False])
+def test_equalization_torchvision_models_unfused(model_coverage: tuple, merge_bias: bool):
+    model, _ = model_coverage
+
+    torch.manual_seed(SEED)
+    model.eval()
+    model = symbolic_trace(model)
+    model = TorchFunctionalToModule().apply(model)
+
+    supported_sinks = list(_supported_layers)
+    supported_sinks = tuple([
+        x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
+    regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
+    # Instantiate model with unfused scales
+    model_unfused = copy.deepcopy(model)
+    # Copy regions and ensure that they point to the modules in the copied model
+    regions_unfused = copy.deepcopy(regions)
+    for r in regions_unfused:
+        for m_name in r.name_to_module:
+            r.name_to_module[m_name] = recurse_getattr(model_unfused, m_name)
+    # Equalize original model
+    scale_factor_regions = equalize_test(
+        model,
+        regions,
+        merge_bias=merge_bias,
+        bias_shrinkage='vaiq',
+        scale_computation_type='maxabs')
+    # Equalize copied model
+    scale_factor_regions_unfused = equalize_test(
+        model_unfused,
+        regions_unfused,
+        merge_bias=merge_bias,
+        bias_shrinkage='vaiq',
+        scale_computation_type='maxabs',
+        fuse_scaling=False)
+    # Ensure that scale factors match
+    for scale_factor, scale_factor_unfused in zip(scale_factor_regions, scale_factor_regions_unfused):
+        assert torch.allclose(scale_factor, scale_factor_unfused, atol=0.0, rtol=0.0)
+    # Ensure that parameters match
+    for name, param in model.named_parameters():
+        param_unfused = recurse_getattr(model_unfused, name)
+        assert torch.allclose(param, param_unfused, atol=0.0, rtol=0.0)
+    # Fuse parametrizations and make sure that weights keep matching
+    model_unfused = fuse_parametrizations(model_unfused)
+    assert all([not parametrize.is_parametrized(m) for m in model_unfused.modules()])
+    for name, param in model.named_parameters():
+        param_unfused = recurse_getattr(model_unfused, name)
+        assert torch.allclose(param, param_unfused, atol=0.0, rtol=0.0)
+
+
 @pytest_cases.parametrize("merge_bias", [True, False])
 def test_models(toy_model, merge_bias, request):
     test_id = request.node.callspec.id
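One easy-to-miss step in the new test above is the loop that re-points `name_to_module` after `copy.deepcopy`: deep-copying the regions also deep-copies the modules they reference, so the copied regions would not refer to the modules inside `model_unfused`. A minimal sketch of the pitfall, using a plain dict as a hypothetical stand-in for `Region.name_to_module` (not from the commit):

```python
import copy

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
region = {"0": model[0]}  # hypothetical stand-in for Region.name_to_module

model_copy = copy.deepcopy(model)
region_copy = copy.deepcopy(region)

# deepcopy gave region_copy its own Linear, unrelated to model_copy...
assert region_copy["0"] is not model_copy[0]
# ...so equalizing through region_copy would not touch model_copy. Re-point it first:
region_copy["0"] = model_copy[0]
assert region_copy["0"] is model_copy[0]
```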
@@ -528,7 +579,7 @@ def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_r
         isinstance(module, RotatedModule) for module in rotated_model.modules()])
     # Optionally fuse the rotations
     if fuse_rotations:
-        rotated_model_unfused = fuse_parametrized_rotations(rotated_model_unfused)
+        rotated_model_unfused = fuse_parametrizations(rotated_model_unfused)
         # Verify that no parametrizations remain after fusing
         for module in rotated_model_unfused.modules():
             assert not parametrize.is_parametrized(module)
@@ -582,7 +633,7 @@ def test_fuse_parametrized_modules(kwargs):
     with torch.no_grad():
         output = qmodel(sample_input)
     # Fuse parametrizations
-    qmodel = fuse_parametrized_rotations(qmodel)
+    qmodel = fuse_parametrizations(qmodel)
     # Verify that scales were not lost
     module = recurse_getattr(model, key)
     assert module.weight_quant.tensor_quant.scaling_impl.init_done
2 changes: 1 addition & 1 deletion tests/brevitas_examples/test_llm.py
@@ -985,7 +985,7 @@ def test_small_models_rotation_optimization_layer_count(
     # with non-optimized quantized perplexities
     caplog.set_level(logging.INFO)
     args, extra_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
-    with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
+    with patch('brevitas_examples.llm.main.fuse_parametrizations', lambda model: model):
         _, _, model = validate_args_and_run_main(args, extra_args)
     assert_layer_types_count(model, exp_layer_types_count)
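A side note on the `patch` target used in this test (illustrative sketch, not from the commit): `unittest.mock.patch` has to replace the name where it is looked up, which is why the test patches `brevitas_examples.llm.main.fuse_parametrizations` rather than the name exported by `brevitas.graph.equalize`. The throwaway modules below (hypothetical names `fuselib` and `app`) mimic that import pattern:

```python
import sys
import types
from unittest.mock import patch

# 'fuselib' defines fuse(); 'app' imports it by name, like main.py imports fuse_parametrizations
fuselib = types.ModuleType("fuselib")
fuselib.fuse = lambda model: "fused"
sys.modules["fuselib"] = fuselib

app = types.ModuleType("app")
exec("from fuselib import fuse\n\ndef run(model):\n    return fuse(model)", app.__dict__)
sys.modules["app"] = app

# Patching the defining module does not affect the reference that app already imported...
with patch("fuselib.fuse", lambda model: model):
    assert app.run("m") == "fused"
# ...patching the name inside app does, which is the pattern used in the test above.
with patch("app.fuse", lambda model: model):
    assert app.run("m") == "m"
```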
