Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/db connection #691

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions platform-service/src/unstract/platform_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

from playhouse.postgres_ext import PostgresqlExtDatabase
from unstract.platform_service.env import Env

be_db = PostgresqlExtDatabase(
Env.PG_BE_DATABASE,
user=Env.PG_BE_USERNAME,
password=Env.PG_BE_PASSWORD,
host=Env.PG_BE_HOST,
port=Env.PG_BE_PORT,
)


@contextmanager
def db_context() -> Generator[PostgresqlExtDatabase, Any, None]:
try:
be_db.connect()
yield be_db
finally:
if not be_db.is_closed():
be_db.close()
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flask import Blueprint, Request
from flask import current_app as app
from flask import jsonify, make_response, request
from peewee import PostgresqlDatabase
from unstract.platform_service import be_db, db_context
from unstract.platform_service.constants import DBTable, DBTableV2, FeatureFlag
from unstract.platform_service.env import Env
from unstract.platform_service.exceptions import APIError
Expand All @@ -20,15 +20,6 @@

from unstract.flags.feature_flag import check_feature_flag_status

be_db = PostgresqlDatabase(
Env.PG_BE_DATABASE,
user=Env.PG_BE_USERNAME,
password=Env.PG_BE_PASSWORD,
host=Env.PG_BE_HOST,
port=Env.PG_BE_PORT,
)
be_db.init(Env.PG_BE_DATABASE)

platform_bp = Blueprint("platform", __name__)


Expand Down Expand Up @@ -85,9 +76,10 @@ def get_organization_from_bearer_token(token: str) -> tuple[Optional[int], str]:


def execute_query(query: str, params: tuple = ()) -> Any:
cursor = be_db.execute_sql(query, params)
result_row = cursor.fetchone()
cursor.close()
with db_context():
cursor = be_db.execute_sql(query, params)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
return None
return result_row[0]
Expand All @@ -104,10 +96,11 @@ def validate_bearer_token(token: Optional[str]) -> bool:
else:
platform_key_table = "account_platformkey"

query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'"
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
with db_context():
query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'"
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
app.logger.error(f"Authentication failed. bearer token not found {token}")
return False
Expand Down Expand Up @@ -174,9 +167,10 @@ def page_usage() -> Any:
)

try:
with be_db.atomic() as transaction:
be_db.execute_sql(query, params)
transaction.commit()
with db_context():
with be_db.atomic() as transaction:
be_db.execute_sql(query, params)
transaction.commit()
app.logger.info("Entry created with id %s for %s", usage_id, org_id)
result["status"] = "OK"
result["unique_id"] = usage_id
Expand Down Expand Up @@ -294,9 +288,11 @@ def usage() -> Any:
current_time,
)
try:
with be_db.atomic() as transaction:
be_db.execute_sql(query, params)
transaction.commit()

with db_context():
with be_db.atomic() as transaction:
be_db.execute_sql(query, params)
transaction.commit()
app.logger.info("Entry created with id %s for %s", usage_id, org_id)
result["status"] = "OK"
result["unique_id"] = usage_id
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

from peewee import PostgresqlDatabase
from playhouse.postgres_ext import PostgresqlExtDatabase
from unstract.platform_service import be_db, db_context
from unstract.platform_service.constants import DBTableV2, FeatureFlag
from unstract.platform_service.exceptions import APIError

