From a2307d23420cccd6d5494995967a18eb9ba249f4 Mon Sep 17 00:00:00 2001
From: Nanbo Liu
Date: Tue, 26 Dec 2023 15:56:39 +0000
Subject: [PATCH] fixed unit tests

---
 .../mlserver_huggingface/codecs/base.py   |  7 ++-
 runtimes/huggingface/tests/test_codecs.py | 58 ++++++++++++++++++-
 runtimes/huggingface/tests/test_common.py |  5 +-
 3 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/runtimes/huggingface/mlserver_huggingface/codecs/base.py b/runtimes/huggingface/mlserver_huggingface/codecs/base.py
index 3122fde43..a2ea56cfe 100644
--- a/runtimes/huggingface/mlserver_huggingface/codecs/base.py
+++ b/runtimes/huggingface/mlserver_huggingface/codecs/base.py
@@ -172,8 +172,8 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:
     @classmethod
     def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
         """
-        Decode Inference requst into dictionary
-        extra Inference kwargs can be kept in 'InferenceRequest.parameters.extra'
+        Decode an InferenceRequest into a dictionary.
+        Extra inference kwargs are extracted from 'InferenceRequest.parameters.extra'.
         """
         values = {}
         field_codecs = cls._find_decode_codecs(request)
@@ -193,7 +193,8 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
                 values.update(extra)
             else:
                 logging.warn(
-                    "Extra inference kwargs should be kept in a dictionary."
+                    f"Extra parameters were provided with value '{extra}' and type '{type(extra)}'\n"
+                    "Extra parameters cannot be parsed, expected a dictionary."
                 )
 
         return values

diff --git a/runtimes/huggingface/tests/test_codecs.py b/runtimes/huggingface/tests/test_codecs.py
index f17904f65..6ea44f4f9 100644
--- a/runtimes/huggingface/tests/test_codecs.py
+++ b/runtimes/huggingface/tests/test_codecs.py
@@ -1,5 +1,5 @@
 import pytest
-
+import logging
 from mlserver.types import (
     InferenceRequest,
     InferenceResponse,
@@ -53,6 +53,62 @@ def test_decode_request(inference_request, expected):
     assert payload == expected
 
 
+@pytest.mark.parametrize(
+    "inference_request, expected_payload, expected_log_msg",
+    [
+        (
+            InferenceRequest(
+                parameters=Parameters(content_type="str", extra="foo3"),
+                inputs=[
+                    RequestInput(
+                        name="foo",
+                        datatype="BYTES",
+                        data=["bar1", "bar2"],
+                        shape=[2, 1],
+                    ),
+                    RequestInput(
+                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
+                    ),
+                ],
+            ),
+            {"foo": ["bar1", "bar2"]},
+            (
+                "Extra parameters were provided with value 'foo3' and type '<class 'str'>'\n"
+                "Extra parameters cannot be parsed, expected a dictionary."
+            ),
+        ),
+        (
+            InferenceRequest(
+                parameters=Parameters(content_type="str", extra=123),
+                inputs=[
+                    RequestInput(
+                        name="foo",
+                        datatype="BYTES",
+                        data=["bar1", "bar2"],
+                        shape=[2, 1],
+                    ),
+                    RequestInput(
+                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
+                    ),
+                ],
+            ),
+            {"foo": ["bar1", "bar2"]},
+            (
+                "Extra parameters were provided with value '123' and type '<class 'int'>'\n"
+                "Extra parameters cannot be parsed, expected a dictionary."
+            ),
+        ),
+    ],
+)
+def test_decode_request_with_invalid_parameter_extra(
+    inference_request, expected_payload, expected_log_msg, caplog
+):
+    caplog.set_level(logging.WARN)
+    payload = HuggingfaceRequestCodec.decode_request(inference_request)
+    assert payload == expected_payload
+    assert expected_log_msg in caplog.text
+
+
 @pytest.mark.parametrize(
     "payload, use_bytes, expected",
     [
diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py
index c2909526b..ab1c96a5a 100644
--- a/runtimes/huggingface/tests/test_common.py
+++ b/runtimes/huggingface/tests/test_common.py
@@ -16,6 +16,7 @@
 from mlserver_huggingface.common import load_pipeline_from_settings
 from mlserver.types import InferenceRequest, RequestInput
 from mlserver.types.dataplane import Parameters
+from mlserver_huggingface.codecs.base import MultiInputRequestCodec
 
 
 @pytest.mark.parametrize(
@@ -252,7 +253,9 @@ async def test_pipeline_uses_inference_kwargs(
     tokenizer = runtime._model.tokenizer
 
     prediction = await runtime.predict(payload)
-    generated_text = json.loads(prediction.outputs[0].data[0])["generated_text"]
+    generated_text = MultiInputRequestCodec.decode_response(prediction)["output"][0][
+        "generated_text"
+    ]
     assert isinstance(generated_text, str)
     tokenized_generated_text = tokenizer.tokenize(generated_text)
     num_predicted_tokens = len(tokenized_generated_text)
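
A minimal usage sketch of the behaviour these changes exercise, for context: a
dict-valued `parameters.extra` is merged into the payload returned by
`decode_request`, while any non-dict value is ignored with the warning asserted
above. The import path for `HuggingfaceRequestCodec` mirrors the one used in
test_codecs.py, and the `max_new_tokens` kwarg is a hypothetical example, not
something this patch adds:

    from mlserver.types import InferenceRequest, Parameters, RequestInput
    from mlserver_huggingface.codecs import HuggingfaceRequestCodec

    # Dict-valued extra: merged into the decoded payload as inference kwargs.
    good = InferenceRequest(
        parameters=Parameters(
            content_type="str",
            extra={"max_new_tokens": 20},  # hypothetical generation kwarg
        ),
        inputs=[
            RequestInput(name="foo", datatype="BYTES", data=["bar1"], shape=[1, 1]),
        ],
    )
    payload = HuggingfaceRequestCodec.decode_request(good)
    # payload now holds the decoded input plus the extra kwarg,
    # e.g. {"foo": ["bar1"], "max_new_tokens": 20}

    # Non-dict extra: left out of the payload; the warning added in this
    # patch ("Extra parameters cannot be parsed, expected a dictionary.")
    # is logged instead.
    bad = InferenceRequest(
        parameters=Parameters(content_type="str", extra="foo3"),
        inputs=[
            RequestInput(name="foo", datatype="BYTES", data=["bar1"], shape=[1, 1]),
        ],
    )
    payload = HuggingfaceRequestCodec.decode_request(bad)
    # e.g. {"foo": ["bar1"]}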