From e04763b757c53fc3686a22cc6097c15960c9bc3b Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Mon, 20 Jan 2025 18:59:20 +0000
Subject: [PATCH] Minor change to comment

---
 src/brevitas/graph/equalize.py    | 20 +++++++++++++++-----
 src/brevitas_examples/llm/main.py | 22 +++-------------------
 2 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 097441279..1519d7ba5 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1531,6 +1531,7 @@ def __init__(
             orphan_sink: bool = False,
             sdpa_regions: bool = False,
             rotate_matmul: bool = False,
+            fuse_rotations: bool = True,
             full_rotation_method: str = 'had',
             return_rewriters: bool = False) -> None:
         super(GraphRotationEqualization, self).__init__()
@@ -1548,6 +1549,17 @@ def __init__(
         self.full_rotation_method = full_rotation_method
         self.return_rewriters = return_rewriters
         self.sdpa_regions = sdpa_regions
+        if not fuse_rotations:
+            # NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__
+            # of the parametrized module, e.g. to "".
+            # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
+            # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
+            # of the weights should not operate on parametrized modules. In this situation, parametrizations
+            # need to be removed beforehand by invoking fuse_parametrized_rotations
+            warnings.warn(
+                "Using parametrized rotations might break type-checking, which could lead to unexpected behaviour."
+            )
+        self.fuse_rotations = fuse_rotations
 
     def rotate_matmuls(self, graph_module):
         matmul_nodes = list(graph_module.graph.nodes)
@@ -1608,10 +1620,8 @@ def find_sink(node):
                     m.pre_process_k = functional_rotate_input
         return regions
 
-    def apply(
-            self,
-            graph_model: GraphModule,
-            fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
+    def apply(self,
+              graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
         regions = _extract_regions(
             graph_model,
@@ -1640,7 +1650,7 @@ def apply(
             self.rotate_matmuls(graph_model)
         if len(regions) > 0:
             rewriters = _apply_rotate(
-                graph_model, regions, self.full_rotation_method, fuse_rotations=fuse_rotations)
+                graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations)
         if self.return_rewriters:
             return graph_model, rewriters
         else:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index b4235e283..d7f49e1a3 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -104,25 +104,9 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
         orphan_sink=args.rotation_orphan_sink,
         full_rotation_method=args.rotation_mode,
         return_rewriters=True,
-        sdpa_regions=args.rotation_sdpa_regions)
-    if not fuse_rotations:
-        # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of
-        # selected modules stop being attributes but, instead, properties, and their value is
-        # computed by passing the original value of the tensor through the forward passes of the
-        # parametrization modules. Parametrizations are registered using
-        # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__
-        # attribute of the parametrized module, e.g. ""
-        # changes to "". Therefore,
-        # algorithms that do type checking might need to use type_before_parametrizations(module),
-        # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance,
-        # the "weight" attribute is parametrized, it will be removed from the attributes
-        # of the class. Consequently, quantization algorithms that rely on in-place modifications
-        # of the weights should not operate on parametrized modules. In this situation, parametrizations
-        # need to be removed beforehand by invoking fuse_parametrized_rotations
-        warn(
-            "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
-        )
-    new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations)
+        sdpa_regions=args.rotation_sdpa_regions,
+        fuse_rotations=fuse_rotations)
+    new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')
     for r in rewriters:
         # The weights between model and new_model are tied, so this check prevents
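
The NOTE in the patch hinges on the fact that torch.nn.utils.parametrize.register_parametrization swaps the __class__ of the module it parametrizes, so a plain type(module) check stops matching while type_before_parametrizations(module) still returns the original class. The sketch below is illustrative only and not part of the patch; RotateWeight and the identity rotation are made-up stand-ins for Brevitas' rotation parametrizations, while the parametrize calls themselves are standard PyTorch APIs.

    # Illustrative sketch: why type checks need type_before_parametrizations()
    # once a (rotation) parametrization is registered. RotateWeight is a
    # hypothetical stand-in, not a Brevitas class.
    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrize


    class RotateWeight(nn.Module):
        def __init__(self, rot: torch.Tensor):
            super().__init__()
            self.register_buffer("rot", rot)

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            # The weight is recomputed on every access instead of being
            # rewritten in place.
            return weight @ self.rot


    linear = nn.Linear(4, 4)
    parametrize.register_parametrization(linear, "weight", RotateWeight(torch.eye(4)))

    print(type(linear))                                      # generated ParametrizedLinear class
    print(parametrize.type_before_parametrizations(linear))  # <class 'torch.nn.modules.linear.Linear'>
    print(isinstance(linear, nn.Linear))                     # True: the generated class subclasses Linear

    # Folding the rotation back into a plain weight tensor, analogous to what
    # fuse_parametrized_rotations must do before weight-mutating algorithms run:
    parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
    print(type(linear))                                      # back to torch.nn.Linear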