Skip to content

Commit

Permalink
fix for llama
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 18, 2024
1 parent c0a6bd2 commit ed61b9f
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,13 +1356,12 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
weight = module.weight.data

if axis == 1:
weight = rot_func(weight, rot_mat, K)
_update_weights(module, rot_func(weight, rot_mat, K), 'weight')
elif axis == 0:
weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
else:
raise RuntimeError("Not supported yet")

module.weight.data = weight
if hasattr(module, 'offload_params'):
module.offload_params(module)

Expand Down

0 comments on commit ed61b9f

Please sign in to comment.