diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index 29cf09226..98e897fbc 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -252,9 +252,8 @@ async def test_pipeline_uses_inference_kwargs( tokenizer = runtime._model.tokenizer prediction = await runtime.predict(payload) - generated_text = MultiInputRequestCodec.decode_response(prediction)["output"][0][ - "generated_text" - ] + decoded_prediction = MultiInputRequestCodec.decode_response(prediction) + generated_text = decoded_prediction.get("output")[0]["generated_text"] assert isinstance(generated_text, str) tokenized_generated_text = tokenizer.tokenize(generated_text) num_predicted_tokens = len(tokenized_generated_text)