-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathtext2sql.py
81 lines (66 loc) · 2.38 KB
/
text2sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import ast
import json
import os
import re
import sqlite3
import time

import pandas as pd

from lotus.models import OpenAIModel
def run_row(query_row):
    """Generate SQL for one benchmark row, execute it, and pair it with the ground truth.

    Reads the module-level ``lm`` language model handle.

    Args:
        query_row: a mapping (e.g. a pandas row) with keys "text2sql_prompt",
            "DB used", "Answer", and "Query ID".

    Returns:
        On success: dict with "query_id", "sql_statement", "prediction", "answer".
        On SQL failure: dict with "error" and "query_id".
    """
    text2sql_prompt = query_row["text2sql_prompt"]
    db_name = query_row["DB used"]
    raw_answer = lm([[{"role": "user", "content": text2sql_prompt}]])[0]
    # Prefer ```sql fenced blocks, then bare ``` fences, then the raw completion.
    sql_statements = re.findall(r"```sql\n(.*?)\n```", raw_answer, re.DOTALL)
    if not sql_statements:
        sql_statements = re.findall(r"```\n(.*?)\n```", raw_answer, re.DOTALL)
    if not sql_statements:
        sql_statements = [raw_answer]
    # The model may emit several statements; only the last one is executed.
    last_sql_statement = sql_statements[-1]
    try:
        try:
            # Ground-truth answers may be stored as Python literals (e.g. lists).
            # ast.literal_eval replaces the original eval(): it parses literals
            # only and cannot execute arbitrary code from the CSV.
            answer = ast.literal_eval(query_row["Answer"])
        except Exception:
            # Not a parseable literal -> keep the raw string as the answer.
            answer = query_row["Answer"]
        conn = sqlite3.connect(f"../dev_folder/dev_databases/{db_name}/{db_name}.sqlite")
        try:
            cursor = conn.cursor()
            cursor.execute(last_sql_statement)
            raw_results = cursor.fetchall()
        finally:
            # Always release the connection (the original leaked it on every call,
            # including the error path).
            conn.close()
        # Keep only the first column of each fetched row.
        predictions = [res[0] for res in raw_results]
        if not isinstance(answer, list):
            # Scalar ground truth: compare against the first fetched value (or None).
            predictions = predictions[0] if predictions else None
        return {
            "query_id": query_row["Query ID"],
            "sql_statement": last_sql_statement,
            "prediction": predictions,
            "answer": answer,
        }
    except Exception as e:
        # Surface SQL execution failures as data instead of crashing the run.
        return {
            "error": f"Error running SQL statement: {last_sql_statement}\n{e}",
            "query_id": query_row["Query ID"],
        }
def parse_args():
    """Parse CLI options: --df_path (queries CSV) and --output_dir (optional results dir)."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--df_path", type=str, default="../tag_queries.csv")
    arg_parser.add_argument("--output_dir", type=str)
    return arg_parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # One row per natural-language benchmark question.
    queries_df = pd.read_csv(args.df_path)
    # Module-level handle: run_row() reads this global `lm`.
    # NOTE(review): assumes a vLLM OpenAI-compatible server on localhost:8000 — confirm.
    lm = OpenAIModel(
        model="meta-llama/Meta-Llama-3.1-70B-Instruct",
        api_base="http://localhost:8000/v1",
        provider="vllm",
        max_tokens=512,
    )
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    # Run each query, timing the full generate+execute round trip.
    for _, query_row in queries_df.iterrows():
        tic = time.time()
        output = run_row(query_row)
        latency = time.time() - tic
        output["latency"] = latency
        print(output)
        # Persist one JSON result file per query when an output dir was given.
        if args.output_dir:
            with open(os.path.join(args.output_dir, f"query_{query_row['Query ID']}.json"), "w+") as f:
                json.dump(output, f)