Skip to content

Commit

Permalink
Fix (minifloat): make MaxFloatInfNaN jit compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Feb 16, 2024
1 parent 61ba479 commit 6bae418
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ def __init__(
if any(map(lambda x: len(x) > mantissa_bit_width, self.__special_values)):
raise RuntimeError('NaN/inf codes need to be the same length as the mantissa.')

# move computation of min for forward pass here so it's jit compatible
self.__min_special_case = min(map(lambda x: int(x, 2), self.__special_values))

@brevitas.jit.script_method
def forward(self):
# idea: take inf and nan values, select the smallest, set max_value to smallest_val - 1
min_special_case = min(map(lambda x: int(x, 2), self.__special_values))
max_value_mantissa = min_special_case - 1
max_value_mantissa = self.__min_special_case - 1

if max_value_mantissa < 0:
# all mantissa values are used, so we need to use decrease exponent values
Expand Down

0 comments on commit 6bae418

Please sign in to comment.