From ec7c214328cc91a63b78c545fe807b66e40dde94 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 13 Nov 2023 14:50:18 +0000
Subject: [PATCH 1/4] changes class_implementation to init_class in gpxq_mode

---
 src/brevitas/graph/gpfq.py | 24 +++++++++++++++++++-----
 src/brevitas/graph/gptq.py | 24 +++++++++++++++++++-----
 src/brevitas/graph/gpxq.py |  2 +-
 3 files changed, 39 insertions(+), 11 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 01dc11a82..d0c25b5ff 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -59,8 +59,7 @@ def __init__(
 
         self.orig_forward = self.model.forward
         self.model.forward = self.catch_stopfwd
-        self.class_implementation = GPFQ
-        GPFQ.p = p
+        self.p = p
 
     def catch_stopfwd(self, *args, **kwargs):
         # Collect quant input
@@ -95,14 +94,29 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def init_class(self, layer, name, act_order, parallel_layers, create_weight_orig):
+        return GPFQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            parallel_layers=parallel_layers,
+            create_weight_orig=create_weight_orig,
+            p=self.p)
+
 
 class GPFQ(GPxQ):
     """
     Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
     """
-    p = 0.25
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            parallel_layers=1,
+            create_weight_orig=True,
+            p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
@@ -111,7 +125,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
-        self.p = GPFQ.p
+        self.p = p
 
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index b224e0a37..a925bb308 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -67,8 +67,6 @@ def __init__(
         self.model.forward = self.catch_stopfwd
         # How many subblock to use during GPTQ for each layer
         self.num_blocks = num_blocks
-        self.class_implementation = GPTQ
-        GPTQ.num_blocks = num_blocks
 
     def catch_stopfwd(self, *args, **kwargs):
         try:
@@ -85,6 +83,15 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
+    def init_class(self, layer, name, act_order, parallel_layers, create_weight_orig):
+        return GPTQ(
+            layer=layer,
+            name=name,
+            act_order=act_order,
+            parallel_layers=parallel_layers,
+            create_weight_orig=create_weight_orig,
+            num_blocks=self.num_blocks)
+
 
 class GPTQ(GPxQ):
     """
@@ -104,15 +111,22 @@ class GPTQ(GPxQ):
     See the License for the specific language governing permissions and
     limitations under the License.
     """
-    num_blocks = 100
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self,
+            layer,
+            name,
+            act_order,
+            parallel_layers=1,
+            create_weight_orig=True,
+            num_blocks=100) -> None:
         super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
         # Define how many columns to update in each mini-block
-        self.blocksize = math.ceil(self.columns / GPTQ.num_blocks)
+        self.blocksize = math.ceil(self.columns / num_blocks)
+        self.num_blocks = num_blocks
 
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index b13c46683..453a79bb6 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -98,7 +98,7 @@ def __enter__(self):
 
                 # Attach hooks for GPTQ
                 if self._is_module_supported(module):
-                    gpxq = self.class_implementation(
+                    gpxq = self.init_class(
                         module,
                         name,
                         act_order=self.act_order,

From 59721b4208a280ebf643bdde0139f9588f27131e Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 13 Nov 2023 15:47:39 +0000
Subject: [PATCH 2/4] changes names to _module_optimizer

---
 src/brevitas/graph/gpfq.py | 3 ++-
 src/brevitas/graph/gptq.py | 4 ++--
 src/brevitas/graph/gpxq.py | 7 ++++---
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index d0c25b5ff..53f9ea80b 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -94,7 +94,8 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
-    def init_class(self, layer, name, act_order, parallel_layers, create_weight_orig):
+    def initialize_module_optimizer(
+            self, layer, name, act_order, parallel_layers, create_weight_orig):
         return GPFQ(
             layer=layer,
             name=name,
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index a925bb308..8078748d5 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -83,7 +83,8 @@ def catch_stopfwd(self, *args, **kwargs):
                 gpxq_class.disable_pre_forward_hook = False
         return out
 
-    def init_class(self, layer, name, act_order, parallel_layers, create_weight_orig):
+    def initialize_module_optimizer(
+            self, layer, name, act_order, parallel_layers, create_weight_orig):
         return GPTQ(
             layer=layer,
             name=name,
@@ -126,7 +127,6 @@ def __init__(
 
         # Define how many columns to update in each mini-block
         self.blocksize = math.ceil(self.columns / num_blocks)
-        self.num_blocks = num_blocks
 
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
         self.H = torch.zeros((self.groups, self.columns, self.columns),
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 453a79bb6..35fdbe984 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -98,15 +98,16 @@ def __enter__(self):
 
                 # Attach hooks for GPTQ
                 if self._is_module_supported(module):
-                    gpxq = self.init_class(
+                    gpxq_module_optimizer = self.initialize_module_optimizer(
                         module,
                         name,
                         act_order=self.act_order,
                         parallel_layers=parallel_layers,
                         create_weight_orig=self.create_weight_orig)
-                    hook_fn = partial(gpxq.update_batch, current_layer=self.current_layer)
+                    hook_fn = partial(
+                        gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
                     self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
-                    self.gpxq_layers[name] = gpxq
+                    self.gpxq_layers[name] = gpxq_module_optimizer
         if not self.use_quant_activations:
             self.disable_quant_inference.disable_act_quantization(
                 self.model, is_training=self.model.training)

From 57bcc6740921272bcaea6c02909d8eba7a98e39e Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 13 Nov 2023 16:15:59 +0000
Subject: [PATCH 3/4] changes parallel_layers to len(parallel_layers) as only
 the length is used

---
 src/brevitas/graph/gpfq.py | 2 +-
 src/brevitas/graph/gptq.py | 2 +-
 src/brevitas/graph/gpxq.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 53f9ea80b..a04f579bc 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -203,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 8078748d5..653d60ccd 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -184,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == len(self.parallel_layers):
+        if current_layer.forward_count == self.parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 35fdbe984..7d7bb030a 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -102,7 +102,7 @@ def __enter__(self):
                         module,
                         name,
                         act_order=self.act_order,
-                        parallel_layers=parallel_layers,
+                        parallel_layers=len(parallel_layers),
                         create_weight_orig=self.create_weight_orig)
                     hook_fn = partial(
                         gpxq_module_optimizer.update_batch, current_layer=self.current_layer)

From b04c68096047e4fc4949cced8136cbc54c365b6a Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Mon, 13 Nov 2023 16:56:33 +0000
Subject: [PATCH 4/4] renaming parallel_layers to len_parallel_layers

---
 src/brevitas/graph/gpfq.py | 10 +++++-----
 src/brevitas/graph/gptq.py | 10 +++++-----
 src/brevitas/graph/gpxq.py |  7 ++++---
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index a04f579bc..cf23d18b0 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -95,12 +95,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
         return GPFQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             p=self.p)
 
@@ -115,14 +115,14 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             p=0.25) -> None:
 
         if act_order:
             raise ValueError("Act_order is not supported in GPFQ")
 
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
@@ -203,7 +203,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 653d60ccd..b10943f1b 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -84,12 +84,12 @@ def catch_stopfwd(self, *args, **kwargs):
         return out
 
     def initialize_module_optimizer(
-            self, layer, name, act_order, parallel_layers, create_weight_orig):
+            self, layer, name, act_order, len_parallel_layers, create_weight_orig):
         return GPTQ(
             layer=layer,
             name=name,
             act_order=act_order,
-            parallel_layers=parallel_layers,
+            len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             num_blocks=self.num_blocks)
 
@@ -118,10 +118,10 @@ def __init__(
             layer,
             name,
             act_order,
-            parallel_layers=1,
+            len_parallel_layers=1,
             create_weight_orig=True,
             num_blocks=100) -> None:
-        super().__init__(layer, name, act_order, parallel_layers, create_weight_orig)
+        super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
         dev = self.layer.weight.device
 
@@ -184,7 +184,7 @@ def update_batch(self, module, input, current_layer):
         # we executed. Once we executed as many as the number of parallel_layers, we raise
         # StopFwdException
         current_layer.forward_count += 1
-        if current_layer.forward_count == self.parallel_layers:
+        if current_layer.forward_count == self.len_parallel_layers:
             current_layer.forward_count = 0
             raise StopFwdException
 
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 7d7bb030a..1279950a8 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -102,7 +102,7 @@ def __enter__(self):
                         module,
                         name,
                         act_order=self.act_order,
-                        parallel_layers=len(parallel_layers),
+                        len_parallel_layers=len(parallel_layers),
                         create_weight_orig=self.create_weight_orig)
                     hook_fn = partial(
                         gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
@@ -138,7 +138,8 @@ def catch_stopfwd(self, *args, **kwargs):
 
 class GPxQ(ABC):
 
-    def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig=True) -> None:
+    def __init__(
+            self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
         self.layer = layer
         self.name = name
         self.act_order = act_order
@@ -160,7 +161,7 @@ def __init__(self, layer, name, act_order, parallel_layers=1, create_weight_orig
         self.rows = weight.shape[0]
         # Number of columns is equal to the input channels (IC)
         self.columns = weight.shape[1]
-        self.parallel_layers = parallel_layers
+        self.len_parallel_layers = len_parallel_layers
         self.disable_pre_forward_hook = False
 
         # Some layers require knowledge from quant inputs to compute quant weights
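
Editor's note: the sketch below is illustrative only and is not part of the patch series. It condenses, under simplified hypothetical names (ModuleOptimizer, QuantizationMode, ExampleMode), the pattern these four commits converge on: each gpxq-style mode keeps its settings as instance attributes and hands them to a per-layer optimizer through an initialize_module_optimizer factory method, instead of mutating class attributes such as GPFQ.p or GPTQ.num_blocks, so two modes constructed with different settings cannot overwrite each other.

# Illustrative sketch only (not part of the patches): simplified stand-in names,
# not the real Brevitas classes.
from abc import ABC, abstractmethod


class ModuleOptimizer:
    """Per-layer optimizer: every setting arrives through __init__ (no class attributes)."""

    def __init__(self, layer, name, act_order, len_parallel_layers=1, p=0.25) -> None:
        self.layer = layer
        self.name = name
        self.act_order = act_order
        self.len_parallel_layers = len_parallel_layers
        self.p = p


class QuantizationMode(ABC):
    """Stand-in for a gpxq_mode-style context: keeps its settings on the instance."""

    def __init__(self, act_order=False, p=0.25) -> None:
        self.act_order = act_order
        self.p = p  # instance attribute, not a mutated class attribute

    @abstractmethod
    def initialize_module_optimizer(self, layer, name, act_order, len_parallel_layers):
        ...


class ExampleMode(QuantizationMode):

    def initialize_module_optimizer(self, layer, name, act_order, len_parallel_layers):
        # The factory method forwards the per-instance configuration to the optimizer.
        return ModuleOptimizer(
            layer=layer,
            name=name,
            act_order=act_order,
            len_parallel_layers=len_parallel_layers,
            p=self.p)


if __name__ == "__main__":
    # Two modes with different settings no longer clobber each other, because
    # nothing is written to a shared class attribute.
    mode_a = ExampleMode(p=0.1)
    mode_b = ExampleMode(p=0.9)
    opt_a = mode_a.initialize_module_optimizer("layer_a", "fc1", False, len_parallel_layers=2)
    opt_b = mode_b.initialize_module_optimizer("layer_b", "fc2", False, len_parallel_layers=1)
    assert (opt_a.p, opt_b.p) == (0.1, 0.9)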