Skip to content

Commit

Permalink
no upcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 12, 2025
1 parent 51a8c38 commit 361e664
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,15 +1332,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.allocate_params(module)
axis = _get_output_axis(module)
weight = module.weight.data
original_dtype = next(module.parameters()).dtype
if axis == 0:
rotated_weight = rot_func(
weight.t().to(torch.float32), rot_mat.to(torch.float32),
K).t().to(original_dtype)
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
elif axis == 1:
rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32),
K).to(original_dtype)
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")
Expand All @@ -1360,13 +1356,10 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
weight = module.weight.data
original_dtype = next(module.parameters()).dtype
if axis == 1:
rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32),
K).to(original_dtype)
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
elif axis == 0:
rotated_weight = rot_func(
weight.t().to(torch.float32), rot_mat.to(torch.float32),
K).t().to(original_dtype)
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")
Expand Down

0 comments on commit 361e664

Please sign in to comment.