fix(key_management_endpoints.py): fix user-membership check when creating team key (#6890)

* fix(key_management_endpoints.py): fix user-membership check when creating team key

* docs: add deprecation notice on original `/v1/messages` endpoint + add better swagger tags on pass-through endpoints

* fix(gemini/): fix image_url handling for gemini

Fixes #6897

* fix(teams.tsx): fix member add when role is 'user'

* fix(team_endpoints.py): /team/member_add

fix adding several new members to a team

* test(test_vertex.py): remove redundant test

* test(test_proxy_server.py): fix team member add tests
krrishdholakia authored and ishaan-jaff committed Nov 28, 2024
1 parent c73ce95 commit 0f08577
Showing 19 changed files with 401 additions and 164 deletions.
23 changes: 23 additions & 0 deletions litellm/llms/prompt_templates/factory.py
@@ -33,6 +33,7 @@
ChatCompletionAssistantToolCall,
ChatCompletionFunctionMessage,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage,
@@ -681,6 +682,27 @@ def construct_tool_use_system_prompt(
return tool_use_system_prompt


def convert_generic_image_chunk_to_openai_image_obj(
image_chunk: GenericImageParsingChunk,
) -> str:
"""
Convert a generic image chunk to an OpenAI image object.
Input:
GenericImageParsingChunk(
type="base64",
media_type="image/jpeg",
data="...",
)
Return:
"data:image/jpeg;base64,{base64_image}"
"""
return "data:{};{},{}".format(
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
)
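A quick round-trip sketch of the new helper together with the existing convert_to_anthropic_image_obj (the base64 payload below is illustrative, not from the commit):

```python
# Hedged usage sketch -- assumes both helpers are importable from
# litellm.llms.prompt_templates.factory; the payload is a placeholder.
from litellm.llms.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
    convert_to_anthropic_image_obj,
)

openai_url = "data:image/jpeg;base64,/9j/4AAQSkZJRg=="
chunk = convert_to_anthropic_image_obj(openai_url)
# chunk is a GenericImageParsingChunk, roughly:
#   {"type": "base64", "media_type": "image/jpeg", "data": "/9j/4AAQSkZJRg=="}
round_tripped = convert_generic_image_chunk_to_openai_image_obj(chunk)
# round_tripped should equal openai_url again: "data:image/jpeg;base64,..."
```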


def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
"""
Input:
@@ -706,6 +728,7 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
data=base64_data,
)
except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception(
@@ -294,7 +294,12 @@ def _transform_request_body(
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}

try:
content = _gemini_convert_messages_with_history(messages=messages)
if custom_llm_provider == "gemini":
content = litellm.GoogleAIStudioGeminiConfig._transform_messages(
messages=messages
)
else:
content = litellm.VertexGeminiConfig._transform_messages(messages=messages)
tools: Optional[Tools] = optional_params.pop("tools", None)
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
@@ -35,7 +35,12 @@
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
@@ -78,6 +83,8 @@
)
from ..vertex_llm_base import VertexBase
from .transformation import (
_gemini_convert_messages_with_history,
_process_gemini_image,
async_transform_request_body,
set_headers,
sync_transform_request_body,
@@ -912,6 +919,10 @@ def _transform_response(

return model_response

@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
return _gemini_convert_messages_with_history(messages=messages)


class GoogleAIStudioGeminiConfig(
VertexGeminiConfig
@@ -1015,6 +1026,32 @@ def map_openai_params(
model, non_default_params, optional_params, drop_params
)

@staticmethod
def _transform_messages(messages: List[AllMessageValues]) -> List[ContentType]:
"""
Google AI Studio Gemini does not support image urls in messages.
"""
for message in messages:
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
if element.get("type") == "image_url":
img_element = element
_image_url: Optional[str] = None
if isinstance(img_element.get("image_url"), dict):
_image_url = img_element["image_url"].get("url") # type: ignore
else:
_image_url = img_element.get("image_url") # type: ignore
if _image_url and "https://" in _image_url:
image_obj = convert_to_anthropic_image_obj(_image_url)
img_element["image_url"] = ( # type: ignore
convert_generic_image_chunk_to_openai_image_obj(
image_obj
)
)
return _gemini_convert_messages_with_history(messages=messages)
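A hedged sketch of what the override above does to a hypothetical OpenAI-style message containing an https image URL (URL and payload are illustrative):

```python
# Input: an https image_url, which Google AI Studio Gemini cannot accept directly.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }
]

# GoogleAIStudioGeminiConfig._transform_messages(messages=messages) fetches the
# image via convert_to_anthropic_image_obj and rewrites the part in place to
# roughly:
#   {"type": "image_url", "image_url": "data:image/jpeg;base64,<fetched bytes>"}
# before handing the list to _gemini_convert_messages_with_history.
```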


async def make_call(
client: Optional[AsyncHTTPHandler],
20 changes: 20 additions & 0 deletions litellm/proxy/_new_secret_config.yaml
@@ -12,3 +12,23 @@ model_list:
vertex_ai_project: "adroit-crow-413218"
vertex_ai_location: "us-east5"

router_settings:
routing_strategy: usage-based-routing-v2
#redis_url: "os.environ/REDIS_URL"
redis_host: "os.environ/REDIS_HOST"
redis_port: "os.environ/REDIS_PORT"

litellm_settings:
cache: true
cache_params:
type: redis
host: "os.environ/REDIS_HOST"
port: "os.environ/REDIS_PORT"
namespace: "litellm.caching"
ttl: 600
# key_generation_settings:
# team_key_generation:
# allowed_team_member_roles: ["admin"]
# required_params: ["tags"] # require team admins to set tags for cost-tracking when generating a team key
# personal_key_generation: # maps to 'Default Team' on UI
# allowed_user_roles: ["proxy_admin"]
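The "os.environ/..." strings above are LiteLLM's convention for resolving a setting from the process environment at proxy startup, so the referenced variables must be set before launch. A minimal sketch, with placeholder values:

```python
import os

# Placeholder values for illustration -- export the real host/port in the
# environment the proxy runs in so "os.environ/REDIS_HOST" etc. resolve.
os.environ["REDIS_HOST"] = "localhost"
os.environ["REDIS_PORT"] = "6379"
```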
1 change: 0 additions & 1 deletion litellm/proxy/_types.py
@@ -1982,7 +1982,6 @@ def __init__(self, **data):
# Replace member_data with the single Member object
data["member"] = member
# Call the superclass __init__ method to initialize the object
traceback.print_stack()
super().__init__(**data)


143 changes: 98 additions & 45 deletions litellm/proxy/auth/auth_checks.py
@@ -523,6 +523,10 @@ async def _cache_management_object(
proxy_logging_obj: Optional[ProxyLogging],
):
await user_api_key_cache.async_set_cache(key=key, value=value)
if proxy_logging_obj is not None:
await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache(
key=key, value=value
)
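With the added lines, one call now writes the management object through to both the in-memory key cache and the shared dual cache, so a later cache-only lookup can succeed without a DB round trip. A hedged calling sketch (arguments are illustrative; the key format mirrors the team lookups later in this file):

```python
# Hedged sketch -- assumes a DualCache and ProxyLogging instance are in scope.
await _cache_management_object(
    key="team_id:my-team",
    value=team_table_obj,
    user_api_key_cache=user_api_key_cache,
    proxy_logging_obj=proxy_logging_obj,
)
```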


async def _cache_team_object(
@@ -586,33 +590,71 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient):
)


async def get_team_object(
async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient):
return await prisma_client.db.litellm_teamtable.find_unique(
where={"team_id": team_id}
)


async def _get_team_object_from_user_api_key_cache(
team_id: str,
prisma_client: Optional[PrismaClient],
prisma_client: PrismaClient,
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
last_db_access_time: LimitedSizeOrderedDict,
db_cache_expiry: int,
proxy_logging_obj: Optional[ProxyLogging],
key: str,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
db_access_time_key = key
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None

# check if in cache
key = "team_id:{}".format(team_id)
if response is None:
raise Exception

_response = LiteLLM_TeamTableCachedObj(**response.dict())
# save the team object to cache
await _cache_team_object(
team_id=team_id,
team_table=_response,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)

# save to db access time
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=_response,
last_db_access_time=last_db_access_time,
)

return _response


async def _get_team_object_from_cache(
key: str,
proxy_logging_obj: Optional[ProxyLogging],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span],
) -> Optional[LiteLLM_TeamTableCachedObj]:
cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None

## CHECK REDIS CACHE ##
if (
proxy_logging_obj is not None
and proxy_logging_obj.internal_usage_cache.dual_cache
):

cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache(
key=key, parent_otel_span=parent_otel_span
@@ -628,47 +670,58 @@ async def get_team_object(
elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
return cached_team_obj

if check_cache_only:
return None


async def get_team_object(
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
check_db_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits
- if not, then raise an error
"""
if prisma_client is None:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)

# else, check db
try:
db_access_time_key = "team_id:{}".format(team_id)
should_check_db = _should_check_db(
key=db_access_time_key,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
# check if in cache
key = "team_id:{}".format(team_id)

if not check_db_only:
cached_team_obj = await _get_team_object_from_cache(
key=key,
proxy_logging_obj=proxy_logging_obj,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
)
if should_check_db:
response = await _get_team_db_check(
team_id=team_id, prisma_client=prisma_client
)
else:
response = None

if response is None:
raise Exception
if cached_team_obj is not None:
return cached_team_obj

if check_cache_only:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}."
)

_response = LiteLLM_TeamTableCachedObj(**response.dict())
# save the team object to cache
await _cache_team_object(
# else, check db
try:
return await _get_team_object_from_user_api_key_cache(
team_id=team_id,
team_table=_response,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)

# save to db access time
# save to db access time
_update_last_db_access_time(
key=db_access_time_key,
value=_response,
last_db_access_time=last_db_access_time,
db_cache_expiry=db_cache_expiry,
key=key,
)

return _response
except Exception:
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
