From 1474588045a105f170cf925265f557f83757b455 Mon Sep 17 00:00:00 2001 From: Hyeon-Seo Yun Date: Mon, 28 Nov 2022 18:00:54 -0500 Subject: [PATCH 1/2] fixed wrong runtime shape inference for BatchNorm1dToQuantScaleBias --- src/brevitas/nn/quant_bn.py | 2 +- src/brevitas/nn/quant_scale_bias.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_bn.py b/src/brevitas/nn/quant_bn.py index 98c8c9c65..8a5ff930d 100644 --- a/src/brevitas/nn/quant_bn.py +++ b/src/brevitas/nn/quant_bn.py @@ -97,7 +97,7 @@ def __init__( super(BatchNorm1dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1, 1), + runtime_shape=(1, -1), weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant, diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py index d2ebbf474..d50d1e3ff 100644 --- a/src/brevitas/nn/quant_scale_bias.py +++ b/src/brevitas/nn/quant_scale_bias.py @@ -80,8 +80,9 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1, 1), **kwargs) -> None: - ScaleBias.__init__(self, num_features, bias) + ScaleBias.__init__(self, num_features, bias, runtime_shape=runtime_shape) QuantWBIOL.__init__( self, weight_quant=weight_quant, From abc5ae40c511c291a1dc4300ec312f2e5f05b583 Mon Sep 17 00:00:00 2001 From: Hyeon-Seo Yun Date: Tue, 6 Dec 2022 15:58:18 -0500 Subject: [PATCH 2/2] reverted the default runtime_shape of BatchNorm1dToQuantScaleBias from (1,-1) to (1, -1, 1), as requested by the maintainer of brevitas branch --- src/brevitas/nn/quant_bn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_bn.py b/src/brevitas/nn/quant_bn.py index 8a5ff930d..98c8c9c65 100644 --- a/src/brevitas/nn/quant_bn.py +++ b/src/brevitas/nn/quant_bn.py @@ -97,7 +97,7 @@ def __init__( super(BatchNorm1dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1), + runtime_shape=(1, -1, 1), weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant,