Skip to content

Commit

Permalink
transition to stateful seq2seq models
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 16, 2025
1 parent 649d966 commit f02ee0e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
16 changes: 7 additions & 9 deletions notebooks/distil-whisper-asr/distil-whisper-asr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@
" encoder_calibration_data = []\n",
" decoder_calibration_data = []\n",
" ov_model.encoder.request = InferRequestWrapper(ov_model.encoder.request, encoder_calibration_data, apply_caching=True)\n",
" ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request,\n",
" ov_model.decoder.request = InferRequestWrapper(ov_model.decoder.request,\n",
" decoder_calibration_data,\n",
" apply_caching=True)\n",
"\n",
Expand All @@ -996,7 +996,7 @@
" ov_model.generate(input_features)\n",
" finally:\n",
" ov_model.encoder.request = ov_model.encoder.request.request\n",
" ov_model.decoder_with_past.request = ov_model.decoder_with_past.request.request\n",
" ov_model.decoder.request = ov_model.decoder.request.request\n",
"\n",
" return encoder_calibration_data, decoder_calibration_data"
]
Expand Down Expand Up @@ -1146,23 +1146,21 @@
" gc.collect()\n",
"\n",
"    print(\"Quantizing decoder\")\n",
" quantized_decoder_with_past = nncf.quantize(\n",
" ov_model.decoder_with_past.model,\n",
" quantized_decoder = nncf.quantize(\n",
" ov_model.decoder.model,\n",
" nncf.Dataset(decoder_calibration_data),\n",
" subset_size=len(decoder_calibration_data),\n",
" model_type=nncf.ModelType.TRANSFORMER,\n",
" # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search\n",
" advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)\n",
" )\n",
" ov.save_model(quantized_decoder_with_past, quantized_model_path / \"openvino_decoder_with_past_model.xml\")\n",
" del quantized_decoder_with_past\n",
"    ov.save_model(quantized_decoder, quantized_model_path / \"openvino_decoder_model.xml\")\n",
" del quantized_decoder\n",
" del decoder_calibration_data\n",
" gc.collect()\n",
"\n",
" # Copy the config file and the first-step-decoder manually\n",
" shutil.copy(model_path / \"config.json\", quantized_model_path / \"config.json\")\n",
" shutil.copy(model_path / \"openvino_decoder_model.xml\", quantized_model_path / \"openvino_decoder_model.xml\")\n",
" shutil.copy(model_path / \"openvino_decoder_model.bin\", quantized_model_path / \"openvino_decoder_model.bin\")\n",
"\n",
" quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False)\n",
" quantized_ov_model.to(device.value)\n",
Expand Down Expand Up @@ -1392,7 +1390,7 @@
" whole_infer_times = []\n",
" time_fn(ov_model, \"generate\", whole_infer_times)\n",
" time_fn(ov_model.encoder, \"forward\", encoder_infer_times)\n",
" time_fn(ov_model.decoder_with_past, \"forward\", decoder_with_past_infer_times)\n",
" time_fn(ov_model.decoder, \"forward\", decoder_with_past_infer_times)\n",
"\n",
" ground_truths = []\n",
" predictions = []\n",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/grammar-correction/grammar-correction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@
"grammar_corrector_pipe_fp32 = grammar_corrector_pipe\n",
"grammar_corrector_pipe_int8 = None\n",
"if to_quantize.value:\n",
" quantized_model_path = Path(\"quantized_decoder_with_past\") / \"openvino_model.xml\"\n",
"    quantized_model_path = Path(\"quantized_decoder\") / \"openvino_model.xml\"\n",
" grammar_corrector_pipe_int8 = get_quantized_pipeline(\n",
" grammar_corrector_pipe_fp32,\n",
" grammar_corrector_tokenizer,\n",
Expand Down Expand Up @@ -1063,7 +1063,7 @@
"\n",
"if to_quantize.value:\n",
" model_size_fp32, model_size_int8 = calculate_compression_rate(\n",
" grammar_corrector_dir / \"openvino_decoder_with_past_model.xml\",\n",
" grammar_corrector_dir / \"openvino_decoder_model.xml\",\n",
" quantized_model_path,\n",
" )"
]
Expand Down
10 changes: 5 additions & 5 deletions notebooks/grammar-correction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]:
calibration_data = []
ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
ov_decoder = grammar_corrector_pipe_fp32.model.decoder

# Wrap decoder inference for data collection
ov_decoder.request = InferRequestWrapper(ov_decoder.request, calibration_data, apply_caching=True)
Expand Down Expand Up @@ -55,7 +55,7 @@ def quantize(
quantized_model = core.read_model(model=quantized_model_path)
else:
calibration_data = collect_calibration_data(grammar_corrector_pipe_fp32, calibration_dataset_size)
ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
ov_decoder = grammar_corrector_pipe_fp32.model.decoder
quantized_model = nncf.quantize(
ov_decoder.model,
calibration_dataset=nncf.Dataset(calibration_data),
Expand Down Expand Up @@ -93,9 +93,9 @@ def get_quantized_pipeline(

# Load quantized model into grammar correction pipeline
grammar_corrector_model_int8 = OVModelForSeq2SeqLM.from_pretrained(grammar_corrector_dir, device=device)
grammar_corrector_model_int8.decoder_with_past.model = quantized_model
grammar_corrector_model_int8.decoder_with_past.request = None
grammar_corrector_model_int8.decoder_with_past._compile()
grammar_corrector_model_int8.decoder.model = quantized_model
grammar_corrector_model_int8.decoder.request = None
grammar_corrector_model_int8.decoder._compile()
grammar_corrector_pipe_int8 = pipeline(
"text2text-generation", model=grammar_corrector_model_int8, tokenizer=grammar_corrector_tokenizer, device=torch.device("cpu")
)
Expand Down

0 comments on commit f02ee0e

Please sign in to comment.