Skip to content

Commit

Permalink
Activations (nn): Add QuantIdentity
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Apr 7, 2020
1 parent a730ec5 commit 9b57b61
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion brevitas/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .quant_accumulator import ClampQuantAccumulator, TruncQuantAccumulator
from .quant_activation import QuantReLU, QuantSigmoid, QuantTanh, QuantHardTanh
from .quant_activation import QuantReLU, QuantSigmoid, QuantTanh, QuantHardTanh, QuantIdentity
from .quant_avg_pool import QuantAvgPool2d
from .quant_linear import QuantLinear
from .quant_conv import QuantConv2d, PaddingType
Expand Down
54 changes: 54 additions & 0 deletions brevitas/nn/quant_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,57 @@ def __init__(self,
scaling_stats_op=scaling_stats_op,
scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
scaling_stats_permute_dims=scaling_stats_permute_dims)


class QuantIdentity(QuantActivation):

def __init__(self,
bit_width: int,
min_val: float = -1.0,
max_val: float = 1.0,
narrow_range: bool = False,
quant_type: QuantType = QuantType.FP,
float_to_int_impl_type: FloatToIntImplType = FloatToIntImplType.ROUND,
scaling_impl_type: ScalingImplType = ScalingImplType.PARAMETER,
scaling_override: Optional[Module] = None,
scaling_per_channel: bool = False,
scaling_stats_sigma: float = 3.0,
scaling_stats_op: StatsOp = StatsOp.MEAN_LEARN_SIGMA_STD,
scaling_stats_buffer_momentum: float = 0.1,
scaling_stats_permute_dims: Tuple = (1, 0, 2, 3),
per_channel_broadcastable_shape: Optional[Tuple[int, ...]] = None,
min_overall_bit_width: Optional[int] = 2,
max_overall_bit_width: Optional[int] = None,
bit_width_impl_override: Union[BitWidthParameter] = None,
bit_width_impl_type: BitWidthImplType = BitWidthImplType.CONST,
restrict_bit_width_type: RestrictValueType = RestrictValueType.INT,
restrict_scaling_type: RestrictValueType = RestrictValueType.LOG_FP,
scaling_min_val: Optional[float] = SCALING_MIN_VAL,
override_pretrained_bit_width: bool = False,
return_quant_tensor: bool = False):
super(QuantIdentity, self).__init__(return_quant_tensor=return_quant_tensor)
activation_impl = Identity()
self.act_quant_proxy = ActivationQuantProxy(activation_impl=activation_impl,
bit_width=bit_width,
signed=True,
narrow_range=narrow_range,
scaling_override=scaling_override,
min_val=min_val,
max_val=max_val,
quant_type=quant_type,
float_to_int_impl_type=float_to_int_impl_type,
scaling_impl_type=scaling_impl_type,
scaling_per_channel=scaling_per_channel,
scaling_min_val=scaling_min_val,
per_channel_broadcastable_shape=per_channel_broadcastable_shape,
min_overall_bit_width=min_overall_bit_width,
max_overall_bit_width=max_overall_bit_width,
bit_width_impl_override=bit_width_impl_override,
bit_width_impl_type=bit_width_impl_type,
restrict_bit_width_type=restrict_bit_width_type,
restrict_scaling_type=restrict_scaling_type,
override_pretrained_bit_width=override_pretrained_bit_width,
scaling_stats_sigma=scaling_stats_sigma,
scaling_stats_op=scaling_stats_op,
scaling_stats_buffer_momentum=scaling_stats_buffer_momentum,
scaling_stats_permute_dims=scaling_stats_permute_dims)

0 comments on commit 9b57b61

Please sign in to comment.