Merge branch 'main' into fix/table-extraction-fe
vishnuszipstack authored Sep 30, 2024
2 parents 7af6578 + bac1a78 commit ff4656b
Showing 22 changed files with 162 additions and 64 deletions.
8 changes: 4 additions & 4 deletions backend/api_v2/api_deployment_views.py
@@ -49,10 +49,9 @@ def post(
file_objs = request.FILES.getlist(ApiExecution.FILES_FORM_DATA)
serializer = ExecutionRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
timeout = serializer.get_timeout(serializer.validated_data)
include_metadata = (
request.data.get(ApiExecution.INCLUDE_METADATA, "false").lower() == "true"
)
timeout = serializer.validated_data.get(ApiExecution.TIMEOUT_FORM_DATA)
include_metadata = serializer.validated_data.get(ApiExecution.INCLUDE_METADATA)
use_file_history = serializer.validated_data.get(ApiExecution.USE_FILE_HISTORY)
if not file_objs or len(file_objs) == 0:
raise InvalidAPIRequest("File shouldn't be empty")
response = DeploymentHelper.execute_workflow(
@@ -61,6 +60,7 @@ def post(
file_objs=file_objs,
timeout=timeout,
include_metadata=include_metadata,
use_file_history=use_file_history,
)
if "error" in response and response["error"]:
return Response(
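For context, a minimal client-side sketch of how the updated endpoint can be exercised with the new form fields. The deployment URL, auth header, and file name are placeholders (not part of this commit); only the form field names — files, timeout, include_metadata, use_file_history — come from the ApiExecution constants below.

import requests

# Hypothetical values; replace with a real API deployment URL and key.
API_URL = "https://unstract.example.com/deployment/api/my-org/my-api/"
API_KEY = "<api-key>"

with open("sample.pdf", "rb") as f:
    response = requests.post(
        API_URL,
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"files": f},                 # ApiExecution.FILES_FORM_DATA
        data={
            "timeout": 300,                 # ApiExecution.TIMEOUT_FORM_DATA
            "include_metadata": "true",     # coerced to bool by the serializer
            "use_file_history": "true",     # new, undocumented flag
        },
    )
print(response.status_code, response.json())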
1 change: 1 addition & 0 deletions backend/api_v2/constants.py
@@ -4,3 +4,4 @@ class ApiExecution:
FILES_FORM_DATA: str = "files"
TIMEOUT_FORM_DATA: str = "timeout"
INCLUDE_METADATA: str = "include_metadata"
USE_FILE_HISTORY: str = "use_file_history" # Undocumented parameter
5 changes: 5 additions & 0 deletions backend/api_v2/deployment_helper.py
@@ -131,13 +131,16 @@ def execute_workflow(
file_objs: list[UploadedFile],
timeout: int,
include_metadata: bool = False,
use_file_history: bool = False,
) -> ReturnDict:
"""Execute workflow by api.
Args:
organization_name (str): organization name
api (APIDeployment): api model object
file_obj (UploadedFile): input file
use_file_history (bool): Use FileHistory table to return results on already
processed files. Defaults to False
Returns:
ReturnDict: execution status/ result
@@ -150,6 +153,7 @@
workflow_id=workflow_id,
execution_id=execution_id,
file_objs=file_objs,
use_file_history=use_file_history,
)
try:
result = WorkflowHelper.execute_workflow_async(
@@ -159,6 +163,7 @@
timeout=timeout,
execution_id=execution_id,
include_metadata=include_metadata,
use_file_history=use_file_history,
)
result.status_api = DeploymentHelper.construct_status_endpoint(
api_endpoint=api.api_endpoint, execution_id=execution_id
27 changes: 12 additions & 15 deletions backend/api_v2/serializers.py
@@ -6,6 +6,7 @@
from django.core.validators import RegexValidator
from pipeline_v2.models import Pipeline
from rest_framework.serializers import (
BooleanField,
CharField,
IntegerField,
JSONField,
@@ -80,26 +81,22 @@ def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:


class ExecutionRequestSerializer(Serializer):
"""Execution request serializer
timeout: 0: maximum value of timeout, -1: async execution
"""Execution request serializer.
Attributes:
timeout (int): Timeout for the API deployment, maximum value can be 300s.
If -1 it corresponds to async execution. Defaults to -1
include_metadata (bool): Flag to include metadata in API response
use_file_history (bool): Flag to use FileHistory to save and retrieve
responses quickly. This is undocumented to the user and can be
helpful for demos.
"""

timeout = IntegerField(
min_value=-1, max_value=ApiExecution.MAXIMUM_TIMEOUT_IN_SEC, default=-1
)

def validate_timeout(self, value: Any) -> int:
if not isinstance(value, int):
raise ValidationError("timeout must be a integer.")
if value == 0:
value = ApiExecution.MAXIMUM_TIMEOUT_IN_SEC
return value

def get_timeout(self, validated_data: dict[str, Union[int, None]]) -> int:
value = validated_data.get(ApiExecution.TIMEOUT_FORM_DATA, -1)
if not isinstance(value, int):
raise ValidationError("timeout must be a integer.")
return value
include_metadata = BooleanField(default=False)
use_file_history = BooleanField(default=False)


class APIDeploymentListSerializer(ModelSerializer):
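As a rough illustration of the simplified serializer (assuming a configured Django/DRF environment), the declarative IntegerField and BooleanField now handle type validation and defaults in place of the removed validate_timeout/get_timeout helpers:

from api_v2.serializers import ExecutionRequestSerializer  # path as shown in this diff

serializer = ExecutionRequestSerializer(data={"include_metadata": "true"})
serializer.is_valid(raise_exception=True)

# Form-data strings are coerced and omitted fields fall back to their defaults.
assert dict(serializer.validated_data) == {
    "timeout": -1,
    "include_metadata": True,
    "use_file_history": False,
}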
Empty file.
52 changes: 52 additions & 0 deletions backend/backend/custom_db/base.py
@@ -0,0 +1,52 @@
import logging

from django.conf import settings
from django.db.backends.postgresql.base import (
DatabaseWrapper as PostgresDatabaseWrapper,
)

logger = logging.getLogger(__name__)


class DatabaseWrapper(PostgresDatabaseWrapper):
"""Custom DatabaseWrapper to manage PostgreSQL connections and set the
search path."""

def get_new_connection(self, conn_params):
"""Establish a new database connection or reuse an existing one, and
set the search path.
Args:
conn_params: Parameters for the new database connection.
Returns:
connection: The database connection
"""
connection = super().get_new_connection(conn_params)
logger.info(f"DB connection (ID: {id(connection)}) is established or reused.")
self.set_search_path(connection)
return connection

def set_search_path(self, connection):
"""Set the search path for the given database connection.
This ensures that the database queries will look in the specified schema.
Args:
connection: The database connection for which to set the search path.
"""
conn_id = id(connection)
original_autocommit = connection.autocommit
try:
connection.autocommit = True
logger.debug(
f"Setting search_path to {settings.DB_SCHEMA} for DB connection ID "
f"{conn_id}."
)
with connection.cursor() as cursor:
cursor.execute(f"SET search_path TO {settings.DB_SCHEMA}")
logger.debug(
f"Successfully set search_path for DB connection ID {conn_id}."
)
finally:
connection.autocommit = original_autocommit
3 changes: 1 addition & 2 deletions backend/backend/settings/base.py
@@ -375,7 +375,7 @@ def get_required_setting(
ROOT_URLCONF = "backend.base_urls"

# DB Configuration
DB_ENGINE = "django.db.backends.postgresql"
DB_ENGINE = "backend.custom_db"

# Models
AUTH_USER_MODEL = "account_v2.User"
@@ -397,7 +397,6 @@ def get_required_setting(
"PORT": f"{DB_PORT}",
"ATOMIC_REQUESTS": ATOMIC_REQUESTS,
"OPTIONS": {
"options": f"-c search_path={DB_SCHEMA}",
"application_name": os.environ.get("APPLICATION_NAME", ""),
},
}
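Taken together with the new backend above, the default database entry would now resemble the fragment below (an illustrative sketch, not from the commit; the environment variable names for the credentials are assumed). The search_path libpq option disappears because DatabaseWrapper.set_search_path() now issues SET search_path on every new connection.

import os

DB_ENGINE = "backend.custom_db"  # Django loads backend/custom_db/base.py:DatabaseWrapper

DATABASES = {
    "default": {
        "ENGINE": DB_ENGINE,
        "NAME": os.environ.get("DB_NAME", "unstract_db"),      # assumed variable names
        "USER": os.environ.get("DB_USER", "unstract_dev"),
        "PASSWORD": os.environ.get("DB_PASSWORD", ""),
        "HOST": os.environ.get("DB_HOST", "localhost"),
        "PORT": os.environ.get("DB_PORT", "5432"),
        "ATOMIC_REQUESTS": True,
        "OPTIONS": {
            # "options": f"-c search_path={DB_SCHEMA}" is gone; the custom
            # wrapper sets the schema itself after connecting.
            "application_name": os.environ.get("APPLICATION_NAME", ""),
        },
    }
}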
1 change: 0 additions & 1 deletion backend/entrypoint.sh
@@ -16,5 +16,4 @@ fi
--log-level debug \
--timeout 600 \
--access-logfile - \
--reload \
backend.wsgi:application
3 changes: 3 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/views.py
@@ -529,12 +529,15 @@ def export_tool(self, request: Request, pk: Any = None) -> Response:
serializer.is_valid(raise_exception=True)
is_shared_with_org: bool = serializer.validated_data.get("is_shared_with_org")
user_ids = set(serializer.validated_data.get("user_id"))
force_export = serializer.validated_data.get("force_export")

PromptStudioRegistryHelper.update_or_create_psr_tool(
custom_tool=custom_tool,
shared_with_org=is_shared_with_org,
user_ids=user_ids,
force_export=force_export,
)

return Response(
{"message": "Custom tool exported sucessfully."},
status=status.HTTP_200_OK,
4 changes: 2 additions & 2 deletions backend/prompt_studio/prompt_studio_registry/exceptions.py
@@ -20,8 +20,8 @@ class EmptyToolExportError(APIException):
status_code = 500
default_detail = (
"Prompt Studio project without prompts cannot be exported. "
"Please ensure there is at least one prompt and "
"it is active before exporting."
"Please ensure there is at least one active prompt "
"that has been run before exporting."
)


@@ -292,6 +292,9 @@ def frame_export_json(
invalidated_prompts.append(prompt.prompt_key)
continue

if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

if not force_export:
prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
@@ -302,9 +305,6 @@
invalidated_outputs.append(prompt.prompt_key)
continue

if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

vector_db = str(prompt.profile_manager.vector_store.id)
embedding_model = str(prompt.profile_manager.embedding_model.id)
llm = str(prompt.profile_manager.llm.id)
4 changes: 2 additions & 2 deletions backend/prompt_studio/prompt_studio_registry_v2/exceptions.py
@@ -20,8 +20,8 @@ class EmptyToolExportError(APIException):
status_code = 500
default_detail = (
"Prompt Studio project without prompts cannot be exported. "
"Please ensure there is at least one prompt and "
"it is active before exporting."
"Please ensure there is at least one active prompt "
"that has been run before exporting."
)


@@ -136,15 +136,24 @@ def get_tool_by_prompt_registry_id(

@staticmethod
def update_or_create_psr_tool(
custom_tool: CustomTool, shared_with_org: bool, user_ids: set[int]
custom_tool: CustomTool,
shared_with_org: bool,
user_ids: set[int],
force_export: bool,
) -> PromptStudioRegistry:
"""Updates or creates the PromptStudioRegistry record.
This appears as a separate tool in the workflow and is mapped
1:1 with the `CustomTool`.
Args:
tool_id (str): ID of the custom tool.
custom_tool (CustomTool): The instance of the custom tool to be updated
or created.
shared_with_org (bool): Flag indicating whether the tool is shared with
the organization.
user_ids (set[int]): A set of user IDs to whom the tool is shared.
force_export (bool): Indicates if the export is being forced.
Raises:
ToolSaveError
@@ -162,7 +171,7 @@ def update_or_create_psr_tool(
tool_id=custom_tool.tool_id
)
metadata = PromptStudioRegistryHelper.frame_export_json(
tool=custom_tool, prompts=prompts
tool=custom_tool, prompts=prompts, force_export=force_export
)

obj: PromptStudioRegistry
@@ -208,7 +217,9 @@ def update_or_create_psr_tool(

@staticmethod
def frame_export_json(
tool: CustomTool, prompts: list[ToolStudioPrompt]
tool: CustomTool,
prompts: list[ToolStudioPrompt],
force_export: bool,
) -> dict[str, Any]:
export_metadata = {}

@@ -283,19 +294,19 @@ def frame_export_json(
invalidated_prompts.append(prompt.prompt_key)
continue

prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
prompt_id=prompt.prompt_id,
profile_manager=prompt.profile_manager,
).all()

if not prompt_output:
invalidated_outputs.append(prompt.prompt_key)
continue

if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

if not force_export:
prompt_output = PromptStudioOutputManager.objects.filter(
tool_id=tool.tool_id,
prompt_id=prompt.prompt_id,
profile_manager=prompt.profile_manager,
).all()
if not prompt_output:
invalidated_outputs.append(prompt.prompt_key)
continue

vector_db = str(prompt.profile_manager.vector_store.id)
embedding_model = str(prompt.profile_manager.embedding_model.id)
llm = str(prompt.profile_manager.llm.id)
@@ -354,10 +365,12 @@ def frame_export_json(
f"Cannot export tool. Prompt(s): {', '.join(invalidated_prompts)} "
"are empty. Please enter a valid prompt."
)
if invalidated_outputs:
if not force_export and invalidated_outputs:
raise InValidCustomToolError(
f"Cannot export tool. Prompt(s): {', '.join(invalidated_outputs)} "
"were not run. Please run them before exporting."
detail="Cannot export tool. Prompt(s):"
f" {', '.join(invalidated_outputs)}"
" were not run. Please run them before exporting.",
code="warning",
)
export_metadata[JsonSchemaKey.TOOL_SETTINGS] = tool_settings
export_metadata[JsonSchemaKey.OUTPUTS] = outputs
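Condensed, the per-prompt export loop now behaves as sketched below (an illustrative summary of the diff above, not a drop-in copy): the default LLM profile is assigned before the output lookup, and the lookup is skipped entirely when force_export is set.

for prompt in prompts:
    if not prompt.prompt:
        invalidated_prompts.append(prompt.prompt_key)
        continue

    # Applied before the output query so the filter always has a concrete profile.
    if not prompt.profile_manager:
        prompt.profile_manager = default_llm_profile

    if not force_export:
        prompt_output = PromptStudioOutputManager.objects.filter(
            tool_id=tool.tool_id,
            prompt_id=prompt.prompt_id,
            profile_manager=prompt.profile_manager,
        ).all()
        if not prompt_output:
            # Collected and reported later, unless the export is forced.
            invalidated_outputs.append(prompt.prompt_key)
            continue
    # ... vector_db / embedding_model / llm ids are embedded as before ...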
@@ -36,3 +36,4 @@ def get_prompt_studio_users(self, obj: PromptStudioRegistry) -> Any:
class ExportToolRequestSerializer(serializers.Serializer):
is_shared_with_org = serializers.BooleanField(default=False)
user_id = serializers.ListField(child=serializers.IntegerField(), required=False)
force_export = serializers.BooleanField(default=False)
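A sketch of what an export request could carry once force_export is accepted; the endpoint URL and auth header are placeholders, while the three payload keys mirror ExportToolRequestSerializer above.

import requests

# Hypothetical endpoint; only the payload keys come from the serializer.
EXPORT_URL = "https://unstract.example.com/api/v2/prompt-studio/<tool-id>/export/"

payload = {
    "is_shared_with_org": False,
    "user_id": [12, 34],        # users the exported tool is shared with
    "force_export": True,       # bypass the "prompts were not run" guard
}
response = requests.post(
    EXPORT_URL,
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
print(response.status_code, response.json())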
2 changes: 1 addition & 1 deletion backend/usage_v2/helper.py
@@ -61,4 +61,4 @@ def get_aggregated_token_count(run_id: str) -> dict:
except Exception as e:
# Handle any other exceptions that might occur during the execution
logger.error(f"An unexpected error occurred for run_id {run_id}: {str(e)}")
raise APIException("An unexpected error occurred")
raise APIException("Error while aggregating token counts")
14 changes: 9 additions & 5 deletions backend/workflow_manager/endpoint_v2/destination.py
@@ -152,8 +152,9 @@ def handle_output(
file_name: str,
file_hash: FileHash,
workflow: Workflow,
input_file_path: str,
error: Optional[str] = None,
input_file_path: Optional[str] = None,
use_file_history: bool = True,
) -> None:
"""Handle the output based on the connection type."""
connection_type = self.endpoint.connection_type
@@ -163,9 +164,12 @@
if connection_type == WorkflowEndpoint.ConnectionType.API:
self._handle_api_result(file_name=file_name, error=error, result=result)
return
file_history = FileHistoryHelper.get_file_history(
workflow=workflow, cache_key=file_hash.file_hash
)

file_history = None
if use_file_history:
file_history = FileHistoryHelper.get_file_history(
workflow=workflow, cache_key=file_hash.file_hash
)
if connection_type == WorkflowEndpoint.ConnectionType.FILESYSTEM:
self.copy_output_to_output_directory()
elif connection_type == WorkflowEndpoint.ConnectionType.DATABASE:
@@ -188,7 +192,7 @@
self.execution_service.publish_log(
message=f"File '{file_name}' processed successfully"
)
if not file_history:
if use_file_history and not file_history:
FileHistoryHelper.create_file_history(
cache_key=file_hash.file_hash,
workflow=workflow,
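Finally, a call-site sketch for the updated handle_output signature (the surrounding objects are assumed to be in scope as they are in the workflow execution code; this is not code from the commit): input_file_path is now required, and use_file_history decides whether FileHistory is consulted and written for non-API destinations.

destination.handle_output(
    file_name=file_name,
    file_hash=file_hash,
    workflow=workflow,
    input_file_path=input_file_path,  # no longer Optional
    error=None,
    use_file_history=False,           # API deployments pass False unless requested
)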