feat(router.py): support request prioritization for text completion calls (#7540)

* feat(router.py): support request prioritization for text completion calls

* fix(internal_user_endpoints.py): fix sql query to return all keys, including null team id keys on `/user/info`

Fixes #7485

* fix: fix linting errors

* fix: fix linting error

* test(test_router_helper_utils.py): add direct test for '_schedule_factory'

Fixes code qa test
krrishdholakia authored Jan 4, 2025
1 parent f770dd0 commit d43d83f
Showing 7 changed files with 229 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/my-website/docs/scheduler.md
@@ -19,6 +19,11 @@ Prioritize LLM API requests in high-traffic.
- Priority - The lower the number, the higher the priority:
* e.g. `priority=0` > `priority=2000`

Supported Router endpoints:
- `acompletion` (`/v1/chat/completions` on Proxy)
- `atext_completion` (`/v1/completions` on Proxy)


## Quick Start

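The original Quick Start snippet is collapsed in this diff view. As a stand-in, a minimal sketch of prioritized calls through the Router (the model name is a placeholder and an `OPENAI_API_KEY` is assumed to be set; this is not the original docs snippet):

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "openai/gpt-3.5-turbo"},
        }
    ],
    timeout=10,  # requests still queued after this many seconds raise litellm.Timeout
)


async def main():
    # Chat completion with a priority (already supported)
    chat = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        priority=0,  # lower number = higher priority
    )
    # Text completion with a priority (added in this commit)
    text = await router.atext_completion(
        model="gpt-3.5-turbo",
        prompt="Hey!",
        priority=0,
    )
    print(chat.choices[0].message.content, text.choices[0].text)


asyncio.run(main())
```
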
30 changes: 28 additions & 2 deletions litellm/llms/openai/openai.py
@@ -1928,6 +1928,10 @@ async def async_get_assistants(
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        order: Optional[str] = "desc",
        limit: Optional[int] = 20,
        before: Optional[str] = None,
        after: Optional[str] = None,
    ) -> AsyncCursorPage[Assistant]:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
@@ -1937,8 +1941,16 @@ async def async_get_assistants(
            organization=organization,
            client=client,
        )
        request_params = {
            "order": order,
            "limit": limit,
        }
        if before:
            request_params["before"] = before
        if after:
            request_params["after"] = after

        response = await openai_client.beta.assistants.list()
        response = await openai_client.beta.assistants.list(**request_params) # type: ignore

        return response

@@ -1981,6 +1993,10 @@ def get_assistants(
        organization: Optional[str],
        client=None,
        aget_assistants=None,
        order: Optional[str] = "desc",
        limit: Optional[int] = 20,
        before: Optional[str] = None,
        after: Optional[str] = None,
    ):
        if aget_assistants is not None and aget_assistants is True:
            return self.async_get_assistants(
@@ -2000,7 +2016,17 @@ def get_assistants(
            client=client,
        )

        response = openai_client.beta.assistants.list()
        request_params = {
            "order": order,
            "limit": limit,
        }

        if before:
            request_params["before"] = before
        if after:
            request_params["after"] = after

        response = openai_client.beta.assistants.list(**request_params) # type: ignore

        return response

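For reference, the `order`/`limit`/`before`/`after` parameters added above map onto the OpenAI SDK's cursor-based pagination for assistants. A standalone sketch of the equivalent direct SDK calls (assumes an `OPENAI_API_KEY` in the environment; not LiteLLM-specific):

```python
import asyncio

from openai import AsyncOpenAI


async def list_assistants_pages():
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    # First page: newest assistants first, 20 per page (the defaults used above)
    page = await client.beta.assistants.list(order="desc", limit=20)
    for assistant in page.data:
        print(assistant.id, assistant.name)
    # Next page: resume after the last id returned on the previous page
    if page.data:
        next_page = await client.beta.assistants.list(
            order="desc", limit=20, after=page.data[-1].id
        )
        print("next page size:", len(next_page.data))


asyncio.run(list_assistants_pages())
```
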
5 changes: 5 additions & 0 deletions litellm/proxy/_new_secret_config.yaml
@@ -3,6 +3,10 @@ model_list:
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
  - model_name: openai-text-completion
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
  - model_name: chatbot_actions
    litellm_params:
      model: langfuse/azure/gpt-4o
@@ -11,5 +15,6 @@ model_list:
      tpm: 1000000
      prompt_id: "jokes"


litellm_settings:
  callbacks: ["otel"]
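With the `openai-text-completion` entry above, a prioritized text-completion request can be sent to the proxy's `/v1/completions` route. A sketch using the OpenAI SDK against a locally running proxy (the `sk-1234` key and port are placeholders, and passing `priority` via `extra_body` assumes the proxy forwards it to the router the same way it does for chat completions):

```python
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

completion = client.completions.create(
    model="openai-text-completion",
    prompt="Say hello",
    extra_body={"priority": 0},  # lower number = higher priority (assumed to be forwarded)
)
print(completion.choices[0].text)
```
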
4 changes: 3 additions & 1 deletion litellm/proxy/management_endpoints/internal_user_endpoints.py
@@ -404,7 +404,7 @@ async def _get_user_info_for_proxy_admin():
    sql_query = """
        SELECT
            (SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
            (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard') as keys
            (SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
    """
    if prisma_client is None:
        raise Exception(
@@ -413,6 +413,8 @@ async def _get_user_info_for_proxy_admin():

    results = await prisma_client.db.query_raw(sql_query)

    verbose_proxy_logger.debug("results_keys: %s", results)

    _keys_in_db: List = results[0]["keys"] or []
    # cast all keys to LiteLLM_VerificationToken
    keys_in_db = []
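The `OR k.team_id IS NULL` clause matters because of SQL's three-valued logic: for rows with a NULL `team_id`, `team_id != 'litellm-dashboard'` evaluates to NULL rather than TRUE, so those keys were silently dropped from `/user/info`. A small self-contained illustration using an in-memory SQLite table (not the actual Prisma/Postgres schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE keys (token TEXT, team_id TEXT)")
conn.executemany(
    "INSERT INTO keys VALUES (?, ?)",
    [("k1", None), ("k2", "team1"), ("k3", "litellm-dashboard")],
)

# Old filter: the NULL-team key never satisfies the inequality
old = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard'"
).fetchall()

# Fixed filter: NULL-team keys are explicitly included
new = conn.execute(
    "SELECT token FROM keys WHERE team_id != 'litellm-dashboard' OR team_id IS NULL"
).fetchall()

print(old)  # [('k2',)]
print(new)  # [('k1',), ('k2',)]
```
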
70 changes: 70 additions & 0 deletions litellm/router.py
@@ -1356,6 +1356,67 @@ async def schedule_acompletion(
                llm_provider="openai",
            )

    async def _schedule_factory(
        self,
        model: str,
        priority: int,
        original_function: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ):
        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
        ### FLOW ITEM ###
        _request_id = str(uuid.uuid4())
        item = FlowItem(
            priority=priority, # 👈 SET PRIORITY FOR REQUEST
            request_id=_request_id, # 👈 SET REQUEST ID
            model_name=model, # 👈 SAME as 'Router'
        )
        ### [fin] ###

        ## ADDS REQUEST TO QUEUE ##
        await self.scheduler.add_request(request=item)

        ## POLL QUEUE
        end_time = time.time() + self.timeout
        curr_time = time.time()
        poll_interval = self.scheduler.polling_interval # poll every 3ms
        make_request = False

        while curr_time < end_time:
            _healthy_deployments, _ = await self._async_get_healthy_deployments(
                model=model, parent_otel_span=parent_otel_span
            )
            make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
                id=item.request_id,
                model_name=item.model_name,
                health_deployments=_healthy_deployments,
            )
            if make_request: ## IF TRUE -> MAKE REQUEST
                break
            else: ## ELSE -> loop till default_timeout
                await asyncio.sleep(poll_interval)
                curr_time = time.time()

        if make_request:
            try:
                _response = await original_function(*args, **kwargs)
                if isinstance(_response._hidden_params, dict):
                    _response._hidden_params.setdefault("additional_headers", {})
                    _response._hidden_params["additional_headers"].update(
                        {"x-litellm-request-prioritization-used": True}
                    )
                return _response
            except Exception as e:
                setattr(e, "priority", priority)
                raise e
        else:
            raise litellm.Timeout(
                message="Request timed out while polling queue",
                model=model,
                llm_provider="openai",
            )

    def image_generation(self, prompt: str, model: str, **kwargs):
        try:
            kwargs["model"] = model
@@ -1844,10 +1905,19 @@ async def atext_completion(
        is_async: Optional[bool] = False,
        **kwargs,
    ):
        if kwargs.get("priority", None) is not None:
            return await self._schedule_factory(
                model=model,
                priority=kwargs.pop("priority"),
                original_function=self.atext_completion,
                args=(model, prompt),
                kwargs=kwargs,
            )
        try:
            kwargs["model"] = model
            kwargs["prompt"] = prompt
            kwargs["original_function"] = self._atext_completion

            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
            response = await self.async_function_with_fallbacks(**kwargs)

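A usage sketch of the new prioritized path (the deployment mirrors the `openai-text-completion` entry added to the config above, an `OPENAI_API_KEY` is assumed, and the response header is the one set inside `_schedule_factory`):

```python
import asyncio

import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai-text-completion",
            "litellm_params": {"model": "openai/gpt-3.5-turbo"},
        }
    ],
    timeout=30,
)


async def main():
    try:
        response = await router.atext_completion(
            model="openai-text-completion",
            prompt="Say hello",
            priority=1,
        )
        # _schedule_factory tags prioritized responses so callers can tell
        # the scheduling path was taken
        headers = response._hidden_params.get("additional_headers", {})
        print(headers.get("x-litellm-request-prioritization-used"))
    except litellm.Timeout:
        # Raised when the request could not be scheduled before router.timeout
        print("request timed out while waiting in the priority queue")


asyncio.run(main())
```
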
75 changes: 75 additions & 0 deletions tests/proxy_unit_tests/test_proxy_utils.py
@@ -1252,3 +1252,78 @@ def test_get_model_group_info():
        model_group="openai/tts-1",
    )
    assert len(model_list) == 1


import pytest
import asyncio
from unittest.mock import AsyncMock, patch
import json


@pytest.fixture
def mock_team_data():
    return [
        {"team_id": "team1", "team_name": "Test Team 1"},
        {"team_id": "team2", "team_name": "Test Team 2"},
    ]


@pytest.fixture
def mock_key_data():
    return [
        {"token": "test_token_1", "key_name": "key1", "team_id": None, "spend": 0},
        {"token": "test_token_2", "key_name": "key2", "team_id": "team1", "spend": 100},
        {
            "token": "test_token_3",
            "key_name": "key3",
            "team_id": "litellm-dashboard",
            "spend": 50,
        },
    ]


class MockDb:
    def __init__(self, mock_team_data, mock_key_data):
        self.mock_team_data = mock_team_data
        self.mock_key_data = mock_key_data

    async def query_raw(self, query: str, *args):
        # Simulate the SQL query response
        filtered_keys = [
            k
            for k in self.mock_key_data
            if k["team_id"] != "litellm-dashboard" or k["team_id"] is None
        ]

        return [{"teams": self.mock_team_data, "keys": filtered_keys}]


class MockPrismaClientDB:
    def __init__(
        self,
        mock_team_data,
        mock_key_data,
    ):
        self.db = MockDb(mock_team_data, mock_key_data)


@pytest.mark.asyncio
async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
    # Patch the prisma_client import
    from litellm.proxy._types import UserInfoResponse

    with patch(
        "litellm.proxy.proxy_server.prisma_client",
        MockPrismaClientDB(mock_team_data, mock_key_data),
    ):

        from litellm.proxy.management_endpoints.internal_user_endpoints import (
            _get_user_info_for_proxy_admin,
        )

        # Execute the function
        result = await _get_user_info_for_proxy_admin()

        # Verify the result structure
        assert isinstance(result, UserInfoResponse)
        assert len(result.keys) == 2
43 changes: 43 additions & 0 deletions tests/router_unit_tests/test_router_helper_utils.py
@@ -166,6 +166,49 @@ async def test_router_schedule_acompletion(model_list):
    assert response["choices"][0]["message"]["content"] == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_router_schedule_atext_completion(model_list):
"""Test if the 'schedule_atext_completion' function is working correctly"""
from litellm.types.utils import TextCompletionResponse

router = Router(model_list=model_list)
with patch.object(
router, "_atext_completion", AsyncMock()
) as mock_atext_completion:
mock_atext_completion.return_value = TextCompletionResponse()
response = await router.atext_completion(
model="gpt-3.5-turbo",
prompt="Hello, how are you?",
priority=1,
)
mock_atext_completion.assert_awaited_once()
assert "priority" not in mock_atext_completion.call_args.kwargs


@pytest.mark.asyncio
async def test_router_schedule_factory(model_list):
"""Test if the 'schedule_atext_completion' function is working correctly"""
from litellm.types.utils import TextCompletionResponse

router = Router(model_list=model_list)
with patch.object(
router, "_atext_completion", AsyncMock()
) as mock_atext_completion:
mock_atext_completion.return_value = TextCompletionResponse()
response = await router._schedule_factory(
model="gpt-3.5-turbo",
args=(
"gpt-3.5-turbo",
"Hello, how are you?",
),
priority=1,
kwargs={},
original_function=router.atext_completion,
)
mock_atext_completion.assert_awaited_once()
assert "priority" not in mock_atext_completion.call_args.kwargs


@pytest.mark.asyncio
async def test_router_arealtime(model_list):
"""Test if the '_arealtime' function is working correctly"""