
Commit

Fix tests
Giuseppe5 committed Dec 16, 2024
1 parent b08f120 commit ca97643
Showing 2 changed files with 64 additions and 42 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/common.py
@@ -3,6 +3,7 @@

from abc import ABCMeta
from abc import abstractmethod
import warnings

from brevitas import config

@@ -29,8 +30,8 @@ def export_mode(self):
@export_mode.setter
def export_mode(self, value):
if value and config.JIT_ENABLED:
raise RuntimeError(
"Export mode with BREVITAS_JIT is currently not supported. Save the model' "
warnings.warn(
"Export mode with BREVITAS_JIT might fail. If so, save the model' "
"state_dict to a .pth, load it back with BREVITAS_JIT=0, and call export.")
if value and self.training:
raise RuntimeError("Can't enter export mode during training, only during inference")
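
A minimal sketch of the behaviour this first change introduces: with BREVITAS_JIT enabled, setting export_mode now emits a warning instead of raising, while the training-time check still raises. The standalone class and the jit_enabled/training constructor flags below are illustrative assumptions, not the actual Brevitas mixin or config module.

# Sketch only: export_mode warns (instead of raising) when JIT is enabled,
# and still refuses to enter export mode during training.
import warnings


class _ExportModeMixinSketch:

    def __init__(self, jit_enabled: bool, training: bool):
        self._export_mode = False
        self.jit_enabled = jit_enabled  # stands in for config.JIT_ENABLED
        self.training = training

    @property
    def export_mode(self):
        return self._export_mode

    @export_mode.setter
    def export_mode(self, value):
        if value and self.jit_enabled:
            warnings.warn(
                "Export mode with BREVITAS_JIT might fail. If so, save the model's "
                "state_dict to a .pth, load it back with BREVITAS_JIT=0, and call export.")
        if value and self.training:
            raise RuntimeError("Can't enter export mode during training, only during inference")
        self._export_mode = value


m = _ExportModeMixinSketch(jit_enabled=True, training=False)
m.export_mode = True  # warns, but no longer raises
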
101 changes: 61 additions & 40 deletions src/brevitas/export/inference/handler.py
@@ -6,6 +6,8 @@
from typing import Tuple

import torch
from torch import Tensor
import torch.nn as nn

from brevitas import is_dynamo_compiling
from brevitas.function.ops import max_float
@@ -28,19 +30,19 @@

class InferenceHandler(torch.nn.Module, ABC):

def attach_debug_info(self, module):
def attach_debug_info(self, module: nn.Module):
pass

@abstractmethod
def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
pass

@abstractmethod
def quantize(self, x):
def quantize(self, x: Tensor):
pass

@abstractmethod
def dequantize(self, x):
def dequantize(self, x: Tensor):
pass


@@ -52,24 +54,21 @@ def __init__(self):
self.register_buffer('scale', torch.ones(1))
self.register_buffer('zero_point', torch.ones(0))

def attach_debug_info(self, module):
pass

def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
if module.is_quant_enabled:
self.scale = module.scale()
self.zero_point = module.zero_point().to(self.scale.device)
self.bit_width = module.bit_width()
self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)

def quantize(self, x, scale, zero_point):
def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)

def dequantize(self, x, scale, zero_point):
def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
return (x - zero_point) * scale

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.bit_width


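The integer handler's quantize/dequantize pair above is a plain affine round trip. A self-contained sketch of that arithmetic follows, with example scale, zero point, and int8 clamp bounds chosen for illustration (they are not taken from any particular module):

# q = clamp(round(x / scale + zero_point), min, max); x_hat = (q - zero_point) * scale
import torch

scale = torch.tensor(0.1)
zero_point = torch.tensor(0.)
min_clamp, max_clamp = -128, 127  # example 8-bit signed bounds

x = torch.tensor([-20.0, 0.05, 3.14])
q = torch.clamp(torch.round(x / scale + zero_point), min_clamp, max_clamp)
x_hat = (q - zero_point) * scale
print(q)      # tensor([-128., 0., 31.])
print(x_hat)  # tensor([-12.8000, 0.0000, 3.1000])
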
@@ -80,14 +79,15 @@ def __init__(self):
super().__init__()
self.register_buffer('cached_weight', torch.ones(1))

def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
super().prepare_for_export(module)
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value
else:
self.cached_weight = None

def forward(self, x) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
@@ -99,11 +99,11 @@ def forward(self, x) -> Tuple[torch.Tensor]:
class DynamicIntInferenceHandler(IntInferencetHandler):
handled_layer = DynamicActQuantProxyFromInjector

def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy

def forward(self, x, ununsed_scale=None):
def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
return self.module_forward(x)


@@ -115,7 +115,7 @@ def prepare_for_export(self, module):
self.module_forward = module.fused_activation_quant_proxy
self.group_dim = module.group_dim

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:
x, *other = self.module_forward(x)
if is_dynamo_compiling():
start_dim = self.group_dim if self.group_dim != -1 else -2
@@ -129,12 +129,15 @@ class GroupwiseIntWeightInferenceHandler(IntWeightInferencetHandler):

def prepare_for_export(self, module):
super().prepare_for_export(module)
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.quant_tensor.value_
if module.is_quant_enabled:
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.quant_tensor.value_
else:
self.cached_weight = None

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.scale.shape != ():
scale = self.input_view(self.scale)
else:
@@ -156,6 +159,11 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
class FloatInferencetHandler(InferenceHandler):
handled_layer = (ActFloatQuantProxyFromInjector, BiasQuantProxyFromInjector)

def __init__(self):
super().__init__()
self.register_buffer('scale', torch.ones(1))
self.register_buffer('zero_point', torch.ones(0))

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.scale = module.scale()
@@ -182,7 +190,7 @@ def prepare_for_export(self, module):
self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias)
self.min_value = torch.tensor(0.) if not module.is_signed else -self.max_value

def quantize(self, x, scale, zero_point):
def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
# Compute masks
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value
@@ -200,24 +208,29 @@ def quantize(self, x, scale, zero_point):

return x

def dequantize(self, x, scale, zero_point):
def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
return (x - zero_point) * scale

def forward(self, x) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor) -> Tuple[Tensor]:
return self.dequantize(self.quantize(x, self.scale, self.zero_point), self.scale, self.zero_point), self.scale, self.zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values


class FloatWeightInferencetHandler(FloatInferencetHandler):
handled_layer = WeightFloatQuantProxyFromInjector

def __init__(self):
super().__init__()
self.register_buffer('cached_weight', torch.ones(1))

def prepare_for_export(self, module):
super().prepare_for_export(module)
if module.is_quant_enabled:
self.cached_weight = None
super().prepare_for_export(module)
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.value
else:
self.cached_weight = None

def forward(self, x) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.cached_weight is not None:
x = self.cached_weight
else:
@@ -229,12 +242,12 @@ def forward(self, x) -> Tuple[torch.Tensor]:
class GroupwiseFloatInferenceHandler(FloatInferencetHandler):
handled_layer = GroupwiseActFloatQuantProxyFromInjector

def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
if module.is_quant_enabled:
self.module_forward = module.fused_activation_quant_proxy
self.group_dim = module.group_dim

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
def forward(self, x: Tensor) -> Tuple[Tensor]:
x, *other = self.module_forward(x)
if is_dynamo_compiling():
start_dim = self.group_dim if self.group_dim != -1 else -2
@@ -246,13 +259,17 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
class GroupwiseFloatWeightInferenceHandler(FloatWeightInferencetHandler):
handled_layer = GroupwiseWeightFloatQuantProxyFromInjector

def prepare_for_export(self, module):
def prepare_for_export(self, module: nn.Module):
super().prepare_for_export(module)
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view
if module.is_quant_enabled:
self.input_view = module.input_view_impl
self.flattened_view = module.apply_input_view
if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
self.cached_weight = module._cached_weight.quant_tensor.value_
else:
self.cached_weight = None

def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
x = self.input_view(x)
def forward(self, x: Tensor) -> Tuple[Tensor]:
if self.scale.shape != ():
scale = self.input_view(self.scale)
else:
@@ -261,7 +278,11 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
zero_point = self.input_view(self.zero_point)
else:
zero_point = self.zero_point
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
if is_dynamo_compiling():
out = self.flattened_view(out)
return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
if self.cached_weight is not None:
out = self.cached_weight
else:
x = self.input_view(x)
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
if is_dynamo_compiling():
out = self.flattened_view(out)
return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
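
The groupwise handlers rely on input_view_impl and apply_input_view callables provided by the module, which are not shown in this diff. As a rough illustration of the underlying idea only (per-group scales broadcast over a grouped view of the weight, then flattened back), assuming grouping along the last dimension:

# Rough illustration: the real input_view_impl / apply_input_view come from the
# Brevitas module and may differ. Zero point assumed zero for brevity.
import torch

out_ch, in_ch, group_size = 4, 16, 8
n_groups = in_ch // group_size

weight = torch.randn(out_ch, in_ch)
scale = torch.rand(out_ch, n_groups, 1) + 0.01   # one scale per group

grouped = weight.view(out_ch, n_groups, group_size)   # "input view"
q = torch.clamp(torch.round(grouped / scale), -128, 127)
dq = q * scale                                        # dequantize
flattened = dq.view(out_ch, in_ch)                    # "flattened view"
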
