This repository has been archived by the owner on Mar 30, 2024. It is now read-only.

Reduce memory usage for large models (#61)
Signed-off-by: Hung-Han (Henry) Chen <[email protected]>
chenhunghan authored Aug 25, 2023
1 parent 75b4edb commit a0ee699
Showing 9 changed files with 242 additions and 160 deletions.
4 changes: 2 additions & 2 deletions charts/ialacol/Chart.yaml
@@ -1,6 +1,6 @@
 apiVersion: v2
-appVersion: 0.10.3
+appVersion: 0.10.4
 description: A Helm chart for ialacol
 name: ialacol
 type: application
-version: 0.10.3
+version: 0.10.4
2 changes: 0 additions & 2 deletions charts/ialacol/templates/deployment.yaml
@@ -29,8 +29,6 @@ spec:
             value: {{ (.Values.deployment.env).DEFAULT_MODEL_HG_REPO_ID | quote }}
           - name: DEFAULT_MODEL_FILE
             value: {{ (.Values.deployment.env).DEFAULT_MODEL_FILE | quote }}
-          - name: DOWNLOAD_DEFAULT_MODEL
-            value: {{ (.Values.deployment.env).DOWNLOAD_DEFAULT_MODEL | quote }}
           - name: LOGGING_LEVEL
             value: {{ (.Values.deployment.env).LOGGING_LEVEL | quote }}
           - name: TOP_K
53 changes: 8 additions & 45 deletions get_auto_config.py → get_config.py
@@ -1,25 +1,16 @@
-import logging
-from ctransformers import Config, AutoConfig
+from ctransformers import Config
 
 from request_body import ChatCompletionRequestBody, CompletionRequestBody
 from get_env import get_env, get_env_or_none
 from get_default_thread import get_default_thread
-from get_model_type import get_model_type
-
-LOGGING_LEVEL = get_env("LOGGING_LEVEL", "INFO")
-
-log = logging.getLogger("uvicorn")
-try:
-    log.setLevel(LOGGING_LEVEL)
-except ValueError:
-    log.setLevel("INFO")
+from log import log
 
 THREADS = int(get_env("THREADS", str(get_default_thread())))
 
 
-def get_auto_config(
+def get_config(
     body: CompletionRequestBody | ChatCompletionRequestBody,
-) -> AutoConfig:
+) -> Config:
     # ggml only, follow ctransformers defaults
     TOP_K = int(get_env("TOP_K", "40"))
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-top_p
@@ -38,10 +29,6 @@ def get_auto_config(
     MAX_TOKENS = int(get_env("MAX_TOKENS", "9999999"))
     # OpenAI API defaults https://platform.openai.com/docs/api-reference/chat/create#chat/create-stop
     STOP = get_env_or_none("STOP")
-    # ggml only, follow ctransformers defaults
-    CONTEXT_LENGTH = int(get_env("CONTEXT_LENGTH", "-1"))
-    # the layers to offloading to the GPU
-    GPU_LAYERS = int(get_env("GPU_LAYERS", "0"))
 
     log.debug("TOP_K: %s", TOP_K)
     log.debug("TOP_P: %s", TOP_P)
@@ -53,34 +40,20 @@ def get_auto_config(
     log.debug("THREADS: %s", THREADS)
     log.debug("MAX_TOKENS: %s", MAX_TOKENS)
     log.debug("STOP: %s", STOP)
-    log.debug("CONTEXT_LENGTH: %s", CONTEXT_LENGTH)
-    log.debug("GPU_LAYERS: %s", GPU_LAYERS)
 
     top_k = body.top_k if body.top_k else TOP_K
     top_p = body.top_p if body.top_p else TOP_P
     temperature = body.temperature if body.temperature else TEMPERATURE
-    repetition_penalty = body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+    repetition_penalty = (
+        body.repetition_penalty if body.repetition_penalty else REPETITION_PENALTY
+    )
     last_n_tokens = body.last_n_tokens if body.last_n_tokens else LAST_N_TOKENS
     seed = body.seed if body.seed else SEED
     batch_size = body.batch_size if body.batch_size else BATCH_SIZE
     threads = body.threads if body.threads else THREADS
    max_new_tokens = body.max_tokens if body.max_tokens else MAX_TOKENS
     stop = body.stop if body.stop else STOP
 
-    log.info("top_k: %s", top_k)
-    log.info("top_p: %s", top_p)
-    log.info("temperature: %s", temperature)
-    log.info("repetition_penalty: %s", repetition_penalty)
-    log.info("last_n_tokens: %s", last_n_tokens)
-    log.info("seed: %s", seed)
-    log.info("batch_size: %s", batch_size)
-    log.info("threads: %s", threads)
-    log.info("max_new_tokens: %s", max_new_tokens)
-    log.info("stop: %s", stop)
-
-    log.info("CONTEXT_LENGTH: %s", CONTEXT_LENGTH)
-    log.info("GPU_LAYERS: %s", GPU_LAYERS)
-
     config = Config(
         top_k=top_k,
         top_p=top_p,
@@ -92,16 +65,6 @@ def get_auto_config(
         threads=threads,
         max_new_tokens=max_new_tokens,
         stop=stop,
-        context_length=CONTEXT_LENGTH,
-        gpu_layers=GPU_LAYERS,
     )
 
-    model_type = get_model_type(body)
-
-    log.info("model_type: %s", model_type)
-
-    auto_config = AutoConfig(
-        config=config,
-        model_type=model_type,
-    )
-    return auto_config
+    return config
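
Note: get_config() now returns a bare ctransformers Config and no longer builds an AutoConfig or resolves the model type, so that responsibility moves to the caller. A minimal sketch of how a caller might assemble the pieces, assuming a hypothetical local model directory and request-body object not shown in this diff:

# Sketch only: "./models", `body`, and load_model() are illustrative,
# not part of this commit; Config, AutoConfig, and AutoModelForCausalLM
# are real ctransformers APIs.
from ctransformers import AutoConfig, AutoModelForCausalLM

from get_config import get_config
from get_model_type import get_model_type


def load_model(body, model_file: str):
    config = get_config(body)  # sampling/runtime knobs only
    auto_config = AutoConfig(
        config=config,
        model_type=get_model_type(model_file),
    )
    # Load once and reuse the instance across requests.
    return AutoModelForCausalLM.from_pretrained(
        "./models",  # hypothetical path to the downloaded model file
        model_file=model_file,
        config=auto_config,
    )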
28 changes: 0 additions & 28 deletions get_llm.py

This file was deleted.

26 changes: 13 additions & 13 deletions get_model_type.py
@@ -3,36 +3,36 @@
 
 
 def get_model_type(
-    body: ChatCompletionRequestBody | CompletionRequestBody,
+    filename: str,
 ) -> str:
     ctransformer_model_type = "llama"
     # These are also in "starcoder" format
     # https://huggingface.co/TheBloke/WizardCoder-15B-1.0-GGML
     # https://huggingface.co/TheBloke/minotaur-15B-GGML
     if (
-        "star" in body.model
-        or "starchat" in body.model
-        or "WizardCoder" in body.model
-        or "minotaur-15" in body.model
+        "star" in filename
+        or "starchat" in filename
+        or "WizardCoder" in filename
+        or "minotaur-15" in filename
     ):
         ctransformer_model_type = "gpt_bigcode"
-    if "llama" in body.model:
+    if "llama" in filename:
         ctransformer_model_type = "llama"
-    if "mpt" in body.model:
+    if "mpt" in filename:
         ctransformer_model_type = "mpt"
-    if "replit" in body.model:
+    if "replit" in filename:
         ctransformer_model_type = "replit"
-    if "falcon" in body.model:
+    if "falcon" in filename:
         ctransformer_model_type = "falcon"
-    if "dolly" in body.model:
+    if "dolly" in filename:
         ctransformer_model_type = "dolly-v2"
-    if "stablelm" in body.model:
+    if "stablelm" in filename:
         ctransformer_model_type = "gpt_neox"
     # matching https://huggingface.co/stabilityai/stablecode-completion-alpha-3b
-    if "stablecode" in body.model:
+    if "stablecode" in filename:
         ctransformer_model_type = "gpt_neox"
     # matching https://huggingface.co/EleutherAI/pythia-70m
-    if "pythia" in body.model:
+    if "pythia" in filename:
         ctransformer_model_type = "gpt_neox"
 
     MODE_TYPE = get_env("MODE_TYPE", "")
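
Because get_model_type() now takes the model file name instead of the whole request body, it can be called anywhere a filename is known, e.g. at download time. A quick usage sketch, assuming the MODE_TYPE environment override (visible at the end of the hunk) is unset; the file names are illustrative:

from get_model_type import get_model_type

# Substring matches on the GGML file name select the ctransformers backend.
print(get_model_type("llama-2-13b.ggmlv3.q4_0.bin"))  # "llama"
print(get_model_type("mpt-7b.ggmlv3.q5_1.bin"))       # "mpt"
print(get_model_type("starcoder.ggmlv3.q4_0.bin"))    # "gpt_bigcode"
print(get_model_type("pythia-70m.ggmlv0.q4_0.bin"))   # "gpt_neox"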
12 changes: 12 additions & 0 deletions log.py
@@ -0,0 +1,12 @@
+import logging
+
+from get_env import get_env
+
+
+LOGGING_LEVEL = get_env("LOGGING_LEVEL", "INFO")
+
+log = logging.getLogger("uvicorn")
+try:
+    log.setLevel(LOGGING_LEVEL)
+except ValueError:
+    log.setLevel("INFO")
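
This new module centralizes the logger setup that get_auto_config.py previously inlined: one "uvicorn" logger honoring LOGGING_LEVEL, falling back to INFO on an invalid value. Any module now needs only a single import, as the get_config.py diff above shows:

from log import log

log.debug("TOP_K: %s", 40)  # emitted only when LOGGING_LEVEL is DEBUG or lower
log.info("model_type: %s", "llama")  # illustrative message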