Skip to content

Commit

Permalink
Changes for connection retry in case of db server restart fo prompt s…
Browse files Browse the repository at this point in the history
…ervice
  • Loading branch information
johnyrahul committed Sep 13, 2024
1 parent fc77d70 commit 12b0a4a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 45 deletions.
39 changes: 39 additions & 0 deletions prompt-service/src/unstract/prompt_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections.abc import Generator
from contextlib import contextmanager
from os import environ as env
from typing import Any

from playhouse.postgres_ext import PostgresqlExtDatabase


def get_env_or_die(env_key: str) -> str:
env_value = env.get(env_key)
if not env_value:
raise ValueError(f"Env variable {env_key} is required")
return env_value


# Load required environment variables
db_host = get_env_or_die("PG_BE_HOST")
db_port = get_env_or_die("PG_BE_PORT")
db_user = get_env_or_die("PG_BE_USERNAME")
db_pass = get_env_or_die("PG_BE_PASSWORD")
db_name = get_env_or_die("PG_BE_DATABASE")

be_db = PostgresqlExtDatabase(
db_name,
user=db_user,
password=db_pass,
host=db_host,
port=db_port,
)


@contextmanager
def db_context() -> Generator[PostgresqlExtDatabase, Any, None]:
try:
be_db.connect()
yield be_db
finally:
if not be_db.is_closed():
be_db.close()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional

from flask import Request, current_app
from unstract.prompt_service.config import db
from unstract.prompt_service import be_db, db_context
from unstract.prompt_service.constants import DBTableV2, FeatureFlag

from unstract.flags.feature_flag import check_feature_flag_status
Expand All @@ -22,9 +22,10 @@ def validate_bearer_token(token: Optional[str]) -> bool:
platform_key_table = "account_platformkey"

query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'"
cursor = db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
current_app.logger.error(
f"Authentication failed. bearer token not found {token}"
Expand Down Expand Up @@ -84,8 +85,9 @@ def get_account_from_bearer_token(token: Optional[str]) -> str:

@staticmethod
def execute_query(query: str) -> Any:
cursor = db.execute_sql(query)
result_row = cursor.fetchone()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
return None
Expand Down
25 changes: 0 additions & 25 deletions prompt-service/src/unstract/prompt_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from dotenv import load_dotenv
from flask import Flask
from peewee import PostgresqlDatabase
from unstract.prompt_service.constants import LogLevel

load_dotenv()
Expand All @@ -28,15 +27,6 @@
}
)

db = PostgresqlDatabase(None)


def get_env_or_die(env_key: str) -> str:
env_value = env.get(env_key)
if not env_value:
raise ValueError(f"Env variable {env_key} is required")
return env_value


def create_app() -> Flask:
app = Flask("prompt-service")
Expand All @@ -48,19 +38,4 @@ def create_app() -> Flask:
else:
app.logger.setLevel(logging.WARNING)

# Load required environment variables
db_host = get_env_or_die("PG_BE_HOST")
db_port = get_env_or_die("PG_BE_PORT")
db_user = get_env_or_die("PG_BE_USERNAME")
db_pass = get_env_or_die("PG_BE_PASSWORD")
db_name = get_env_or_die("PG_BE_DATABASE")

# Initialize and connect to the database
db.init(
database=db_name,
user=db_user,
password=db_pass,
host=db_host,
port=db_port,
)
return app
15 changes: 8 additions & 7 deletions prompt-service/src/unstract/prompt_service/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from dotenv import load_dotenv
from flask import Flask, current_app, json
from unstract.prompt_service import be_db, db_context
from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware
from unstract.prompt_service.config import db
from unstract.prompt_service.constants import PromptServiceContants as PSKeys
from unstract.prompt_service.exceptions import APIError, RateLimitError
from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError
Expand Down Expand Up @@ -130,12 +130,13 @@ def query_usage_metadata(token: str, metadata: dict[str, Any]) -> dict[str, Any]
"""
logger: Logger = current_app.logger
try:
with db.atomic():
logger.info(
"Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id
)
cursor = db.execute_sql(query, (run_id,))
results: list[tuple] = cursor.fetchall()
with db_context():
with be_db.atomic():
logger.info(
"Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id
)
cursor = be_db.execute_sql(query, (run_id,))
results: list[tuple] = cursor.fetchall()
# Process results as needed
for row in results:
key, item = _get_key_and_item(row)
Expand Down
8 changes: 1 addition & 7 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flask import json, jsonify, request
from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters
from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware
from unstract.prompt_service.config import create_app, db
from unstract.prompt_service.config import create_app
from unstract.prompt_service.constants import PromptServiceContants as PSKeys
from unstract.prompt_service.constants import RunLevel
from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError
Expand Down Expand Up @@ -48,12 +48,6 @@
plugin_loader(app)


@app.before_request
def before_request() -> None:
if db.is_closed():
db.connect(reuse_if_open=True)


def _publish_log(
log_events_id: str,
component: dict[str, str],
Expand Down

0 comments on commit 12b0a4a

Please sign in to comment.