From 79d74a975972ac612e7f8db350a378b30e4ff999 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Fri, 17 Jan 2025 15:22:39 -0500
Subject: [PATCH] revert test file

---
 .../tests/inference/test_text_inference.py | 59 +------------------
 1 file changed, 3 insertions(+), 56 deletions(-)

diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 98b8ac56e7..037e99819e 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -32,6 +32,7 @@
     UserMessage,
 )
 from llama_stack.apis.models import Model
+
 from .utils import group_chunks
 
 
@@ -42,28 +43,6 @@
 # --env FIREWORKS_API_KEY=
 
 
-def skip_if_centml_tool_call(provider):
-    """
-    Skip tool-calling tests if the provider is remote::centml,
-    because CentML currently doesn't generate tool_call responses.
-    """
-    if provider.__provider_spec__.provider_type == "remote::centml":
-        pytest.skip(
-            "CentML does not currently return tool calls. Skipping tool-calling test."
-        )
-
-
-def skip_if_centml_and_8b(inference_model, inference_impl):
-    """
-    Skip if provider is CentML and the model is 8B.
-    CentML only supports 'meta-llama/Llama-3.2-3B-Instruct'.
-    """
-    provider = inference_impl.routing_table.get_provider_impl(inference_model)
-    if provider.__provider_spec__.provider_type == "remote::centml" and "8b" in inference_model.lower(
-    ):
-        pytest.skip("CentML does not support Llama-3.1 8B model.")
-
-
 def get_expected_stop_reason(model: str):
     return (
         StopReason.end_of_message
@@ -111,11 +90,7 @@ class TestInference:
     # share the same provider instance.
     @pytest.mark.asyncio(loop_scope="session")
     async def test_model_list(self, inference_model, inference_stack):
-        inference_impl, models_impl = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
+        _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
         assert len(response) >= 1
@@ -133,9 +108,6 @@ async def test_model_list(self, inference_model, inference_stack):
     async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
@@ -181,9 +153,6 @@
     async def test_completion_logprobs(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             # "remote::nvidia", -- provider doesn't provide all logprobs
@@ -287,10 +256,6 @@ async def test_chat_completion_non_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=sample_messages,
@@ -379,10 +344,6 @@ async def test_chat_completion_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         response = [
             r
             async for r in await inference_impl.chat_completion(
@@ -416,13 +377,6 @@ async def test_chat_completion_with_tool_calling(
     ):
         inference_impl, _ = inference_stack
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
-        # Skip if CentML (it doesn't produce tool calls yet)
-        skip_if_centml_tool_call(provider)
-
         if (
             provider.__provider_spec__.provider_type == "remote::groq"
             and "Llama-3.2" in inference_model
@@ -470,13 +424,6 @@ async def test_chat_completion_with_tool_calling_streaming(
     ):
         inference_impl, _ = inference_stack
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
-        # Skip if CentML (it doesn't produce tool calls yet)
-        skip_if_centml_tool_call(provider)
-
         if (
             provider.__provider_spec__.provider_type == "remote::groq"
             and "Llama-3.2" in inference_model
@@ -530,7 +477,7 @@ async def test_chat_completion_with_tool_calling_streaming(
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
         assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-        assert last.event.delta.content.type == "tool_call"
+        assert isinstance(last.event.delta.content, ToolCall)
        call = last.event.delta.content
         assert call.tool_name == "get_weather"
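
Reviewer note on the final hunk: it restores the isinstance() form of the delta assertion in place of a string type tag. A minimal sketch of why that form is more robust, using hypothetical stand-in dataclasses rather than the real llama_stack types: a streamed delta may carry a plain str chunk, and str has no .type attribute, so the tag comparison can raise AttributeError where isinstance() simply returns False.

# Minimal sketch with hypothetical stand-in types; not the llama_stack APIs.
from dataclasses import dataclass
from typing import Union


@dataclass
class ToolCall:
    tool_name: str


@dataclass
class Delta:
    # A streamed delta may carry either raw text or a parsed tool call.
    content: Union[str, ToolCall]


def is_tool_call(delta: Delta) -> bool:
    # isinstance() is safe for every payload type; `delta.content.type`
    # would raise AttributeError when content is a plain str.
    return isinstance(delta.content, ToolCall)


assert is_tool_call(Delta(content=ToolCall(tool_name="get_weather")))
assert not is_tool_call(Delta(content="partial text chunk"))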