Skip to content

Commit

Permalink
Merge pull request #6 from Xilinx/clean_dict
Browse files Browse the repository at this point in the history
Cleanup state dict of quantized model
  • Loading branch information
volcacius authored Oct 15, 2019
2 parents 180a9d7 + 5d713c5 commit 5db353c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 49 deletions.
10 changes: 10 additions & 0 deletions brevitas/core/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
if parameter_key in missing_keys:
missing_keys.remove(parameter_key)

def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Export the module's state dict with the wrapped 'parameter' entry stripped.

    The matching ``_load_from_state_dict`` tolerates its absence, so the
    entry never needs to round-trip through a checkpoint.
    """
    state = super(_ViewParameterWrapper, self).state_dict(destination, prefix, keep_vars)
    state.pop(prefix + 'parameter')
    return state


class _ViewCatParameterWrapper(torch.jit.ScriptModule):
__constants__ = ['shape', 'cat_dim']
Expand All @@ -114,6 +119,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
if parameter_key in missing_keys:
missing_keys.remove(parameter_key)

def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Export the module's state dict with the wrapped 'parameter' entry stripped.

    The matching ``_load_from_state_dict`` tolerates its absence, so the
    entry never needs to round-trip through a checkpoint.
    """
    state = super(_ViewCatParameterWrapper, self).state_dict(destination, prefix, keep_vars)
    state.pop(prefix + 'parameter')
    return state


class AbsMax(torch.jit.ScriptModule):
__constants__ = ['reduce_dim']
Expand Down
22 changes: 4 additions & 18 deletions brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from torch import nn, Tensor


from brevitas.core import ZERO_HW_SENTINEL_NAME, ZERO_HW_SENTINEL_VALUE
from brevitas.core import ZERO_HW_SENTINEL_NAME
from brevitas.core.bit_width import BitWidthConst, BitWidthParameter, BitWidthImplType
from brevitas.core.function_wrapper import TensorClampSte, TensorClamp
from brevitas.core.quant import IdentityQuant
Expand All @@ -62,6 +62,8 @@
from brevitas import config
from brevitas.config import docstrings

from .quant_proxy import QuantProxy

__all__ = ['WeightQuantProxy', 'BiasQuantProxy']


Expand All @@ -78,13 +80,9 @@ def forward(self, weight):
return weight + 0


class ParameterQuantProxy(nn.Module):
class ParameterQuantProxy(QuantProxy):
__metaclass__ = ABCMeta

def __init__(self):
super(ParameterQuantProxy, self).__init__()
self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))

@property
def tensor_quant(self):
    # Read-only accessor for the underlying tensor quantization
    # implementation held in the private ``_tensor_quant`` attribute.
    return self._tensor_quant
