Skip to content

Commit

Permalink
fix(internal_user_endpoints.py): fix sql query to return all keys, in…
Browse files Browse the repository at this point in the history
…cluding null team id keys on `/user/info`

Fixes #7485
  • Loading branch information
krrishdholakia committed Jan 4, 2025
1 parent 2a813ed commit c5af281
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ async def _get_user_info_for_proxy_admin():
sql_query = """
SELECT
(SELECT json_agg(t.*) FROM "LiteLLM_TeamTable" t) as teams,
(SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard') as keys
(SELECT json_agg(k.*) FROM "LiteLLM_VerificationToken" k WHERE k.team_id != 'litellm-dashboard' OR k.team_id IS NULL) as keys
"""
if prisma_client is None:
raise Exception(
Expand All @@ -413,6 +413,8 @@ async def _get_user_info_for_proxy_admin():

results = await prisma_client.db.query_raw(sql_query)

verbose_proxy_logger.debug("results_keys: %s", results)

_keys_in_db: List = results[0]["keys"] or []
# cast all keys to LiteLLM_VerificationToken
keys_in_db = []
Expand Down
75 changes: 75 additions & 0 deletions tests/proxy_unit_tests/test_proxy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,3 +1252,78 @@ def test_get_model_group_info():
model_group="openai/tts-1",
)
assert len(model_list) == 1


import pytest
import asyncio
from unittest.mock import AsyncMock, patch
import json


@pytest.fixture
def mock_team_data():
return [
{"team_id": "team1", "team_name": "Test Team 1"},
{"team_id": "team2", "team_name": "Test Team 2"},
]


@pytest.fixture
def mock_key_data():
return [
{"token": "test_token_1", "key_name": "key1", "team_id": None, "spend": 0},
{"token": "test_token_2", "key_name": "key2", "team_id": "team1", "spend": 100},
{
"token": "test_token_3",
"key_name": "key3",
"team_id": "litellm-dashboard",
"spend": 50,
},
]


class MockDb:
def __init__(self, mock_team_data, mock_key_data):
self.mock_team_data = mock_team_data
self.mock_key_data = mock_key_data

async def query_raw(self, query: str, *args):
# Simulate the SQL query response
filtered_keys = [
k
for k in self.mock_key_data
if k["team_id"] != "litellm-dashboard" or k["team_id"] is None
]

return [{"teams": self.mock_team_data, "keys": filtered_keys}]


class MockPrismaClientDB:
def __init__(
self,
mock_team_data,
mock_key_data,
):
self.db = MockDb(mock_team_data, mock_key_data)


@pytest.mark.asyncio
async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
# Patch the prisma_client import
from litellm.proxy._types import UserInfoResponse

with patch(
"litellm.proxy.proxy_server.prisma_client",
MockPrismaClientDB(mock_team_data, mock_key_data),
):

from litellm.proxy.management_endpoints.internal_user_endpoints import (
_get_user_info_for_proxy_admin,
)

# Execute the function
result = await _get_user_info_for_proxy_admin()

# Verify the result structure
assert isinstance(result, UserInfoResponse)
assert len(result.keys) == 2

0 comments on commit c5af281

Please sign in to comment.