Skip to content

Commit

Permalink
Merge branch 'main' into feat/MultiTenancyV2-AdapterAndConnectorProcess
Browse files Browse the repository at this point in the history
  • Loading branch information
hari-kuriakose authored Jul 18, 2024
2 parents 5f6e0a3 + 72134f0 commit 4d58415
Show file tree
Hide file tree
Showing 24 changed files with 373 additions and 203 deletions.
7 changes: 7 additions & 0 deletions backend/adapter_processor/adapter_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InternalServiceError,
InValidAdapterId,
TestAdapterError,
TestAdapterInputError,
)
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
Expand Down Expand Up @@ -97,6 +98,12 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
test_result: bool = adapter_instance.test_connection()
logger.info(f"{adapter_id} test result: {test_result}")
return test_result
# HACK: Remove after error is explicitly handled in VertexAI adapter
except json.JSONDecodeError:
raise TestAdapterInputError(
"Credentials is not a valid service account JSON, "
"please provide a valid JSON."
)
except AdapterError as e:
raise TestAdapterError(str(e))

Expand Down
9 changes: 6 additions & 3 deletions backend/adapter_processor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from rest_framework.exceptions import APIException

from backend.exceptions import UnstractBaseException


class IdIsMandatory(APIException):
status_code = 400
Expand Down Expand Up @@ -46,11 +44,16 @@ class UniqueConstraintViolation(APIException):
default_detail = "Unique constraint violated"


class TestAdapterError(UnstractBaseException):
class TestAdapterError(APIException):
status_code = 500
default_detail = "Error while testing adapter"


class TestAdapterInputError(APIException):
status_code = 400
default_detail = "Error while testing adapter, please check the configuration."


class DeleteAdapterInUseError(APIException):
status_code = 409

Expand Down
18 changes: 7 additions & 11 deletions backend/adapter_processor/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,13 @@ def test(self, request: Request) -> Response:
adapter_metadata[AdapterKeys.ADAPTER_TYPE] = serializer.validated_data.get(
AdapterKeys.ADAPTER_TYPE
)
try:
test_result = AdapterProcessor.test_adapter(
adapter_id=adapter_id, adapter_metadata=adapter_metadata
)
return Response(
{AdapterKeys.IS_VALID: test_result},
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error testing adapter : {str(e)}")
raise e
test_result = AdapterProcessor.test_adapter(
adapter_id=adapter_id, adapter_metadata=adapter_metadata
)
return Response(
{AdapterKeys.IS_VALID: test_result},
status=status.HTTP_200_OK,
)


class AdapterInstanceViewSet(ModelViewSet):
Expand Down
30 changes: 15 additions & 15 deletions backend/connector_processor/connector_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from connector_processor.constants import ConnectorKeys
from connector_processor.exceptions import (
InternalServiceError,
InValidConnectorId,
InValidConnectorMode,
OAuthTimeOut,
Expand Down Expand Up @@ -53,26 +52,27 @@ def get_json_schema(connector_id: str) -> dict:
updated_connectors = fetch_connectors_by_key_value(
ConnectorKeys.ID, connector_id
)
if len(updated_connectors) != 0:
connector = updated_connectors[0]
schema_details[ConnectorKeys.OAUTH] = connector.get(ConnectorKeys.OAUTH)
schema_details[ConnectorKeys.SOCIAL_AUTH_URL] = connector.get(
ConnectorKeys.SOCIAL_AUTH_URL
)
try:
schema_details[ConnectorKeys.JSON_SCHEMA] = json.loads(
connector.get(ConnectorKeys.JSON_SCHEMA)
)
except Exception as exc:
logger.error(f"Error occurred while parsing JSON Schema: {exc}")
raise InternalServiceError()
else:
if len(updated_connectors) == 0:
logger.error(
f"Invalid connector Id : {connector_id} "
f"while fetching "
f"JSON Schema"
)
raise InValidConnectorId()

connector = updated_connectors[0]
schema_details[ConnectorKeys.OAUTH] = connector.get(ConnectorKeys.OAUTH)
schema_details[ConnectorKeys.SOCIAL_AUTH_URL] = connector.get(
ConnectorKeys.SOCIAL_AUTH_URL
)
try:
schema_details[ConnectorKeys.JSON_SCHEMA] = json.loads(
connector.get(ConnectorKeys.JSON_SCHEMA)
)
except Exception as exc:
logger.error(f"Error occurred decoding JSON for {connector_id}: {exc}")
raise exc

return schema_details

@staticmethod
Expand Down
13 changes: 0 additions & 13 deletions backend/prompt_studio/prompt_profile_manager/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from adapter_processor.adapter_processor import AdapterProcessor
from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys
from prompt_studio.prompt_studio_core.exceptions import MaxProfilesReachedError

from backend.serializers import AuditSerializer

Expand Down Expand Up @@ -39,15 +38,3 @@ def to_representation(self, instance): # type: ignore
AdapterProcessor.get_adapter_instance_by_id(x2text)
)
return rep

def validate(self, data):
prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL)

