Commit: Merge branch 'main' into chore/modify_launch_script
Showing 10 changed files with 414 additions and 146 deletions.
File renamed without changes.
New file: inference/litellm/server.py (+126 lines)
# inference/litellm/server.py
import logging
import os
import traceback
from typing import Optional

import httpx
from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, Query

from node.schemas import ChatCompletionRequest, CompletionRequest, EmbeddingsRequest

load_dotenv()

# Group all endpoints under "inference" in the Swagger docs
router = APIRouter(prefix="/inference", tags=["inference"])

# Five-minute timeout for requests proxied to LiteLLM
LITELLM_HTTP_TIMEOUT = 60 * 5

# The master key authenticates this service against the LiteLLM proxy; fail fast if missing
LITELLM_MASTER_KEY = os.environ.get("LITELLM_MASTER_KEY")
if not LITELLM_MASTER_KEY:
    raise Exception("Missing LITELLM_MASTER_KEY for authentication")

# Inside Docker Compose the proxy is reachable by service name; otherwise assume a local proxy
LITELLM_URL = "http://litellm:4000" if os.getenv("LAUNCH_DOCKER") == "true" else "http://localhost:4000"
@router.get("/models", summary="List Models") | ||
async def models_endpoint(return_wildcard_routes: Optional[bool] = Query(False, alias="return_wildcard_routes")): | ||
logging.info("Received models list request") | ||
try: | ||
params = {"return_wildcard_routes": return_wildcard_routes} | ||
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client: | ||
response = await client.get( | ||
f"{LITELLM_URL}/models", | ||
params=params, | ||
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"} | ||
) | ||
logging.info(f"LiteLLM models response: {response.json()}") | ||
return response.json() | ||
except httpx.ReadTimeout: | ||
logging.error("Request to LiteLLM timed out") | ||
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out") | ||
except Exception as e: | ||
logging.error(f"Error in models endpoint: {str(e)}") | ||
logging.error(f"Full traceback: {traceback.format_exc()}") | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
@router.post("/chat/completions", summary="Chat Completion") | ||
async def chat_completions_endpoint( | ||
request_body: ChatCompletionRequest, | ||
model: Optional[str] = Query(None, description="Model") | ||
): | ||
""" | ||
Chat Completion endpoint following the OpenAI Chat API specification. | ||
""" | ||
logging.info("Received chat completions request") | ||
payload = request_body.model_dump(exclude_none=True) | ||
if model is not None: | ||
payload["model"] = model | ||
try: | ||
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client: | ||
response = await client.post( | ||
f"{LITELLM_URL}/chat/completions", | ||
json=payload, | ||
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"} | ||
) | ||
logging.info(f"LiteLLM response: {response.json()}") | ||
return response.json() | ||
except httpx.ReadTimeout: | ||
logging.error("Request to LiteLLM timed out") | ||
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out") | ||
except Exception as e: | ||
logging.error(f"Error in chat completions endpoint: {str(e)}") | ||
logging.error(f"Full traceback: {traceback.format_exc()}") | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
@router.post("/completions", summary="Completion") | ||
async def completions_endpoint( | ||
request_body: CompletionRequest, | ||
model: Optional[str] = Query(None, description="Model") | ||
): | ||
logging.info("Received completions request") | ||
payload = request_body.model_dump(exclude_none=True) | ||
if model is not None: | ||
payload["model"] = model | ||
try: | ||
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client: | ||
response = await client.post( | ||
f"{LITELLM_URL}/completions", | ||
json=payload, | ||
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"} | ||
) | ||
logging.info(f"LiteLLM completions response: {response.json()}") | ||
return response.json() | ||
except httpx.ReadTimeout: | ||
logging.error("Request to LiteLLM timed out") | ||
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out") | ||
except Exception as e: | ||
logging.error(f"Error in completions endpoint: {str(e)}") | ||
logging.error(f"Full traceback: {traceback.format_exc()}") | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
@router.post("/embeddings", summary="Embeddings") | ||
async def embeddings_endpoint( | ||
request_body: EmbeddingsRequest, | ||
model: Optional[str] = Query(None, description="Model") | ||
): | ||
logging.info("Received embeddings request") | ||
payload = request_body.model_dump(exclude_none=True) | ||
if model is not None: | ||
payload["model"] = model | ||
try: | ||
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client: | ||
response = await client.post( | ||
f"{LITELLM_URL}/embeddings", | ||
json=payload, | ||
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"} | ||
) | ||
logging.info(f"LiteLLM embeddings response: {response.json()}") | ||
return response.json() | ||
except httpx.ReadTimeout: | ||
logging.error("Request to LiteLLM timed out") | ||
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out") | ||
except Exception as e: | ||
logging.error(f"Error in embeddings endpoint: {str(e)}") | ||
logging.error(f"Full traceback: {traceback.format_exc()}") | ||
raise HTTPException(status_code=500, detail=str(e)) |
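
The diff only adds the router itself; the FastAPI entrypoint that mounts it is not part of this view. As a minimal sketch of the wiring, assuming a hypothetical main.py and an import path derived from the file location (both assumptions), the router could be mounted like this:

# main.py (hypothetical entrypoint; not part of this commit)
# Assumes LITELLM_MASTER_KEY is set in the environment or a .env file,
# and LAUNCH_DOCKER="true" when running under Docker Compose.
from fastapi import FastAPI

from inference.litellm.server import router  # the module added in this commit

app = FastAPI()
app.include_router(router)  # exposes /inference/models, /inference/chat/completions, etc.

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)  # port 8000 is an assumption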
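
Once the app is running, the proxied endpoints can be exercised directly. A hedged example, assuming the app is served on http://localhost:8000 and the LiteLLM proxy has a model named "gpt-4o" configured (both assumptions; neither the launch script nor the proxy config appears in this diff):

import httpx

# Hypothetical request; the model name must match one configured in the LiteLLM proxy
payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Hello from the inference router"}],
}
response = httpx.post(
    "http://localhost:8000/inference/chat/completions",
    json=payload,
    timeout=300,  # mirrors LITELLM_HTTP_TIMEOUT on the server side
)
print(response.json())

The /completions and /embeddings routes behave analogously: the body is validated against the corresponding schema, an optional ?model= query parameter overrides the body's model, and the request is forwarded to LiteLLM with the master key attached.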