Skip to content

Commit

Permalink
Add new models (Perplexity, gemini) & Separate GPT versions (lm-sys#2856
Browse files Browse the repository at this point in the history
)

Co-authored-by: Wei-Lin Chiang <[email protected]>
  • Loading branch information
2 people authored and zhanghao.smooth committed Jan 26, 2024
1 parent 5fb6249 commit 077aa66
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 114 deletions.
1 change: 1 addition & 0 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds."
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR TRY OTHER MODELS.**"
# Maximum input length
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
# Maximum conversation turns
Expand Down
16 changes: 15 additions & 1 deletion fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,10 @@ def to_gradio_chatbot(self):

def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
ret = [{"role": "system", "content": self.system_message}]
if self.system_message == "":
ret = []
else:
ret = [{"role": "system", "content": self.system_message}]

for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
Expand Down Expand Up @@ -689,6 +692,17 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Perplexity AI template
register_conv_template(
Conversation(
name="pplxai",
system_message="Be precise and concise.",
roles=("user", "assistant"),
sep_style=None,
sep=None,
)
)

# Claude default template
register_conv_template(
Conversation(
Expand Down
37 changes: 36 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,8 +1043,12 @@ class ChatGPTAdapter(BaseModelAdapter):
def match(self, model_path: str):
return model_path in (
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-1106",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-turbo",
)

Expand All @@ -1068,6 +1072,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("chatgpt")


class PplxAIAdapter(BaseModelAdapter):
"""The model adapter for Perplexity AI"""

def match(self, model_path: str):
return model_path in (
"pplx-7b-online",
"pplx-70b-online",
)

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("pplxai")


class ClaudeAdapter(BaseModelAdapter):
"""The model adapter for Claude"""

Expand Down Expand Up @@ -1107,6 +1127,19 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("bard")


class GeminiAdapter(BaseModelAdapter):
"""The model adapter for Gemini"""

def match(self, model_path: str):
return model_path in ["gemini-pro"]

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("bard")


class BiLLaAdapter(BaseModelAdapter):
"""The model adapter for Neutralzz/BiLLa-7B-SFT"""

Expand Down Expand Up @@ -1425,7 +1458,7 @@ class MistralAdapter(BaseModelAdapter):
"""The model adapter for Mistral AI models"""

def match(self, model_path: str):
return "mistral" in model_path.lower()
return "mistral" in model_path.lower() or "mixtral" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
Expand Down Expand Up @@ -2102,6 +2135,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(PhoenixAdapter)
register_model_adapter(BardAdapter)
register_model_adapter(PaLM2Adapter)
register_model_adapter(GeminiAdapter)
register_model_adapter(ChatGPTAdapter)
register_model_adapter(AzureOpenAIAdapter)
register_model_adapter(ClaudeAdapter)
Expand Down Expand Up @@ -2154,6 +2188,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(MicrosoftOrcaAdapter)
register_model_adapter(XdanAdapter)
register_model_adapter(YiAdapter)
register_model_adapter(PplxAIAdapter)
register_model_adapter(DeepseekCoderAdapter)
register_model_adapter(DeepseekChatAdapter)
register_model_adapter(MetaMathAdapter)
Expand Down
Loading

0 comments on commit 077aa66

Please sign in to comment.