Commit

Merge branch 'main' into chore/modify_launch_script
enricorotundo committed Feb 17, 2025
2 parents 9c9fdce + 88206a2 commit 8f6322f
Showing 10 changed files with 414 additions and 146 deletions.
11 changes: 7 additions & 4 deletions Dockerfile-node
@@ -11,14 +11,18 @@ SHELL ["/bin/bash", "-c"]

# install deps
RUN apt-get update
RUN apt-get install curl gcc curl git -y
RUN apt-get install gcc curl git libpq-dev -y

# add conda install to path; use base environment
ENV PATH="/opt/conda/bin:${PATH}"
RUN conda create -n node python=3.12
RUN conda create -y -n node python=3.12
RUN echo "source activate node" > /root/.bashrc
ENV PATH="/opt/conda/envs/node/bin:$PATH"
RUN echo

# install postgres (required for poetry to build psycopg2 from source)
RUN conda install -y conda-forge::postgresql=17.2
ENV LDFLAGS="-L/opt/conda/lib"
ENV CPPFLAGS="-I/opt/conda/include"

# set up poetry
ENV PATH="/root/.local/share/pypoetry/venv/bin/:${PATH}"
@@ -40,4 +44,3 @@ CMD set -x && \
((poetry run celery -A node.worker.main:app worker --loglevel=info | tee /dev/stdout) & ) && \
((poetry run python -m node.server.server --communication-protocol http --port 7001) &) && \
poetry run python -m node.server.server --communication-protocol ws --port 7002
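
The new conda-forge postgresql layer plus the LDFLAGS/CPPFLAGS settings give psycopg2 a libpq to compile against. A minimal sanity check, sketched in Python and runnable inside the image after poetry install (not part of this commit):

import psycopg2

# __libpq_version__ encodes the libpq version psycopg2 was compiled against,
# e.g. 170002 when built against the postgresql 17.2 toolchain installed above.
major = psycopg2.__libpq_version__ // 10000
print(f"psycopg2 {psycopg2.__version__} built against libpq {major}.x")
assert major >= 17, "expected libpq from conda-forge postgresql 17.2"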

10 changes: 7 additions & 3 deletions Dockerfile-node-dev
@@ -10,14 +10,18 @@ SHELL ["/bin/bash", "-c"]

# install deps
RUN apt-get update
RUN apt-get install curl gcc curl git -y
RUN apt-get install gcc curl git libpq-dev -y

# add conda install to path; use base environment
ENV PATH="/opt/conda/bin:${PATH}"
RUN conda create -n node python=3.12
RUN conda create -y -n node python=3.12
RUN echo "source activate node" > /root/.bashrc
ENV PATH="/opt/conda/envs/node/bin:$PATH"
RUN echo

# install postgres (required for poetry to build psycopg2 from source)
RUN conda install -y conda-forge::postgresql=17.2
ENV LDFLAGS="-L/opt/conda/lib"
ENV CPPFLAGS="-I/opt/conda/include"

# set up poetry
ENV PATH="/root/.local/share/pypoetry/venv/bin/:${PATH}"
42 changes: 42 additions & 0 deletions Makefile
@@ -127,6 +127,47 @@ restart-celery:
sudo systemctl status --no-pager celeryworker.service) & \
fi

restart-hub:
@echo "Checking LOCAL_HUB setting..."
@LOCAL_HUB=$$(if [ "$$(uname)" = "Darwin" ]; then \
grep -E '^LOCAL_HUB=' .env | cut -d'=' -f2 | tr -d '"' | tr -d ' '; \
else \
grep -oP 'LOCAL_HUB=\K.*' .env | tr -d '"' | tr -d ' '; \
fi); \
echo "LOCAL_HUB is $$LOCAL_HUB"; \
if [ "$$LOCAL_HUB" = "true" ]; then \
echo "LOCAL_HUB is True, cleaning and restarting hub..."; \
port=$$(grep HUB_DB_SURREAL_PORT .env | cut -d'=' -f2 | tr -d '"' | tr -d ' '); \
echo "Port is $$port"; \
pids=$$(lsof -ti:$$port || true); \
if [ -n "$$pids" ]; then \
echo "Killing process(es): $$pids"; \
kill -9 $$pids || true; \
fi; \
echo "Waiting for port $$port to be free..."; \
for i in $$(seq 1 10); do \
if ! lsof -ti:$$port 2>/dev/null | grep -q .; then \
echo "Port $$port is now free"; \
break; \
fi; \
if [ $$i -eq 10 ]; then \
echo "Failed to free port $$port after 10 seconds"; \
exit 1; \
fi; \
echo "Still waiting... ($$i/10)"; \
sleep 1; \
done; \
if [ -d "node/storage/hub/hub.db" ]; then \
rm -rf node/storage/hub/hub.db; \
echo "Hub.db removed"; \
fi; \
PYTHONPATH="$$(pwd)" poetry run python node/storage/hub/init_hub.py; \
PYTHONPATH="$$(pwd)" poetry run python node/storage/hub/init_hub.py --user; \
echo "Hub restarted successfully."; \
else \
echo "LOCAL_HUB is False, skipping hub restart."; \
fi
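
The port-freeing loop in restart-hub translates directly to Python; a hedged sketch mirroring the recipe above (the lsof flags, kill -9, and the 10-second wait come from the Makefile; none of this is part of the commit):

import os
import signal
import subprocess
import time

def free_port(port: int, timeout_s: int = 10) -> None:
    # lsof -ti:PORT prints the PIDs bound to the port, one per line
    pids = subprocess.run(["lsof", f"-ti:{port}"], capture_output=True, text=True).stdout.split()
    for pid in pids:
        try:
            os.kill(int(pid), signal.SIGKILL)  # same as `kill -9 ... || true`
        except ProcessLookupError:
            pass  # process already gone
    for _ in range(timeout_s):
        if not subprocess.run(["lsof", f"-ti:{port}"], capture_output=True, text=True).stdout.strip():
            return  # port is free
        time.sleep(1)
    raise RuntimeError(f"failed to free port {port} after {timeout_s} seconds")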

# Target to restart all node components in parallel
restart-node:
@echo "Restarting all components in parallel..."
@@ -136,6 +177,7 @@ restart-node:
@poetry lock
@poetry install
@echo "poetry install done"
@$(MAKE) restart-hub
@$(MAKE) restart-servers & $(MAKE) restart-celery
@wait
@echo "All node components have been restarted."
21 changes: 17 additions & 4 deletions launch.sh
@@ -680,7 +680,19 @@ setup_poetry() {
poetry lock

# Install dependencies and create the virtual environment
poetry install
# pass env vars so poetry can build psycopg2 from source
# https://www.psycopg.org/docs/install.html#build-prerequisites
if [ "$os" = "Darwin" ]; then
PATH="/opt/homebrew/opt/postgresql@17/bin:$PATH" && \
LDFLAGS="-L/opt/homebrew/opt/postgresql@17/lib" && \
CPPFLAGS="-I/opt/homebrew/opt/postgresql@17/include" && \
poetry install
else
export PATH="/usr/lib/postgresql/16/bin:$PATH" && \
export LDFLAGS="-L/usr/lib/postgresql/16/lib" && \
export CPPFLAGS="-I/usr/lib/postgresql/16/include" && \
poetry install
fi

# Verify the presence of a .venv folder within the project directory
if [ -d ".venv" ]; then
@@ -1794,6 +1806,7 @@ else:

# Check LiteLLM
services+=("LiteLLM")
sleep 2
if curl -s http://localhost:4000/health > /dev/null; then
statuses+=("")
logs+=("")
@@ -2110,6 +2123,7 @@ main() {
install_python312
darwin_install_miniforge
darwin_clean_node
darwin_setup_local_db
setup_poetry
install_surrealdb
check_and_copy_env
@@ -2118,7 +2132,6 @@
darwin_start_rabbitmq
check_and_set_private_key
start_hub_surrealdb
darwin_setup_local_db
darwin_start_local_db
darwin_start_servers
darwin_start_celery_worker
@@ -2128,6 +2141,7 @@
install_python312
linux_install_miniforge
linux_clean_node
linux_setup_local_db
setup_poetry
install_surrealdb
check_and_copy_env
@@ -2136,7 +2150,6 @@
linux_start_rabbitmq
check_and_set_private_key
start_hub_surrealdb
linux_setup_local_db
linux_start_local_db
linux_start_servers
linux_start_celery_worker
@@ -2151,4 +2164,4 @@ main() {

if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main
fi
fi
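
The setup_poetry() change above boils down to extending the build environment with the per-OS PostgreSQL paths before running poetry install, so psycopg2 can locate pg_config, libpq, and its headers. The same plumbing sketched in Python (paths copied from launch.sh; illustrative only):

import os
import platform
import subprocess

env = os.environ.copy()
pg = ("/opt/homebrew/opt/postgresql@17" if platform.system() == "Darwin"
      else "/usr/lib/postgresql/16")
env["PATH"] = f"{pg}/bin:" + env["PATH"]  # so the build can find pg_config
env["LDFLAGS"] = f"-L{pg}/lib"            # libpq for linking
env["CPPFLAGS"] = f"-I{pg}/include"       # libpq headers
subprocess.run(["poetry", "install"], env=env, check=True)
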
126 changes: 126 additions & 0 deletions node/inference/server.py
@@ -0,0 +1,126 @@
# inference/litellm/server.py
import logging
import traceback
from typing import Optional
import os
import httpx
from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException, Query
from node.schemas import ChatCompletionRequest, CompletionRequest, EmbeddingsRequest

load_dotenv()

# Group all endpoints under "inference" in the Swagger docs
router = APIRouter(prefix="/inference", tags=["inference"])

LITELLM_HTTP_TIMEOUT = 60 * 5
LITELLM_MASTER_KEY = os.environ.get("LITELLM_MASTER_KEY")
if not LITELLM_MASTER_KEY:
raise Exception("Missing LITELLM_MASTER_KEY for authentication")
LITELLM_URL = "http://litellm:4000" if os.getenv("LAUNCH_DOCKER") == "true" else "http://localhost:4000"


@router.get("/models", summary="List Models")
async def models_endpoint(return_wildcard_routes: Optional[bool] = Query(False, alias="return_wildcard_routes")):
logging.info("Received models list request")
try:
params = {"return_wildcard_routes": return_wildcard_routes}
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client:
response = await client.get(
f"{LITELLM_URL}/models",
params=params,
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
)
logging.info(f"LiteLLM models response: {response.json()}")
return response.json()
except httpx.ReadTimeout:
logging.error("Request to LiteLLM timed out")
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out")
except Exception as e:
logging.error(f"Error in models endpoint: {str(e)}")
logging.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/chat/completions", summary="Chat Completion")
async def chat_completions_endpoint(
request_body: ChatCompletionRequest,
model: Optional[str] = Query(None, description="Model")
):
"""
Chat Completion endpoint following the OpenAI Chat API specification.
"""
logging.info("Received chat completions request")
payload = request_body.model_dump(exclude_none=True)
if model is not None:
payload["model"] = model
try:
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client:
response = await client.post(
f"{LITELLM_URL}/chat/completions",
json=payload,
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
)
logging.info(f"LiteLLM response: {response.json()}")
return response.json()
except httpx.ReadTimeout:
logging.error("Request to LiteLLM timed out")
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out")
except Exception as e:
logging.error(f"Error in chat completions endpoint: {str(e)}")
logging.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/completions", summary="Completion")
async def completions_endpoint(
request_body: CompletionRequest,
model: Optional[str] = Query(None, description="Model")
):
logging.info("Received completions request")
payload = request_body.model_dump(exclude_none=True)
if model is not None:
payload["model"] = model
try:
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client:
response = await client.post(
f"{LITELLM_URL}/completions",
json=payload,
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
)
logging.info(f"LiteLLM completions response: {response.json()}")
return response.json()
except httpx.ReadTimeout:
logging.error("Request to LiteLLM timed out")
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out")
except Exception as e:
logging.error(f"Error in completions endpoint: {str(e)}")
logging.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/embeddings", summary="Embeddings")
async def embeddings_endpoint(
request_body: EmbeddingsRequest,
model: Optional[str] = Query(None, description="Model")
):
logging.info("Received embeddings request")
payload = request_body.model_dump(exclude_none=True)
if model is not None:
payload["model"] = model
try:
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client:
response = await client.post(
f"{LITELLM_URL}/embeddings",
json=payload,
headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"}
)
logging.info(f"LiteLLM embeddings response: {response.json()}")
return response.json()
except httpx.ReadTimeout:
logging.error("Request to LiteLLM timed out")
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out")
except Exception as e:
logging.error(f"Error in embeddings endpoint: {str(e)}")
logging.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))
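
Once a node is running, the new router can be exercised end to end; a hedged usage sketch against the HTTP server on port 7001 (the port from the Dockerfile CMD; the model id is hypothetical and depends on the LiteLLM config):

import asyncio
import httpx

async def main() -> None:
    payload = {
        "model": "gpt-4o-mini",  # hypothetical model id
        "messages": [{"role": "user", "content": "Hello"}],
    }
    async with httpx.AsyncClient(timeout=300) as client:
        resp = await client.post(
            "http://localhost:7001/inference/chat/completions", json=payload
        )
        resp.raise_for_status()
        print(resp.json())

asyncio.run(main())
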
9 changes: 9 additions & 0 deletions node/schemas.py
@@ -492,3 +492,12 @@ class ChatCompletionRequest(BaseModel):
tool_choice: Optional[str] = None
parallel_tool_calls: Optional[bool] = None

class CompletionRequest(BaseModel):
model: str
prompt: str
max_tokens: Optional[int] = 50
temperature: Optional[float] = 0.7

class EmbeddingsRequest(BaseModel):
model: str
input: Union[str, List[str]]
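
The two new models mirror the OpenAI completions and embeddings payloads, and model_dump(exclude_none=True) is how the inference router serializes them before forwarding. A small illustrative sketch (field values are made up):

from node.schemas import CompletionRequest, EmbeddingsRequest

completion = CompletionRequest(model="gpt-3.5-turbo-instruct", prompt="Say hi", max_tokens=16)
embeddings = EmbeddingsRequest(model="text-embedding-3-small", input=["first text", "second text"])
# exclude_none drops unset optional fields, matching the router's forwarded payloads
print(completion.model_dump(exclude_none=True))
print(embeddings.model_dump(exclude_none=True))
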
28 changes: 2 additions & 26 deletions node/server/http_server.py
@@ -45,6 +45,7 @@
from node.worker.package_worker import run_agent, run_tool, run_environment, run_orchestrator, run_kb, run_memory
from node.client import Node as NodeClient
from node.storage.server import router as storage_router
from node.inference.server import router as inference_router
import os

logger = logging.getLogger(__name__)
@@ -291,35 +292,10 @@ async def get_server_connection(self, server_id: str):
logger.error(f"Error getting server connection: {e}")
raise

####### Inference endpoints #######
@router.post("/inference/chat")
async def chat_endpoint(request: ChatCompletionRequest):
"""
Forward chat completion requests to litellm proxy
"""
logger.info(f"Received chat request: {request}")
try:
async with httpx.AsyncClient(timeout=LITELLM_HTTP_TIMEOUT) as client:
response = await client.post(
f"{LITELLM_URL}/chat/completions",
json=request.model_dump(exclude_none=True),
headers={
"Authorization": f"Bearer {LITELLM_MASTER_KEY}"
}
)
logger.info(f"LiteLLM response: {response.json()}")
return response.json()
except httpx.ReadTimeout:
logger.error("Request to LiteLLM timed out")
raise HTTPException(status_code=504, detail="Request to LiteLLM timed out")
except Exception as e:
logger.error(f"Error in chat endpoint: {str(e)}")
logger.error(f"Full traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(e))

# Include the router
self.app.include_router(router)
self.app.include_router(storage_router)
self.app.include_router(inference_router)

self.app.add_middleware(
CORSMiddleware,