-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllms.py
18 lines (15 loc) · 797 Bytes
/
llms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Factory helpers that construct the chat models (LLMs) and the embedding model.
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_openai import OpenAIEmbeddings
import os
def load_llm(llm_name: str):
    """Build the chat model selected by ``llm_name``.

    Args:
        llm_name: Backend selector. Recognized values are ``"openai"``,
            ``"groq"`` and ``"local"`` (exact, case-sensitive match).

    Returns:
        A ``ChatOpenAI`` instance for ``"openai"`` and ``"local"``, or a
        ``ChatGroq`` instance for ``"groq"``.

    Raises:
        ValueError: If ``llm_name`` is not one of the recognized values.
        KeyError: If the environment variable holding the API key for the
            chosen backend (``OPENAI_API_KEY`` or ``GROQ_API_KEY``) is unset.
    """
    if llm_name == "openai":
        return ChatOpenAI(  # type: ignore
            model_name="gpt-4o",
            openai_api_key=os.environ["OPENAI_API_KEY"],
            temperature=0.1,
            streaming=True,
        )

    if llm_name == "groq":
        # Earlier note in this file listed temperature = 0.1 and the model
        # "mixtral-8x7b-32768" as alternatives to the settings used here.
        return ChatGroq(
            temperature=0.0,
            groq_api_key=os.environ["GROQ_API_KEY"],
            model_name="llama3-70b-8192",
        )

    if llm_name == "local":
        # Uses the OpenAI-compatible client against a server on localhost.
        # NOTE(review): no API key is passed explicitly here; presumably the
        # client falls back to its own default lookup -- confirm.
        return ChatOpenAI(
            model="llama3gradlt",
            base_url="http://localhost:11434/v1",
            temperature=0.1,
        )

    # Previously an unknown name fell through to `return llm` with `llm`
    # never assigned, surfacing as an UnboundLocalError. Fail explicitly.
    raise ValueError(
        f"Unknown llm_name {llm_name!r}; expected one of 'openai', 'groq', 'local'"
    )
def load_embedding():
    """Create the OpenAI embedding client used by this module.

    Returns:
        An ``OpenAIEmbeddings`` instance configured for the
        ``text-embedding-3-small`` model.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is unset.
    """
    openai_key = os.environ["OPENAI_API_KEY"]
    embedder = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_key)
    return embedder