Feat (ptq): support for bfloat16 in evaluate
Giuseppe5 committed Dec 6, 2023
1 parent 904e00d commit e2ad283
Showing 3 changed files with 44 additions and 6 deletions.
15 changes: 15 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Expand Up @@ -61,6 +61,9 @@ def __init__(

@brevitas.jit.script_method
def forward(self, x: Tensor):
dtype = x.dtype
if dtype == torch.bfloat16:
x = x.float()
if self.stats_reduce_dim is None:
# k is 1-indexed, so round away from zero
k = int(math.floor(.01 * self.q * x.numel() + 0.5))
@@ -73,6 +76,8 @@ def forward(self, x: Tensor):
             # k is 1-indexed, so round away from zero
             k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
             result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+        if dtype == torch.bfloat16:
+            result = result.to(torch.bfloat16)
         return result


@@ -94,6 +99,9 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tensor:
+        dtype = x.dtype
+        if dtype == torch.bfloat16:
+            x = x.float()
         if self.stats_reduce_dim is None:
             # k is 1-indexed, so round away from zero
             k = int(math.ceil(.01 * self.q * x.numel()))
@@ -106,6 +114,8 @@ def forward(self, x: Tensor) -> Tensor:
             # k is 1-indexed, so round away from zero
             k = int(math.ceil(.01 * self.q * dim_slice.numel()))
             result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
+        if dtype == torch.bfloat16:
+            result = result.to(torch.bfloat16)
         result = torch.clamp(result, max=self.zero())
         return result

@@ -130,6 +140,9 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tensor:
+        dtype = x.dtype
+        if dtype == torch.bfloat16:
+            x = x.float()
         if self.stats_reduce_dim is None:
             low_k = int(math.ceil(.01 * self.low_q * x.numel()))
             # k is 1-indexed, so round away from zero
@@ -150,6 +163,8 @@ def forward(self, x: Tensor) -> Tensor:
         low_result = torch.clamp(low_result, max=self.zero())
         interval = high_result - low_result
         abs_interval = torch.abs(interval)
+        if dtype == torch.bfloat16:
+            abs_interval = abs_interval.to(torch.bfloat16)
         return abs_interval


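All three percentile stats ops get the same treatment: the input is upcast to float32 before the reduction, presumably because kthvalue is not implemented for bfloat16 on all backends, and the result is cast back afterwards so callers see the original dtype. A minimal standalone sketch of the pattern (abs_percentile and q are illustrative names, not Brevitas API):

import math

import torch


def abs_percentile(x: torch.Tensor, q: float = 99.999) -> torch.Tensor:
    dtype = x.dtype
    if dtype == torch.bfloat16:
        # kthvalue does not accept bfloat16 inputs everywhere,
        # so run the reduction in float32
        x = x.float()
    # k is 1-indexed, so round away from zero
    k = int(math.floor(.01 * q * x.numel() + 0.5))
    result = x.abs().view(-1).kthvalue(k).values
    if dtype == torch.bfloat16:
        # restore the caller-visible dtype
        result = result.to(torch.bfloat16)
    return result


print(abs_percentile(torch.randn(4096, dtype=torch.bfloat16)))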
17 changes: 13 additions & 4 deletions src/brevitas/quant_tensor/__init__.py
@@ -15,6 +15,7 @@
 from .torch_handler import QUANT_TENSOR_FN_HANDLER
 
 IS_VALID_ATOL = 2e-1
+BFLOAT16_IS_VALID_ATOL = 0.5
 
 
 class QuantTensorBase(NamedTuple):
@@ -104,8 +105,15 @@ def is_not_none(self):
 
     @property
     def _pre_round_int_value(self):
-        int_value = self.value / self.scale
-        int_value = int_value + self.zero_point
+        value = self.value
+        scale = self.scale
+        zero_point = self.zero_point
+        if self.value.dtype == torch.bfloat16:
+            value = self.value.to(torch.float32)
+            scale = self.scale.to(torch.float32)
+            zero_point = self.zero_point.to(torch.float32)
+        int_value = value / scale
+        int_value = int_value + zero_point
         return int_value
 
     @property
@@ -114,8 +122,9 @@ def is_valid(self):
         with torch.no_grad():
             pre_round_int_value = self._pre_round_int_value
             rounded_int_value = torch.round(pre_round_int_value)
-            is_int = torch.isclose(
-                pre_round_int_value, rounded_int_value, atol=IS_VALID_ATOL).all()
+            max_abs_diff = torch.max(torch.abs(pre_round_int_value - rounded_int_value))
+            atol = BFLOAT16_IS_VALID_ATOL if self.value.dtype == torch.bfloat16 else IS_VALID_ATOL
+            is_int = max_abs_diff < atol
             if self.bit_width >= 2:
                 if self.signed:
                     is_upper_b = (2.0 ** (self.bit_width - 1) - 1 >= rounded_int_value).all()
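The looser BFLOAT16_IS_VALID_ATOL follows from bfloat16's 8-bit significand: dividing a dequantized bfloat16 value by its scale lands much further from the integer grid than the same computation in float32, which is also why _pre_round_int_value above does the division in float32. A rough illustration of the effect, with a made-up scale rather than anything Brevitas-specific:

import torch

scale = torch.tensor(0.013)
int_grid = torch.arange(-128, 128, dtype=torch.float32)

# float32 round trip: dequantize, then re-derive the integer representation
fp32_err = ((int_grid * scale) / scale - int_grid).abs().max()

# bfloat16 round trip: the dequantized value is stored in bfloat16 first
bf16_value = (int_grid * scale).to(torch.bfloat16)
bf16_err = (bf16_value.to(torch.float32) / scale - int_grid).abs().max()

print(fp32_err)  # effectively zero, far inside IS_VALID_ATOL = 2e-1
print(bf16_err)  # typically around 0.2-0.3 here, past 2e-1 but within 0.5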
18 changes: 16 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -37,6 +37,7 @@
 
 # Ignore warnings about __torch_function__
 warnings.filterwarnings("ignore")
+ONNX_DTYPE_OPSET = 19
 
 model_names = sorted(
     name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
@@ -72,6 +73,8 @@
     metavar='ARCH',
     choices=model_names,
     help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
+parser.add_argument(
+    '--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
 parser.add_argument(
     '--target-backend',
     default='fx',
@@ -207,6 +210,7 @@
     default=3,
     type=int,
     help='Exponent bit width used with float quantization for activations (default: 3)')
+parser.add_argument('--onnx-opset-version', default=None, type=int, help='ONNX opset version')
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(
@@ -219,11 +223,19 @@
 
 
 def main():
     args = parser.parse_args()
+    dtype = getattr(torch, args.dtype)
+    if args.export_onnx_qcdq:
+        if dtype == torch.bfloat16 and args.onnx_opset_version < ONNX_DTYPE_OPSET:
+            raise RuntimeError(
+                f"ONNX export with bfloat16 requires opset {ONNX_DTYPE_OPSET} or superior")
+    if args.export_torch_qcdq:
+        if dtype == torch.bfloat16:
+            raise RuntimeError("Torch export does not support bfloat16 export")
     random.seed(SEED)
     np.random.seed(SEED)
     torch.manual_seed(SEED)
 
     if args.act_quant_calibration_type == 'stats':
         act_quant_calib_config = str(args.act_quant_percentile) + 'stats'
     else:
@@ -327,10 +339,12 @@ def main():
     if args.act_equalization is not None:
         print("Applying activation equalization:")
         apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
+    model = model.to(dtype)
 
     # Define the quantized model
     quant_model = quantize_model(
         model,
+        dtype=dtype,
         backend=args.target_backend,
         scale_factor_type=args.scale_factor_type,
         bias_bit_width=args.bias_bit_width,
@@ -400,7 +414,7 @@ def main():
     export_name = os.path.join(args.export_dir, config)
     if args.export_onnx_qcdq:
         export_name = export_name + '.onnx'
-        export_onnx_qcdq(model, ref_input, export_name)
+        export_onnx_qcdq(model, ref_input, export_name, opset_version=args.onnx_opset_version)
     if args.export_torch_qcdq:
         export_name = export_name + '.pt'
         export_torch_qcdq(model, ref_input, export_name)
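The ONNX_DTYPE_OPSET guard reflects that opset 19 is the first ONNX opset whose QuantizeLinear/DequantizeLinear accept bfloat16. A condensed, standalone sketch of the dtype plumbing added in main() (the argv values are examples; note the comparison presumes --onnx-opset-version was actually provided, since comparing None would raise a TypeError):

import argparse

import torch

ONNX_DTYPE_OPSET = 19  # bfloat16 support in ONNX Q/DQ ops starts at opset 19

parser = argparse.ArgumentParser()
parser.add_argument('--dtype', default='float', choices=['float', 'bfloat16'])
parser.add_argument('--onnx-opset-version', default=None, type=int)
args = parser.parse_args(['--dtype', 'bfloat16', '--onnx-opset-version', '19'])

# 'float' resolves to torch.float, 'bfloat16' to torch.bfloat16
dtype = getattr(torch, args.dtype)

# Mirrors the export guard above: fail fast before calibration and export
if dtype == torch.bfloat16 and args.onnx_opset_version < ONNX_DTYPE_OPSET:
    raise RuntimeError(
        f"ONNX export with bfloat16 requires opset {ONNX_DTYPE_OPSET} or superior")

print(dtype)  # torch.bfloat16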
