Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
infwinston committed Jan 3, 2024
1 parent 8b60a00 commit 14b4ae7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
37 changes: 17 additions & 20 deletions fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,17 @@

ip_expiration_dict = defaultdict(lambda: 0)

# Information about custom OpenAI compatible API models.
# JSON file format:
# JSON file format of API-based models:
# {
# "vicuna-7b": {
# "model_name": "vicuna-7b-v1.5",
# "api_base": "http://8.8.8.55:5555/v1",
# "api_key": "password"
# "api_key": "password",
# "api_type": "openai", # openai, anthropic, palm, mistral
# "anony_only": false, # whether to show this model in anonymous mode only
# },
# }
openai_compatible_models_info = {}
api_endpoint_info = {}


class State:
Expand Down Expand Up @@ -122,7 +123,7 @@ def get_conv_log_filename():
return name


def get_model_list(controller_url, register_openai_compatible_models):
def get_model_list(controller_url, register_api_endpoint_file):
if controller_url:
ret = requests.post(controller_url + "/refresh_all_workers")
assert ret.status_code == 200
Expand All @@ -132,19 +133,17 @@ def get_model_list(controller_url, register_openai_compatible_models):
models = []

# Add API providers
if register_openai_compatible_models:
global openai_compatible_models_info
openai_compatible_models_info = json.load(
open(register_openai_compatible_models)
)
models += list(openai_compatible_models_info.keys())
if register_api_endpoint_file:
global api_endpoint_info
api_endpoint_info = json.load(open(register_api_endpoint_file))
models += list(api_endpoint_info.keys())

models = list(set(models))
visible_models = models.copy()
for mdl in visible_models:
if mdl not in openai_compatible_models_info:
if mdl not in api_endpoint_info:
continue
mdl_dict = openai_compatible_models_info[mdl]
mdl_dict = api_endpoint_info[mdl]
if mdl_dict["anony_only"]:
visible_models.remove(mdl)

Expand Down Expand Up @@ -181,7 +180,7 @@ def load_demo(url_params, request: gr.Request):
if args.model_list_mode == "reload":
models, all_models = get_model_list(
controller_url,
args.register_openai_compatible_models,
args.register_api_endpoint_file,
)

return load_demo_single(models, url_params)
Expand Down Expand Up @@ -367,9 +366,7 @@ def bot_response(

conv, model_name = state.conv, state.model_name
model_api_dict = (
openai_compatible_models_info[model_name]
if model_name in openai_compatible_models_info
else None
api_endpoint_info[model_name] if model_name in api_endpoint_info else None
)

if model_api_dict is None:
Expand Down Expand Up @@ -846,9 +843,9 @@ def build_demo(models):
help="Shows term of use before loading the demo",
)
parser.add_argument(
"--register-openai-compatible-models",
"--register-api-endpoint-file",
type=str,
help="Register custom OpenAI API compatible models by loading them from a JSON file",
help="Register API-based model endpoints from a JSON file",
)
parser.add_argument(
"--gradio-auth-path",
Expand All @@ -867,7 +864,7 @@ def build_demo(models):
set_global_vars(args.controller_url, args.moderate)
models, all_models = get_model_list(
args.controller_url,
args.register_openai_compatible_models,
args.register_api_endpoint_file,
)

# Set authorization credentials
Expand Down
8 changes: 4 additions & 4 deletions fastchat/serve/gradio_web_server_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def load_demo(url_params, request: gr.Request):
if args.model_list_mode == "reload":
models, all_models = get_model_list(
args.controller_url,
args.register_openai_compatible_models,
args.register_api_endpoint_file,
)

single_updates = load_demo_single(models, url_params)
Expand Down Expand Up @@ -164,9 +164,9 @@ def build_demo(models, elo_results_file, leaderboard_table_file):
help="Shows term of use before loading the demo",
)
parser.add_argument(
"--register-openai-compatible-models",
"--register-api-endpoint-file",
type=str,
help="Register custom OpenAI API compatible models by loading them from a JSON file",
help="Register API-based model endpoints from a JSON file",
)
parser.add_argument(
"--gradio-auth-path",
Expand Down Expand Up @@ -194,7 +194,7 @@ def build_demo(models, elo_results_file, leaderboard_table_file):
set_global_vars_anony(args.moderate)
models, all_models = get_model_list(
args.controller_url,
args.register_openai_compatible_models,
args.register_api_endpoint_file,
)

# Set authorization credentials
Expand Down

0 comments on commit 14b4ae7

Please sign in to comment.