Feat (GPFA2Q): unify upper bound method, fix act_order and cleanup
fabianandresgrob committed Nov 20, 2023
1 parent 593eb30 commit 36727ae
Showing 5 changed files with 56 additions and 58 deletions.
23 changes: 3 additions & 20 deletions src/brevitas/core/scaling/pre_scaling.py
@@ -14,6 +14,7 @@
from brevitas.core.stats import SCALAR_SHAPE
from brevitas.core.stats.stats_wrapper import _Stats
from brevitas.function import abs_binary_sign_grad
from brevitas.nn.utils import get_upper_bound_on_l1_norm

__all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]

@@ -170,33 +171,15 @@ def __init__(
)
self.accumulator_bit_width = accumulator_bit_width_impl

@brevitas.jit.script_method
def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
# This is the minimum of the two maximum magnitudes that P could take, which are -2^{P-1}
# and 2^{P-1}-1. Note that evaluating to -2^{P-1} would mean there is a possibility of overflow
# on the positive side of this range.
max_accumulator_bit_width = self.accumulator_bit_width() # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
# This is the maximum possible magnitude that the input data could take. When the data is signed,
# this is 2^{N-1}. When the data is unsigned, this is 2^N - 1. We use a slightly looser bound here
# to simplify our derivations on the export validation.
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse

@brevitas.jit.script_method
def forward(self, weights: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Takes weights as input and returns the pre-clipping scaling factor"""
weights = self.stats_input_view_shape_impl(weights)
d_w = self.stats(weights) # denominator for weight normalization
s = self.scaling_impl(weights) # s
g = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) # g
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s
T = get_upper_bound_on_l1_norm(
self.accumulator_bit_width(), input_bit_width, input_is_signed) # T / s
g = torch.clamp_max(g / s, T)
value = d_w / g # calculating final pre-clipping scaling factor
# re-apply clamp_min_ste from restrict_scaling_impl to the specified pre_scaling_min_val
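For intuition, a minimal standalone sketch of the clamping step above, with toy tensors standing in for the real stats, scaling and restrict implementations. The bound T is written out inline for a 16-bit accumulator and signed 8-bit input (illustrative values, not Brevitas defaults); capping g / s at T is what bounds the learned weight l1-norm so that a P-bit accumulator cannot overflow, per the A2Q derivation referenced in the removed docstring.

import torch

# toy per-channel quantities for a layer with 8 output channels
d_w = torch.rand(8, 1) * 10.            # l1-norm denominator from self.stats(weights)
s = torch.full((8, 1), 0.005)           # quantization scale from self.scaling_impl(weights)
g = torch.full((8, 1), 2.0)             # learned pre-scaling parameter
T = (2. ** 15 - 1.) * 2. ** (1. - 8.)   # upper bound on the l1-norm, here ~255.99
g = torch.clamp_max(g / s, T)           # g / s = 400 here, clamped down to T
value = d_w / g                         # final pre-clipping scaling factor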
71 changes: 36 additions & 35 deletions src/brevitas/graph/gpfq.py
@@ -15,6 +15,7 @@
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.nn.utils import get_upper_bound_on_l1_norm


class gpfq_mode(gpxq_mode):
@@ -48,7 +49,7 @@ def __init__(
p: float = 1.0,
return_forward_output: bool = False,
act_order: bool = False,
accumulator_bit_width=None) -> None:
accumulator_bit_width: int = None) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -109,7 +110,7 @@ def initialize_module_optimizer(
create_weight_orig=create_weight_orig,
p=self.p)
else:
return GPA2Q(
return GPFA2Q(
layer=layer,
name=name,
act_order=act_order,
@@ -272,44 +273,26 @@ def single_layer_update(self):
del self.quantized_input


class L1NormMixin(ABC):

def __init__(self, accumulator_bit_width) -> None:
self.accumulator_bit_width = accumulator_bit_width

def get_upper_bound_on_l1_norm(self, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
max_accumulator_bit_width = self.accumulator_bit_width # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse


class GPA2Q(GPFQ, L1NormMixin):
class GPFA2Q(GPFQ):

def __init__(
self,
layer,
name,
act_order,
parallel_layers=1,
len_parallel_layers=1,
create_weight_orig=True,
accumulator_bit_width=None,
p=0.25) -> None:
p=1.0) -> None:
GPFQ.__init__(
self,
layer=layer,
name=name,
act_order=act_order,
parallel_layers=parallel_layers,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p)
L1NormMixin.__init__(self, accumulator_bit_width)
self.accumulator_bit_width = accumulator_bit_width

def single_layer_update(self):
weight = self.layer.weight.data
@@ -328,38 +311,56 @@ def single_layer_update(self):
# get upper bound
input_bit_width = self.layer.quant_input_bit_width()
input_is_signed = self.layer.is_quant_input_signed
T = self.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed)
T = get_upper_bound_on_l1_norm(self.accumulator_bit_width, input_bit_width, input_is_signed)
s = self.layer.quant_weight_scale()

permutation_list = [torch.tensor(range(weight.shape[-1]))]
l1_norm = torch.zeros(weight.shape[:-1], device=dev)

# We don't need the full Hessian, just the diagonal
self.H_diag = self.quantized_input.transpose(2, 1).square().sum(
2) # summing over Batch dimension
permutation_list = []
for group_index in range(self.groups):
if self.act_order:
# Re-order Hessian_diagonal so that weights associated to
# higher magnitude activations are quantized first
perm = torch.argsort(self.H_diag[group_index, :], descending=True)
else:
# No permutation, permutation tensor is an ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, t].unsqueeze(1),
self.float_input[group_index, :,
t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2
weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(
0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(
self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm
q_arg = U[group_index].matmul(
self.quantized_input[group_index, :,
permutation_list[group_index][t]]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, t] = q_arg
weight[group_index, :, permutation_list[group_index][t]] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)

for group_index in range(self.groups):
candidate_l1 = l1_norm[group_index] + torch.abs(q[group_index])
candidate_l1_mask = candidate_l1 > T * s
if torch.any(candidate_l1_mask):
# set all values to 0 that are exceeding T * s
weight[group_index, :, t][candidate_l1_mask] = 0
weight[group_index, :, permutation_list[group_index][t]][candidate_l1_mask] = 0
q[group_index][candidate_l1_mask] = 0
else:
l1_norm[group_index] = candidate_l1
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :, t].unsqueeze(0))
self.quantized_input[group_index, :,
permutation_list[group_index][t]].unsqueeze(0))

del self.float_input
del self.quantized_input
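To make the new act_order handling and the l1-norm budget easier to follow, a condensed single-group sketch of the loop above. The GPFQ error-feedback matrix U and the real quantizer are left out (plain rounding stands in for get_quant_weights), so this only illustrates the control flow: when act_order is set, columns are visited in descending order of the Hessian diagonal, and any output channel whose running l1-norm would exceed the budget T * s has its update zeroed.

import torch

torch.manual_seed(0)
OC, IC, N = 4, 6, 32
weight = torch.randn(OC, IC)
quant_input = torch.randn(N, IC)           # quantized calibration activations
budget = 3.0                               # stands in for T * s per output channel
act_order = True

# diagonal of the Hessian ~ per-column sum of squared activations
H_diag = quant_input.square().sum(dim=0)   # shape [IC]
perm = torch.argsort(H_diag, descending=True) if act_order else torch.arange(IC)

l1_norm = torch.zeros(OC)
for t in range(IC):
    col = int(perm[t])
    q = weight[:, col].round()             # stand-in for get_quant_weights
    candidate = l1_norm + q.abs()
    mask = candidate > budget
    if torch.any(mask):
        q[mask] = 0                        # zero the channels that would exceed the budget
    else:
        l1_norm = candidate
    weight[:, col] = q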
14 changes: 14 additions & 0 deletions src/brevitas/nn/utils.py
@@ -127,3 +127,17 @@ def calculate_min_accumulator_bit_width(
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = ceil_ste(min_bit_width)
return min_bit_width # returns the minimum accumulator that can be used without risk of overflow


def get_upper_bound_on_l1_norm(
accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
"""Calculate the upper bound on the l1-norm of the weights using the derivations from
`Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
assert input_bit_width is not None, "A2Q relies on input bit-width."
assert input_is_signed is not None, "A2Q relies on input sign."
input_is_signed = float(input_is_signed) # 1. if signed else 0.
max_accumulator_bit_width = accumulator_bit_width # P
max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1. # 2^{P-1}-1
max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
return max_accumulator_mag * max_input_mag_inverse
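As a quick sanity check on the new helper (assuming a brevitas build that includes this commit), a small usage example: with a 16-bit accumulator and a signed 8-bit input, the bound evaluates to (2^15 - 1) * 2^(1 - 8), roughly 256.

import torch

from brevitas.nn.utils import get_upper_bound_on_l1_norm

T = get_upper_bound_on_l1_norm(
    accumulator_bit_width=16,
    input_bit_width=torch.tensor(8.),
    input_is_signed=True)
print(T)  # tensor(255.9922), i.e. (2**15 - 1) / 2**7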
Original file line number Diff line number Diff line change
@@ -471,7 +471,7 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=0.25, accumulator_bit_width=None):
def apply_gpfq(calib_loader, model, act_order, p=1.0, accumulator_bit_width=None):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
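For context, a hedged example of how the updated entry point might be invoked from a PTQ script; calib_loader and quant_model are placeholders for a calibration DataLoader and an already-quantized model prepared elsewhere, and passing a non-None accumulator_bit_width appears to be what routes layers to the GPFA2Q optimizer (the branch condition itself is outside this hunk).

apply_gpfq(
    calib_loader,
    quant_model,
    act_order=True,             # visit columns by activation energy, highest first
    p=1.0,                      # new default, previously 0.25
    accumulator_bit_width=16)   # presumably selects the accumulator-aware GPFA2Q path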
Original file line number Diff line number Diff line change
@@ -166,8 +166,7 @@
'weight-narrow-range',
default=True,
help='Narrow range for weight quantization (default: enabled)')
parser.add_argument(
'--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 0.25)')
parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)')
parser.add_argument(
'--quant-format',
default='int',
@@ -272,6 +271,7 @@ def main():
f"GPFQ P: {args.gpfq_p} - "
f"GPTQ Act Order: {args.gptq_act_order} - "
f"GPFQ Act Order: {args.gpfq_act_order} - "
f"GPFQ Accumulator Bit Width: {args.accumulator_bit_width} - "
f"Learned Round: {args.learned_round} - "
f"Weight narrow range: {args.weight_narrow_range} - "
f"Bias bit width: {args.bias_bit_width} - "
