diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index cf23d18b0..2d312a549 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -43,7 +43,7 @@ def __init__( inplace: bool = True, create_weight_orig: bool = True, use_quant_activations: bool = True, - p: int = 0.25, + p: float = 1.0, return_forward_output: bool = False, act_order: bool = False) -> None: if not inplace: @@ -117,12 +117,10 @@ def __init__( act_order, len_parallel_layers=1, create_weight_orig=True, - p=0.25) -> None: - - if act_order: - raise ValueError("Act_order is not supported in GPFQ") + p=1.0) -> None: super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig) + self.float_input = None self.quantized_input = None self.index_computed = False @@ -220,25 +218,41 @@ def single_layer_update(self): weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype) self.float_input = self.float_input.to(dev) self.quantized_input = self.quantized_input.to(dev) - permutation_list = [torch.tensor(range(weight.shape[-1]))] + # We don't need full Hessian, we just need the diagonal + self.H_diag = self.quantized_input.transpose(2, 1).square().sum( + 2) # summing over Batch dimension + permutation_list = [] + for group_index in range(self.groups): + if self.act_order: + # Re-order Hessian_diagonal so that weights associated to + # higher magnitude activations are quantized first + perm = torch.argsort(self.H_diag[group_index, :], descending=True) + else: + # No permutation, permutation tensor is an ordered index + perm = torch.tensor(range(weight.shape[-1]), device=dev) + permutation_list.append(perm) for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( - weight[group_index, :, t].unsqueeze(1), - self.float_input[group_index, :, - t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]] - norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2 + weight[group_index, 
:, permutation_list[group_index][t]].unsqueeze(1), + self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze( + 0)) #[OC/Groups, 1] * [1, INSHAPE[1]] + norm = torch.linalg.norm( + self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2 if norm > 0: - q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm + q_arg = U[group_index].matmul( + self.quantized_input[group_index, :, + permutation_list[group_index][t]]) / norm else: q_arg = torch.zeros_like(U[group_index, :, 0]) - weight[group_index, :, t] = q_arg + weight[group_index, :, permutation_list[group_index][t]] = q_arg q = self.get_quant_weights(t, 0, permutation_list) for group_index in range(self.groups): U[group_index] -= torch.matmul( q[group_index].unsqueeze(1), - self.quantized_input[group_index, :, t].unsqueeze(0)) + self.quantized_input[group_index, :, + permutation_list[group_index][t]].unsqueeze(0)) del self.float_input del self.quantized_input diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index cbb56e7f3..b3e7b00d6 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -471,12 +471,12 @@ def apply_gptq(calib_loader, model, act_order=False): gptq.update() -def apply_gpfq(calib_loader, model, p=0.25): +def apply_gpfq(calib_loader, model, act_order, p=0.25): model.eval() dtype = next(model.parameters()).dtype device = next(model.parameters()).device with torch.no_grad(): - with gpfq_mode(model, p=p, use_quant_activations=True) as gpfq: + with gpfq_mode(model, p=p, use_quant_activations=True, act_order=act_order) as gpfq: gpfq_model = gpfq.model for i in tqdm(range(gpfq.num_layers)): for i, (images, target) in enumerate(calib_loader): diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py 
b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index dd49e8531..6f4243741 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -167,7 +167,7 @@ default=True, help='Narrow range for weight quantization (default: enabled)') parser.add_argument( - '--gpfq-p', default=0.25, type=float, help='P parameter for GPFQ (default: 0.25)') + '--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)') parser.add_argument( '--quant-format', default='int', @@ -207,10 +207,12 @@ default=3, type=int, help='Exponent bit width used with float quantization for activations (default: 3)') -add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)') +add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)') add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)') add_bool_arg( parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)') +add_bool_arg( + parser, 'gpfq-act-order', default=False, help='GPFQ Act order heuristic (default: disabled)') add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)') add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)') @@ -241,6 +243,7 @@ def main(): f"{'gptq_' if args.gptq else ''}" f"{'gpfq_' if args.gpfq else ''}" f"{'gptq_act_order_' if args.gptq_act_order else ''}" + f"{'gpfq_act_order_' if args.gpfq_act_order else ''}" f"{'learned_round_' if args.learned_round else ''}" f"{'weight_narrow_range_' if args.weight_narrow_range else ''}" f"{args.bias_bit_width}bias_" @@ -263,6 +266,7 @@ def main(): f"GPFQ: {args.gpfq} - " f"GPFQ P: {args.gpfq_p} - " f"GPTQ Act Order: {args.gptq_act_order} - " + f"GPFQ Act Order: {args.gpfq_act_order} - " f"Learned Round: {args.learned_round} - " f"Weight narrow range: {args.weight_narrow_range} - " f"Bias bit width: 
{args.bias_bit_width} - " @@ -359,7 +363,7 @@ def main(): if args.gpfq: print("Performing GPFQ:") - apply_gpfq(calib_loader, quant_model, p=args.gpfq_p) + apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpfq_act_order) if args.gptq: print("Performing GPTQ:")