Skip to content

Commit

Permalink
Merge pull request #20 from TJor-L/feature
Browse files Browse the repository at this point in the history
Feature
  • Loading branch information
TJor-L authored Sep 8, 2023
2 parents ad4c10a + 2d303e9 commit 9a63fe1
Show file tree
Hide file tree
Showing 6 changed files with 2,114 additions and 89 deletions.
24 changes: 24 additions & 0 deletions databaseMod/wineSample/dropWineData.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# File : dropWineData.py
# Time : 2023/9/7 17:56
# Author : Dijkstra Liu
# Email : [email protected]
#
#     /> —— フ
#      | `_  _ l
#      ノ ミ_xノ
#      /     |
#     /  ヽ  ノ
#     │  | | \
#  / ̄|   | | |
# | ( ̄ヽ__ヽ_)__)
#  \_つ
#
# Description:
#     Maintenance script: connects to the 'wine' Milvus database and drops
#     the 'wine_data' collection. Destructive — removes the collection and
#     every vector stored in it. Run manually before re-ingesting the data.
from databaseMod.milvusDB import MilvusDB


# Point the shared MilvusDB helper at the wine database/collection.
mvs_db = MilvusDB()
mvs_db.db_name = 'wine'
mvs_db.collection = 'wine_data'
# Irreversibly drop the collection by name.
mvs_db.drop_collection_by_name('wine_data')
12 changes: 9 additions & 3 deletions databaseMod/wineSample/searchWine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Ad-hoc test script: run one vector-similarity search against the wine
# collection and pretty-print the hits.
from databaseMod.milvusDB import MilvusDB

# Select the 'wine' database and its 'wine_data' collection on the shared helper.
mvs_db = MilvusDB()
mvs_db.db_name = 'wine'
mvs_db.collection = 'wine_data'
Expand All @@ -11,8 +10,15 @@

# Fixed sample query; `output_fields` (defined above) lists the scalar fields
# returned alongside each hit.
res = mvs_db.conduct_vector_similar_search(query="sweet red wine with price under 50", limit=5,
                                           output_fields=output_fields)
print(res)

# Format every hit as a numbered "* Product i: field: value, ..." line.
entity_strings = []
index = 1  # running product number across all result sets
for search_res in res:  # Milvus returns one result set per query vector
    for hit in search_res:
        print(hit)
        # hit.entity.to_dict() nests the field values under the "entity" key.
        entity = hit.entity.to_dict()["entity"]
        entity_str = "* Product " + str(index) + ": " + ', '.join(f"{key}: {value}" for key, value in entity.items())
        entity_strings.append(entity_str)
        index = index + 1

result_str = '\n'.join(entity_strings)
print(result_str)
1,750 changes: 1,750 additions & 0 deletions development.log

Large diffs are not rendered by default.

123 changes: 112 additions & 11 deletions modules/actions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,46 @@
from langchain.schema import AgentAction, AgentFinish, OutputParserException
import re
import const

from databaseMod.milvusDB import MilvusDB
from flask import current_app
from modules.factories.connection_string_factory import create_connection_string
from modules.factories.database_factory import create_database
from modules.factories.embedding_factory import create_embedding
from modules.factories.model_factory import create_model

from langchain import PromptTemplate
from utils.logging import LOGGER

# TODO: Connect to MySQL so history messages can be acquired per user from the
# database. The current test implementation lets all users share one common
# history-message buffer.
chat_history = ChatMessageHistory()  # shared in-memory chat log for ALL users

# Module-level Milvus connection to the wine catalogue, shared by every request.
mvs_db = MilvusDB()
mvs_db.db_name = 'wine'
mvs_db.collection = 'wine_data'
# Return every scalar field except the primary key and the embedding vector.
output_fields = [field.name for field in mvs_db.collection.schema.fields if
                 field.name not in {'id', 'wine_info_embed'}]


def format_products(entities):
    """Render entity dicts as a numbered, newline-separated product list.

    :param entities: iterable of dicts mapping field name -> value.
    :return: one ``* Product i: key: value, ...`` line per entity, joined
        with newlines; empty string for an empty iterable.
    """
    return '\n'.join(
        "* Product " + str(i) + ": " + ', '.join(f"{key}: {value}" for key, value in entity.items())
        for i, entity in enumerate(entities, start=1)
    )


def wine_search(query: str) -> str:
    """A wine search tool, use it when you need to search products from your company"""
    # Vector-similarity search against the wine collection; keep the 5 best hits.
    res = mvs_db.conduct_vector_similar_search(query=query, limit=5,
                                               output_fields=output_fields)

    # Milvus returns one result set per query vector; hit.entity.to_dict()
    # nests the field values under the "entity" key. Flatten all hits so the
    # product numbering runs continuously across result sets (as before).
    entities = [hit.entity.to_dict()["entity"] for search_res in res for hit in search_res]

    result_str = format_products(entities)
    # Lazy %-formatting: the message is only built if INFO logging is enabled.
    LOGGER.info("Search Tool found:%s", result_str)
    return result_str


class ChatBase:
def __init__(self, model=None, in_memory=None, chats_history=None, number=None):
Expand All @@ -40,10 +69,82 @@ def convert_message(self, history):

