Skip to content

Commit

Permalink
Fix (solver/trunc): Added a ShiftRoundSaturate quantizer and update t…
Browse files Browse the repository at this point in the history
…runc solver
  • Loading branch information
nickfraser committed Jan 27, 2025
1 parent 7d6c934 commit b121f5d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/brevitas/inject/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ class StatsOp(AutoName):
# Typically adopted for asymmetric quantization
MIN_MAX = auto()
PERCENTILE_INTERVAL = auto()


class TruncScalingImplType(AutoName):
"""
Selector for the scaling implementation used by truncation quantizers.
"""
# Resolved to TruncMsbScaling by the truncation solver.
MSB = auto()
# Resolved to TruncScalingWrapper by the truncation solver.
WRAPPER = auto()
25 changes: 23 additions & 2 deletions src/brevitas/quant/scaled_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
'Int8WeightPerChannelFloatMSE',
'TruncTo8bit',
'RoundTo8bit',
'ShiftRoundSaturateTo8bit',
'Int4WeightPerTensorFloatDecoupled',
'Int8WeightPerChannelFloatDecoupled',
'Uint8ActPerTensorFloatBatchQuant1d',
Expand Down Expand Up @@ -261,7 +262,7 @@ class Uint8ActPerTensorFloatMSE(MSESymmetricScale, Uint8ActPerTensorFloat):

class TruncTo8bit(TruncQuantSolver):
"""
8-bit signed int truncator that preserves the input scale factor and zero-point.
8-bit int truncator that preserves most-significant bits and zero-point.
Examples:
>>> from brevitas.nn import TruncAvgPool2d
Expand All @@ -271,11 +272,12 @@ class TruncTo8bit(TruncQuantSolver):
quant_type = 'int'
bit_width_impl_type = 'const'
float_to_int_impl_type = 'floor'
trunc_scaling_impl_type = 'msb'


class RoundTo8bit(TruncQuantSolver):
"""
8-bit signed int truncator with rounding that preserves the input scale factor and zero-point.
8-bit int truncator with rounding that preserves most-significant bits and zero-point.
Examples:
>>> from brevitas.nn import TruncAvgPool2d
Expand All @@ -285,6 +287,25 @@ class RoundTo8bit(TruncQuantSolver):
quant_type = 'int'
bit_width_impl_type = 'const'
float_to_int_impl_type = 'round'
trunc_scaling_impl_type = 'msb'


class ShiftRoundSaturateTo8bit(TruncQuantSolver,
ParamFromRuntimePercentileScaling,
PerTensorPoTScaling8bit):
"""
8-bit shift-round-saturate quantizer which uses statistics to calculate the amount of truncation
of the least-significant bits and most-significant bits. Zero-point is preserved.
Examples:
>>> from brevitas.nn import TruncAvgPool2d
>>> pool = TruncAvgPool2d(kernel_size=(3, 3), trunc_quant=ShiftRoundSaturateTo8bit)
"""
bit_width = 8
quant_type = 'int'
bit_width_impl_type = 'const'
float_to_int_impl_type = 'round'
# NOTE(review): presumably matched against TruncScalingImplType.WRAPPER by the
# truncation solver (the other truncators here use 'msb') -- confirm.
trunc_scaling_impl_type = 'wrapper'


class Int4WeightPerTensorFloatDecoupled(WeightPerTensorFloatDecoupledL2Param):
Expand Down
46 changes: 43 additions & 3 deletions src/brevitas/quant/solver/trunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.quant import TruncIntQuant
from brevitas.core.scaling import TruncMsbScaling
from brevitas.core.scaling import TruncPowerOfTwoIntScaling
from brevitas.core.scaling import TruncScalingWrapper
from brevitas.inject import ExtendedInjector
from brevitas.inject import value
from brevitas.inject.enum import QuantType
from brevitas.inject.enum import RestrictValueType
from brevitas.inject.enum import TruncScalingImplType
from brevitas.proxy import TruncQuantProxyFromInjector
from brevitas.quant.solver.common import SolveBitWidthImplFromEnum
from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
from brevitas.quant.solver.act import *
from brevitas.quant.solver.common import *


class SolveTruncTensorQuantFromEnum(ExtendedInjector):
Expand All @@ -26,9 +31,44 @@ def tensor_quant(quant_type):
raise RuntimeError(f'{quant_type} not recognized.')


class SolveTruncScalingImplFromEnum(ExtendedInjector):
    """
    Translate the ``trunc_scaling_impl_type`` directive into the class that
    implements truncation scaling.
    """

    @value
    def trunc_scaling_impl(trunc_scaling_impl_type="msb"):
        # The directive is compared with ``==`` against the enum members; the
        # string default presumably relies on the enum accepting string
        # comparison -- TODO confirm against AutoName.
        if trunc_scaling_impl_type == TruncScalingImplType.MSB:
            return TruncMsbScaling
        if trunc_scaling_impl_type == TruncScalingImplType.WRAPPER:
            return TruncScalingWrapper
        # Neither known directive matched: fail loudly rather than guess.
        raise RuntimeError(f'trunc_scaling_impl_type={trunc_scaling_impl_type} not recognized.')


class SolveTruncIntScalingImplFromEnum(ExtendedInjector):
    """
    Translate the ``restrict_scaling_type`` directive into the class that
    implements integer scaling for truncation.
    """

    @value
    def trunc_int_scaling_impl(restrict_scaling_type):
        # Only power-of-two restricted scaling is handled here; every other
        # directive is rejected up front.
        is_power_of_two = restrict_scaling_type == RestrictValueType.POWER_OF_TWO
        if not is_power_of_two:
            raise RuntimeError(f'restrict_scaling_type={restrict_scaling_type} not recognized.')
        return TruncPowerOfTwoIntScaling


class TruncQuantSolver(SolveBitWidthImplFromEnum,
SolveTensorQuantFloatToIntImplFromEnum,
SolveTruncTensorQuantFromEnum):
SolveActScalingImplFromEnum,
SolveIntScalingImplFromEnum,
SolveScalingStatsOpFromEnum,
SolveRestrictScalingImplFromEnum,
SolveActScalingInitFromEnum,
SolveStatsReduceDimFromEnum,
SolveActScalingShape,
SolveScalingStatsInputViewShapeImplFromEnum,
SolveActScalingPerOutputChannelShape,
SolveUpdateStateDictImplFromEnum,
SolveInputViewImpl,
SolveTruncTensorQuantFromEnum,
SolveTruncScalingImplFromEnum,
SolveTruncIntScalingImplFromEnum):
"""
Translate enum directives to truncation-specific quantization core modules.
It should be placed last in the list of classes a quantizer inherits from,
Expand Down

0 comments on commit b121f5d

Please sign in to comment.