Fix (minifloat): correct minifloat computation and tests (#1067)
Giuseppe5 authored Oct 26, 2024
1 parent 79c9ca6 commit 06af14b
Showing 3 changed files with 40 additions and 5 deletions.
src/brevitas/quant_tensor/float_quant_tensor.py (10 additions, 3 deletions)

@@ -103,8 +103,9 @@ def _pre_round_float_value(self):
             scale = self.scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+        eps = torch.finfo(scale.dtype).tiny
         int_scale = float_internal_scale(
-            self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value

@@ -137,11 +138,17 @@ def device(self):
     def minifloat(self, float_datatype=True):
         # TODO: Check if OCP and cast to proper data-type if matching
         assert float_datatype, "Minifloat quant returns only higher precision dtype"
 
         if self.is_valid:
+            value = self.value
+            scale = self.scale
             if self.scale.dtype == torch.bfloat16:
                 value = self.value.type(torch.float32)
                 scale = self.scale.type(torch.float32)
+            minifloat_value = value / scale
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
+            eps = torch.finfo(scale.dtype).tiny
             int_scale = float_internal_scale(
-                self.value, self.mantissa_bit_width, fp_internal_scale, self.eps)
+                minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
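Both hunks above fix the same two problems: float_internal_scale was fed the dequantized self.value instead of the scaled minifloat_value, and it read a self.eps attribute rather than an epsilon derived from the scale's dtype. For readers outside the codebase, here is a minimal forward-only sketch of the corrected computation; float_internal_scale_sketch below is an assumption about the helper's floor-of-log2 behavior for illustration, not a copy of the library function (which, among other things, handles gradients).

import torch

def float_internal_scale_sketch(x, mantissa_bit_width, fp_internal_scale_min, eps):
    # Per-element power-of-two scale: exponent of x minus the mantissa width,
    # clamped from below so subnormals share the minimum exponent.
    # eps keeps log2 finite when x == 0.
    e = torch.floor(torch.log2(torch.abs(x) + eps)) - mantissa_bit_width
    e = torch.clamp_min(e, fp_internal_scale_min)
    return torch.exp2(e)

# Corrected flow from the diff: operate on value / scale rather than the
# dequantized self.value, and derive eps from the scale's dtype.
value, scale = torch.randn(4), torch.tensor(0.05)
mantissa_bit_width, exponent_bias = 3, 7  # e4m3-style parameters, for illustration
fp_internal_scale = 1. - exponent_bias - mantissa_bit_width
minifloat_value = value / scale
eps = torch.finfo(scale.dtype).tiny  # smallest positive normal float
int_scale = float_internal_scale_sketch(
    minifloat_value, mantissa_bit_width, fp_internal_scale, eps)
minifloat = torch.round(minifloat_value / int_scale) * int_scale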
src/brevitas/quant_tensor/groupwise_float_quant_tensor.py (11 additions, 2 deletions)

@@ -144,7 +144,9 @@ def _pre_round_float_value(self):
             scale = scale.type(torch.float32)
         minifloat_value = value / scale
         fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-        int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+        eps = torch.finfo(scale.dtype).tiny
+        int_scale = float_internal_scale(
+            minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
         minifloat_value = minifloat_value / int_scale
         return minifloat_value

@@ -179,8 +181,15 @@ def minifloat(self, float_datatype=True):
         assert float_datatype, "Minifloat quant returns only higher precision dtype"
 
         if self.is_valid:
+            value, scale, zp = self.expand()
+            if self.scale.dtype == torch.bfloat16:
+                value = value.type(torch.float32)
+                scale = scale.type(torch.float32)
+            minifloat_value = value / scale
             fp_internal_scale = 1. - self.exponent_bias - self.mantissa_bit_width
-            int_scale = float_internal_scale(self.value, self.mantissa_bit_width, fp_internal_scale)
+            eps = torch.finfo(scale.dtype).tiny
+            int_scale = float_internal_scale(
+                minifloat_value, self.mantissa_bit_width, fp_internal_scale, eps)
             float_value = torch.round(self._pre_round_float_value) * int_scale
             return float_value.type(self.scale.dtype)
         else:
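The groupwise class gets the same fix plus one extra step: value, scale, and zero-point are stored in grouped form, so expand() reconstructs the full tensors before dividing. The eps and bfloat16 handling mirror the per-tensor file; a small standalone illustration of that dtype logic, using only core torch:

import torch

# torch.finfo(dtype).tiny is the smallest positive normal number, so the
# log2 guard adapts to whatever precision the scale tensor carries.
print(torch.finfo(torch.float32).tiny)   # ~1.1755e-38
print(torch.finfo(torch.bfloat16).tiny)  # same magnitude; bfloat16 shares fp32's exponent range

# bfloat16 is upcast before dividing: its 7-bit mantissa is too coarse for the
# division to be reliable, while the cast to float32 itself is exact.
value = torch.randn(2, 8, dtype=torch.bfloat16)
scale = torch.full((2, 1), 0.05, dtype=torch.bfloat16)
minifloat_value = value.type(torch.float32) / scale.type(torch.float32)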
tests/brevitas/quant_tensor/test_quant_tensor.py (19 additions, 0 deletions)

@@ -4,11 +4,14 @@
 
 from packaging import version
 import pytest
+import pytest_cases
 import torch
 
 from brevitas import torch_version
 from brevitas.nn import QuantIdentity
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
+from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant_tensor import FloatQuantTensor
 from brevitas.quant_tensor import IntQuantTensor
@@ -119,3 +122,19 @@ def test_quant_tensor_view():
     assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
     assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
     assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)
+
+
+QUANT_CLASS = {'fp8': Fp8e4m3ActPerTensorFloat, 'mxfp8': MXFloat8e4m3Act}
+
+
+@pytest_cases.parametrize('quant_class_key_value', QUANT_CLASS.items())
+def test_minifloat(quant_class_key_value):
+    key, quant_class = quant_class_key_value
+
+    x = torch.randn((1, 32))
+    q = QuantIdentity(quant_class, group_dim=-1, return_quant_tensor=True)
+    q.eval()
+
+    qx = q(x)
+    # Check that minifloat() doesn't raise an error
+    qx.minifloat()
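The new test is a smoke test only. For completeness, a hypothetical standalone version of the same flow; the construction mirrors the test, and the dtype assertion follows from the return float_value.type(self.scale.dtype) line in the diff above.

import torch
from brevitas.nn import QuantIdentity
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat

q = QuantIdentity(Fp8e4m3ActPerTensorFloat, group_dim=-1, return_quant_tensor=True)
q.eval()
qx = q(torch.randn(1, 32))
mf = qx.minifloat()                # value / scale, rounded onto the fp8 grid
assert mf.dtype == qx.scale.dtype  # minifloat() casts back to the scale's dtype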
