diff --git a/prompt-service/src/unstract/prompt_service/__init__.py b/prompt-service/src/unstract/prompt_service/__init__.py index e69de29bb..ffb749e2f 100644 --- a/prompt-service/src/unstract/prompt_service/__init__.py +++ b/prompt-service/src/unstract/prompt_service/__init__.py @@ -0,0 +1,39 @@ +from collections.abc import Generator +from contextlib import contextmanager +from os import environ as env +from typing import Any + +from playhouse.postgres_ext import PostgresqlExtDatabase + + +def get_env_or_die(env_key: str) -> str: + env_value = env.get(env_key) + if not env_value: + raise ValueError(f"Env variable {env_key} is required") + return env_value + + +# Load required environment variables +db_host = get_env_or_die("PG_BE_HOST") +db_port = get_env_or_die("PG_BE_PORT") +db_user = get_env_or_die("PG_BE_USERNAME") +db_pass = get_env_or_die("PG_BE_PASSWORD") +db_name = get_env_or_die("PG_BE_DATABASE") + +be_db = PostgresqlExtDatabase( + db_name, + user=db_user, + password=db_pass, + host=db_host, + port=db_port, +) + + +@contextmanager +def db_context() -> Generator[PostgresqlExtDatabase, Any, None]: + try: + be_db.connect() + yield be_db + finally: + if not be_db.is_closed(): + be_db.close() diff --git a/prompt-service/src/unstract/prompt_service/authentication_middleware.py b/prompt-service/src/unstract/prompt_service/authentication_middleware.py index fbf03c37b..c5bd45042 100644 --- a/prompt-service/src/unstract/prompt_service/authentication_middleware.py +++ b/prompt-service/src/unstract/prompt_service/authentication_middleware.py @@ -1,7 +1,7 @@ from typing import Any, Optional from flask import Request, current_app -from unstract.prompt_service.config import db +from unstract.prompt_service import be_db, db_context from unstract.prompt_service.constants import DBTableV2, FeatureFlag from unstract.flags.feature_flag import check_feature_flag_status @@ -22,9 +22,10 @@ def validate_bearer_token(token: Optional[str]) -> bool: platform_key_table = "account_platformkey" query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" - cursor = db.execute_sql(query) - result_row = cursor.fetchone() - cursor.close() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() + cursor.close() if not result_row or len(result_row) == 0: current_app.logger.error( f"Authentication failed. bearer token not found {token}" @@ -84,8 +85,9 @@ def get_account_from_bearer_token(token: Optional[str]) -> str: @staticmethod def execute_query(query: str) -> Any: - cursor = db.execute_sql(query) - result_row = cursor.fetchone() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() cursor.close() if not result_row or len(result_row) == 0: return None diff --git a/prompt-service/src/unstract/prompt_service/config.py b/prompt-service/src/unstract/prompt_service/config.py index 5d6a2ef15..b9a0caf7d 100644 --- a/prompt-service/src/unstract/prompt_service/config.py +++ b/prompt-service/src/unstract/prompt_service/config.py @@ -4,7 +4,6 @@ from dotenv import load_dotenv from flask import Flask -from peewee import PostgresqlDatabase from unstract.prompt_service.constants import LogLevel load_dotenv() @@ -28,15 +27,6 @@ } ) -db = PostgresqlDatabase(None) - - -def get_env_or_die(env_key: str) -> str: - env_value = env.get(env_key) - if not env_value: - raise ValueError(f"Env variable {env_key} is required") - return env_value - def create_app() -> Flask: app = Flask("prompt-service") @@ -48,19 +38,4 @@ def create_app() -> Flask: else: app.logger.setLevel(logging.WARNING) - # Load required environment variables - db_host = get_env_or_die("PG_BE_HOST") - db_port = get_env_or_die("PG_BE_PORT") - db_user = get_env_or_die("PG_BE_USERNAME") - db_pass = get_env_or_die("PG_BE_PASSWORD") - db_name = get_env_or_die("PG_BE_DATABASE") - - # Initialize and connect to the database - db.init( - database=db_name, - user=db_user, - password=db_pass, - host=db_host, - port=db_port, - ) return app diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index ae37cc408..4c51df19e 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -7,8 +7,8 @@ from dotenv import load_dotenv from flask import Flask, current_app, json +from unstract.prompt_service import be_db, db_context from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware -from unstract.prompt_service.config import db from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.exceptions import APIError, RateLimitError from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError @@ -130,12 +130,13 @@ def query_usage_metadata(token: str, metadata: dict[str, Any]) -> dict[str, Any] """ logger: Logger = current_app.logger try: - with db.atomic(): - logger.info( - "Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id - ) - cursor = db.execute_sql(query, (run_id,)) - results: list[tuple] = cursor.fetchall() + with db_context(): + with be_db.atomic(): + logger.info( + "Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id + ) + cursor = be_db.execute_sql(query, (run_id,)) + results: list[tuple] = cursor.fetchall() # Process results as needed for row in results: key, item = _get_key_and_item(row) diff --git a/prompt-service/src/unstract/prompt_service/main.py b/prompt-service/src/unstract/prompt_service/main.py index b92fb0076..d80d53781 100644 --- a/prompt-service/src/unstract/prompt_service/main.py +++ b/prompt-service/src/unstract/prompt_service/main.py @@ -7,7 +7,7 @@ from flask import json, jsonify, request from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware -from unstract.prompt_service.config import create_app, db +from unstract.prompt_service.config import create_app from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.constants import RunLevel from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError @@ -48,12 +48,6 @@ plugin_loader(app) -@app.before_request -def before_request() -> None: - if db.is_closed(): - db.connect(reuse_if_open=True) - - def _publish_log( log_events_id: str, component: dict[str, str],