Export improvements
volcacius committed Jul 17, 2023
1 parent 5fe532f · commit b6ed36e
Showing 1 changed file with 6 additions and 4 deletions.
src/brevitas_examples/llm/llm_quant/export.py: 10 changes (6 additions, 4 deletions)
@@ -18,7 +18,6 @@
 from brevitas.export.manager import BaseManager
 from brevitas.nn import QuantLinear
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
-from brevitas_examples.llm.llm_quant.mlir_custom_mm import *  # noqa


 class WeightBlockQuantHandlerBase(BaseHandler, ABC):
@@ -30,6 +29,7 @@ def __init__(self):
         self.scale = None
         self.zero_point = None
         self.bit_width = None
+        self.dtype = None

     def scaling_impl(self, proxy_module):
         return proxy_module.tensor_quant.scaling_impl
@@ -77,11 +77,12 @@ def prepare_for_export(self, module):
         quant_layer = module.tracked_module_list[0]
         quant_weight = quant_layer.quant_weight()
         self.int_weight = quant_weight.int().detach()
-        self.scale = self.export_scale(module, self.bit_width)
+        self.dtype = quant_weight.value.dtype
+        self.scale = self.export_scale(module, self.bit_width).detach()
         self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
         self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape
         if (quant_weight.zero_point != 0.).any():
-            self.zero_point = self.export_zero_point(module, self.scale, self.bit_width)
+            self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
             self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
             self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape

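The .detach() calls and the cached quant_weight.value.dtype above are export-time bookkeeping. A minimal sketch of the motivation, assuming only standard PyTorch semantics (the names and values below are hypothetical, not Brevitas API): a scale computed through differentiable ops keeps an autograd graph alive unless it is detached before being stored on the handler.

import torch

# Hypothetical stand-in for a scale coming out of a quant proxy: it is the
# result of differentiable ops on a trainable parameter, so it carries a
# grad_fn and would keep the whole autograd graph alive if cached as-is.
log_scale = torch.nn.Parameter(torch.tensor([-2.0, -1.5]))
scale = torch.exp(log_scale)
print(scale.requires_grad, scale.grad_fn is not None)    # True True

# Detaching yields a plain tensor that is safe to store on an export handler,
# mirroring the .detach() added to the exported scale and zero point.
cached_scale = scale.detach()
print(cached_scale.requires_grad, cached_scale.grad_fn)  # False None

# The commit also caches the floating-point dtype of the dequantized weight
# so later arithmetic can happen in that dtype rather than an integer one.
cached_dtype = torch.float32  # plays the role of quant_weight.value.dtype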
@@ -94,7 +95,8 @@ def forward(self, x):
             zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous()
             # contiguous above is to avoid the reshape below being mapped to a unsafe view
             zero_point = zero_point.view(self.reshaped_zero_point_shape)
-            int_weight = int_weight - zero_point
+            # avoid unsigned subtraction
+            int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
         else:
             zero_point = torch.zeros_like(scale)
         quant_weight = int_weight * scale
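The dtype cast in the new subtraction is what the "avoid unsigned subtraction" comment refers to. A small sketch of the failure mode it guards against, assuming only PyTorch tensor semantics (the values are made up): if the integer weights live in an unsigned dtype such as torch.uint8, subtracting a larger zero point wraps around modulo 256, whereas casting both operands to the cached floating-point dtype first gives the intended signed result.

import torch

# Unsigned integer weights with a nonzero zero point, e.g. low-bit values
# stored in uint8 (illustrative values only).
int_weight = torch.tensor([0, 3, 7], dtype=torch.uint8)
zero_point = torch.tensor([8], dtype=torch.uint8)

# Direct subtraction stays in uint8 and wraps around on underflow.
wrapped = int_weight - zero_point
print(wrapped)  # tensor([248, 251, 255], dtype=torch.uint8)

# Casting both operands first, as the updated forward() does, yields the
# intended signed values.
dtype = torch.float32  # plays the role of self.dtype
correct = int_weight.to(dtype) - zero_point.to(dtype)
print(correct)  # tensor([-8., -5., -1.])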