def keep_memory_message(self, num, with_memory):
    """Build a windowed conversation memory over the shared chat history.

    :param num: fallback window size when OPENAI_BUFFER_TOP_K is not configured.
    :param with_memory: when falsy, the window is 0 (no prior turns are kept).
    :return: a configured ConversationBufferWindowMemory instance.
    """
    window = current_app.config.get("OPENAI_BUFFER_TOP_K", num) if with_memory else 0
    memory = ConversationBufferWindowMemory(
        k=window,
        chat_memory=chat_history,
        memory_key="chat_history",
        input_key="input",
    )
    LOGGER.info("Memory object created.")
    return memory

def chat_with_database(self, query):
    """Answer *query* as the CognoPal shopping assistant, grounding the
    reply in wine products retrieved from the Milvus vector database.

    :param query: the customer's new input message.
    :return: dict with "reply" (the model's answer) and "history" (the full
        chat log as message dicts).
    """
    # System prompt: persona and sales-stage rules, with a {products} slot
    # that is filled from the vector search before the chain template is built.
    template = """
## Roles and Rules
Never forget your name is CognoPal.
You are created by Cogno. You are an AI Assistant Customized for Seamless Global Shopping.
You work as a personal shopping assistant recommending customers with products they might enjoy.
Always answer in the language the prospect asks in.
Keep your responses in short length to retain the user's attention.
Start the conversation by just a greeting and how is the prospect doing without
pitching in your first turn.
Always think about at which conversation stage you are at before answering:
1: Introduction: Start the conversation by introducing yourself.
Be polite and respectful while keeping the tone of the conversation professional.
Your greeting should be welcoming. Always clarify in your greeting the reason why you
are messaging.
2: Value proposition: Briefly explain how your product/service can benefit the prospect.
Focus on the unique selling points and value proposition of your product/service that
sets it apart from competitors.
3: Needs analysis: Ask open-ended questions to uncover the prospect's needs and pain
points. Listen carefully to their responses and take notes.
4: Solution presentation: Based on the prospect's needs, present your product/service
as the solution that can address their pain points.
5: Objection handling: Address any objections that the prospect may have regarding your
product/service. Be prepared to provide evidence or testimonials to support your claims.
## Product information
### Here are some company's products' information:
{products}
**I will provide you with a chat log later, as well as costumer's new input. Please respond to the costumer's new input.**
**Please try to use the product information given above. These products are from our company, please do not recommend products other than those mentioned in the information above and in the chat log.**
"""

    # Retrieve the products most similar to the customer's query.
    products = wine_search(query)

    prompt = PromptTemplate(
        input_variables=["products"],
        template=template,
    )
    # Bake the product list into the system prompt, then append the
    # conversation section that ConversationChain fills in each turn.
    template = prompt.format(products=products) + """
## Chat log:
{chat_history}
Customer: {input}
CognoPal:
"""
    # BUG FIX: was LOGGER.info("cur tem".format(template)) — no {} placeholder,
    # so it always logged the literal "cur tem" and dropped the template.
    LOGGER.info("cur tem: {}".format(template))
    LOGGER.info("Database result: {}".format(products))
    # Keep the last OPENAI_BUFFER_TOP_K turns (default 5) of the shared history.
    memory = ConversationBufferWindowMemory(
        k=current_app.config.get("OPENAI_BUFFER_TOP_K", 5), chat_memory=chat_history, ai_prefix="CognoPal",
        human_prefix="Customer",
        memory_key="chat_history",
        input_key="input")

    prompt = PromptTemplate(input_variables=["chat_history", "input"], template=template)

    conversation = ConversationChain(
        llm=create_model(current_app.config.get("MODEL", const.OPENAI)), verbose=True, memory=memory, prompt=prompt
    )

    reply = conversation.predict(input=query)
    LOGGER.info("Reply generated: {}".format(reply))
    history = messages_to_dict(chat_history.messages)
    LOGGER.info("Chat function ends.")

    return {"reply": reply, "history": history}

def chat(self, query, prompt="", isSearch=False):
tools = self.set_up_tools()
LOGGER.info("get into chat func")
Expand All @@ -69,11 +170,10 @@ def chat(self, query, prompt="", isSearch=False):
with_memory,
len(history)))
try:
# TODO: Add prompt
if isSearch is True:
response = self.search_from_knowledge_base(query=query)
LOGGER.info("DataBase result{}".format(response['reply']))
return response
# if isSearch is True:
# response = self.search_from_knowledge_base(query=query)
# LOGGER.info("DataBase result{}".format(response['reply']))
# return response
# Create a Conversation Chain
# Convert dicts to message object
self.convert_message(history)
Expand All @@ -99,7 +199,9 @@ def chat(self, query, prompt="", isSearch=False):
return {"reply": reply, "history": history}
else:

memory = ConversationBufferWindowMemory(k=5, chat_memory=chat_history, memory_key="chat_history", input_key="input")
products = wine_search(query)
memory = ConversationBufferWindowMemory(k=5, chat_memory=chat_history, memory_key="chat_history",
input_key="input")

output_parser = CustomOutputParser()

Expand All @@ -120,10 +222,9 @@ def chat(self, query, prompt="", isSearch=False):
agent=agent, tools=tools, verbose=True, memory=memory
)

reply = agent_executor.run({'input': query})
reply = agent_executor.run({'input': query, 'products': products})
LOGGER.info("Reply generated: {}".format(reply))


history = messages_to_dict(chat_history.messages)
LOGGER.info("Chat function ends.")

Expand Down Expand Up @@ -188,4 +289,4 @@ def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
action = match.group(1).strip()
action_input = match.group(2)
# Return the action and action input
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
Loading

0 comments on commit 9a63fe1

Please sign in to comment.