diff --git a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
index cc91ba70bec..bd1865e674c 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -106,11 +106,11 @@ def get_filter_fn(node_list, fn):
     def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
         """Filter the node with target operator in match and check if it is in `node_list`."""
         target_node = None
-        for node in pattern_graph.nodes:
+        for node in pattern_graph.nodes:  # pragma: no cover
             if node.target == target_op:
                 target_node = node
                 break
-        if target_node is None:
+        if target_node is None:  # pragma: no cover
             return False
         matched_node = match.nodes_map[target_node]
         return matched_node in node_list
@@ -137,7 +137,8 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):
     for node in gm.graph.nodes:
         if meta := getattr(node, "meta"):
             if quantization_annotation := meta.get(xiq.QUANT_ANNOTATION_KEY):
-                if quantization_annotation._annotated:
+                none_annotation = xiq._X86InductorQuantizationAnnotation(_annotated=True)
+                if quantization_annotation != none_annotation:  # pragma: no cover
                     continue
         unquantized_node_set.add(node)
     return unquantized_node_set
@@ -161,18 +162,18 @@ def _parse_node_candidate_set_from_user_config(config, gm):
     op_type_configs, op_name_configs = config._get_op_name_op_type_config()
     op_type_filters = []
     op_name_filters = []
-    for op_type_name, config in op_type_configs.items():
+    for op_type_name, config in op_type_configs.items():  # pragma: no cover
         op_type = getattr(torch.nn, op_type_name)
-        if config.act_dtype == "fp16":
+        if config.act_dtype == "fp16":  # pragma: no cover
             filter = xpq._get_module_type_filter(op_type)
             op_type_filters.append(filter)
     for op_name, config in op_name_configs.items():
-        if config.act_dtype == "fp16":
+        if config.act_dtype == "fp16":  # pragma: no cover
             filter = xpq._get_module_name_filter(op_name)
             op_name_filters.append(filter)
     node_set_from_user_config = set()
     all_filters = op_type_filters + op_name_filters
-    for node in gm.graph.nodes:
+    for node in gm.graph.nodes:  # pragma: no cover
         if any([filter(node) for filter in all_filters]):
             node_set_from_user_config.add(node)
     return node_set_from_user_config
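Note on the hunk above: the new `quantization_annotation != none_annotation` check relies on the annotation type being a dataclass, so instances compare field by field. A node marked `_annotated=True` but carrying no input/output qspecs equals a freshly constructed `none_annotation` and is therefore still treated as unquantized. A minimal sketch of that equality semantics, using a hypothetical stand-in dataclass rather than the real `xiq._X86InductorQuantizationAnnotation`:

    from dataclasses import dataclass, field
    from typing import Any, Dict, Optional

    @dataclass
    class FakeAnnotation:  # stand-in for xiq._X86InductorQuantizationAnnotation
        input_qspec_map: Dict[Any, Any] = field(default_factory=dict)
        output_qspec: Optional[Any] = None
        _annotated: bool = False

    none_annotation = FakeAnnotation(_annotated=True)
    # annotated but no qspecs -> compares equal, node stays in the unquantized set
    assert FakeAnnotation(_annotated=True) == none_annotation
    # carries a qspec -> compares unequal, node is skipped as quantized
    assert FakeAnnotation(_annotated=True, output_qspec="qspec") != none_annotation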
diff --git a/neural_compressor/torch/algorithms/pt2e_quant/utility.py b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
index 92635db1f70..e4efd62271e 100644
--- a/neural_compressor/torch/algorithms/pt2e_quant/utility.py
+++ b/neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -20,6 +20,8 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2
+
 
 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
@@ -53,6 +55,9 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
 
 
 def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
+    NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
+    if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES:  # pragma: no cover
+        return None
     default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
     input_act_quant_spec = create_quant_spec_from_config(
         inc_config.act_dtype, inc_config.act_sym, inc_config.act_granularity, inc_config.act_algo, is_dynamic=is_dynamic
@@ -75,5 +80,22 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # Skip the local config for now (need torch 2.4)
+    # setting a local config needs torch > 2.3.2
+    if GT_TORCH_VERSION_2_3_2:  # pragma: no cover
+        op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
+        if op_type_config_dict:
+            for op_type, config in op_type_config_dict.items():
+                _nn_module_type = getattr(torch.nn, op_type, None)
+                if _nn_module_type:
+                    quantizer.set_module_type_qconfig(
+                        _nn_module_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+                _nn_func_type = getattr(torch.nn.functional, op_type, None)
+                if _nn_func_type:
+                    quantizer.set_function_type_qconfig(
+                        _nn_func_type, _map_inc_config_to_torch_quant_config(config, is_dynamic)
+                    )
+        if op_name_config_dict:
+            for op_name, config in op_name_config_dict.items():
+                quantizer.set_module_name_qconfig(op_name, _map_inc_config_to_torch_quant_config(config, is_dynamic))
     return quantizer
diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py
index 3091aa83d88..0697979996d 100644
--- a/neural_compressor/torch/utils/environ.py
+++ b/neural_compressor/torch/utils/environ.py
@@ -91,6 +91,9 @@ def get_torch_version():
     return version
 
 
+GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
+
+
 def get_accelerator(device_name="auto"):
     global accelerator  # update the global accelerator when calling this func
     from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
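With the `GT_TORCH_VERSION_2_3_2` gate in place, per-op configs set via `set_local` now reach the `X86InductorQuantizer` on torch > 2.3.2 instead of being dropped. A hedged usage sketch of the flow this enables (the module name `fc1` is illustrative, matching the new test below):

    from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config

    quant_config = get_default_static_config()
    # Keep `fc1` in fp32: _map_inc_config_to_torch_quant_config returns None for
    # an all-float config, so the quantizer leaves that module unquantized while
    # the global int8 config still applies to everything else.
    quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))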
diff --git a/test/3x/torch/quantization/test_pt2e_quant.py b/test/3x/torch/quantization/test_pt2e_quant.py
index 3857832598a..e2c643f07c6 100644
--- a/test/3x/torch/quantization/test_pt2e_quant.py
+++ b/test/3x/torch/quantization/test_pt2e_quant.py
@@ -17,7 +17,7 @@
     prepare,
     quantize,
 )
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version
 
 torch.manual_seed(0)
 
@@ -119,6 +119,42 @@ def calib_fn(model):
         logger.warning("out shape is %s", out.shape)
         assert out is not None
 
+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
+    def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
+        model, example_inputs = self.build_simple_torch_model_and_example_inputs()
+        float_model_output = model(*example_inputs)
+        quant_config = None
+
+        def calib_fn(model):
+            for i in range(4):
+                model(*example_inputs)
+
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        q_model = quantize(model=model, quant_config=quant_config, run_fn=calib_fn)
+
+        # check the quantize/dequantize node occurrence
+        expected_node_occurrence = {
+            # Only quantize the `fc2`
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+        }
+        expected_node_occurrence = {
+            torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        node_in_graph = self.get_node_in_graph(q_model)
+        for node, cnt in expected_node_occurrence.items():
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but found {node_in_graph.get(node, 0)}"
+
+        from torch._inductor import config
+
+        config.freezing = True
+        q_model_out = q_model(*example_inputs)
+        assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"
+        opt_model = torch.compile(q_model)
+        out = opt_model(*example_inputs)
+        assert out is not None
+
     @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
     @pytest.mark.parametrize("is_dynamic", [False, True])
     def test_prepare_and_convert_on_simple_model(self, is_dynamic, force_not_import_ipex):
@@ -193,9 +229,9 @@ def get_node_in_graph(graph_module):
                 nodes_in_graph[n] += 1
             else:
                 nodes_in_graph[n] = 1
-        return
+        return nodes_in_graph
 
-    @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0")
+    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>2.3.2")
     def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         model, example_inputs = self.build_model_include_conv_and_linear()
         model = export(model, example_inputs=example_inputs)
@@ -221,9 +257,7 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         }
         node_in_graph = self.get_node_in_graph(converted_model)
         for node, cnt in expected_node_occurrence.items():
-            assert (
-                expected_node_occurrence.get(node, 0) == cnt
-            ), f"Node {node} should occur {cnt} times, but {node_in_graph[node]}"
+            assert node_in_graph.get(node, 0) == cnt, f"Node {node} should occur {cnt} times, but found {node_in_graph.get(node, 0)}"
         # inference
         from torch._inductor import config
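For reference, the new skip conditions hinge on `GT_TORCH_VERSION_2_3_2` being a strict greater-than, so torch 2.3.2 itself is excluded from the gated paths. A small sketch of that behavior (the helper function is hypothetical; the real code evaluates the comparison once at import time in environ.py):

    from packaging.version import Version

    def gt_2_3_2(torch_version: str) -> bool:
        # Mirrors GT_TORCH_VERSION_2_3_2: strict '>', so "2.3.2" itself fails the gate
        return Version(torch_version) > Version("2.3.2")

    assert gt_2_3_2("2.4.0")
    assert not gt_2_3_2("2.3.2")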