Skip to content

Commit

Permalink
Feat (inference_mode): pickle compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 18, 2024
1 parent c3208cf commit b5fdc91
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/brevitas/export/inference/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def _override_weight_caching_mode(m: nn.Module, enabled: bool, metadata_only: bo

class quant_inference_mode:

def __init__(self, model, cache_quant_weight=False, enabled=True):
def __init__(self, model, cache_quant_weight=False, delete_injector=False, enabled=True):
self.model = model
self.enabled = enabled
self.delete_injector = delete_injector
self.injector_reference = dict()
self.cache_quant_weight = cache_quant_weight
self.export_manager = InferenceManager
self.hook_list = []
Expand Down Expand Up @@ -74,6 +76,10 @@ def __exit__(self, type, value, traceback):
lambda m: _override_weight_caching_mode(m, enabled=False, metadata_only=False))
InferenceManager.set_export_mode(self.model, enabled=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
if self.delete_injector:
for m in self.model.modules():
if m in self.injector_reference:
m.quant_injector = self.injector_reference[m]

def hook(self, module, inp, out):
# After one forward pass with caching enabled, we can:
Expand All @@ -85,6 +91,11 @@ def hook(self, module, inp, out):
self.model.apply(InferenceManager.set_export_handler)
InferenceManager.set_export_mode(self.model, enabled=True)
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
if self.delete_injector:
for m in self.model.modules():
if hasattr(m, 'quant_injector'):
self.injector_reference[m] = m.quant_injector
del m.quant_injector


# Inheritance from BaseManager is not techincally needed
Expand Down

0 comments on commit b5fdc91

Please sign in to comment.