Minor change to comment
pablomlago committed Jan 20, 2025
1 parent 033e443 commit c6614a0
Showing 2 changed files with 18 additions and 24 deletions.
20 changes: 15 additions & 5 deletions src/brevitas/graph/equalize.py
@@ -1531,6 +1531,7 @@ def __init__(
orphan_sink: bool = False,
sdpa_regions: bool = False,
rotate_matmul: bool = False,
+ fuse_rotations: bool = True,
full_rotation_method: str = 'had',
return_rewriters: bool = False) -> None:
super(GraphRotationEqualization, self).__init__()
@@ -1548,6 +1549,17 @@ def __init__
self.full_rotation_method = full_rotation_method
self.return_rewriters = return_rewriters
self.sdpa_regions = sdpa_regions
+ if not fuse_rotations:
+     # NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__
+     # of the parametrized module, e.g. to "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>".
+     # Therefore, algorithms that do type checking might need to use type_before_parametrizations(module)
+     # instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
+     # of the weights should not operate on parametrized modules. In this situation, parametrizations
+     # need to be removed beforehand by invoking fuse_parametrized_rotations.
+     warnings.warn(
+         "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
+     )
+ self.fuse_rotations = fuse_rotations

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
@@ -1608,10 +1620,8 @@ def find_sink(node):
m.pre_process_k = functional_rotate_input
return regions

- def apply(
-         self,
-         graph_model: GraphModule,
-         fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
+ def apply(self,
+           graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
rewriters = []
regions = _extract_regions(
graph_model,
Expand Down Expand Up @@ -1640,7 +1650,7 @@ def apply(
self.rotate_matmuls(graph_model)
if len(regions) > 0:
rewriters = _apply_rotate(
-     graph_model, regions, self.full_rotation_method, fuse_rotations=fuse_rotations)
+     graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations)
if self.return_rewriters:
return graph_model, rewriters
else:
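To make the NOTE added above concrete, here is a minimal sketch (not part of the commit) of how an unfused rotation behaves once it is registered as a parametrization. The RightRotation module and the random orthogonal matrix are illustrative stand-ins for the rotations Brevitas applies; only the torch.nn.utils.parametrize calls are the real API.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class RightRotation(nn.Module):
    # Hypothetical stand-in for a rotation parametrization: weight -> weight @ rot.
    def __init__(self, rot: torch.Tensor):
        super().__init__()
        self.register_buffer("rot", rot)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot


layer = nn.Linear(4, 4, bias=False)
rot = torch.linalg.qr(torch.randn(4, 4)).Q  # random orthogonal matrix
parametrize.register_parametrization(layer, "weight", RightRotation(rot))

print(type(layer))  # <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
print(parametrize.type_before_parametrizations(layer))  # <class 'torch.nn.modules.linear.Linear'>

This is why type-based dispatch (e.g. layerwise_layer_handler) needs type_before_parametrizations rather than type once fuse_rotations=False is used.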
22 changes: 3 additions & 19 deletions src/brevitas_examples/llm/main.py
@@ -104,25 +104,9 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
return_rewriters=True,
- sdpa_regions=args.rotation_sdpa_regions)
- if not fuse_rotations:
-     # NOTE: When fuse_rotations=False, parametrized rotations are applied, i.e. the weights of
-     # selected modules stop being attributes but, instead, properties, and their value is
-     # computed by passing the original value of the tensor through the forward passes of the
-     # parametrization modules. Parametrizations are registered using
-     # torch.nn.utils.parametrize.register_parametrization, which modifies the __class__
-     # attribute of the parametrized module, e.g. "<class 'torch.nn.modules.linear.Linear'>"
-     # changes to "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>". Therefore,
-     # algorithms that do type checking might need to use type_before_parametrizations(module),
-     # instead of only type(module) (see layerwise_layer_handler). Moreover, if, for instance,
-     # the "weight" attribute is parametrized, it will be removed from the attributes
-     # of the class. Consequently, quantization algorithms that rely on in-place modifications
-     # of the weights should not operate on parametrized modules. In this situation, parametrizations
-     # need to be removed beforehand by invoking fuse_parametrized_rotations
-     warn(
-         "Using parametrized results might break type-checking, which could lead to unexpected behaviour."
-     )
- new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations)
+ sdpa_regions=args.rotation_sdpa_regions,
+ fuse_rotations=fuse_rotations)
+ new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')
for r in rewriters:
# The weights between model and new_model are tied, so this check prevents
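Both versions of the comment also state that algorithms relying on in-place weight modifications must not run on parametrized modules, and point to fuse_parametrized_rotations for removing the parametrizations first. A rough sketch of the PyTorch mechanism that helper presumably builds on, continuing the toy layer from the sketch above:

import torch.nn.utils.parametrize as parametrize

# `layer` is the parametrized nn.Linear from the previous sketch.
if parametrize.is_parametrized(layer, "weight"):
    # leave_parametrized=True bakes the rotated values into a plain nn.Parameter,
    # so "weight" is an ordinary attribute again and in-place weight updates are safe.
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

print(type(layer))  # reverts to <class 'torch.nn.modules.linear.Linear'>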
