From 7d27050eede1138c7d5ca1cae51a5bbd00925e5c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 10 Jan 2025 17:35:35 +0000 Subject: [PATCH] Layerwise fix --- src/brevitas/graph/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index c3c4b82e9..4954e1c10 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -251,7 +251,7 @@ def apply(self, model: GraphModule) -> GraphModule: weight = RotationWeightParametrization( self.rot_mat, self.rot_func, self.axis, self.K)(weight) # Modify the weights in-place - getattr(old_module, self.tensor_name).data = weight + setattr(old_module, self.tensor_name, torch.nn.Parameter(weight)) if hasattr(old_module, 'offload_params'): old_module.offload_params(old_module)