Skip to content

Commit

Permalink
sglang llm model
Browse files Browse the repository at this point in the history
  • Loading branch information
luv-bansal committed Jan 8, 2025
1 parent ec775ac commit 263878f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import itertools
import os
import sys
import threading
from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
Expand Down Expand Up @@ -67,10 +65,13 @@ def __init__(self,

def start_server(self, python_executable, checkpoints):
    """Launch the sglang OpenAI-compatible server as a subprocess and block until it is ready.

    Args:
      python_executable: Path of the Python interpreter used to run ``sglang.launch_server``.
      checkpoints: Model path (local checkpoint directory or HF repo id) passed as ``--model-path``.

    Raises:
      RuntimeError: If the server fails to launch or never becomes reachable; any
        partially-started subprocess is terminated first.
    """
    try:
      # Launch exactly once; all tuning options come from attributes set in load_model().
      self.process = execute_shell_command(
          f"{python_executable} -m sglang.launch_server --model-path {checkpoints} "
          f"--dtype {self.dtype} --tensor-parallel-size {self.tensor_parallel_size} "
          f"--quantization {self.quantization} --mem-fraction-static {self.mem_fraction_static} "
          f"--context-length {self.context_length} --port {self.port} --host localhost")
      # Poll the HTTP endpoint so callers can use the server as soon as this returns.
      wait_for_server(f'http://localhost:{self.port}')
    except Exception as e:
      # getattr: the launch itself may have raised before self.process was ever assigned.
      if getattr(self, 'process', None):
        logger.error("Terminating the sglang server process.")
        terminate_process(self.process)
      # Chain the original exception so the root cause stays in the traceback.
      raise RuntimeError("Failed to start sglang server: " + str(e)) from e

Expand All @@ -86,7 +87,7 @@ def load_model(self):
self.mem_fraction_static = 0.9
self.tensor_parallel_size = 1
self.dtype = "float16"
self.port = 4675
self.port = 8761
self.context_length = 4096
self.quantization = "awq"

Expand All @@ -105,10 +106,11 @@ def load_model(self):
python_executable = sys.executable

# if checkpoints section is in config.yaml file then checkpoints will be downloaded at this path during model upload time.
checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
# checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
checkpoints = "casperhansen/llama-3.3-70b-instruct-awq"

try:
# Start the sglang server
# Start the sglang server
self.server_manager.start_server(python_executable, checkpoints)
except Exception as e:
logger.error(f"Error starting sglang server: {e}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Config file for the sglang runner

model:
id: "sglang-llama-3.3-70b-instruct"
user_id: "luv_2261"
app_id: "test-upload"
id: "sglang-llama-3_3-70b-instruct"
user_id: "user_id"
app_id: "app_id"
model_type_id: "text-to-text"

build_info:
Expand All @@ -14,9 +14,9 @@ inference_compute_info:
cpu_memory: "16Gi"
num_accelerators: 1
accelerator_type: ["NVIDIA-L40S"]
accelerator_memory: "46Gi"
accelerator_memory: "40Gi"

checkpoints:
type: "huggingface"
repo_id: "casperhansen/llama-3.3-70b-instruct-awq"
hf_token: "token"
# checkpoints:
# type: "huggingface"
# repo_id: "casperhansen/llama-3.3-70b-instruct-awq"
# hf_token: "token"

0 comments on commit 263878f

Please sign in to comment.