Skip to content

Commit

Permalink
fixed unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanbo Liu committed Dec 26, 2023
1 parent bd8c293 commit a2307d2
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 5 deletions.
7 changes: 4 additions & 3 deletions runtimes/huggingface/mlserver_huggingface/codecs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def encode_request(cls, payload: Dict[str, Any], **kwargs) -> InferenceRequest:
@classmethod
def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
"""
Decode Inference requst into dictionary
extra Inference kwargs can be kept in 'InferenceRequest.parameters.extra'
Decode Inference request into dictionary
extra Inference kwargs are extracted from 'InferenceRequest.parameters.extra'
"""
values = {}
field_codecs = cls._find_decode_codecs(request)
Expand All @@ -193,7 +193,8 @@ def decode_request(cls, request: InferenceRequest) -> Dict[str, Any]:
values.update(extra)
else:
logging.warn(
"Extra inference kwargs should be kept in a dictionary."
f"Extra parameters is provided with value '{extra}' and type '{type(extra)}' \n"
"Extra parameters cannot be parsed, expected a dictionary."
)
return values

Expand Down
58 changes: 57 additions & 1 deletion runtimes/huggingface/tests/test_codecs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest

import logging
from mlserver.types import (
InferenceRequest,
InferenceResponse,
Expand Down Expand Up @@ -53,6 +53,62 @@ def test_decode_request(inference_request, expected):
assert payload == expected


@pytest.mark.parametrize(
    "inference_request, expected_payload, expected_log_msg",
    [
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra="foo3"),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"]},
            # Expected log message must be a plain string — NOT a
            # logging.warn(...) call, which returns None and would make the
            # substring assertion below raise TypeError. It must also match
            # the codec's message byte-for-byte ("value '...'", closing quote
            # after the type) or the `in caplog.text` check can never pass.
            "Extra parameters is provided with value 'foo3' and type "
            "'<class 'str'>' \n"
            "Extra parameters cannot be parsed, expected a dictionary.",
        ),
        (
            InferenceRequest(
                parameters=Parameters(content_type="str", extra=123),
                inputs=[
                    RequestInput(
                        name="foo",
                        datatype="BYTES",
                        data=["bar1", "bar2"],
                        shape=[2, 1],
                    ),
                    RequestInput(
                        name="foo2", datatype="BYTES", data=["var1"], shape=[1, 1]
                    ),
                ],
            ),
            {"foo": ["bar1", "bar2"]},
            "Extra parameters is provided with value '123' and type "
            "'<class 'int'>' \n"
            "Extra parameters cannot be parsed, expected a dictionary.",
        ),
    ],
)
def test_decode_request_with_invalid_parameter_extra(
    inference_request, expected_payload, expected_log_msg, caplog
):
    """A non-dict `Parameters.extra` is ignored (payload keeps only the
    decoded inputs) and a warning explaining the rejection is logged."""
    # logging.WARNING is the canonical name; WARN is a deprecated alias.
    caplog.set_level(logging.WARNING)
    payload = HuggingfaceRequestCodec.decode_request(inference_request)
    assert payload == expected_payload
    assert expected_log_msg in caplog.text


@pytest.mark.parametrize(
"payload, use_bytes, expected",
[
Expand Down
5 changes: 4 additions & 1 deletion runtimes/huggingface/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mlserver_huggingface.common import load_pipeline_from_settings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.types.dataplane import Parameters
from mlserver_huggingface.codecs.base import MultiInputRequestCodec


@pytest.mark.parametrize(
Expand Down Expand Up @@ -252,7 +253,9 @@ async def test_pipeline_uses_inference_kwargs(
tokenizer = runtime._model.tokenizer

prediction = await runtime.predict(payload)
generated_text = json.loads(prediction.outputs[0].data[0])["generated_text"]
generated_text = MultiInputRequestCodec.decode_response(prediction)["output"][0][
"generated_text"
]
assert isinstance(generated_text, str)
tokenized_generated_text = tokenizer.tokenize(generated_text)
num_predicted_tokens = len(tokenized_generated_text)
Expand Down

0 comments on commit a2307d2

Please sign in to comment.