profile_count = ProfileManager.objects.filter(
prompt_studio_tool=prompt_studio_tool
).count()

if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
raise MaxProfilesReachedError()

return data
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,11 @@ def _execute_prompts_in_single_pass(
doc_path, tool_id, org_id, user_id, document_id, run_id
):
prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)
prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES]
prompts = [
prompt
for prompt in prompts
if prompt.prompt_type != TSPKeys.NOTES and prompt.active
]
if not prompts:
logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
raise NoPromptsFound()
Expand Down
16 changes: 15 additions & 1 deletion backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from file_management.file_management_helper import FileManagerHelper
from permissions.permission import IsOwner, IsOwnerOrSharedUser
from prompt_studio.processor_loader import ProcessorConfig, load_plugins
from prompt_studio.prompt_profile_manager.constants import ProfileManagerErrors
from prompt_studio.prompt_profile_manager.constants import (
ProfileManagerErrors,
ProfileManagerKeys,
)
from prompt_studio.prompt_profile_manager.models import ProfileManager
from prompt_studio.prompt_profile_manager.serializers import ProfileManagerSerializer
from prompt_studio.prompt_studio.constants import ToolStudioPromptErrors
Expand All @@ -26,6 +29,7 @@
)
from prompt_studio.prompt_studio_core.exceptions import (
IndexingAPIError,
MaxProfilesReachedError,
ToolDeleteError,
)
from prompt_studio.prompt_studio_core.prompt_studio_helper import PromptStudioHelper
Expand Down Expand Up @@ -345,6 +349,16 @@ def create_profile_manager(self, request: HttpRequest, pk: Any = None) -> Respon
serializer = ProfileManagerSerializer(data=request.data, context=context)

serializer.is_valid(raise_exception=True)
# Check for the maximum number of profiles constraint
prompt_studio_tool = serializer.validated_data[
ProfileManagerKeys.PROMPT_STUDIO_TOOL
]
profile_count = ProfileManager.objects.filter(
prompt_studio_tool=prompt_studio_tool
).count()

if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
raise MaxProfilesReachedError()
try:
self.perform_create(serializer)
except IntegrityError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ class PromptStudioOutputManagerKeys:
DOCUMENT_MANAGER = "document_manager"
IS_SINGLE_PASS_EXTRACT = "is_single_pass_extract"
NOTES = "NOTES"


class PromptOutputManagerErrorMessage:
TOOL_VALIDATION = "tool_id parameter is required"
TOOL_NOT_FOUND = "Tool not found"
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_studio_output_manager/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
from .views import PromptStudioOutputView

prompt_doc_list = PromptStudioOutputView.as_view({"get": "list"})
get_output_for_tool_default = PromptStudioOutputView.as_view(
{"get": "get_output_for_tool_default"}
)

urlpatterns = format_suffix_patterns(
[
path("prompt-output/", prompt_doc_list, name="prompt-doc-list"),
path(
"prompt-output/prompt-default-profile/",
get_output_for_tool_default,
name="prompt-default-profile-outputs",
),
]
)
58 changes: 56 additions & 2 deletions backend/prompt_studio/prompt_studio_output_manager/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging
from typing import Optional
from typing import Any, Optional

from django.core.exceptions import ObjectDoesNotExist
from django.db.models import QuerySet
from django.http import HttpRequest
from prompt_studio.prompt_studio.models import ToolStudioPrompt
from prompt_studio.prompt_studio_output_manager.constants import (
PromptOutputManagerErrorMessage,
PromptStudioOutputManagerKeys,
)
from prompt_studio.prompt_studio_output_manager.serializers import (
PromptStudioOutputSerializer,
)
from rest_framework import viewsets
from rest_framework import status, viewsets
from rest_framework.exceptions import APIException
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.common_utils import CommonUtils
from utils.filtering import FilterHelper
Expand Down Expand Up @@ -49,3 +55,51 @@ def get_queryset(self) -> Optional[QuerySet]:
queryset = PromptStudioOutputManager.objects.filter(**filter_args)

return queryset

