From b6ed36e218df13f9994428f467d21553b645e5c4 Mon Sep 17 00:00:00 2001
From: Alessandro Pappalardo
Date: Fri, 7 Jul 2023 00:35:44 +0100
Subject: [PATCH] Export improvements

---
 src/brevitas_examples/llm/llm_quant/export.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py
index 782b03030..2edd9f777 100644
--- a/src/brevitas_examples/llm/llm_quant/export.py
+++ b/src/brevitas_examples/llm/llm_quant/export.py
@@ -18,7 +18,6 @@
 from brevitas.export.manager import BaseManager
 from brevitas.nn import QuantLinear
 from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
-from brevitas_examples.llm.llm_quant.mlir_custom_mm import *  # noqa
 
 
 class WeightBlockQuantHandlerBase(BaseHandler, ABC):
@@ -30,6 +29,7 @@ def __init__(self):
         self.scale = None
         self.zero_point = None
         self.bit_width = None
+        self.dtype = None
 
     def scaling_impl(self, proxy_module):
         return proxy_module.tensor_quant.scaling_impl
@@ -77,11 +77,12 @@ def prepare_for_export(self, module):
         quant_layer = module.tracked_module_list[0]
         quant_weight = quant_layer.quant_weight()
         self.int_weight = quant_weight.int().detach()
-        self.scale = self.export_scale(module, self.bit_width)
+        self.dtype = quant_weight.value.dtype
+        self.scale = self.export_scale(module, self.bit_width).detach()
         self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
         self.reshaped_scaling_shape = self.scaling_impl(module).reshaped_scaling_shape
         if (quant_weight.zero_point != 0.).any():
-            self.zero_point = self.export_zero_point(module, self.scale, self.bit_width)
+            self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
             self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
             self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape
 
@@ -94,7 +95,8 @@ def forward(self, x):
             zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous()
             # contiguous above is to avoid the reshape below being mapped to a unsafe view
             zero_point = zero_point.view(self.reshaped_zero_point_shape)
-            int_weight = int_weight - zero_point
+            # avoid unsigned subtraction
+            int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
         else:
             zero_point = torch.zeros_like(scale)
         quant_weight = int_weight * scale
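
Note, not part of the patch: the final hunk casts both operands before the
subtraction because, per the added comment, the exported integer weight can be
an unsigned tensor, and unsigned integer subtraction wraps around instead of
going negative. A minimal sketch of the failure mode, assuming hypothetical
uint8 stand-ins for int_weight and zero_point:

    import torch

    # stand-in values for an exported integer weight and its zero point
    int_weight = torch.tensor([1, 2, 3], dtype=torch.uint8)
    zero_point = torch.tensor([4, 4, 4], dtype=torch.uint8)

    # uint8 subtraction wraps around: 1 - 4 == 253
    print(int_weight - zero_point)
    # tensor([253, 254, 255], dtype=torch.uint8)

    # casting first, as the patch does via self.dtype, yields the intended
    # signed result
    print(int_weight.to(torch.float32) - zero_point.to(torch.float32))
    # tensor([-3., -2., -1.])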