Skip to content

Commit

Permalink
get the bot running SQL code again
Browse files Browse the repository at this point in the history
  • Loading branch information
lostmygithubaccount committed Feb 13, 2024
1 parent d7ae737 commit d2789e0
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 313 deletions.
2 changes: 2 additions & 0 deletions src/ibis_birdbrain/attachments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __repr__(self):
from ibis_birdbrain.attachments.text import (
TextAttachment,
CodeAttachment,
ErrorAttachment,
WebpageAttachment,
)

Expand All @@ -110,5 +111,6 @@ def __repr__(self):
"ChartAttachment",
"TextAttachment",
"CodeAttachment",
"ErrorAttachment",
"WebpageAttachment",
]
21 changes: 21 additions & 0 deletions src/ibis_birdbrain/attachments/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,24 @@ def __str__(self):
**language**: {self.language}
**code**:\n{self.content}"""
)

class ErrorAttachment(TextAttachment):
"""An error attachment."""

content: str

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def encode(self):
...

def decode(self):
...

def __str__(self):
return (
super().__str__()
+ f"""
**error**:\n{self.content}"""
)
147 changes: 78 additions & 69 deletions src/ibis_birdbrain/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Attachments,
TableAttachment,
CodeAttachment,
ErrorAttachment,
)
from ibis_birdbrain.messages import Message, Messages, Email
from ibis_birdbrain.utils.messages import to_message
Expand Down Expand Up @@ -84,8 +85,8 @@ class Flows(Enum):
VISUALIZE = "visualize"


class EDAFlows(Enum):
"""Ibis Birdbrain EDA flows."""
class SQLFlows(Enum):
"""Ibis Birdbrain SQL flows."""

GET_CODE = "get_code"
FIX_CODE = "fix_code"
Expand All @@ -102,7 +103,7 @@ def respond(messages: Messages) -> str:
@marvin.fn
def messages_to_text_query(messages: Messages) -> str:
"""Convert the messages to an English text query.
Returns the English prose that concisely describes the desired query.
"""

Expand All @@ -120,61 +121,6 @@ def text_to_sql(text: str, attachments: Attachments) -> str:
return _text_to_sql(text, attachments).strip().strip(";")


def respond_flow(messages: Messages) -> Messages:
pass


def sql_flow(messages: Messages) -> Messages:
extract_guid_instructions = f"""
Extract relevant attachment GUIDs (ONLY the ATTACHMENT GUIDs) from the messages.
Options include: {messages.attachments()}
"""

rm = Messages()

extract_guid_instructions = inspect.cleandoc(extract_guid_instructions)

guids = marvin.extract(
messages,
str,
instructions=extract_guid_instructions,
)
log.info(f"Extracted GUIDs: {guids}")

# get the attachments
attachments = Attachments()
# TODO: fix this ugliness
for guid in guids:
for message in messages:
if guid in messages[message].attachments:
attachments.append(messages[message].attachments[guid])
log.info(f"Attachments: {attachments}")

# get the text query
text_query = messages_to_text_query(messages)

# convert the text to SQL
sql = text_to_sql(text_query, attachments)
a = CodeAttachment(language="sql", content=sql)

# construct the response message
m = Email(
body=f"SQL attachted for query: {text_query}",
subject="SQL code",
attachments=[a],
)

# append the message to the response messages
rm.append(m)

return rm


def visualize_flow(messages: Messages) -> Messages:
pass


# bot
class Bot:
"""Ibis Birdbrain bot."""
Expand Down Expand Up @@ -253,22 +199,85 @@ def __call__(

match flow:
case Flows.RESPOND:
response = respond(self.messages)
m = Email(
body=response,
subject="response",
to_address=self.user_name,
from_address=self.name,
)
self.messages.append(m)
self.respond_flow()
case Flows.SQL_CODE:
ms = sql_flow(self.messages)
self.sql_flow()
self.respond_flow()
# append the table attachment(s) to the response message
# TODO: fix this ugliness
for m in ms:
self.messages.append(ms[m])
for attachment in self.messages[-2].attachments:
if isinstance(
self.messages[-2].attachments[attachment], TableAttachment
):
self.messages[-1].attachments.append(
self.messages[-2].attachments[attachment]
)
case Flows.VISUALIZE:
pass
case _:
pass

return self.messages[-1]

def sql_flow(self) -> None:
extract_guid_instructions = f"""
Extract relevant attachment GUIDs (ONLY the ATTACHMENT GUIDs) from the messages.
Options include: {self.messages.attachments()}
"""

extract_guid_instructions = inspect.cleandoc(extract_guid_instructions)

guids = marvin.extract(
self.messages,
str,
instructions=extract_guid_instructions,
)
log.info(f"Extracted GUIDs: {guids}")

# get the attachments
attachments = Attachments()
# TODO: fix this ugliness
for guid in guids:
for message in self.messages:
if guid in self.messages[message].attachments:
attachments.append(self.messages[message].attachments[guid])
# log.info(f"Attachments: {attachments}")

# get the text query
text_query = messages_to_text_query(self.messages)

# convert the text to SQL
sql = text_to_sql(text_query, attachments)
a = CodeAttachment(language="sql", content=sql)

# run the SQL
try:
t = self.con.sql(sql)
at = TableAttachment(t)
except Exception as e:
at = ErrorAttachment(e)
log.error(f"SQL error: {e}")

# construct the response message
m = Email(
body=f"SQL attachted for query: {text_query}",
subject="SQL code",
attachments=[a, at],
)

# append the message to the response messages
self.messages.append(m)

def respond_flow(self) -> None:
response = respond(self.messages)
m = Email(
body=response,
subject="response",
to_address=self.user_name,
from_address=self.name,
)
self.messages.append(m)

def visualize_flow(self) -> None:
pass
Loading

0 comments on commit d2789e0

Please sign in to comment.