Fix (ptq): remove flexml option
Giuseppe5 committed Dec 6, 2023
1 parent 422b632 commit 159dca5
Showing 3 changed files with 12 additions and 44 deletions.
```diff
@@ -21,7 +21,6 @@
 from brevitas import config
 from brevitas import torch_version
 from brevitas.graph.quantize import preprocess_for_quantize
-from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
@@ -194,21 +193,10 @@ def ptq_torchvision_models(args):
     model = get_torchvision_model(config_namespace.model_name)
 
     # Preprocess the model for quantization
-    if config_namespace.target_backend == 'flexml':
-        # Flexml requires static shapes, thus representative input is passed in
-        img_shape = model_config['center_crop_shape']
-        model = preprocess_for_flexml_quantize(
-            model,
-            torch.ones(1, 3, img_shape, img_shape),
-            equalize_iters=config_namespace.graph_eq_iterations,
-            equalize_merge_bias=config_namespace.graph_eq_merge_bias)
-    elif config_namespace.target_backend == 'fx' or config_namespace.target_backend == 'layerwise':
-        model = preprocess_for_quantize(
-            model,
-            equalize_iters=config_namespace.graph_eq_iterations,
-            equalize_merge_bias=config_namespace.graph_eq_merge_bias)
-    else:
-        raise RuntimeError(f"{config_namespace.target_backend} backend not supported.")
+    model = preprocess_for_quantize(
+        model,
+        equalize_iters=config_namespace.graph_eq_iterations,
+        equalize_merge_bias=config_namespace.graph_eq_merge_bias)
 
     if config_namespace.act_equalization is not None:
         print("Applying activation equalization:")
@@ -297,11 +285,6 @@ def ptq_torchvision_models(args):
 
 def validate_config(config_namespace):
     is_valid = True
-    # Flexml supports only per-tensor scale factors, power of two scale factors
-    if config_namespace.target_backend == 'flexml' and (
-            config_namespace.weight_quant_granularity == 'per_channel' or
-            config_namespace.scale_factor_type == 'float_scale'):
-        is_valid = False
     # Merge bias can be enabled only when graph equalization is enabled
     if config_namespace.graph_eq_iterations == 0 and config_namespace.graph_eq_merge_bias:
         is_valid = False
```
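For reference, a minimal sketch of how the single preprocessing path above can be exercised after this change. The torchvision model and the equalization values are illustrative assumptions, not the benchmark defaults; only the `preprocess_for_quantize` call and its keyword names come from the diff.

```python
# Minimal sketch: graph-mode preprocessing now follows one path for both the
# 'fx' and 'layerwise' backends, with no flexml-specific branch.
import torch
from torchvision import models  # illustrative model source, not from the script

from brevitas.graph.quantize import preprocess_for_quantize

# Hypothetical model choice; the benchmark script loads models by name instead.
model = models.resnet18(weights=None).eval()

# Same call that replaces the removed flexml branch above.
model = preprocess_for_quantize(
    model,
    equalize_iters=20,           # stands in for config_namespace.graph_eq_iterations
    equalize_merge_bias=True)    # stands in for config_namespace.graph_eq_merge_bias
```

The effect of the change is visible here: there is no backend branch anymore, so the same `preprocess_for_quantize` call serves both remaining backends.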
```diff
@@ -17,7 +17,6 @@
 from brevitas.graph.gptq import gptq_mode
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.quantize import quantize
-from brevitas.graph.target.flexml import quantize_flexml
 from brevitas.inject import value
 import brevitas.nn as qnn
 from brevitas.quant.experimental.float import Fp8e4m3Act
@@ -52,7 +51,7 @@
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import save_inp_out_data
 from brevitas_examples.imagenet_classification.ptq.learned_round_utils import split_layers
 
-QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize, 'flexml': quantize_flexml}
+QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize}
 
 BIAS_BIT_WIDTH_MAP = {32: Int32Bias, 16: Int16Bias, None: None}
 
@@ -238,7 +237,7 @@ def layerwise_bit_width_fn_weight(module):
         **act_bit_width_dict)
 
     if backend != 'layerwise':
-        # Fx and flexml backend requires three mappings for quantization
+        # Fx backend requires three mappings for quantization
         quantize_kwargs = {
             'compute_layer_map': quant_layer_map,
             'quant_act_map': quant_act_map,
```
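For context, a hedged sketch of how the trimmed `QUANTIZE_MAP` is presumably consumed. Only the map contents and the imported functions come from the diff; the dispatch helper below is hypothetical.

```python
# Hypothetical dispatch sketch: QUANTIZE_MAP now only maps 'layerwise' and 'fx'
# to their quantize functions; 'flexml' is no longer a valid key.
from brevitas.graph.quantize import layerwise_quantize, quantize

QUANTIZE_MAP = {'layerwise': layerwise_quantize, 'fx': quantize}


def apply_backend_quantization(model, backend, **quantize_kwargs):
    # Hypothetical helper illustrating the lookup; an unsupported backend such
    # as 'flexml' now raises KeyError instead of selecting quantize_flexml.
    quantize_fn = QUANTIZE_MAP[backend]
    return quantize_fn(model, **quantize_kwargs)
```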
src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py (26 changes: 6 additions and 20 deletions)

```diff
@@ -17,9 +17,7 @@
 
 from brevitas.export import export_onnx_qcdq
 from brevitas.export import export_torch_qcdq
-from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.quantize import preprocess_for_quantize
-from brevitas.graph.target.flexml import preprocess_for_flexml_quantize
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_act_equalization
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_bias_correction
 from brevitas_examples.imagenet_classification.ptq.ptq_common import apply_gpfq
@@ -75,7 +73,7 @@
 parser.add_argument(
     '--target-backend',
     default='fx',
-    choices=['fx', 'layerwise', 'flexml'],
+    choices=['fx', 'layerwise'],
     help='Backend to target for quantization (default: fx)')
 parser.add_argument(
     '--scale-factor-type',
@@ -306,23 +304,11 @@ def main():
     model = get_torchvision_model(args.model_name)
 
     # Preprocess the model for quantization
-    if args.target_backend == 'flexml':
-        # flexml requires static shapes, pass a representative input in
-        img_shape = model_config['center_crop_shape']
-        model = preprocess_for_flexml_quantize(
-            model,
-            torch.ones(1, 3, img_shape, img_shape),
-            equalize_iters=args.graph_eq_iterations,
-            equalize_merge_bias=args.graph_eq_merge_bias,
-            merge_bn=not args.calibrate_bn)
-    elif args.target_backend == 'fx' or args.target_backend == 'layerwise':
-        model = preprocess_for_quantize(
-            model,
-            equalize_iters=args.graph_eq_iterations,
-            equalize_merge_bias=args.graph_eq_merge_bias,
-            merge_bn=not args.calibrate_bn)
-    else:
-        raise RuntimeError(f"{args.target_backend} backend not supported.")
+    model = preprocess_for_quantize(
+        model,
+        equalize_iters=args.graph_eq_iterations,
+        equalize_merge_bias=args.graph_eq_merge_bias,
+        merge_bn=not args.calibrate_bn)
 
     if args.act_equalization is not None:
         print("Applying activation equalization:")
```
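A small, self-contained illustration of the user-facing effect of the narrowed choices list: argparse now rejects `--target-backend flexml` before any quantization code runs. The parser below mirrors only the option touched above; the prog name is hypothetical.

```python
import argparse

# Hypothetical stand-alone parser reproducing only the --target-backend option.
parser = argparse.ArgumentParser(prog='ptq_evaluate')
parser.add_argument(
    '--target-backend',
    default='fx',
    choices=['fx', 'layerwise'],
    help='Backend to target for quantization (default: fx)')

args = parser.parse_args(['--target-backend', 'fx'])  # accepted
print(args.target_backend)  # fx

# argparse rejects the removed backend with an error and exit code 2:
# parser.parse_args(['--target-backend', 'flexml'])
# -> error: argument --target-backend: invalid choice: 'flexml'
```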
