Skip to content

Commit

Permalink
Layerwise fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 10, 2025
1 parent 8b71a53 commit 7d27050
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def apply(self, model: GraphModule) -> GraphModule:
weight = RotationWeightParametrization(
self.rot_mat, self.rot_func, self.axis, self.K)(weight)
# Modify the weights in-place
getattr(old_module, self.tensor_name).data = weight
setattr(old_module, self.tensor_name, torch.nn.Parameter(weight))

if hasattr(old_module, 'offload_params'):
old_module.offload_params(old_module)
Expand Down

0 comments on commit 7d27050

Please sign in to comment.