From 7e26cbf53f84402aad832db853bfed441fdb6422 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 6 Nov 2023 18:44:08 +0000 Subject: [PATCH] Fix (core/bit_width): fix in _load_from_state_dict --- src/brevitas/core/bit_width/parameter.py | 6 ++++-- tests/brevitas/core/test_bit_width.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/bit_width/parameter.py b/src/brevitas/core/bit_width/parameter.py index 8a54a07df..d5dd63adc 100644 --- a/src/brevitas/core/bit_width/parameter.py +++ b/src/brevitas/core/bit_width/parameter.py @@ -106,7 +106,8 @@ def _load_from_state_dict( del state_dict[bit_width_offset_key] super(BitWidthParameter, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - if config.IGNORE_MISSING_KEYS and bit_width_offset_key in missing_keys: + if (config.IGNORE_MISSING_KEYS or + self.override_pretrained) and bit_width_offset_key in missing_keys: missing_keys.remove(bit_width_offset_key) @@ -147,5 +148,6 @@ def _load_from_state_dict( del state_dict[bit_width_coeff_key] super(RemoveBitwidthParameter, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - if config.IGNORE_MISSING_KEYS and bit_width_coeff_key in missing_keys: + if (config.IGNORE_MISSING_KEYS or + self.override_pretrained) and bit_width_coeff_key in missing_keys: missing_keys.remove(bit_width_coeff_key) diff --git a/tests/brevitas/core/test_bit_width.py b/tests/brevitas/core/test_bit_width.py index e8b1c7879..51883ae18 100644 --- a/tests/brevitas/core/test_bit_width.py +++ b/tests/brevitas/core/test_bit_width.py @@ -142,6 +142,7 @@ def test_load_from_stateful_const( """ if (bit_width_init_two < min_bit_width_init) and not override_pretrained: pytest.xfail('bit_width cannot be smaller than min_bit_width') + override_value = bit_width_parameter.bit_width_offset bit_width_parameter.load_state_dict(bit_width_stateful_const.state_dict()) bit_width_parameter_tensor = bit_width_parameter()