diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 3478ce2a8..b304fd22f 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1384,7 +1384,7 @@ def _apply_rotate( # Only "weight" is rotated tensor_names_axis = [("weight", _get_input_axis(module))] # If rotations are fused or if the module is an orphan sink, transform is applied directly onto the tensor - rewriter_class = ModuleInstanceRegisterParametrization if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization + rewriter_class = ModuleInstanceTransformTensor if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization # Obtain rewriters for applying the rotations for tensor_name, axis in tensor_names_axis: rewriter = rewriter_class( diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index c56294d9a..2d5c7a78f 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -359,7 +359,8 @@ def forward(self, x): model_unfused.linear.weight.data = model_fused.linear.weight.data model_fused = ModuleInstanceTransformTensor( - model_fused.linear, "weight", rot_mat, rot_func, None, axis).apply(model_fused) + model_fused.linear, "weight", RotationWeightParametrization(rot_mat, rot_func, axis, + None)).apply(model_fused) model_unfused = ModuleInstanceRegisterParametrization( model_unfused.linear, "weight",