From 097128a14c6f40c94ac237daaed273ad62a954cc Mon Sep 17 00:00:00 2001
From: Kaihui-intel
Date: Mon, 10 Feb 2025 06:01:57 +0800
Subject: [PATCH] add UT for AutoRound VLM quantization

Signed-off-by: Kaihui-intel
---
 .../transformers/quantization/utils.py        |  1 -
 .../weight_only/test_transfomers.py           | 41 +++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index 39edc633d8c..f5474e33144 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -126,7 +126,6 @@ def _replace_linear(
             current_key_name = []
         current_key_name.append(name)
         is_removed = False
-        print(isinstance(module, auto_round.export.export_to_itrex.model_wrapper.WeightOnlyLinear))
         if (
             isinstance(module, torch.nn.Linear)
             or isinstance(module, INCWeightOnlyLinear)
diff --git a/test/3x/torch/quantization/weight_only/test_transfomers.py b/test/3x/torch/quantization/weight_only/test_transfomers.py
index 83f6b664da0..f3dd4cae038 100644
--- a/test/3x/torch/quantization/weight_only/test_transfomers.py
+++ b/test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -10,6 +10,7 @@
 from neural_compressor.torch.utils import get_ipex_version
 from neural_compressor.transformers import (
     AutoModelForCausalLM,
+    Qwen2VLForConditionalGeneration,
     AutoRoundConfig,
     AwqConfig,
     GPTQConfig,
@@ -19,6 +20,12 @@
 
 ipex_version = get_ipex_version()
 
+try:
+    import auto_round
+
+    auto_round_installed = True
+except ImportError:
+    auto_round_installed = False
 
 class TestTansformersLikeAPI:
     def setup_class(self):
@@ -30,6 +37,7 @@ def setup_class(self):
     def teardown_class(self):
         shutil.rmtree("nc_workspace", ignore_errors=True)
         shutil.rmtree("transformers_tmp", ignore_errors=True)
+        shutil.rmtree("transformers_vlm_tmp", ignore_errors=True)
 
     def test_quantization_for_llm(self):
         model_name_or_path = self.model_name_or_path
@@ -208,3 +216,36 @@ def test_loading_autoawq_model(self):
         else:
             target_text = ["One day, the little girl in the back of my mind will say, “I’m so glad you’"]
         assert gen_text == target_text, "loading autoawq quantized model failed."
+
+    @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
+    def test_vlm(self):
+        model_name = "Qwen/Qwen2-VL-2B-Instruct"
+        woq_config = AutoRoundConfig(
+            bits=4,
+            group_size=128,
+            is_vlm=True,
+            dataset="liuhaotian/llava_conv_58k",
+            iters=2,
+            n_samples=5,
+            seq_len=512,
+            batch_size=1,
+            export_format="itrex",
+        )
+
+        woq_model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, quantization_config=woq_config, attn_implementation='eager')
+
+        from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear
+        assert isinstance(woq_model.model.layers[0].self_attn.k_proj, WeightOnlyQuantizedLinear), "replacing model failed."
+
+        # save
+        woq_model.save_pretrained("transformers_vlm_tmp")
+
+        # load
+        loaded_model = Qwen2VLForConditionalGeneration.from_pretrained("transformers_vlm_tmp")
+        assert isinstance(loaded_model.model.layers[0].self_attn.k_proj, WeightOnlyQuantizedLinear), "loading model failed."
+
+        # Phi-3-vision-128k-instruct
+        model_name = "microsoft/Phi-3-vision-128k-instruct"
+        woq_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, attn_implementation='eager')
+
+        assert isinstance(woq_model.model.layers[0].self_attn.o_proj, WeightOnlyQuantizedLinear), "quantization failed."
\ No newline at end of file
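
Reviewer note (not part of the patch): the new test quantizes Qwen2-VL-2B with AutoRoundConfig, saves the INT4 model to "transformers_vlm_tmp", and reloads it. A minimal sketch of smoke-testing that saved artifact outside pytest is given below; it assumes the same optional dependencies the test requires (auto_round, intel_extension_for_pytorch), and the prompt plus generation settings are illustrative assumptions, not taken from the repository.

    # Minimal sketch, assuming the artifact produced by test_vlm exists at
    # "transformers_vlm_tmp"; prompt and max_new_tokens are illustrative only.
    from transformers import AutoTokenizer

    from neural_compressor.transformers import Qwen2VLForConditionalGeneration

    # Reload the INT4 weight-only quantized VLM saved by the test.
    model = Qwen2VLForConditionalGeneration.from_pretrained("transformers_vlm_tmp")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    # Text-only generation is enough to confirm the packed linear layers load and run.
    inputs = tokenizer("What can a vision-language model do?", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))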