Expand Down Expand Up @@ -381,9 +379,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super(WeightQuantProxy, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
zero_hw_sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
if zero_hw_sentinel_key in missing_keys:
missing_keys.remove(zero_hw_sentinel_key)
if config.REINIT_WEIGHT_QUANT_ON_LOAD:
self.re_init_tensor_quant()

Expand Down Expand Up @@ -436,13 +431,4 @@ def forward(self,
else:
return x

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # Delegate the actual tensor loading to the base implementation first;
    # it populates missing_keys/unexpected_keys/error_msgs in place.
    super(BiasQuantProxy, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                      missing_keys, unexpected_keys, error_msgs)
    # The zero_hw_sentinel buffer is registered on the module at __init__
    # time, so a checkpoint that lacks it must not be reported as an error:
    # drop it from the missing-keys list the base class just filled.
    zero_hw_sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
    if zero_hw_sentinel_key in missing_keys:
        missing_keys.remove(zero_hw_sentinel_key)



29 changes: 29 additions & 0 deletions brevitas/proxy/quant_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABCMeta

import torch
from torch import nn

from brevitas.core import ZERO_HW_SENTINEL_NAME, ZERO_HW_SENTINEL_VALUE


class QuantProxy(nn.Module):
    """Common base class for quantization proxy modules.

    Registers the zero-valued hardware sentinel buffer and keeps it out of
    serialized checkpoints: ``state_dict`` strips the entry on save, and
    ``_load_from_state_dict`` tolerates it being either absent (current
    checkpoints) or present (checkpoints saved before it was stripped).
    """
    __metaclass__ = ABCMeta

    def __init__(self):
        super(QuantProxy, self).__init__()
        # Constant sentinel tensor shared by the quantization implementations.
        self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(QuantProxy, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
        sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
        # The sentinel is stripped on save (see state_dict below), so its
        # absence from a checkpoint is expected, not an error...
        if sentinel_key in missing_keys:
            missing_keys.remove(sentinel_key)
        # ...and its presence is fine too, for checkpoints produced before
        # the sentinel was removed from the serialized state.
        if sentinel_key in unexpected_keys:
            unexpected_keys.remove(sentinel_key)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Export the state dict without the zero_hw_sentinel buffer."""
        state = super(QuantProxy, self).state_dict(destination, prefix, keep_vars)
        state.pop(prefix + ZERO_HW_SENTINEL_NAME)
        return state
36 changes: 5 additions & 31 deletions brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from brevitas.core.scaling import ScalingImplType, StandaloneScaling, IntScaling
from brevitas.core.stats import StatsOp

from .quant_proxy import QuantProxy


class FusedActivationQuantProxy(torch.jit.ScriptModule):

Expand All @@ -72,7 +74,7 @@ def forward(self, x, zero_hw_sentinel):
return x, output_scale, output_bit_width


class ActivationQuantProxy(Module):
class ActivationQuantProxy(QuantProxy):

def __init__(self,
activation_impl: Module,
Expand Down Expand Up @@ -107,8 +109,6 @@ def __init__(self,
if scaling_per_channel and per_channel_broadcastable_shape is None:
raise Exception("Per channel scaling requires to specify number of channels.")

self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))

if scaling_per_channel and not scaling_stats_op == StatsOp.MAX_AVE:
scaling_shape = per_channel_broadcastable_shape
scaling_stats_reduce_dim = 1
Expand Down Expand Up @@ -210,16 +210,8 @@ def forward(self, x):
output, output_scale, output_bit_width = self.fused_activation_quant_proxy(x, self.zero_hw_sentinel)
return output, output_scale, output_bit_width

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super(ActivationQuantProxy, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
zero_hw_sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
if zero_hw_sentinel_key in missing_keys:
missing_keys.remove(zero_hw_sentinel_key)


class ClampQuantProxy(Module):
class ClampQuantProxy(QuantProxy):

def __init__(self,
signed: bool,
Expand All @@ -232,7 +224,6 @@ def __init__(self,
msb_clamp_bit_width_impl_type: BitWidthImplType,
override_pretrained_bit_width: bool):
super(ClampQuantProxy, self).__init__()
self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))

if quant_type == QuantType.FP:
self.tensor_quant = IdentityPrescaledIntQuant()
Expand Down Expand Up @@ -261,16 +252,8 @@ def forward(self, x, input_scale, input_bit_width):
x, output_scale, output_bit_width = self.tensor_quant(x, input_scale, input_bit_width, self.zero_hw_sentinel)
return x, output_scale, output_bit_width

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super(ClampQuantProxy, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
zero_hw_sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
if zero_hw_sentinel_key in missing_keys:
missing_keys.remove(zero_hw_sentinel_key)


class TruncQuantProxy(Module):
class TruncQuantProxy(QuantProxy):

def __init__(self,
signed: bool,
Expand All @@ -283,7 +266,6 @@ def __init__(self,
explicit_rescaling: bool,
override_pretrained_bit_width: bool):
super(TruncQuantProxy, self).__init__()
self.register_buffer(ZERO_HW_SENTINEL_NAME, torch.tensor(ZERO_HW_SENTINEL_VALUE))
self.explicit_rescaling = explicit_rescaling

if quant_type == QuantType.FP:
Expand Down Expand Up @@ -318,11 +300,3 @@ def forward(self, x, input_scale, input_bit_width):
output_scale = output_scale / trunc_scale
output_bit_width = input_bit_width - trunc_bit_width
return x, output_scale, output_bit_width

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super(TruncQuantProxy, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
zero_hw_sentinel_key = prefix + ZERO_HW_SENTINEL_NAME
if zero_hw_sentinel_key in missing_keys:
missing_keys.remove(zero_hw_sentinel_key)

0 comments on commit 5db353c

Please sign in to comment.