Updated groupwise int quant tensor and notebook
Giuseppe5 committed Aug 13, 2024
1 parent abcac4e commit c1de55e
Showing 2 changed files with 58 additions and 36 deletions.
22 changes: 11 additions & 11 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -30,9 +30,9 @@
 ],
 "source": [
 "from brevitas.quant.experimental.float_base import Fp8e4m3Mixin\n",
-"from brevitas.quant.experimental.mx_quant_ocp import MXFloatWeight\n",
+"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight\n",
 "from brevitas.quant.experimental.float_quant_ocp import FpOCPWeightPerTensorFloat, FpOCPActPerTensorFloat\n",
-"from brevitas.quant.experimental.mx_quant_ocp import MXFloatAct\n",
+"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act\n",
 "import brevitas.nn as qnn\n",
 "import torch.nn as nn\n",
 "import torch\n",
@@ -72,12 +72,12 @@
 "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
 "\n",
 "\n",
-"class MXFloat8Weight(MXFloatWeight, Fp8e4m3Mixin):\n",
+"class MXFloat8Weight(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
 "    # The group dimension for the weights is automatically identified based on the layer type\n",
 "    # If a new layer type is used, it can be manually specified\n",
 "    pass\n",
 "\n",
-"class MXFloat8Act(MXFloatAct, Fp8e4m3Mixin):\n",
+"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
 "    # It is necessary to specify the group dimension for the activation quantization\n",
 "    group_dim = 1\n",
 "\n",
@@ -105,16 +105,16 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from brevitas.quant.experimental.mx_quant_ocp import MXFloatWeightMSE\n",
+"from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3WeightMSE\n",
 "from brevitas.quant_tensor import GroupwiseFloatQuantTensor\n",
 "\n",
 "\n",
-"class MXFloat8Weight(MXFloatWeightMSE, Fp8e4m3Mixin):\n",
+"class MXFloat8Weight(MXFloat8e4m3WeightMSE, Fp8e4m3Mixin):\n",
 "    # The group dimension for the weights is automatically identified based on the layer type\n",
 "    # If a new layer type is used, it can be manually specified\n",
 "    pass\n",
 "\n",
-"class MXFloat8Act(MXFloatAct, Fp8e4m3Mixin):\n",
+"class MXFloat8Act(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
 "    # It is necessary to specify the group dimension for the activation quantization\n",
 "    group_dim = 1\n",
 "\n",
@@ -143,18 +143,18 @@
 "outputs": [],
 "source": [
 "from brevitas.quant_tensor import GroupwiseIntQuantTensor\n",
-"from brevitas.quant.experimental.mx_quant_ocp import MXIntWeight\n",
-"from brevitas.quant.experimental.mx_quant_ocp import MXIntAct\n",
+"from brevitas.quant.experimental.mx_quant_ocp import MXInt8Weight\n",
+"from brevitas.quant.experimental.mx_quant_ocp import MXInt8Act\n",
 "import torch.nn as nn\n",
 "import brevitas.nn as qnn\n",
 "import torch\n",
 "\n",
-"class MXFloat8Weight(MXIntWeight):\n",
+"class MXFloat8Weight(MXInt8Weight):\n",
 "    # The group dimension for the weights is automatically identified based on the layer type\n",
 "    # If a new layer type is used, it can be manually specified\n",
 "    bit_width = 8\n",
 "\n",
-"class MXFloat8Act(MXIntAct):\n",
+"class MXFloat8Act(MXInt8Act):\n",
 "    # It is necessary to specify the group dimension for the activation quantization\n",
 "    group_dim = 1\n",
 "    bit_width = 8\n",
72 changes: 47 additions & 25 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -3,6 +3,7 @@

 import torch

+from brevitas.function.ops_ste import round_ste
 from brevitas.quant_tensor import _unpack_quant_tensor
 from brevitas.quant_tensor.base_quant_tensor import GroupwisIntQuantTensorBase
 from brevitas.quant_tensor.base_quant_tensor import QuantTensor
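The newly imported round_ste is Brevitas' straight-through-estimator round: the forward pass rounds, but the backward pass treats the op as identity (torch.round alone has zero gradient almost everywhere). A quick illustration:

import torch
from brevitas.function.ops_ste import round_ste

x = torch.tensor([1.4, 2.6], requires_grad=True)
y = round_ste(x)       # forward: tensor([1., 3.])
y.sum().backward()
print(x.grad)          # tensor([1., 1.]): the gradient passes straight through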
@@ -102,28 +103,41 @@ def zero_point(self):
         return new_zp

     @property
-    def _pre_round_float_value(self):
-        value, scale, zp = self.expand()
+    def _pre_round_int_value(self):
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
         if self.scale.dtype == torch.bfloat16:
-            value = value.type(torch.float32)
-            scale = scale.type(torch.float32)
-        minifloat_value = value / scale
-        fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
-        minifloat_value = minifloat_value / int_scale
-        return minifloat_value
+            value = self.value.type(torch.float32)
+            scale = self.scale.type(torch.float32)
+            zero_point = self.zero_point.type(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
+        return int_value

     @property
     def is_valid(self):
         with torch.no_grad():
-            pre_round_minifloat_value = self._pre_round_float_value
-            rounded_minifloat_value = torch.round(pre_round_minifloat_value)
-            max_abs_diff = torch.max(torch.abs(pre_round_minifloat_value - rounded_minifloat_value))
+            pre_round_int_value = self._pre_round_int_value
+            rounded_int_value = torch.round(pre_round_int_value)
+            max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
             atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
-            is_minifloat = max_abs_diff < atol
-            # We are missing the checks about self being contained between max and min value
-            # given by mantissa, exponent, inf, nan, and saturating
-            return is_minifloat
+            is_int = max_abs_diff < atol
+            if self.bit_width >= 2:
+                if self.signed:
+                    is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
+                    is_lower_b = (-2.0 ** (self.bit_width - 1) <= rounded_int_value).all()
+                else:
+                    is_upper_b = (2.0 ** self.bit_width - 1 >= rounded_int_value).all()
+                    is_lower_b = (0. <= rounded_int_value).all()
+                return (is_int & is_upper_b & is_lower_b).item()
+            else:  # binary case
+                unique_vals = rounded_int_value.unique(
+                    sorted=False, return_counts=False, return_inverse=False)
+                is_binary = unique_vals.view(-1).size()[0] == 2
+                is_signed = (unique_vals < 0.).any().item()
+                sign_match = is_signed == self.signed
+                return is_int.item() and is_binary and sign_match

     @property
     def device(self):
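The new helper inverts the affine dequantization that produced value. A worked sketch, assuming the usual Brevitas convention value = (int_value - zero_point) * scale (the numbers are illustrative):

import torch

scale = torch.tensor(0.1)
zero_point = torch.tensor(3.)
q = torch.tensor([-5., 0., 7.])           # original integer codes
value = (q - zero_point) * scale          # dequantized payload stored in the tensor

pre_round = value / scale + zero_point    # what _pre_round_int_value computes
print(torch.allclose(pre_round, q))       # True: rounding is a no-op, so is_valid passes
# is_valid additionally checks that the codes fit the signed/unsigned bit_width range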
@@ -139,17 +153,25 @@ def device(self):
             raise RuntimeError("Value and metadata are on different devices")
         return value_device

-    def minifloat(self, float_datatype=True):
-        # TODO: Check if OCP and cast to proper data-type if matching
-        assert float_datatype, "Minifloat quant returns only higher precision dtype"
-
+    def int(self, float_datatype=False):
         if self.is_valid:
-            fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
-            float_value = torch.round(self._pre_round_float_value) * int_scale
-            return float_value.type(self.scale.dtype)
+            int_value = round_ste(self._pre_round_int_value)
+            if float_datatype:
+                # Values at 8 bit and lower can be represented exactly with float16 and bfloat16;
+                # otherwise (e.g. an int16 bias), we upscale to float32
+                if self.bit_width <= 8.:
+                    return int_value.type(self.scale.dtype)
+                else:
+                    return int_value.type(torch.float32)
+            else:
+                if self.bit_width <= 8. and self.signed_t.item():
+                    return int_value.to(torch.int8)
+                elif self.bit_width <= 8. and not self.signed_t.item():
+                    return int_value.to(torch.uint8)
+                else:
+                    return int_value.to(torch.int32)
         else:
-            raise RuntimeError(f"FloatQuantTensor not valid.")
+            raise RuntimeError("GroupwiseIntQuantTensor not valid.")

     @staticmethod
     def check_input_type(tensor):
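For reference, the behavior of the new int() dispatch (a summary of the branches above, not code from the commit):

# For a valid signed 8-bit GroupwiseIntQuantTensor qt:
#   qt.int()                     -> torch.int8 codes (torch.uint8 when unsigned)
#   qt.int(float_datatype=True)  -> the same codes in qt.scale's dtype (e.g. float16)
# For bit_width > 8 (e.g. an int16 bias):
#   qt.int()                     -> torch.int32
#   qt.int(float_datatype=True)  -> torch.float32, since float16 cannot hold all codes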
