From fc77d70c405b06496892aadee2f732838ea3eee5 Mon Sep 17 00:00:00 2001 From: johnyrahul Date: Fri, 13 Sep 2024 23:53:34 +0530 Subject: [PATCH 1/2] Changes for connection retry in case of db server restart --- .../src/unstract/platform_service/__init__.py | 24 +++++++++++ .../platform_service/controller/platform.py | 42 +++++++++---------- .../helper/adapter_instance.py | 20 +++++---- .../platform_service/helper/prompt_studio.py | 16 +++---- .../src/unstract/platform_service/run.py | 7 ---- 5 files changed, 63 insertions(+), 46 deletions(-) diff --git a/platform-service/src/unstract/platform_service/__init__.py b/platform-service/src/unstract/platform_service/__init__.py index e69de29bb..ca06894e2 100644 --- a/platform-service/src/unstract/platform_service/__init__.py +++ b/platform-service/src/unstract/platform_service/__init__.py @@ -0,0 +1,24 @@ +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +from playhouse.postgres_ext import PostgresqlExtDatabase +from unstract.platform_service.env import Env + +be_db = PostgresqlExtDatabase( + Env.PG_BE_DATABASE, + user=Env.PG_BE_USERNAME, + password=Env.PG_BE_PASSWORD, + host=Env.PG_BE_HOST, + port=Env.PG_BE_PORT, +) + + +@contextmanager +def db_context() -> Generator[PostgresqlExtDatabase, Any, None]: + try: + be_db.connect() + yield be_db + finally: + if not be_db.is_closed(): + be_db.close() diff --git a/platform-service/src/unstract/platform_service/controller/platform.py b/platform-service/src/unstract/platform_service/controller/platform.py index 9d51243ce..66d4cb9b5 100644 --- a/platform-service/src/unstract/platform_service/controller/platform.py +++ b/platform-service/src/unstract/platform_service/controller/platform.py @@ -8,7 +8,7 @@ from flask import Blueprint, Request from flask import current_app as app from flask import jsonify, make_response, request -from peewee import PostgresqlDatabase +from unstract.platform_service import be_db, db_context from unstract.platform_service.constants import DBTable, DBTableV2, FeatureFlag from unstract.platform_service.env import Env from unstract.platform_service.exceptions import APIError @@ -20,15 +20,6 @@ from unstract.flags.feature_flag import check_feature_flag_status -be_db = PostgresqlDatabase( - Env.PG_BE_DATABASE, - user=Env.PG_BE_USERNAME, - password=Env.PG_BE_PASSWORD, - host=Env.PG_BE_HOST, - port=Env.PG_BE_PORT, -) -be_db.init(Env.PG_BE_DATABASE) - platform_bp = Blueprint("platform", __name__) @@ -85,9 +76,10 @@ def get_organization_from_bearer_token(token: str) -> tuple[Optional[int], str]: def execute_query(query: str, params: tuple = ()) -> Any: - cursor = be_db.execute_sql(query, params) - result_row = cursor.fetchone() - cursor.close() + with db_context(): + cursor = be_db.execute_sql(query, params) + result_row = cursor.fetchone() + cursor.close() if not result_row or len(result_row) == 0: return None return result_row[0] @@ -104,10 +96,11 @@ def validate_bearer_token(token: Optional[str]) -> bool: else: platform_key_table = "account_platformkey" - query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" - cursor = be_db.execute_sql(query) - result_row = cursor.fetchone() - cursor.close() + with db_context(): + query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() + cursor.close() if not result_row or len(result_row) == 0: app.logger.error(f"Authentication failed. bearer token not found {token}") return False @@ -174,9 +167,10 @@ def page_usage() -> Any: ) try: - with be_db.atomic() as transaction: - be_db.execute_sql(query, params) - transaction.commit() + with db_context(): + with be_db.atomic() as transaction: + be_db.execute_sql(query, params) + transaction.commit() app.logger.info("Entry created with id %s for %s", usage_id, org_id) result["status"] = "OK" result["unique_id"] = usage_id @@ -294,9 +288,11 @@ def usage() -> Any: current_time, ) try: - with be_db.atomic() as transaction: - be_db.execute_sql(query, params) - transaction.commit() + + with db_context(): + with be_db.atomic() as transaction: + be_db.execute_sql(query, params) + transaction.commit() app.logger.info("Entry created with id %s for %s", usage_id, org_id) result["status"] = "OK" result["unique_id"] = usage_id diff --git a/platform-service/src/unstract/platform_service/helper/adapter_instance.py b/platform-service/src/unstract/platform_service/helper/adapter_instance.py index c0d03dbc1..abab86072 100644 --- a/platform-service/src/unstract/platform_service/helper/adapter_instance.py +++ b/platform-service/src/unstract/platform_service/helper/adapter_instance.py @@ -1,6 +1,7 @@ from typing import Any, Optional -from peewee import PostgresqlDatabase +from playhouse.postgres_ext import PostgresqlExtDatabase +from unstract.platform_service import be_db, db_context from unstract.platform_service.constants import DBTableV2, FeatureFlag from unstract.platform_service.exceptions import APIError @@ -10,7 +11,7 @@ class AdapterInstanceRequestHelper: @staticmethod def get_adapter_instance_from_db( - db_instance: PostgresqlDatabase, + db_instance: PostgresqlExtDatabase, organization_id: str, adapter_instance_id: str, organization_uid: Optional[int] = None, @@ -38,11 +39,12 @@ def get_adapter_instance_from_db( f'"{organization_id}".adapter_adapterinstance x ' f"WHERE id='{adapter_instance_id}'" ) - cursor = db_instance.execute_sql(query) - result_row = cursor.fetchone() - if not result_row: - raise APIError(message="Adapter not found", code=404) - columns = [desc[0] for desc in cursor.description] - data_dict: dict[str, Any] = dict(zip(columns, result_row)) - cursor.close() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() + if not result_row: + raise APIError(message="Adapter not found", code=404) + columns = [desc[0] for desc in cursor.description] + data_dict: dict[str, Any] = dict(zip(columns, result_row)) + cursor.close() return data_dict diff --git a/platform-service/src/unstract/platform_service/helper/prompt_studio.py b/platform-service/src/unstract/platform_service/helper/prompt_studio.py index 6801792c1..acc14cccd 100644 --- a/platform-service/src/unstract/platform_service/helper/prompt_studio.py +++ b/platform-service/src/unstract/platform_service/helper/prompt_studio.py @@ -1,6 +1,7 @@ from typing import Any from peewee import PostgresqlDatabase +from unstract.platform_service import be_db, db_context from unstract.platform_service.constants import DBTableV2, FeatureFlag from unstract.platform_service.exceptions import APIError @@ -38,11 +39,12 @@ def get_prompt_instance_from_db( f'"{organization_id}".prompt_studio_registry_promptstudioregistry x' f" WHERE prompt_registry_id='{prompt_registry_id}'" ) - cursor = db_instance.execute_sql(query) - result_row = cursor.fetchone() - if not result_row: - raise APIError(message="Custom Tool not found", code=404) - columns = [desc[0] for desc in cursor.description] - data_dict: dict[str, Any] = dict(zip(columns, result_row)) - cursor.close() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() + if not result_row: + raise APIError(message="Custom Tool not found", code=404) + columns = [desc[0] for desc in cursor.description] + data_dict: dict[str, Any] = dict(zip(columns, result_row)) + cursor.close() return data_dict diff --git a/platform-service/src/unstract/platform_service/run.py b/platform-service/src/unstract/platform_service/run.py index ae7f9e1a4..9cf9ecbb0 100644 --- a/platform-service/src/unstract/platform_service/run.py +++ b/platform-service/src/unstract/platform_service/run.py @@ -3,7 +3,6 @@ from dotenv import load_dotenv from flask import Flask from unstract.platform_service.controller import api -from unstract.platform_service.controller.platform import be_db load_dotenv() @@ -37,12 +36,6 @@ def create_app() -> Flask: app = create_app() -@app.before_request -def before_request() -> None: - if be_db.is_closed(): - be_db.connect(reuse_if_open=True) - - if __name__ == "__main__": # Start the server app.run(host="0.0.0.0", port=3001, load_dotenv=True) From 12b0a4ab44887f01f7a8159759683b2484bc189b Mon Sep 17 00:00:00 2001 From: johnyrahul Date: Sat, 14 Sep 2024 00:12:57 +0530 Subject: [PATCH 2/2] Changes for connection retry in case of db server restart fo prompt service --- .../src/unstract/prompt_service/__init__.py | 39 +++++++++++++++++++ .../authentication_middleware.py | 14 ++++--- .../src/unstract/prompt_service/config.py | 25 ------------ .../src/unstract/prompt_service/helper.py | 15 +++---- .../src/unstract/prompt_service/main.py | 8 +--- 5 files changed, 56 insertions(+), 45 deletions(-) diff --git a/prompt-service/src/unstract/prompt_service/__init__.py b/prompt-service/src/unstract/prompt_service/__init__.py index e69de29bb..ffb749e2f 100644 --- a/prompt-service/src/unstract/prompt_service/__init__.py +++ b/prompt-service/src/unstract/prompt_service/__init__.py @@ -0,0 +1,39 @@ +from collections.abc import Generator +from contextlib import contextmanager +from os import environ as env +from typing import Any + +from playhouse.postgres_ext import PostgresqlExtDatabase + + +def get_env_or_die(env_key: str) -> str: + env_value = env.get(env_key) + if not env_value: + raise ValueError(f"Env variable {env_key} is required") + return env_value + + +# Load required environment variables +db_host = get_env_or_die("PG_BE_HOST") +db_port = get_env_or_die("PG_BE_PORT") +db_user = get_env_or_die("PG_BE_USERNAME") +db_pass = get_env_or_die("PG_BE_PASSWORD") +db_name = get_env_or_die("PG_BE_DATABASE") + +be_db = PostgresqlExtDatabase( + db_name, + user=db_user, + password=db_pass, + host=db_host, + port=db_port, +) + + +@contextmanager +def db_context() -> Generator[PostgresqlExtDatabase, Any, None]: + try: + be_db.connect() + yield be_db + finally: + if not be_db.is_closed(): + be_db.close() diff --git a/prompt-service/src/unstract/prompt_service/authentication_middleware.py b/prompt-service/src/unstract/prompt_service/authentication_middleware.py index fbf03c37b..c5bd45042 100644 --- a/prompt-service/src/unstract/prompt_service/authentication_middleware.py +++ b/prompt-service/src/unstract/prompt_service/authentication_middleware.py @@ -1,7 +1,7 @@ from typing import Any, Optional from flask import Request, current_app -from unstract.prompt_service.config import db +from unstract.prompt_service import be_db, db_context from unstract.prompt_service.constants import DBTableV2, FeatureFlag from unstract.flags.feature_flag import check_feature_flag_status @@ -22,9 +22,10 @@ def validate_bearer_token(token: Optional[str]) -> bool: platform_key_table = "account_platformkey" query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'" - cursor = db.execute_sql(query) - result_row = cursor.fetchone() - cursor.close() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() + cursor.close() if not result_row or len(result_row) == 0: current_app.logger.error( f"Authentication failed. bearer token not found {token}" @@ -84,8 +85,9 @@ def get_account_from_bearer_token(token: Optional[str]) -> str: @staticmethod def execute_query(query: str) -> Any: - cursor = db.execute_sql(query) - result_row = cursor.fetchone() + with db_context(): + cursor = be_db.execute_sql(query) + result_row = cursor.fetchone() cursor.close() if not result_row or len(result_row) == 0: return None diff --git a/prompt-service/src/unstract/prompt_service/config.py b/prompt-service/src/unstract/prompt_service/config.py index 5d6a2ef15..b9a0caf7d 100644 --- a/prompt-service/src/unstract/prompt_service/config.py +++ b/prompt-service/src/unstract/prompt_service/config.py @@ -4,7 +4,6 @@ from dotenv import load_dotenv from flask import Flask -from peewee import PostgresqlDatabase from unstract.prompt_service.constants import LogLevel load_dotenv() @@ -28,15 +27,6 @@ } ) -db = PostgresqlDatabase(None) - - -def get_env_or_die(env_key: str) -> str: - env_value = env.get(env_key) - if not env_value: - raise ValueError(f"Env variable {env_key} is required") - return env_value - def create_app() -> Flask: app = Flask("prompt-service") @@ -48,19 +38,4 @@ def create_app() -> Flask: else: app.logger.setLevel(logging.WARNING) - # Load required environment variables - db_host = get_env_or_die("PG_BE_HOST") - db_port = get_env_or_die("PG_BE_PORT") - db_user = get_env_or_die("PG_BE_USERNAME") - db_pass = get_env_or_die("PG_BE_PASSWORD") - db_name = get_env_or_die("PG_BE_DATABASE") - - # Initialize and connect to the database - db.init( - database=db_name, - user=db_user, - password=db_pass, - host=db_host, - port=db_port, - ) return app diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index ae37cc408..4c51df19e 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -7,8 +7,8 @@ from dotenv import load_dotenv from flask import Flask, current_app, json +from unstract.prompt_service import be_db, db_context from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware -from unstract.prompt_service.config import db from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.exceptions import APIError, RateLimitError from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError @@ -130,12 +130,13 @@ def query_usage_metadata(token: str, metadata: dict[str, Any]) -> dict[str, Any] """ logger: Logger = current_app.logger try: - with db.atomic(): - logger.info( - "Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id - ) - cursor = db.execute_sql(query, (run_id,)) - results: list[tuple] = cursor.fetchall() + with db_context(): + with be_db.atomic(): + logger.info( + "Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id + ) + cursor = be_db.execute_sql(query, (run_id,)) + results: list[tuple] = cursor.fetchall() # Process results as needed for row in results: key, item = _get_key_and_item(row) diff --git a/prompt-service/src/unstract/prompt_service/main.py b/prompt-service/src/unstract/prompt_service/main.py index b92fb0076..d80d53781 100644 --- a/prompt-service/src/unstract/prompt_service/main.py +++ b/prompt-service/src/unstract/prompt_service/main.py @@ -7,7 +7,7 @@ from flask import json, jsonify, request from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware -from unstract.prompt_service.config import create_app, db +from unstract.prompt_service.config import create_app from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.constants import RunLevel from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError @@ -48,12 +48,6 @@ plugin_loader(app) -@app.before_request -def before_request() -> None: - if db.is_closed(): - db.connect(reuse_if_open=True) - - def _publish_log( log_events_id: str, component: dict[str, str],