From a3dad710de729f072eb473e5ff1275f51f033ef2 Mon Sep 17 00:00:00 2001
From: Fabian Grob
Date: Thu, 7 Dec 2023 15:04:25 +0000
Subject: [PATCH] Feat (GPFA2Q): unify quant_input initialization

---
 src/brevitas/graph/gpfq.py | 8 ++++++--
 src/brevitas/graph/gpxq.py | 8 ++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 0cb086d01..0c2b85b8b 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -297,6 +297,10 @@ def __init__(
         self.accumulator_bit_width = accumulator_bit_width
 
     def single_layer_update(self):
+        # raise error in case no quant-input is here
+        if self.quant_input is None:
+            raise ValueError(
+                'Expected quant input to calculate Upper Bound on L1 norm, but received None')
         weight = self.layer.weight.data
         dev = weight.device
         dtype = weight.dtype
@@ -311,8 +315,8 @@ def single_layer_update(self):
         self.quantized_input = self.quantized_input.to(dev)
 
         # get upper bound
-        input_bit_width = self.layer.quant_input_bit_width()
-        input_is_signed = self.layer.is_quant_input_signed
+        input_bit_width = self.quant_input.bit_width
+        input_is_signed = self.quant_input.signed
         T = get_upper_bound_on_l1_norm(self.accumulator_bit_width, input_bit_width, input_is_signed)
         s = self.layer.quant_weight_scale()
         s = s.view(self.groups, -1)  # [Groups, OC/Groups]
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 1279950a8..f3d246d75 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -182,6 +182,14 @@ def process_input(self, inp):
                 signed=inp.signed,
                 training=inp.training)
             inp = inp.value
+        elif self.layer.is_input_quant_enabled:
+            self.quant_input = QuantTensor(
+                value=None,
+                scale=self.layer.quant_input_scale(),
+                zero_point=self.layer.quant_input_zero_point(),
+                bit_width=self.layer.quant_input_bit_width(),
+                signed=self.layer.is_quant_input_signed,
+                training=self.layer.training)
 
         # If input is unbatched, add batch_size = 1
         if len(inp.shape) == 1:
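
A minimal, self-contained sketch of the pattern the two hunks implement: quant_input metadata is
initialized in one place (taken from an incoming QuantTensor when available, otherwise built from
the layer's own input quant accessors when input quant is enabled), and the accumulator-aware
L1-norm bound computation fails fast with a ValueError when that metadata is missing. All names
below (InputQuantMetadata, FakeQuantLayer, collect_input_quant_metadata, upper_bound_on_l1_norm)
are hypothetical stand-ins rather than Brevitas APIs, and the bound formula is an assumed
A2Q-style approximation, not necessarily the exact one used by get_upper_bound_on_l1_norm.

# Hypothetical stand-ins for illustration only; not Brevitas APIs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class InputQuantMetadata:
    """Metadata-only stand-in for a QuantTensor: no values, just quant info."""
    scale: float
    zero_point: float
    bit_width: int
    signed: bool


class FakeQuantLayer:
    """Toy layer exposing the accessors the patch reads from."""

    def __init__(self, input_quant_enabled: bool):
        self.is_input_quant_enabled = input_quant_enabled

    # Assumed stand-ins for the layer's input quant accessors.
    def quant_input_scale(self) -> float:
        return 0.05

    def quant_input_zero_point(self) -> float:
        return 0.0

    def quant_input_bit_width(self) -> int:
        return 8

    @property
    def is_quant_input_signed(self) -> bool:
        return False


def collect_input_quant_metadata(
        layer: FakeQuantLayer,
        inp_metadata: Optional[InputQuantMetadata]) -> Optional[InputQuantMetadata]:
    """Unified initialization: prefer metadata carried by the input tensor,
    otherwise fall back to the layer's own input quantizer, else None."""
    if inp_metadata is not None:
        return inp_metadata
    if layer.is_input_quant_enabled:
        return InputQuantMetadata(
            scale=layer.quant_input_scale(),
            zero_point=layer.quant_input_zero_point(),
            bit_width=layer.quant_input_bit_width(),
            signed=layer.is_quant_input_signed)
    return None


def upper_bound_on_l1_norm(
        accumulator_bit_width: int,
        quant_input: Optional[InputQuantMetadata]) -> float:
    """Assumed A2Q-style bound: T = (2^(P-1) - 1) / 2^(N - signed), with P the
    accumulator bit width and N the input bit width (power-of-two input range)."""
    if quant_input is None:
        # Mirrors the ValueError added to GPFA2Q.single_layer_update.
        raise ValueError(
            'Expected quant input to calculate Upper Bound on L1 norm, but received None')
    max_accumulator_mag = 2.0 ** (accumulator_bit_width - 1) - 1.0
    max_input_mag = 2.0 ** (quant_input.bit_width - int(quant_input.signed))
    return max_accumulator_mag / max_input_mag


if __name__ == '__main__':
    layer = FakeQuantLayer(input_quant_enabled=True)
    # No QuantTensor input here, so metadata is pulled from the layer's input quantizer.
    quant_input = collect_input_quant_metadata(layer, inp_metadata=None)
    print(upper_bound_on_l1_norm(accumulator_bit_width=16, quant_input=quant_input))

Keeping the metadata collection in a single helper mirrors the intent of the patch: the consumer
(GPFA2Q in Brevitas) no longer needs to know whether the quant info came from the input tensor or
from the layer itself, and a missing quant input is reported explicitly instead of failing later.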