def get_output_for_tool_default(self, request: HttpRequest) -> Response:
    """Return the stored default-profile output for every prompt of a tool.

    Query parameters:
        tool_id: required; id of the tool whose prompts are listed.
        document_manager: optional; restricts outputs to one document.

    Returns:
        Response mapping each prompt's ``prompt_key`` to its stored output,
        or ``""`` when no output exists for that prompt's default profile.

    Raises:
        APIException: when ``tool_id`` is missing.
    """
    tool_id = request.GET.get("tool_id")
    document_manager_id = request.GET.get("document_manager")

    if not tool_id:
        # NOTE(review): APIException's `code` argument sets `default_code`,
        # not the HTTP status — this response is still a 500. If a 400 is
        # intended, a ValidationError would be the right type; left as-is
        # to preserve the current response contract. TODO confirm.
        raise APIException(
            detail=PromptOutputManagerErrorMessage.TOOL_VALIDATION, code=400
        )

    # QuerySet.filter() never raises ObjectDoesNotExist — an unknown
    # tool_id simply yields an empty queryset (and an empty result dict),
    # so no try/except is needed here.
    tool_studio_prompts = ToolStudioPrompt.objects.filter(tool_id=tool_id)

    # prompt_key -> output text ("" when nothing is stored).
    result: dict[str, Any] = {}

    for tool_prompt in tool_studio_prompts:
        # Guard the relation itself before dereferencing it: accessing
        # .profile_id on a null profile_manager would raise AttributeError.
        if not tool_prompt.profile_manager:
            result[tool_prompt.prompt_key] = ""
            continue

        queryset = PromptStudioOutputManager.objects.filter(
            prompt_id=str(tool_prompt.prompt_id),
            profile_manager=str(tool_prompt.profile_manager.profile_id),
            is_single_pass_extract=False,
            document_manager_id=document_manager_id,
        )

        # Empty queryset means no stored output for this prompt; default
        # to "". When multiple rows match, the last iterated value wins
        # (matches the original behavior).
        output_value = ""
        for output in queryset:
            output_value = output.output
        result[tool_prompt.prompt_key] = output_value

    return Response(result, status=status.HTTP_200_OK)
3 changes: 0 additions & 3 deletions backend/prompt_studio/prompt_studio_registry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ class JsonSchemaKey:
ENABLE_CHALLENGE = "enable_challenge"
CHALLENGE_LLM = "challenge_llm"
ENABLE_SINGLE_PASS_EXTRACTION = "enable_single_pass_extraction"
IMAGE_URL = "image_url"
IMAGE_NAME = "image_name"
IMAGE_TAG = "image_tag"
SUMMARIZE_PROMPT = "summarize_prompt"
SUMMARIZE_AS_SOURCE = "summarize_as_source"
ENABLE_HIGHLIGHT = "enable_highlight"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,14 @@ def get_tool_by_prompt_registry_id(
f"ID {prompt_registry_id}: {e} "
)
return None
# The below properties are introduced after 0.20.0
# So defaulting to 0.20.0 if the properties are not found
image_url = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_URL, "docker:unstract/tool-structure:0.0.20"
)
image_name = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_NAME, "unstract/tool-structure"
)
image_tag = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_TAG, "0.0.20"
)
return Tool(
tool_uid=prompt_registry_tool.prompt_registry_id,
properties=Properties.from_dict(prompt_registry_tool.tool_property),
spec=Spec.from_dict(prompt_registry_tool.tool_spec),
icon=prompt_registry_tool.icon,
image_url=image_url,
image_name=image_name,
image_tag=image_tag,
image_url=settings.STRUCTURE_TOOL_IMAGE_URL,
image_name=settings.STRUCTURE_TOOL_IMAGE_NAME,
image_tag=settings.STRUCTURE_TOOL_IMAGE_TAG,
)

@staticmethod
Expand Down Expand Up @@ -176,7 +165,6 @@ def update_or_create_psr_tool(
obj, created = PromptStudioRegistry.objects.update_or_create(
custom_tool=custom_tool,
created_by=custom_tool.created_by,
modified_by=custom_tool.modified_by,
defaults={
"name": custom_tool.tool_name,
"tool_property": properties.to_dict(),
Expand All @@ -190,7 +178,7 @@ def update_or_create_psr_tool(
logger.info(f"PSR {obj.prompt_registry_id} was created")
else:
logger.info(f"PSR {obj.prompt_registry_id} was updated")

obj.modified_by = custom_tool.modified_by
obj.shared_to_org = shared_with_org
if not shared_with_org:
obj.shared_users.clear()
Expand Down Expand Up @@ -242,9 +230,6 @@ def frame_export_json(
export_metadata[JsonSchemaKey.DESCRIPTION] = tool.description
export_metadata[JsonSchemaKey.AUTHOR] = tool.author
export_metadata[JsonSchemaKey.TOOL_ID] = str(tool.tool_id)
export_metadata[JsonSchemaKey.IMAGE_URL] = settings.STRUCTURE_TOOL_IMAGE_URL
export_metadata[JsonSchemaKey.IMAGE_NAME] = settings.STRUCTURE_TOOL_IMAGE_NAME
export_metadata[JsonSchemaKey.IMAGE_TAG] = settings.STRUCTURE_TOOL_IMAGE_TAG

default_llm_profile = ProfileManager.get_default_llm_profile(tool)
challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm
Expand Down Expand Up @@ -283,6 +268,8 @@ def frame_export_json(
tool_settings[JsonSchemaKey.ENABLE_HIGHLIGHT] = tool.enable_highlight

for prompt in prompts:
if prompt.prompt_type == JsonSchemaKey.NOTES or not prompt.active:
continue

if not prompt.prompt:
invalidated_prompts.append(prompt.prompt_key)
Expand All @@ -298,8 +285,6 @@ def frame_export_json(
invalidated_outputs.append(prompt.prompt_key)
continue

if prompt.prompt_type == JsonSchemaKey.NOTES:
continue
if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

Expand Down
Loading

0 comments on commit 4d58415

Please sign in to comment.