Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat!: Enhance Reporting with CSV Export and Add Error Message to Reports #11

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The script uses a local SQLite database (`file_processing.db`) with the followin
- `total_embedding_tokens` (INTEGER): Total tokens used for embeddings.
- `total_llm_cost` (REAL): Total cost incurred for LLM operations.
- `total_llm_tokens` (INTEGER): Total tokens used for LLM operations.
- `error_message` (TEXT): Details of errors if `execution_status` is `ERROR`; otherwise NULL.
- `updated_at` (TEXT): Last updated timestamp
- `created_at` (TEXT): Creation timestamp

Expand Down Expand Up @@ -66,6 +67,7 @@ This will display detailed usage information.
- `--print_report`: Print a detailed report of all processed files at the end.
- `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
- `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
- `--csv_report`: Path to export the detailed report as a CSV file.

## Usage Examples

Expand Down
95 changes: 85 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import time
import textwrap
import csv
from dataclasses import dataclass
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -35,6 +36,7 @@ class Arguments:
skip_unprocessed: bool = False
log_level: str = "INFO"
print_report: bool = False
csv_report: str = ""
include_metadata: bool = True
verify: bool = True

Expand All @@ -58,6 +60,7 @@ def init_db():
total_embedding_tokens INTEGER,
total_llm_cost REAL,
total_llm_tokens INTEGER,
error_message TEXT,
updated_at TEXT,
created_at TEXT
)"""
Expand All @@ -73,6 +76,7 @@ def init_db():
"total_embedding_tokens": "INTEGER",
"total_llm_cost": "REAL",
"total_llm_tokens": "INTEGER",
"error_message": "TEXT",
}

# Add missing columns
Expand Down Expand Up @@ -126,10 +130,14 @@ def update_db(
total_embedding_tokens = None
total_llm_cost = None
total_llm_tokens = None
error_message = None

if result is not None:
total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens = calculate_cost_and_tokens(result)

if execution_status == "ERROR":
error_message = extract_error_message(result)

conn = sqlite3.connect(DB_NAME)
conn.set_trace_callback(
lambda x: (
Expand All @@ -142,8 +150,8 @@ def update_db(
now = datetime.now().isoformat()
c.execute(
"""
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
(
file_name,
Expand All @@ -156,6 +164,7 @@ def update_db(
total_embedding_tokens,
total_llm_cost,
total_llm_tokens,
error_message,
now,
file_name,
now,
Expand Down Expand Up @@ -211,6 +220,17 @@ def calculate_cost_and_tokens(result):

return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens

# Exract error message from the result JSON
def extract_error_message(result):
result_data = json.loads(result)
# Check for error in extraction_result
extraction_result = result_data.get("extraction_result", [])
if extraction_result and isinstance(extraction_result, list):
for item in extraction_result:
if "error" in item and item["error"]:
return item["error"]
# Fallback to the parent error
return result_data.get("error", "No error message found")

# Print final summary with count of each status and average time using a single SQL query
def print_summary():
Expand Down Expand Up @@ -243,7 +263,7 @@ def print_report():
# Fetch required fields, including total_cost and total_tokens
c.execute(
"""
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
FROM file_status
"""
)
Expand All @@ -254,23 +274,69 @@ def print_report():
print("\nDetailed Report:")
if report_data:
# Tabulate the data with column headers
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]
headers = [
textwrap.fill(header, width=20)
for header in [
"File Name",
"Execution Status",
"Time Elapsed (seconds)",
"Total Embedding Cost",
"Total Embedding Tokens",
"Total LLM Cost",
"Total LLM Tokens",
"Error Message"
]
]


# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
formatted_data = []
# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
for row in report_data:
formatted_row = [
"None" if cell is None else
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
f"{cell:.8f}" if isinstance(cell, float) else cell
for cell in row
]
formatted_data.append(formatted_row)
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
cell if idx == 2 else f"{cell:.8f}" if isinstance(cell, float) else cell
for idx, cell in enumerate(row)
]
formatted_data.append(formatted_row)

print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
else:
print("No records found in the database.")

def export_report_to_csv(output_path):
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

c.execute(
"""
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
FROM file_status
"""
)
report_data = c.fetchall()
conn.close()

if not report_data:
print("No data available to export.")
return

# Define the headers
headers = [
"File Name", "Execution Status", "Time Elapsed (seconds)",
"Total Embedding Cost", "Total Embedding Tokens",
"Total LLM Cost", "Total LLM Tokens", "Error Message"
]

try:
with open(output_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(headers) # Write headers
writer.writerows(report_data) # Write data rows
print(f"CSV successfully exported to {output_path}")
except Exception as e:
print(f"Error exporting to CSV: {e}")


def get_status_endpoint(file_path, client, args: Arguments):
"""Returns status_endpoint, status and response (if available)"""
Expand Down Expand Up @@ -523,6 +589,12 @@ def main():
help="Disable SSL certificate verification.",
)

parser.add_argument(
'--csv_report',
dest="csv_report",
type=str,
help='Path to export the detailed report as a CSV file',
)

args = Arguments(**vars(parser.parse_args()))

Expand All @@ -543,6 +615,9 @@ def main():
"Elapsed time calculation of a file which was resumed"
" from pending state will not be correct"
)

if args.csv_report:
export_report_to_csv(args.csv_report)


if __name__ == "__main__":
Expand Down