diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 6b1945dcc..80026d814 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -121,7 +121,7 @@ def update_batch(self, module, input, current_layer): current_layer.forward_count = 0 raise StopFwdException - def single_layer_update(self, percdamp=.01): + def single_layer_update(self, percdamp=.01, c=1e3): assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." if hasattr(self.layer, 'allocate_params'): self.layer.allocate_params(self.layer) @@ -174,7 +174,8 @@ def single_layer_update(self, percdamp=.01): self.H[i, diag, diag] += damp self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + # stabilizing the Cholesky decomposition with a fairly large constant, c + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :] * (c ** 2), upper=True) / c h_inv = self.H except LinAlgError as e: warnings.warn(