Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 13, 2025
1 parent a32a38b commit 7a320ff
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,12 +1334,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
weight = module.weight.data

if axis == 0:
weight = rot_func(weight.t(), rot_mat, K).t()
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
elif axis == 1:
weight = rot_func(weight, rot_mat, K)
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")
module.weight.data = weight

if getattr(module, 'bias', None) is not None:
bias = module.bias.data
Expand Down Expand Up @@ -1518,7 +1519,7 @@ def apply(self,
eq_layers = set()
orphan_regions = []
self.find_module(graph_model, orphan_regions)
if self.rotate_sdpa:
if self.sdpa_regions:
sdpa_regions = self.rotate_sdpa(graph_model)
regions.extend(sdpa_regions)
for r in regions:
Expand Down

0 comments on commit 7a320ff

Please sign in to comment.