From 2f24dff6f4acd71603ea58be0b3ad767c00af0b0 Mon Sep 17 00:00:00 2001 From: i-colbert Date: Wed, 4 Dec 2024 18:26:17 +0000 Subject: [PATCH] Fix (gptq): stabilize Cholesky decomposition --- src/brevitas/graph/gptq.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 6b1945dcc..80026d814 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -121,7 +121,7 @@ def update_batch(self, module, input, current_layer): current_layer.forward_count = 0 raise StopFwdException - def single_layer_update(self, percdamp=.01): + def single_layer_update(self, percdamp=.01, c=1e3): assert not self.layer.weight_quant.requires_quant_input, "Error: GPTQ does not support weight quantizers that require quantized inputs." if hasattr(self.layer, 'allocate_params'): self.layer.allocate_params(self.layer) @@ -174,7 +174,8 @@ def single_layer_update(self, percdamp=.01): self.H[i, diag, diag] += damp self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :]) self.H[i, :, :] = torch.cholesky_inverse(self.H[i, :, :]) - self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :], upper=True) + # stabilizing the Cholesky decomposition with a fairly large constant, c + self.H[i, :, :] = torch.linalg.cholesky(self.H[i, :, :] * (c ** 2), upper=True) / c h_inv = self.H except LinAlgError as e: warnings.warn(