
Commit

Fix (gptq): handling compatibility issues
i-colbert committed Mar 7, 2024
1 parent 4a25f90 commit c06da79
Showing 2 changed files with 16 additions and 6 deletions.
6 changes: 5 additions & 1 deletion src/brevitas/graph/gptq.py
@@ -3,9 +3,10 @@
 
 from copy import deepcopy
 import math
-from typing import List, Optional, Set
+from typing import List, Optional
 import warnings
 
+from packaging import version
 import torch
 
 try:
@@ -14,6 +15,7 @@
     LinAlgError = RuntimeError
 import unfoldNd
 
+from brevitas import torch_version
 from brevitas.graph.gpxq import GPxQ
 from brevitas.graph.gpxq import gpxq_mode
 from brevitas.graph.gpxq import StopFwdException
@@ -133,6 +135,8 @@ def __init__(
             dtype=torch.float32)
         self.nsamples = 0
 
+        assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
+
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
             return input
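The guard added to GPTQ.__init__ turns an obscure failure in the Cholesky-based weight update on old PyTorch builds into an immediate AssertionError. A minimal standalone sketch of the same pattern follows; it assumes torch_version is obtained by parsing torch.__version__ with packaging, and check_torch_for_gptq is an illustrative name, not Brevitas API.

    from packaging import version
    import torch

    # Assumption: brevitas.torch_version is a packaging Version parsed from torch.__version__.
    torch_version = version.parse(torch.__version__)

    def check_torch_for_gptq():
        # Same check as the assert added above: torch 1.10 or higher is required.
        assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"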
16 changes: 11 additions & 5 deletions tests/brevitas/graph/test_gpxq.py
@@ -104,9 +104,6 @@ def test_toymodels(
     if (name == 'gptq' and acc_bit_width < 32):
         pytest.skip("GPTQ does not support accumulator-aware quantization.")
 
-    if (name == 'gptq' and torch_version <= version.parse('1.9.1')):
-        pytest.skip(f"GPTQ usage of linalg_cholesky() is not compatible with torch {torch_version}")
-
     if name == 'gpfq':
         filter_func = filter_func_dict[filter_func_str]
         apply_gpxq = partial(
@@ -123,8 +120,17 @@
     dataset = TensorDataset(inp, inp)
     calib_loader = DataLoader(dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)
 
-    if (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
-                                                      filter_func_str == 'identity'):
+    if (name == 'gptq' and torch_version < version.parse('1.10')):
+        # GPTQ usage of linalg_cholesky() is not compatible with torch 1.9.1 and below
+        with pytest.raises(AssertionError):
+            apply_gpxq(
+                calib_loader=calib_loader,
+                model=model,
+                act_order=act_order,
+                use_quant_activations=use_quant_activations)
+
+    elif (name == 'gpfq') and (acc_bit_width < 32) and (not use_quant_activations or
+                                                        filter_func_str == 'identity'):
         # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
         # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
         # happen when `use_quant_activations=False` or when the input to a model is not quantized
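With the version guard in place, the old pytest.skip for torch 1.9.1 is replaced by a positive expectation: on torch older than 1.10 the GPTQ path should raise the new AssertionError. Below is a condensed sketch of that expectation outside the Brevitas test harness, where run_guarded is a hypothetical stand-in for constructing GPTQ.

    import pytest
    from packaging import version

    def run_guarded(torch_version):
        # Hypothetical stand-in for building the GPTQ object; mirrors the assert added in gptq.py.
        assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"

    def test_old_torch_rejected():
        # Matches the pytest.raises(AssertionError) branch in the diff above.
        with pytest.raises(AssertionError):
            run_guarded(version.parse('1.9.1'))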
