Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[trunc/quant_avg_pool] Update Trunc and QuantAveragePool to match how Brevitas Ops work #170

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/qonnx/custom_op/general/trunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,36 @@

from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.quant import resolve_rounding_mode
from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode


def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
    """Truncate a quantized tensor onto a coarser output scale and bit width.

    Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

    The input is mapped to its integer representation using ``scale`` and
    ``zeropt``, divided by a power-of-two truncation factor derived from
    ``output_scale / scale``, clamped to the integer range reported by
    ``min_int`` / ``max_int`` for ``output_bit_width``, rounded with the
    requested rounding mode, and finally de-quantized with ``output_scale``
    and a correspondingly rescaled zero-point.

    Args:
        inp_tensor: Values to truncate.
        scale: Scale of the input quantization.
        zeropt: Zero-point of the input quantization.
        input_bit_width: Not used by the computation; kept in the signature
            so that existing positional callers keep working.
        narrow: Forwarded to ``min_int`` / ``max_int`` together with
            ``signed`` (presumably selects the narrow integer range -- see
            those helpers for the exact meaning).
        signed: Forwarded to ``min_int`` / ``max_int``.
        output_scale: Scale of the output quantization. ``output_scale / scale``
            is snapped to the nearest power of two; it must be positive,
            otherwise the logarithm below yields NaN.
        output_bit_width: Bit width used to compute the clamping range.
        rounding_mode: Name passed to ``resolve_rounding_mode`` to obtain the
            function that converts the clamped values to integers.

    Returns:
        The truncated values, de-quantized with ``output_scale``.
    """
    # Scaling: move to the (possibly fractional) integer domain of the input.
    y = inp_tensor / scale
    y = y + zeropt
    # Rounding
    y = np.round(y)

    # Rescale. The truncation factor should be a power of two, so the ratio
    # of the two scales is snapped to the nearest one.
    trunc_scale = 2 ** np.round(np.log2(output_scale / scale))
    y = y / trunc_scale

    # Clamping. The bounds are coerced with np.asarray so that plain Python
    # numbers returned by min_int / max_int work too (they have no .astype),
    # and cast to y's dtype so the clamp does not change the result dtype.
    lower_bound = np.asarray(min_int(signed, narrow, output_bit_width), dtype=y.dtype)
    upper_bound = np.asarray(max_int(signed, narrow, output_bit_width), dtype=y.dtype)
    y = np.clip(y, lower_bound, upper_bound)

    # To int (truncate)
    rounding_fx = resolve_rounding_mode(rounding_mode)
    y = rounding_fx(y)

    # Rescale: the zero-point lives in the input integer domain, so it is
    # divided by the same truncation factor before being removed.
    output_zeropt = zeropt / trunc_scale
    y = y - output_zeropt
    y = y * output_scale

    return y

Expand All @@ -73,6 +80,13 @@ def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
"rounding_mode": ("s", True, "FLOOR"),
"narrow": ("i", False, 0, {0, 1}),
"signed": ("i", False, 1, {0, 1}),
"output_scale": (
"f",
False,
-1.0,
), # Invalid scale signifies that it needs to be computed from input/output bit_width
}

def make_shape_compatible_op(self, model):
Expand All @@ -93,8 +107,14 @@ def execute_node(self, context, graph):
output_bit_width = context[node.input[4]]
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
narrow = self.get_nodeattr("narrow")
signed = self.get_nodeattr("signed")
output_scale = self.get_nodeattr("output_scale")
output_scale = 2 ** (input_bit_width - output_bit_width) if output_scale <= 0.0 else output_scale
# calculate output
ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
ret = trunc(
inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
)
# set context according to output name
context[node.output[0]] = ret

Expand Down