Skip to content

Commit

Permalink
Attempt revert
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 13, 2025
1 parent 361e664 commit 0471543
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,12 +1332,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.allocate_params(module)
axis = _get_output_axis(module)
weight = module.weight.data

if axis == 0:
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
weight = rot_func(weight.t(), rot_mat, K).t()
elif axis == 1:
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
weight = rot_func(weight, rot_mat, K)
else:
raise RuntimeError("Not supported yet")

Expand All @@ -1354,7 +1353,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.allocate_params(module)
axis = _get_input_axis(module)
weight = module.weight.data
original_dtype = next(module.parameters()).dtype

if axis == 1:
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
Expand Down

0 comments on commit 0471543

Please sign in to comment.