Feat (GPFQ): add A2Q accumulator bit width bound
fabianandresgrob committed Nov 20, 2023
1 parent bd46f89 commit 8964e07
Showing 4 changed files with 133 additions and 11 deletions.
123 changes: 115 additions & 8 deletions src/brevitas/graph/gpfq.py
@@ -1,11 +1,13 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from copy import deepcopy
from typing import List, Optional

import numpy as np
import torch
from torch import Tensor
import unfoldNd

from brevitas.graph.gpxq import GPxQ
@@ -45,7 +47,8 @@ def __init__(
use_quant_activations: bool = True,
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False) -> None:
act_order: bool = False,
accumulator_bit_width=None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -60,6 +63,7 @@ def __init__(
self.orig_forward = self.model.forward
self.model.forward = self.catch_stopfwd
self.p = p
self.accumulator_bit_width = accumulator_bit_width

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
@@ -96,13 +100,23 @@ def catch_stopfwd(self, *args, **kwargs):

def initialize_module_optimizer(
self, layer, name, act_order, len_parallel_layers, create_weight_orig):
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
if not self.accumulator_bit_width:
return GPFQ(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p)
else:
return GPA2Q(
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
accumulator_bit_width=self.accumulator_bit_width)


class GPFQ(GPxQ):
@@ -256,3 +270,96 @@ def single_layer_update(self):

del self.float_input
del self.quantized_input


class L1NormMixin(ABC):

def __init__(self, accumulator_bit_width) -> None:
self.accumulator_bit_width = accumulator_bit_width

def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I. Colbert, A. Pappalardo, and J. Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
max_accumulator_bit_width = self.accumulator_bit_width # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse
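
For intuition, here is a quick numeric check of the bound above (illustrative only, not part of the commit): with an assumed 16-bit accumulator (P = 16) and 8-bit unsigned inputs (N = 8), T comes out just under 128.

# Illustrative sanity check with assumed values, mirroring the expression above
P, N, input_is_signed = 16, 8, 0.0
max_accumulator_mag = pow(2., P - 1.) - 1.             # 2^{P-1} - 1 = 32767.0
max_input_mag_inverse = pow(2., input_is_signed - N)   # 2^{-8}
T = max_accumulator_mag * max_input_mag_inverse        # ~127.996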


class GPA2Q(GPFQ, L1NormMixin):

def __init__(
self,
layer,
name,
act_order,
len_parallel_layers=1,
create_weight_orig=True,
accumulator_bit_width=None,
p=0.25) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
L1NormMixin.__init__(self, accumulator_bit_width)

def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)

# get upper bound
input_bit_width = self.layer.quant_input_bit_width()
input_is_signed = self.layer.is_quant_input_signed
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()

permutation_list = [torch.tensor(range(weight.shape[-1]))]
l1_norm = torch.zeros(weight.shape[:-1], device=dev)
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, t].unsqueeze(1),
self.float_input[group_index, :, t].unsqueeze(0))  # [OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, t] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)

for group_index in range(self.groups):
candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index])
candidate_l1_mask = candidate_l1 > T * s
if torch.any(candidate_l1_mask):
# set all values to 0 that are exceeding T * s
weight[group_index, :, t][candidate_l1_mask] = 0
q[group_index][candidate_l1_mask] = 0
else:
l1_norm[group_index] = candidate_l1
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :, t].unsqueeze(0))

del self.float_input
del self.quantized_input
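
For context, a minimal usage sketch of the new keyword (not part of the diff; model and calib_loader are assumed to exist, device/dtype handling is omitted, and the 16-bit accumulator width is just an example). Passing accumulator_bit_width makes initialize_module_optimizer above return GPA2Q instead of GPFQ, so each layer update also enforces the l1-norm budget T * s:

import torch
from brevitas.graph.gpfq import gpfq_mode

with torch.no_grad():
    # accumulator_bit_width=16 routes every layer through GPA2Q
    with gpfq_mode(model, p=0.25, act_order=False, accumulator_bit_width=16) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images)
            gpfq.update()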
1 change: 1 addition & 0 deletions src/brevitas/graph/gptq.py
@@ -133,6 +133,7 @@ def __init__(
device=dev,
dtype=torch.float32)
self.nsamples = 0
self.num_blocks = num_blocks

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
@@ -471,12 +471,16 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=0.25):
def apply_gpfq(calib_loader, model, act_order, p=0.25, accumulator_bit_width=None):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
with torch.no_grad():
with gpfq_mode(model, p=p, use_quant_activations=True, act_order=act_order) as gpfq:
with gpfq_mode(model,
p=p,
use_quant_activations=True,
act_order=act_order,
accumulator_bit_width=accumulator_bit_width) as gpfq:
gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
@@ -207,6 +207,11 @@
default=3,
type=int,
help='Exponent bit width used with float quantization for activations (default: 3)')
parser.add_argument(
'--accumulator-bit-width',
default=None,
type=int,
help='Accumulator Bit Width for GPFQ in combination with A2Q (default: None)')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(
@@ -363,7 +368,12 @@ def main():

if args.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpfq_act_order)
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpfq_act_order,
accumulator_bit_width=args.accumulator_bit_width)

if args.gptq:
print("Performing GPTQ:")
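
A hypothetical command line combining the new flag with the existing GPFQ switch might look like this (the entry-point script name is assumed, since it is not shown in this diff; only --gpfq and --accumulator-bit-width appear in the changes above):

# Example only: script name is a placeholder
python ptq_evaluate.py --gpfq --accumulator-bit-width 16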