Expand All @@ -10,7 +11,7 @@
class AdapterInstanceRequestHelper:
@staticmethod
def get_adapter_instance_from_db(
db_instance: PostgresqlDatabase,
db_instance: PostgresqlExtDatabase,
organization_id: str,
adapter_instance_id: str,
organization_uid: Optional[int] = None,
Expand Down Expand Up @@ -38,11 +39,12 @@ def get_adapter_instance_from_db(
f'"{organization_id}".adapter_adapterinstance x '
f"WHERE id='{adapter_instance_id}'"
)
cursor = db_instance.execute_sql(query)
result_row = cursor.fetchone()
if not result_row:
raise APIError(message="Adapter not found", code=404)
columns = [desc[0] for desc in cursor.description]
data_dict: dict[str, Any] = dict(zip(columns, result_row))
cursor.close()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
if not result_row:
raise APIError(message="Adapter not found", code=404)
columns = [desc[0] for desc in cursor.description]
data_dict: dict[str, Any] = dict(zip(columns, result_row))
cursor.close()
return data_dict
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

from peewee import PostgresqlDatabase
from unstract.platform_service import be_db, db_context
from unstract.platform_service.constants import DBTableV2, FeatureFlag
from unstract.platform_service.exceptions import APIError

Expand Down Expand Up @@ -38,11 +39,12 @@ def get_prompt_instance_from_db(
f'"{organization_id}".prompt_studio_registry_promptstudioregistry x'
f" WHERE prompt_registry_id='{prompt_registry_id}'"
)
cursor = db_instance.execute_sql(query)
result_row = cursor.fetchone()
if not result_row:
raise APIError(message="Custom Tool not found", code=404)
columns = [desc[0] for desc in cursor.description]
data_dict: dict[str, Any] = dict(zip(columns, result_row))
cursor.close()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
if not result_row:
raise APIError(message="Custom Tool not found", code=404)
columns = [desc[0] for desc in cursor.description]
data_dict: dict[str, Any] = dict(zip(columns, result_row))
cursor.close()
return data_dict
7 changes: 0 additions & 7 deletions platform-service/src/unstract/platform_service/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dotenv import load_dotenv
from flask import Flask
from unstract.platform_service.controller import api
from unstract.platform_service.controller.platform import be_db

load_dotenv()

Expand Down Expand Up @@ -37,12 +36,6 @@ def create_app() -> Flask:
app = create_app()


@app.before_request
def before_request() -> None:
if be_db.is_closed():
be_db.connect(reuse_if_open=True)


if __name__ == "__main__":
# Start the server
app.run(host="0.0.0.0", port=3001, load_dotenv=True)
39 changes: 39 additions & 0 deletions prompt-service/src/unstract/prompt_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections.abc import Generator
from contextlib import contextmanager
from os import environ as env
from typing import Any

from playhouse.postgres_ext import PostgresqlExtDatabase


def get_env_or_die(env_key: str) -> str:
env_value = env.get(env_key)
if not env_value:
raise ValueError(f"Env variable {env_key} is required")
return env_value


# Load required environment variables
db_host = get_env_or_die("PG_BE_HOST")
db_port = get_env_or_die("PG_BE_PORT")
db_user = get_env_or_die("PG_BE_USERNAME")
db_pass = get_env_or_die("PG_BE_PASSWORD")
db_name = get_env_or_die("PG_BE_DATABASE")

be_db = PostgresqlExtDatabase(
db_name,
user=db_user,
password=db_pass,
host=db_host,
port=db_port,
)


@contextmanager
def db_context() -> Generator[PostgresqlExtDatabase, Any, None]:
try:
be_db.connect()
yield be_db
finally:
if not be_db.is_closed():
be_db.close()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Optional

from flask import Request, current_app
from unstract.prompt_service.config import db
from unstract.prompt_service import be_db, db_context
from unstract.prompt_service.constants import DBTableV2, FeatureFlag

from unstract.flags.feature_flag import check_feature_flag_status
Expand All @@ -22,9 +22,10 @@ def validate_bearer_token(token: Optional[str]) -> bool:
platform_key_table = "account_platformkey"

query = f"SELECT * FROM {platform_key_table} WHERE key = '{token}'"
cursor = db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
current_app.logger.error(
f"Authentication failed. bearer token not found {token}"
Expand Down Expand Up @@ -84,8 +85,9 @@ def get_account_from_bearer_token(token: Optional[str]) -> str:

@staticmethod
def execute_query(query: str) -> Any:
cursor = db.execute_sql(query)
result_row = cursor.fetchone()
with db_context():
cursor = be_db.execute_sql(query)
result_row = cursor.fetchone()
cursor.close()
if not result_row or len(result_row) == 0:
return None
Expand Down
25 changes: 0 additions & 25 deletions prompt-service/src/unstract/prompt_service/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from dotenv import load_dotenv
from flask import Flask
from peewee import PostgresqlDatabase
from unstract.prompt_service.constants import LogLevel

load_dotenv()
Expand All @@ -28,15 +27,6 @@
}
)

db = PostgresqlDatabase(None)


def get_env_or_die(env_key: str) -> str:
env_value = env.get(env_key)
if not env_value:
raise ValueError(f"Env variable {env_key} is required")
return env_value


def create_app() -> Flask:
app = Flask("prompt-service")
Expand All @@ -48,19 +38,4 @@ def create_app() -> Flask:
else:
app.logger.setLevel(logging.WARNING)

# Load required environment variables
db_host = get_env_or_die("PG_BE_HOST")
db_port = get_env_or_die("PG_BE_PORT")
db_user = get_env_or_die("PG_BE_USERNAME")
db_pass = get_env_or_die("PG_BE_PASSWORD")
db_name = get_env_or_die("PG_BE_DATABASE")

# Initialize and connect to the database
db.init(
database=db_name,
user=db_user,
password=db_pass,
host=db_host,
port=db_port,
)
return app
15 changes: 8 additions & 7 deletions prompt-service/src/unstract/prompt_service/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from dotenv import load_dotenv
from flask import Flask, current_app, json
from unstract.prompt_service import be_db, db_context
from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware
from unstract.prompt_service.config import db
from unstract.prompt_service.constants import PromptServiceContants as PSKeys
from unstract.prompt_service.exceptions import APIError, RateLimitError
from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError
Expand Down Expand Up @@ -130,12 +130,13 @@ def query_usage_metadata(token: str, metadata: dict[str, Any]) -> dict[str, Any]
"""
logger: Logger = current_app.logger
try:
with db.atomic():
logger.info(
"Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id
)
cursor = db.execute_sql(query, (run_id,))
results: list[tuple] = cursor.fetchall()
with db_context():
with be_db.atomic():
logger.info(
"Querying usage metadata for org_id: %s, run_id: %s", org_id, run_id
)
cursor = be_db.execute_sql(query, (run_id,))
results: list[tuple] = cursor.fetchall()
# Process results as needed
for row in results:
key, item = _get_key_and_item(row)
Expand Down
8 changes: 1 addition & 7 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flask import json, jsonify, request
from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters
from unstract.prompt_service.authentication_middleware import AuthenticationMiddleware
from unstract.prompt_service.config import create_app, db
from unstract.prompt_service.config import create_app
from unstract.prompt_service.constants import PromptServiceContants as PSKeys
from unstract.prompt_service.constants import RunLevel
from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError
Expand Down Expand Up @@ -48,12 +48,6 @@
plugin_loader(app)


@app.before_request
def before_request() -> None:
if db.is_closed():
db.connect(reuse_if_open=True)


def _publish_log(
log_events_id: str,
component: dict[str, str],
Expand Down
Loading