Decoupled PerChannel/PerTensor quantization
Giuseppe5 committed Sep 12, 2024
1 parent f58f64b commit 2609941
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/brevitas/quant/base.py
@@ -53,7 +53,7 @@
 from brevitas.inject.enum import StatsOp
 from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
 from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
-from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum
+from brevitas.quant.solver.common import SolveScalingStatsInputViewShapeImplFromEnum, SolveStatsReduceDimFromEnum
 from brevitas.quant.solver.parameter import SolveInputViewImpl
 from brevitas.quant.solver.parameter import SolveParameterScalingShape
 from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule
@@ -334,8 +334,21 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
     stats_reduce_dim = SCALING_STATS_REDUCE_DIM
     scaling_per_output_type = ScalingPerOutputType.CHANNEL
 
+class PerChannelL2Norm(ExtendedInjector):
+    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
+    normalize_stats_impl = L2Norm
+
+class PerChannelPreNorm(ExtendedInjector):
+
+    pre_scaling_impl = ParameterPreScalingWeightNorm
+    scaling_stats_input_view_shape_impl = OverOutputChannelView
+    scaling_impl = (this << 1).scaling_impl
+    normalize_stats_impl = (this << 1).normalize_stats_impl
+    tracked_parameter_list = (this << 1).tracked_parameter_list
+    pre_scaling_shape = (this << 1).pre_scaling_shape
 
 class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
+                                         SolveScalingStatsInputViewShapeImplFromEnum,
                                          SolveWeightScalingStatsInputDimsFromModule,
                                          SolveWeightScalingPerOutputChannelShapeFromModule,
                                          SolveParameterScalingShape,
@@ -359,6 +372,8 @@ def scaling_init(scaling_init_impl, bit_width):
         scales = scaling_init_impl.parameter_list_stats() / (pow(2., bit_width - 1.) - 1.)
         return scales
 
+    per_channel_pre_norm = PerChannelPreNorm
+
     proxy_class = DecoupledWeightQuantProxyFromInjector
     tensor_quant = DecoupledRescalingIntQuant
     decoupled_int_quant = DecoupledIntQuant
@@ -367,22 +382,25 @@ def scaling_init(scaling_init_impl, bit_width):
     scaling_init_impl = StatsFromParameterScaling
     restrict_scaling_impl = LogFloatRestrictValue
     scaling_stats_impl = AbsMax
-    pre_scaling_impl = ParameterPreScalingWeightNorm
     restrict_pre_scaling_impl = LogFloatRestrictValue
-    normalize_stats_impl = L2Norm
-    scaling_per_output_type = ScalingPerOutputType.CHANNEL
-    pre_scaling_shape = this.scaling_shape  # TODO: decouple pre_scaling_shape from scaling_shape
+    normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
+    scaling_per_output_type = ScalingPerOutputType.TENSOR
+    pre_scaling_shape = this.scaling_per_output_channel_shape
     int_scaling_impl = SingleArgStatelessBuffer(1.)
     zero_point_impl = ZeroZeroPoint
     pre_zero_point_impl = ZeroZeroPoint
     bit_width_impl = BitWidthConst
     narrow_range = True
     signed = True
-    scaling_stats_input_view_shape_impl = OverOutputChannelView
-    stats_reduce_dim = SCALING_STATS_REDUCE_DIM
+    scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
+    stats_reduce_dim = None
     scaling_min_val = 1e-10
     pre_scaling_min_val = 1e-10
+
+    @value
+    def pre_scaling_impl():
+        return this.per_channel_pre_norm.pre_scaling_impl
 
 
 class AccumulatorAwareWeightQuant(WeightNormPerChannelFloatDecoupled):
     """Experimental accumulator-aware weight quantizer based on `Quantized Neural Networks
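Net effect of the change: WeightNormPerChannelFloatDecoupled now computes its weight-norm pre-scale per output channel, via the new PerChannelPreNorm sub-injector, whose (this << 1) references resolve attributes against the enclosing quantizer, while the main scale is computed over the whole tensor (ScalingPerOutputType.TENSOR with stats_reduce_dim = None). A minimal usage sketch follows; it assumes a Brevitas install that includes this commit, and the layer shape and the 4-bit width are illustrative, not part of the commit:

import torch
from brevitas.nn import QuantConv2d
from brevitas.quant.base import WeightNormPerChannelFloatDecoupled

# Bind the quantizer to a layer. The bit width is supplied here because the
# base class only fixes bit_width_impl = BitWidthConst, not the width itself.
conv = QuantConv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    weight_quant=WeightNormPerChannelFloatDecoupled,
    weight_bit_width=4)

y = conv(torch.randn(1, 3, 32, 32))

# After this commit the quantizer's scale should be per-tensor rather than
# per-channel; the per-channel L2 weight norm lives in the pre-scale instead.
print(conv.quant_weight().scale.shape)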
