diff --git a/brevitas/nn/__init__.py b/brevitas/nn/__init__.py
index 768da3dc0..565880fa1 100644
--- a/brevitas/nn/__init__.py
+++ b/brevitas/nn/__init__.py
@@ -1,5 +1,5 @@
 from .quant_accumulator import ClampQuantAccumulator, TruncQuantAccumulator
-from .quant_activation import QuantReLU, QuantSigmoid, QuantTanh, QuantHardTanh
+from .quant_activation import QuantReLU, QuantSigmoid, QuantTanh, QuantHardTanh, QuantIdentity
 from .quant_avg_pool import QuantAvgPool2d
 from .quant_linear import QuantLinear
 from .quant_conv import QuantConv2d, PaddingType
diff --git a/brevitas/nn/quant_activation.py b/brevitas/nn/quant_activation.py
index f0ce4e5bc..6b6e88530 100644
--- a/brevitas/nn/quant_activation.py
+++ b/brevitas/nn/quant_activation.py
@@ -284,3 +284,57 @@ def __init__(self,
                                                     scaling_stats_op=scaling_stats_op,
                                                     scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
                                                     scaling_stats_permute_dims=scaling_stats_permute_dims)
+
+
+class QuantIdentity(QuantActivation):
+
+    def __init__(self,
+                 bit_width: int,
+                 min_val: float = -1.0,
+                 max_val: float = 1.0,
+                 narrow_range: bool = False,
+                 quant_type: QuantType = QuantType.FP,
+                 float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
+                 scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
+                 scaling_override: Optional[Module] = None,
+                 scaling_per_channel: bool = False,
+                 scaling_stats_sigma: float = 3.0,
+                 scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
+                 scaling_stats_buffer_momentum: float = 0.1,
+                 scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
+                 per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
+                 min_overall_bit_width: Optional[int] = 2,
+                 max_overall_bit_width: Optional[int] = None,
+                 bit_width_impl_override: Optional[BitWidthParameter] = None,
+                 bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
+                 restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
+                 restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
+                 scaling_min_val: Optional[float] = SCALING_MIN_VAL,
+                 override_pretrained_bit_width: bool = False,
+                 return_quant_tensor: bool = False):
+        super(QuantIdentity, self).__init__(return_quant_tensor=return_quant_tensor)
+        activation_impl = Identity()
+        self.act_quant_proxy = ActivationQuantProxy(activation_impl=activation_impl,
+                                                    bit_width=bit_width,
+                                                    signed=True,
+                                                    narrow_range=narrow_range,
+                                                    scaling_override=scaling_override,
+                                                    min_val=min_val,
+                                                    max_val=max_val,
+                                                    quant_type=quant_type,
+                                                    float_to_int_impl_type=float_to_int_impl_type,
+                                                    scaling_impl_type=scaling_impl_type,
+                                                    scaling_per_channel=scaling_per_channel,
+                                                    scaling_min_val=scaling_min_val,
+                                                    per_channel_broadcastable_shape=per_channel_broadcastable_shape,
+                                                    min_overall_bit_width=min_overall_bit_width,
+                                                    max_overall_bit_width=max_overall_bit_width,
+                                                    bit_width_impl_override=bit_width_impl_override,
+                                                    bit_width_impl_type=bit_width_impl_type,
+                                                    restrict_bit_width_type=restrict_bit_width_type,
+                                                    restrict_scaling_type=restrict_scaling_type,
+                                                    override_pretrained_bit_width=override_pretrained_bit_width,
+                                                    scaling_stats_sigma=scaling_stats_sigma,
+                                                    scaling_stats_op=scaling_stats_op,
+                                                    scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
+                                                    scaling_stats_permute_dims=scaling_stats_permute_dims)
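A minimal usage sketch, not part of the patch: it exercises only the `QuantIdentity` constructor added above. The `brevitas.core.quant` import path for `QuantType` and the tensor shape are assumptions for illustration. Note that with the default `quant_type=QuantType.FP` the module is a plain passthrough, so the example opts into integer quantization explicitly.

```python
import torch

from brevitas.core.quant import QuantType  # assumed import path for QuantType
from brevitas.nn import QuantIdentity

# 4-bit signed quantizer over the default [-1.0, 1.0] range; QuantIdentity
# applies no nonlinearity, it only (fake-)quantizes its input.
quant_identity = QuantIdentity(bit_width=4, quant_type=QuantType.INT)

x = torch.rand(8, 3, 32, 32) * 2.0 - 1.0  # values in [-1.0, 1.0)
y = quant_identity(x)  # quantized tensor, same shape as x
```

With `return_quant_tensor=True` the module would instead return the quantized output together with its scale and bit width, matching the behavior of the other activation wrappers in this file.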