From 36727ae72def8a119d3f53d969801261f83d32f1 Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Thu, 16 Nov 2023 15:52:24 +0000 Subject: [PATCH] Feat (GPFA2Q): unify upper bound method, fix act_order and cleanup --- src/brevitas/core/scaling/pre_scaling.py | 23 +----- src/brevitas/graph/gpfq.py | 71 ++++++++++--------- src/brevitas/nn/utils.py | 14 ++++ .../imagenet_classification/ptq/ptq_common.py | 2 +- .../ptq/ptq_evaluate.py | 4 +- 5 files changed, 56 insertions(+), 58 deletions(-) diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py index 632242507..6f9a1905a 100644 --- a/src/brevitas/core/scaling/pre_scaling.py +++ b/src/brevitas/core/scaling/pre_scaling.py @@ -14,6 +14,7 @@ from brevitas.core.stats import SCALAR_SHAPE from brevitas.core.stats.stats_wrapper import _Stats from brevitas.function import abs_binary_sign_grad +from brevitas.nn.utils import get_upper_bound_on_l1_norm __all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"] @@ -170,25 +171,6 @@ def __init__( ) self.accumulator_bit_width = accumulator_bit_width_impl - @brevitas.jit.script_method - def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: - """Calculate the upper bound on the l1-norm of the weights using the derivations from - `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` - by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" - assert input_bit_width is not None, "A2Q relies on input bit-width." - assert input_is_signed is not None, "A2Q relies on input sign." - input_is_signed = float(input_is_signed) # 1. if signed else 0. - # This is the minimum of the two maximum magnitudes that P could take, which are -2^{P-1} - # and 2^{P-1}-1. Note that evaluating to -2^{P-1} would mean there is a possibility of overflow - # on the positive side of this range. - max_accumulator_bit_width = self.accumulator_bit_width() # P - max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 - # This is the maximum possible magnitude that the input data could take. When the data is signed, - # this is 2^{N-1}. When the data is unsigned, this is 2^N - 1. We use a slightly looser bound here - # to simplify our derivations on the export validation. 
- max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) - return max_accumulator_mag * max_input_mag_inverse - @brevitas.jit.script_method def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: """Takes weights as input and returns the pre-clipping scaling factor""" @@ -196,7 +178,8 @@ def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: boo d_w = self.stats(weights) # denominator for weight normalization s = self.scaling_impl(weights) # s g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g - T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s + T = get_upper_bound_on_l1_norm( + self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s g = torch.clamp_max(g / s, T) value = d_w / g # calculating final pre-clipping scaling factor # re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index d38413625..8d0273856 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -15,6 +15,7 @@ from brevitas.graph.gpxq import StopFwdException from brevitas.graph.gpxq import SUPPORTED_CONV_OP import brevitas.nn as qnn +from brevitas.nn.utils import get_upper_bound_on_l1_norm class gpfq_mode(gpxq_mode): @@ -48,7 +49,7 @@ def __init__( p: float = 1.0, return_forward_output: bool = False, act_order: bool = False, - accumulator_bit_width=None) -> None: + accumulator_bit_width: int = None) -> None: if not inplace: model = deepcopy(model) super().__init__( @@ -109,7 +110,7 @@ def initialize_module_optimizer( create_weight_orig=create_weight_orig, p=self.p) else: - return GPA2Q( + return GPFA2Q( layer=layer, name=name, act_order=act_order, @@ -272,44 +273,26 @@ def single_layer_update(self): del self.quantized_input -class L1NormMixin(ABC): - - def __init__(self, accumulator_bit_width) -> None: - self.accumulator_bit_width = accumulator_bit_width - - def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: - """Calculate the upper bound on the l1-norm of the weights using the derivations from - `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` - by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" - assert input_bit_width is not None, "A2Q relies on input bit-width." - assert input_is_signed is not None, "A2Q relies on input sign." - input_is_signed = float(input_is_signed) # 1. if signed else 0. - max_accumulator_bit_width = self.accumulator_bit_width # P - max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. 
# 2^{P-1}-1 - max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) - return max_accumulator_mag * max_input_mag_inverse - - -class GPA2Q(GPFQ, L1NormMixin): +class GPFA2Q(GPFQ): def __init__( self, layer, name, act_order, - parallel_layers=1, + len_parallel_layers=1, create_weight_orig=True, accumulator_bit_width=None, - p=0.25) -> None: + p=1.0) -> None: GPFQ.__init__( self, layer=layer, name=name, act_order=act_order, - parallel_layers=parallel_layers, + len_parallel_layers=len_parallel_layers, create_weight_orig=create_weight_orig, p=p) - L1NormMixin.__init__(self, accumulator_bit_width) + self.accumulator_bit_width = accumulator_bit_width def single_layer_update(self): weight = self.layer.weight.data @@ -328,24 +311,41 @@ def single_layer_update(self): # get upper bound input_bit_width = self.layer.quant_input_bit_width() input_is_signed = self.layer.is_quant_input_signed - T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) + T = get_upper_bound_on_l1_norm(self.accumulator_bit_width, input_bit_width, input_is_signed) s = self.layer.quant_weight_scale() - permutation_list = [torch.tensor(range(weight.shape[-1]))] l1_norm = torch.zeros(weight.shape[:-1], device=dev) + + # We don't need full Hessian, we just need the diagonal + self.H_diag = self.quantized_input.transpose(2, 1).square().sum( + 2) # summing over Batch dimension + permutation_list = [] + for group_index in range(self.groups): + if self.act_order: + # Re-order Hessian_diagonal so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(self.H_diag[group_index, :], descending=True) + else: + # No permutation, permutation tensor is a ordered index + perm = torch.tensor(range(weight.shape[-1]), device=dev) + permutation_list.append(perm) + for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( - weight[group_index, :, t].unsqueeze(1), - self.float_input[group_index, :, - t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]] - norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1), + self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( + 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm( + self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 if norm > 0: - q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + q_arg = U[group_index].matmul( + self.quantized_input[group_index, :, + permutation_list[group_index][t]]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, t] = q_arg + weight[group_index, :, permutation_list[group_index][t]] = q_arg q = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): @@ -353,13 +353,14 @@ def single_layer_update(self): candidate_l1_mask = candidate_l1 > T * s if torch.any(candidate_l1_mask): # set all values to 0 that are exceeding T * s - weight[group_index, :, t][candidate_l1_mask] = 0 + weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0 q[group_index][candidate_l1_mask] = 0 else: l1_norm[group_index] = candidate_l1 U[group_index] -= torch.matmul( q[group_index].unsqueeze(1), - self.quantized_input[group_index, :, t].unsqueeze(0)) + self.quantized_input[group_index, :, + permutation_list[group_index][t]].unsqueeze(0)) del self.float_input del self.quantized_input diff --git 
a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py index 3e7b423ee..634f2d751 100644 --- a/src/brevitas/nn/utils.py +++ b/src/brevitas/nn/utils.py @@ -127,3 +127,17 @@ def calculate_min_accumulator_bit_width( min_bit_width = alpha + phi(alpha) + 1. min_bit_width = ceil_ste(min_bit_width) return min_bit_width # returns the minimum accumulator that can be used without risk of overflow + + +def get_upper_bound_on_l1_norm( + accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor: + """Calculate the upper bound on the l1-norm of the weights using the derivations from + `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance` + by I.Colbert, A.Pappalardo, and J.Petri-Koenig.""" + assert input_bit_width is not None, "A2Q relies on input bit-width." + assert input_is_signed is not None, "A2Q relies on input sign." + input_is_signed = float(input_is_signed) # 1. if signed else 0. + max_accumulator_bit_width = accumulator_bit_width # P + max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1 + max_input_mag_inverse = pow(2., input_is_signed - input_bit_width) + return max_accumulator_mag * max_input_mag_inverse diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 056362265..8c8837f2e 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -471,7 +471,7 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() -def apply_gpfq(calib_loader, model, act_order, p=0.25, accumulator_bit_width=None): +def apply_gpfq(calib_loader, model, act_order, p=1.0, accumulator_bit_width=None): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7cb6e96c3..5aa47568c 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -166,8 +166,7 @@ 'weight-narrow-range', default=True, help='Narrow range for weight quantization (default: enabled)') -parser.add_argument( - '--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 0.25)') +parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)') parser.add_argument( '--quant-format', default='int', @@ -272,6 +271,7 @@ def main(): f"GPFQ P: {args.gpfq_p} - " f"GPTQ Act Order: {args.gptq_act_order} - " f"GPFQ Act Order: {args.gpfq_act_order} - " + f"GPFQ Accumulator Bit Width: {args.accumulator_bit_width} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: {args.bias_bit_width} - "
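
Usage sketch (illustrative only, not part of the patch above): both AccumulatorAwareParameterPreScaling and GPFA2Q now call the single helper get_upper_bound_on_l1_norm added in src/brevitas/nn/utils.py, which computes the A2Q l1-norm bound T = (2^{P-1} - 1) * 2^{s - N}, where P is the accumulator bit width, N the input bit width, and s is 1 for signed inputs and 0 otherwise. The bit widths below are example values chosen purely for illustration, assuming brevitas with this patch applied is installed.

import torch

from brevitas.nn.utils import get_upper_bound_on_l1_norm

# Example (hypothetical) configuration: 16-bit accumulator, 8-bit unsigned inputs.
accumulator_bit_width = 16          # P
input_bit_width = torch.tensor(8.)  # N, passed as a tensor as in the quant API
input_is_signed = False             # s = 0 for unsigned inputs

T = get_upper_bound_on_l1_norm(accumulator_bit_width, input_bit_width, input_is_signed)
# T = (2^15 - 1) * 2^(0 - 8) = 32767 / 256 ≈ 127.996
print(T)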