Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new models (Perplexity, gemini) & Separate GPT versions #2856

Merged
merged 4 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Shown to the user when a conversation hits the maximum turn/length limit.
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
# Shown when a session has been idle for too long and must be refreshed.
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
# Warning shown for models that stream slowly (responses arrive all at once).
SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds."
# Markdown-formatted message shown when a model's rate limit has been hit.
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR TRY OTHER MODELS.**"
# Maximum input length (characters); overridable via environment variable.
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
# Maximum conversation turns
Expand Down
16 changes: 15 additions & 1 deletion fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ def to_gradio_chatbot(self):

def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
ret = [{"role": "system", "content": self.system_message}]
if self.system_message == "":
ret = []
else:
ret = [{"role": "system", "content": self.system_message}]

for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
Expand Down Expand Up @@ -679,6 +682,17 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Perplexity AI template ("pplxai"). Registered with no separator style/string
# (sep_style=None, sep=None) — presumably prompts are sent via the hosted chat
# API as role/content messages rather than rendered to a single prompt string;
# confirm against the API client code.
register_conv_template(
Conversation(
name="pplxai",
system_message="Be precise and concise.",
roles=("user", "assistant"),
sep_style=None,
sep=None,
)
)

# Claude default template
register_conv_template(
Conversation(
Expand Down
37 changes: 36 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,12 @@ class ChatGPTAdapter(BaseModelAdapter):
def match(self, model_path: str):
    """Return True iff *model_path* is one of the supported OpenAI chat models.

    Matching is exact (no substring matching) so that fine-tuned or
    similarly named local models do not accidentally route here.
    """
    supported = (
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-1106",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-0613",
        "gpt-4-turbo",
    )
    return model_path in supported

Expand All @@ -1063,6 +1067,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("chatgpt")


class PplxAIAdapter(BaseModelAdapter):
    """The model adapter for Perplexity AI's hosted pplx models."""

    def match(self, model_path: str):
        # Exact match against the two hosted "online" models; any other
        # path falls through to later adapters.
        return model_path in {"pplx-7b-online", "pplx-70b-online"}

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        # API-backed model: there are no local weights to load.
        raise NotImplementedError()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        # Use the dedicated Perplexity conversation template.
        return get_conv_template("pplxai")


class ClaudeAdapter(BaseModelAdapter):
"""The model adapter for Claude"""

Expand Down Expand Up @@ -1102,6 +1122,19 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("bard")


class GeminiAdapter(BaseModelAdapter):
    """The model adapter for Google's Gemini API."""

    def match(self, model_path: str):
        # Only the exact "gemini-pro" identifier routes here.
        return model_path == "gemini-pro"

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        # API-backed model: nothing to load locally.
        raise NotImplementedError()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        # Reuses the "bard" template (same Google chat-style conversation).
        return get_conv_template("bard")


class BiLLaAdapter(BaseModelAdapter):
"""The model adapter for Neutralzz/BiLLa-7B-SFT"""

Expand Down Expand Up @@ -1420,7 +1453,7 @@ class MistralAdapter(BaseModelAdapter):
"""The model adapter for Mistral AI models"""

def match(self, model_path: str):
    """Return True iff *model_path* names a Mistral or Mixtral model.

    Matching is case-insensitive substring matching, so it covers both
    upstream names (e.g. "mistralai/Mistral-7B-Instruct-v0.1") and the
    Mixtral mixture-of-experts variants ("Mixtral-8x7B").
    """
    # Defect fixed: the scraped diff left BOTH the old and the new return
    # statements in the text, so the stale first `return` (without the
    # "mixtral" check) shadowed the merged one. Keep only the merged
    # behavior, and lower-case the path once instead of twice.
    lowered = model_path.lower()
    return "mistral" in lowered or "mixtral" in lowered

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
Expand Down Expand Up @@ -2056,6 +2089,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
# Register API-backed and hosted-model adapters. NOTE(review): registration
# order likely determines match precedence — confirm against
# register_model_adapter / the adapter lookup loop before reordering.
register_model_adapter(PhoenixAdapter)
register_model_adapter(BardAdapter)
register_model_adapter(PaLM2Adapter)
register_model_adapter(GeminiAdapter)
register_model_adapter(ChatGPTAdapter)
register_model_adapter(AzureOpenAIAdapter)
register_model_adapter(ClaudeAdapter)
Expand Down Expand Up @@ -2107,6 +2141,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
# Register additional model adapters, including the Perplexity AI adapter
# (PplxAIAdapter) added in this change.
register_model_adapter(MicrosoftOrcaAdapter)
register_model_adapter(XdanAdapter)
register_model_adapter(YiAdapter)
register_model_adapter(PplxAIAdapter)
register_model_adapter(DeepseekCoderAdapter)
register_model_adapter(DeepseekChatAdapter)
register_model_adapter(MetaMathAdapter)
Expand Down
Loading