diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py
index ca2baea17..51abdc1a6 100644
--- a/src/brevitas/core/scaling/runtime.py
+++ b/src/brevitas/core/scaling/runtime.py
@@ -53,9 +53,9 @@ def __init__(
     @brevitas.jit.script_method
     def forward(
             self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
+        stats = self.parameter_list_stats()
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        stats = self.parameter_list_stats()
         return self.stats_scaling_impl(stats, threshold)
 
 
diff --git a/tests/brevitas/core/test_binary_quant.py b/tests/brevitas/core/test_binary_quant.py
index 7a7df3c44..4f82e4815 100644
--- a/tests/brevitas/core/test_binary_quant.py
+++ b/tests/brevitas/core/test_binary_quant.py
@@ -31,7 +31,7 @@ def test_binary_quant(self, binary_quant_impl_all, inp, scale_init):
         scaling_impl = mock.Mock(return_value=scale_init)
         binary_quant = binary_quant_impl_all(scaling_impl)
         output, scale, zp, bit_width = binary_quant(inp)
-        scaling_impl.assert_called_once_with(inp, torch.tensor(1.).type_as(inp))
+        scaling_impl.assert_called_once_with(inp)
         assert is_binary_output_value_correct(scale, output)
         assert is_binary_output_sign_correct(inp, output)
         assert (scale == scale_init).all()
@@ -81,4 +81,4 @@ def test_output_zero_point(self, binary_quant_all, inp):
     @given(inp=float_tensor_random_shape_st())
     def test_output_scale(self, binary_quant_all, scaling_impl_all, inp):
         _, scale, _, _ = binary_quant_all(inp)
-        assert_allclose(scale, scaling_impl_all(inp, torch.tensor(1.).type_as(inp)))
+        assert_allclose(scale, scaling_impl_all(inp))
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 16b8a4b5f..b4f2cc89e 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -110,7 +110,7 @@ def test_float_to_quant_float(inp, minifloat_format):
 @jit_disabled_for_mock()
 def test_scaling_impls_called_once(inp, minifloat_format):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
-    scaling_impl = mock.Mock(side_effect=lambda x: 1.)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: 1.)
     float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
@@ -160,7 +160,7 @@ def test_inner_scale(inp, minifloat_format, scale):
     bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format
     # set scaling_impl to scale and float_scaling_impl to 1 to use the same scale as we are here
     float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.)
-    scaling_impl = mock.Mock(side_effect=lambda x: scale)
+    scaling_impl = mock.Mock(side_effect=lambda x, y: scale)
     if exponent_bit_width == 0 or mantissa_bit_width == 0:
         with pytest.raises(RuntimeError):
             float_quant = FloatQuant(