-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #436 from swarmauri/master
🚢 Shipping v0.4.4
- Loading branch information
Showing
16 changed files
with
705 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
'scikit-learn', | ||
'gensim', | ||
'textblob', | ||
'spacy==3.7.4', | ||
'spacy', | ||
'pygments', | ||
'gradio', | ||
'websockets', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.4.3" | ||
__version__ = "0.4.4" | ||
__long_desc__ = """ | ||
# Swarmauri SDK | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,76 @@ | ||
import json | ||
from typing import List, Literal, Dict | ||
from mistralai import Mistral | ||
from mistralai import Mistral | ||
from swarmauri.core.typing import SubclassUnion | ||
|
||
from swarmauri.standard.messages.base.MessageBase import MessageBase | ||
from swarmauri.standard.messages.concrete.AgentMessage import AgentMessage | ||
from swarmauri.standard.llms.base.LLMBase import LLMBase | ||
|
||
|
||
class MistralModel(LLMBase): | ||
"""Provider resources: https://docs.mistral.ai/getting-started/models/""" | ||
|
||
api_key: str | ||
allowed_models: List[str] = ['open-mistral-7b', | ||
'open-mixtral-8x7b', | ||
'open-mixtral-8x22b', | ||
'mistral-small-latest', | ||
'mistral-medium-latest', | ||
'mistral-large-latest', | ||
'codestral', | ||
'open-mistral-nemo', | ||
'codestral-latest', | ||
'open-codestral-mamba', | ||
allowed_models: List[str] = [ | ||
"open-mistral-7b", | ||
"open-mixtral-8x7b", | ||
"open-mixtral-8x22b", | ||
"mistral-small-latest", | ||
"mistral-medium-latest", | ||
"mistral-large-latest", | ||
"open-mistral-nemo", | ||
"codestral-latest", | ||
"open-codestral-mamba", | ||
] | ||
name: str = "open-mixtral-8x7b" | ||
type: Literal['MistralModel'] = 'MistralModel' | ||
type: Literal["MistralModel"] = "MistralModel" | ||
|
||
def _format_messages(self, messages: List[SubclassUnion[MessageBase]]) -> List[Dict[str, str]]: | ||
message_properties = ['content', 'role'] | ||
formatted_messages = [message.model_dump(include=message_properties, exclude_none=True) for message in messages] | ||
def _format_messages( | ||
self, messages: List[SubclassUnion[MessageBase]] | ||
) -> List[Dict[str, str]]: | ||
message_properties = ["content", "role"] | ||
formatted_messages = [ | ||
message.model_dump(include=message_properties, exclude_none=True) | ||
for message in messages | ||
] | ||
return formatted_messages | ||
|
||
def predict(self, | ||
conversation, | ||
temperature: int = 0.7, | ||
max_tokens: int = 256, | ||
def predict( | ||
self, | ||
conversation, | ||
temperature: int = 0.7, | ||
max_tokens: int = 256, | ||
top_p: int = 1, | ||
enable_json: bool=False, | ||
safe_prompt: bool=False): | ||
|
||
enable_json: bool = False, | ||
safe_prompt: bool = False, | ||
): | ||
|
||
formatted_messages = self._format_messages(conversation.history) | ||
|
||
client = Mistral(api_key=self.api_key) | ||
client = Mistral(api_key=self.api_key) | ||
if enable_json: | ||
response = client.chat.complete( | ||
model=self.name, | ||
messages=formatted_messages, | ||
temperature=temperature, | ||
response_format={ "type": "json_object" }, | ||
response_format={"type": "json_object"}, | ||
max_tokens=max_tokens, | ||
top_p=top_p, | ||
safe_prompt=safe_prompt | ||
safe_prompt=safe_prompt, | ||
) | ||
else: | ||
response = client.chat.complete( | ||
model=self.name, | ||
messages=formatted_messages, | ||
temperature=temperature, | ||
max_tokens=max_tokens, | ||
top_p=top_p, | ||
safe_prompt=safe_prompt | ||
top_p=top_p, | ||
safe_prompt=safe_prompt, | ||
) | ||
|
||
result = json.loads(response.json()) | ||
message_content = result['choices'][0]['message']['content'] | ||
message_content = result["choices"][0]["message"]["content"] | ||
conversation.add_message(AgentMessage(content=message_content)) | ||
|
||
return conversation | ||
return conversation |
Oops, something went wrong.