Skip to content

Commit

Permalink
test (core/float): Enhanced testing of minifloat formats (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser authored Dec 20, 2024
1 parent fd01451 commit 09235be
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 5 deletions.
19 changes: 19 additions & 0 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from functools import wraps
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -127,3 +128,21 @@ def is_broadcastable(tensor, other):
else:
return False
return True


def torch_dtype(dtype):

def decorator(fn):

@wraps(fn)
def wrapped_fn(*args, **kwargs):
cur_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(dtype)
fn(*args, **kwargs)
finally:
torch.set_default_dtype(cur_dtype)

return wrapped_fn

return decorator
35 changes: 35 additions & 0 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from hypothesis import given
from hypothesis import settings
import mock
import pytest
import torch
Expand All @@ -15,9 +16,11 @@
from brevitas.core.scaling import FloatScaling
from brevitas.function.ops import max_float
from brevitas.utils.torch_utils import float_internal_scale
from brevitas.utils.torch_utils import torch_dtype
from tests.brevitas.hyp_helper import float_st
from tests.brevitas.hyp_helper import float_tensor_random_shape_st
from tests.brevitas.hyp_helper import random_minifloat_format
from tests.brevitas.hyp_helper import random_minifloat_format_and_value
from tests.marker import jit_disabled_for_mock


Expand Down Expand Up @@ -233,3 +236,35 @@ def test_inner_scale(inp, minifloat_format, scale):
out_nans = out.isnan()
expected_out_nans = expected_out.isnan()
assert torch.equal(out[~out_nans], expected_out[~expected_out_nans])


@given(
minifloat_format_and_value=random_minifloat_format_and_value(
min_bit_width=4, max_bit_with=10, rand_exp_bias=True))
@settings(max_examples=1000)
@jit_disabled_for_mock()
@torch_dtype(torch.float64)
@torch.no_grad()
def test_valid_float_values(minifloat_format_and_value):
minifloat_value, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format_and_value
scaling_impl = mock.Mock(side_effect=lambda x, y: 1.0)
float_scaling = FloatScaling(None, None, True)
float_clamp = FloatClamp(
tensor_clamp_impl=TensorClamp(),
signed=signed,
inf_values=None,
nan_values=None,
saturating=True)
float_quant = FloatQuant(
bit_width=bit_width,
exponent_bit_width=exponent_bit_width,
mantissa_bit_width=mantissa_bit_width,
exponent_bias=exponent_bias,
signed=signed,
input_view_impl=Identity(),
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling,
float_clamp_impl=float_clamp)
inp = torch.tensor(minifloat_value)
quant_value, *_ = float_quant(inp)
assert torch.equal(inp, quant_value)
57 changes: 52 additions & 5 deletions tests/brevitas/hyp_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,29 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid


@st.composite
def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH):
def random_minifloat_format(
draw,
min_bit_width=MIN_INT_BIT_WIDTH,
max_bit_with=MAX_INT_BIT_WIDTH,
rand_exp_bias=False,
valid_only=False):
""""
Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
"""
# TODO: add support for new minifloat format that comes with FloatQuantTensor
bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
signed = draw(st.booleans())

exponent_bias = 2 ** (exponent_bit_width - 1) - 1
if valid_only:
# Only works if min_bit_width >= 3
signed = draw(st.booleans())
exponent_bit_width = draw(st.integers(min_value=1, max_value=bit_width - 1 - int(signed)))
else:
exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
signed = draw(st.booleans())

if rand_exp_bias:
exponent_bias = draw(st.integers(min_value=-127, max_value=127))
else:
exponent_bias = 2 ** (exponent_bit_width - 1) - 1

# if no budget is left, return
if bit_width == exponent_bit_width:
Expand All @@ -246,3 +259,37 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=
mantissa_bit_width = bit_width - exponent_bit_width - int(signed)

return bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias


@st.composite
def random_valid_minifloat(
draw, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias):
""""
Generate a random floating-point value that can be represented in the specified minifloat format.
"""
# Sanity-check that the format is valid
assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)
# Generate int values of the minifloat components
sign = draw(st.integers(min_value=0, max_value=int(signed)))
mantissa = draw(st.integers(min_value=0, max_value=int(2 ** mantissa_bit_width - 1)))
exponent = draw(st.integers(min_value=0, max_value=int(2 ** exponent_bit_width - 1)))
# Scale mantissa between 0-1
mantissa_fixed = mantissa / 2 ** mantissa_bit_width
# Add 1 unless denormalised
mantissa_fixed += 0. if exponent == 0 else 1.
# Adjust exponent if denormalised, otherwise leave it unchanged
exponent_value = 1 if exponent == 0 else exponent
valid_minifloat = ((-1.) ** sign) * (mantissa_fixed * 2 ** (exponent_value - exponent_bias))
return valid_minifloat, exponent, mantissa, sign


@st.composite
def random_minifloat_format_and_value(
draw,
min_bit_width=MIN_INT_BIT_WIDTH,
max_bit_with=MAX_INT_BIT_WIDTH,
rand_exp_bias=False,
valid_format_only=True):
bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = draw(random_minifloat_format(min_bit_width=min_bit_width, max_bit_with=max_bit_with, rand_exp_bias=rand_exp_bias, valid_only=valid_format_only))
valid_minifloat, exponent, mantissa, sign = draw(random_valid_minifloat(bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, signed=signed, exponent_bias=exponent_bias))
return valid_minifloat, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias

0 comments on commit 09235be

Please sign in to comment.