diff --git a/src/brevitas/quant/base.py b/src/brevitas/quant/base.py index 48bec58c1..c88e57204 100644 --- a/src/brevitas/quant/base.py +++ b/src/brevitas/quant/base.py @@ -53,7 +53,8 @@ from brevitas.inject.enum import StatsOp from brevitas.proxy import DecoupledWeightQuantProxyFromInjector from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector -from brevitas.quant.solver.common import SolveScalingStatsInputViewShapeImplFromEnum, SolveStatsReduceDimFromEnum +from brevitas.quant.solver.common import SolveScalingStatsInputViewShapeImplFromEnum +from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum from brevitas.quant.solver.parameter import SolveInputViewImpl from brevitas.quant.solver.parameter import SolveParameterScalingShape from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule @@ -334,20 +335,41 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, stats_reduce_dim = SCALING_STATS_REDUCE_DIM scaling_per_output_type = ScalingPerOutputType.CHANNEL + class PerChannelL2Norm(ExtendedInjector): stats_reduce_dim = SCALING_STATS_REDUCE_DIM normalize_stats_impl = L2Norm + class PerChannelPreNorm(ExtendedInjector): pre_scaling_impl = ParameterPreScalingWeightNorm scaling_stats_input_view_shape_impl = OverOutputChannelView - scaling_impl = (this<<1).scaling_impl - normalize_stats_impl = (this<<1).normalize_stats_impl - tracked_parameter_list = (this<<1).tracked_parameter_list - pre_scaling_shape = (this<<1).pre_scaling_shape + scaling_impl = (this << 1).scaling_impl + normalize_stats_impl = (this << 1).normalize_stats_impl + tracked_parameter_list = (this << 1).tracked_parameter_list + pre_scaling_shape = (this << 1).pre_scaling_shape + + +class SolvePostScaleGranularity(ExtendedInjector): + + @value + def scaling_stats_input_view_shape_impl(scaling_per_output_type): + if scaling_per_output_type == ScalingPerOutputType.TENSOR: + return StatsInputViewShapeImpl.OVER_TENSOR + elif scaling_per_output_type == ScalingPerOutputType.CHANNEL: + return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS + + @value + def stats_reduce_dim(scaling_per_output_type): + if scaling_per_output_type == ScalingPerOutputType.TENSOR: + return None + elif scaling_per_output_type == ScalingPerOutputType.CHANNEL: + return SCALING_STATS_REDUCE_DIM + -class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum, +class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity, + SolveStatsReduceDimFromEnum, SolveScalingStatsInputViewShapeImplFromEnum, SolveWeightScalingStatsInputDimsFromModule, SolveWeightScalingPerOutputChannelShapeFromModule, @@ -384,7 +406,7 @@ def scaling_init(scaling_init_impl, bit_width): scaling_stats_impl = AbsMax restrict_pre_scaling_impl = LogFloatRestrictValue normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl - scaling_per_output_type = ScalingPerOutputType.TENSOR + scaling_per_output_type = ScalingPerOutputType.CHANNEL pre_scaling_shape = this.scaling_per_output_channel_shape int_scaling_impl = SingleArgStatelessBuffer(1.) zero_point_impl = ZeroZeroPoint @@ -392,12 +414,10 @@ def scaling_init(scaling_init_impl, bit_width): bit_width_impl = BitWidthConst narrow_range = True signed = True - scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR - stats_reduce_dim = None scaling_min_val = 1e-10 pre_scaling_min_val = 1e-10 - @value + @value def pre_scaling_impl(): return this.per_channel_pre_norm.pre_scaling_impl diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py index a6b1c05af..90b300b13 100644 --- a/tests/brevitas/nn/nn_quantizers_fixture.py +++ b/tests/brevitas/nn/nn_quantizers_fixture.py @@ -11,6 +11,7 @@ from brevitas import torch_version import brevitas.config as config +from brevitas.inject.enum import ScalingPerOutputType from brevitas.nn import QuantConv1d from brevitas.nn import QuantConv2d from brevitas.nn import QuantConv3d @@ -48,6 +49,11 @@ EMBED_DIM = 9 NUM_HEADS = 3 + +class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint): + scaling_per_output_type = ScalingPerOutputType.TENSOR + + LSTM_WEIGHT_QUANTIZER = { 'None': None, 'quant_sym': Int8WeightPerTensorFloat, @@ -62,6 +68,7 @@ 'quant_sym': Int8WeightPerTensorFloat, 'quant_asym': ShiftedUint8WeightPerTensorFloat, 'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint, + 'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint, 'quant_mx': MXInt8Weight, 'quant_float': Fp8e4m3WeightPerTensorFloat, **A2Q_WBIOL_WEIGHT_QUANTIZER}