
Commit

fix
Giuseppe5 committed Sep 13, 2024
1 parent 2609941 commit 028c352
Showing 2 changed files with 37 additions and 10 deletions.
40 changes: 30 additions & 10 deletions src/brevitas/quant/base.py
@@ -53,7 +53,8 @@
from brevitas.inject.enum import StatsOp
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.quant.solver.common import SolveScalingStatsInputViewShapeImplFromEnum, SolveStatsReduceDimFromEnum
from brevitas.quant.solver.common import SolveScalingStatsInputViewShapeImplFromEnum
from brevitas.quant.solver.common import SolveStatsReduceDimFromEnum
from brevitas.quant.solver.parameter import SolveInputViewImpl
from brevitas.quant.solver.parameter import SolveParameterScalingShape
from brevitas.quant.solver.weight import SolveWeightScalingPerOutputChannelShapeFromModule
@@ -334,20 +335,41 @@ class WeightPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
scaling_per_output_type = ScalingPerOutputType.CHANNEL


class PerChannelL2Norm(ExtendedInjector):
stats_reduce_dim = SCALING_STATS_REDUCE_DIM
normalize_stats_impl = L2Norm


class PerChannelPreNorm(ExtendedInjector):

pre_scaling_impl = ParameterPreScalingWeightNorm
scaling_stats_input_view_shape_impl = OverOutputChannelView
scaling_impl = (this<<1).scaling_impl
normalize_stats_impl = (this<<1).normalize_stats_impl
tracked_parameter_list = (this<<1).tracked_parameter_list
pre_scaling_shape = (this<<1).pre_scaling_shape
scaling_impl = (this << 1).scaling_impl
normalize_stats_impl = (this << 1).normalize_stats_impl
tracked_parameter_list = (this << 1).tracked_parameter_list
pre_scaling_shape = (this << 1).pre_scaling_shape


class SolvePostScaleGranularity(ExtendedInjector):

@value
def scaling_stats_input_view_shape_impl(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS

@value
def stats_reduce_dim(scaling_per_output_type):
if scaling_per_output_type == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output_type == ScalingPerOutputType.CHANNEL:
return SCALING_STATS_REDUCE_DIM


class WeightNormPerChannelFloatDecoupled(SolveStatsReduceDimFromEnum,
class WeightNormPerChannelFloatDecoupled(SolvePostScaleGranularity,
SolveStatsReduceDimFromEnum,
SolveScalingStatsInputViewShapeImplFromEnum,
SolveWeightScalingStatsInputDimsFromModule,
SolveWeightScalingPerOutputChannelShapeFromModule,
@@ -384,20 +406,18 @@ def scaling_init(scaling_init_impl, bit_width):
scaling_stats_impl = AbsMax
restrict_pre_scaling_impl = LogFloatRestrictValue
normalize_stats_impl = PerChannelL2Norm.normalize_stats_impl
scaling_per_output_type = ScalingPerOutputType.TENSOR
scaling_per_output_type = ScalingPerOutputType.CHANNEL
pre_scaling_shape = this.scaling_per_output_channel_shape
int_scaling_impl = SingleArgStatelessBuffer(1.)
zero_point_impl = ZeroZeroPoint
pre_zero_point_impl = ZeroZeroPoint
bit_width_impl = BitWidthConst
narrow_range = True
signed = True
scaling_stats_input_view_shape_impl = StatsInputViewShapeImpl.OVER_TENSOR
stats_reduce_dim = None
scaling_min_val = 1e-10
pre_scaling_min_val = 1e-10

@value
@value
def pre_scaling_impl():
return this.per_channel_pre_norm.pre_scaling_impl

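For context (not part of this commit): a minimal usage sketch of what the new SolvePostScaleGranularity dependency enables. With it in place, a weight-norm quantizer built on WeightNormPerChannelFloatDecoupled (for example Int8WeightNormL2PerChannelFixedPoint) can be switched between per-channel and per-tensor post-scaling by overriding a single attribute; the solver then resolves the matching stats view and reduce dim. The quantizer import path, layer sizes, and subclass name below are illustrative assumptions, not code from this commit.

import torch

from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv2d
from brevitas.quant.fixed_point import Int8WeightNormL2PerChannelFixedPoint  # assumed location


class Int8WeightNormL2PerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
    # Per-tensor post-scale: SolvePostScaleGranularity resolves the stats view to
    # OVER_TENSOR and stats_reduce_dim to None for this setting.
    scaling_per_output_type = ScalingPerOutputType.TENSOR


# Hypothetical layer, just to show where the quantizer plugs in.
conv = QuantConv2d(3, 8, kernel_size=3, weight_quant=Int8WeightNormL2PerTensorFixedPoint)
out = conv(torch.randn(1, 3, 16, 16))
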
7 changes: 7 additions & 0 deletions tests/brevitas/nn/nn_quantizers_fixture.py
@@ -11,6 +11,7 @@

from brevitas import torch_version
import brevitas.config as config
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
@@ -48,6 +49,11 @@
EMBED_DIM = 9
NUM_HEADS = 3


class Int8WeightNormL2PerChannelPerTensorFixedPoint(Int8WeightNormL2PerChannelFixedPoint):
scaling_per_output_type = ScalingPerOutputType.TENSOR


LSTM_WEIGHT_QUANTIZER = {
'None': None,
'quant_sym': Int8WeightPerTensorFloat,
@@ -62,6 +68,7 @@
'quant_sym': Int8WeightPerTensorFloat,
'quant_asym': ShiftedUint8WeightPerTensorFloat,
'quant_decoupled': Int8WeightNormL2PerChannelFixedPoint,
'quant_decoupled_per_tensor': Int8WeightNormL2PerChannelPerTensorFixedPoint,
'quant_mx': MXInt8Weight,
'quant_float': Fp8e4m3WeightPerTensorFloat,
**A2Q_WBIOL_WEIGHT_QUANTIZER}
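Assuming the dictionary keys above are surfaced in the parametrized test ids (the usual pytest parametrize pattern for these fixtures), the new per-tensor case could be exercised selectively with something like:

pytest tests/brevitas/nn -k quant_decoupled_per_tensor
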

