diff --git a/src/brevitas/core/scaling/pre_scaling.py b/src/brevitas/core/scaling/pre_scaling.py
index 6f9a1905a..dd125396d 100644
--- a/src/brevitas/core/scaling/pre_scaling.py
+++ b/src/brevitas/core/scaling/pre_scaling.py
@@ -14,7 +14,7 @@ from brevitas.core.stats import SCALAR_SHAPE
 from brevitas.core.stats.stats_wrapper import _Stats
 from brevitas.function import abs_binary_sign_grad
-from brevitas.nn.utils import get_upper_bound_on_l1_norm
+from brevitas.function import get_upper_bound_on_l1_norm
 
 __all__ = ["ParameterPreScalingWeightNorm", "AccumulatorAwareParameterPreScaling"]
 
 
diff --git a/src/brevitas/function/ops.py b/src/brevitas/function/ops.py
index f68ae9ede..feb20ac38 100644
--- a/src/brevitas/function/ops.py
+++ b/src/brevitas/function/ops.py
@@ -201,3 +201,17 @@ def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_b
             device=mantissa_bit_width.device)))
     max_val = max_mantissa * (2 ** max_exponent)
     return max_val
+
+
+def get_upper_bound_on_l1_norm(
+        accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
+    """Calculate the upper bound on the l1-norm of the weights using the derivations from
+    `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
+    by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
+    assert input_bit_width is not None, "A2Q relies on input bit-width."
+    assert input_is_signed is not None, "A2Q relies on input sign."
+    input_is_signed = float(input_is_signed)  # 1. if signed else 0.
+    max_accumulator_bit_width = accumulator_bit_width  # P
+    max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1.  # 2^{P-1}-1
+    max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
+    return max_accumulator_mag * max_input_mag_inverse
diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 8d0273856..d4a608156 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -1,21 +1,19 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from abc import ABC
 from copy import deepcopy
 from typing import List, Optional
 
 import numpy as np
 import torch
-from torch import Tensor
 import unfoldNd
 
+from brevitas.function import get_upper_bound_on_l1_norm
 from brevitas.graph.gpxq import GPxQ
 from brevitas.graph.gpxq import gpxq_mode
 from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
-from brevitas.nn.utils import get_upper_bound_on_l1_norm
 
 
 class gpfq_mode(gpxq_mode):
diff --git a/src/brevitas/nn/utils.py b/src/brevitas/nn/utils.py
index 634f2d751..3e7b423ee 100644
--- a/src/brevitas/nn/utils.py
+++ b/src/brevitas/nn/utils.py
@@ -127,17 +127,3 @@ def calculate_min_accumulator_bit_width(
     min_bit_width = alpha + phi(alpha) + 1.
     min_bit_width = ceil_ste(min_bit_width)
     return min_bit_width  # returns the minimum accumulator that can be used without risk of overflow
-
-
-def get_upper_bound_on_l1_norm(
-        accumulator_bit_width: int, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
-    """Calculate the upper bound on the l1-norm of the weights using the derivations from
-    `Quantized Neural Networks for Low-Precision Accumulation with Guaranteed Overflow Avoidance`
-    by I.Colbert, A.Pappalardo, and J.Petri-Koenig."""
-    assert input_bit_width is not None, "A2Q relies on input bit-width."
-    assert input_is_signed is not None, "A2Q relies on input sign."
-    input_is_signed = float(input_is_signed)  # 1. if signed else 0.
-    max_accumulator_bit_width = accumulator_bit_width  # P
-    max_accumulator_mag = pow(2., max_accumulator_bit_width - 1.) - 1.  # 2^{P-1}-1
-    max_input_mag_inverse = pow(2., input_is_signed - input_bit_width)
-    return max_accumulator_mag * max_input_mag_inverse
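
Illustrative usage of the relocated helper via its new import path (a minimal sketch: it assumes a Brevitas build that includes this change, and the numbers are a worked example, not part of the PR):

import torch

from brevitas.function import get_upper_bound_on_l1_norm  # new location per this diff

# With a P=16-bit accumulator and signed N=8-bit inputs, the bound evaluates to
# (2^{P-1} - 1) * 2^{1-N} = 32767 / 128 = 255.9921875; any weight vector whose
# l1-norm stays below this value cannot overflow the accumulator, since
# |sum_i w_i * x_i| <= ||w||_1 * 2^{N-1} <= 2^{P-1} - 1.
bound = get_upper_bound_on_l1_norm(
    accumulator_bit_width=16, input_bit_width=torch.tensor(8.), input_is_signed=True)
print(bound)  # tensor(255.9922)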