Skip to content

Commit

Permalink
Merge branch 'main' into feat/UN-1451-pdm-lock-automation
Browse files Browse the repository at this point in the history
  • Loading branch information
kirtimanmishrazipstack authored Jul 18, 2024
2 parents c07ef66 + 64861dc commit 48bcb51
Show file tree
Hide file tree
Showing 36 changed files with 515 additions and 303 deletions.
17 changes: 6 additions & 11 deletions backend/connector_processor/connector_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
InValidConnectorId,
InValidConnectorMode,
OAuthTimeOut,
TestConnectorException,
TestConnectorInputException,
TestConnectorInputError,
)

from unstract.connectors.base import UnstractConnector
Expand Down Expand Up @@ -100,15 +99,15 @@ def get_all_supported_connectors(
return supported_connectors

@staticmethod
def test_connectors(connector_id: str, cred_string: dict[str, Any]) -> bool:
def test_connectors(connector_id: str, credentials: dict[str, Any]) -> bool:
logger.info(f"Testing connector: {connector_id}")
connector: dict[str, Any] = fetch_connectors_by_key_value(
ConnectorKeys.ID, connector_id
)[0]
if connector.get(ConnectorKeys.OAUTH):
try:
oauth_key = cred_string.get(ConnectorAuthKey.OAUTH_KEY)
cred_string = ConnectorAuthHelper.get_oauth_creds_from_cache(
oauth_key = credentials.get(ConnectorAuthKey.OAUTH_KEY)
credentials = ConnectorAuthHelper.get_oauth_creds_from_cache(
cache_key=oauth_key, delete_key=False
)
except Exception as exc:
Expand All @@ -120,17 +119,13 @@ def test_connectors(connector_id: str, cred_string: dict[str, Any]) -> bool:

try:
connector_impl = Connectorkit().get_connector_by_id(
connector_id, cred_string
connector_id, credentials
)
test_result = connector_impl.test_credentials()
logger.info(f"{connector_id} test result: {test_result}")
return test_result
except ConnectorError as e:
logger.error(f"Error while testing {connector_id}: {e}")
raise TestConnectorInputException(core_err=e)
except Exception as e:
logger.error(f"Error while testing {connector_id}: {e}")
raise TestConnectorException
raise TestConnectorInputError(core_err=e)

def get_connector_data_with_key(connector_id: str, key_value: str) -> Any:
"""Generic Function to get connector data with provided key."""
Expand Down
4 changes: 2 additions & 2 deletions backend/connector_processor/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class JSONParseException(APIException):

class OAuthTimeOut(APIException):
status_code = 408
default_detail = "Timed Out. Please re authenticate."
default_detail = "Timed out. Please re-authenticate."


class InternalServiceError(APIException):
Expand All @@ -44,7 +44,7 @@ class TestConnectorException(APIException):
default_detail = "Error while testing connector."


class TestConnectorInputException(UnstractBaseException):
class TestConnectorInputError(UnstractBaseException):
def __init__(self, core_err: ConnectorError) -> None:
super().__init__(detail=core_err.message, core_err=core_err)
self.default_detail = core_err.message
Expand Down
4 changes: 2 additions & 2 deletions backend/connector_processor/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def test(self, request: Request) -> Response:
"""Tests the connector against the credentials passed."""
serializer: TestConnectorSerializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
connector_id = serializer.validated_data.get(ConnectorKeys.CONNECTOR_ID)
connector_id = serializer.validated_data.get(CIKey.CONNECTOR_ID)
cred_string = serializer.validated_data.get(CIKey.CONNECTOR_METADATA)
test_result = ConnectorProcessor.test_connectors(
connector_id=connector_id, cred_string=cred_string
connector_id=connector_id, credentials=cred_string
)
return Response(
{ConnectorKeys.IS_VALID: test_result},
Expand Down
147 changes: 64 additions & 83 deletions backend/pdm.lock

Large diffs are not rendered by default.

13 changes: 0 additions & 13 deletions backend/prompt_studio/prompt_profile_manager/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from adapter_processor.adapter_processor import AdapterProcessor
from prompt_studio.prompt_profile_manager.constants import ProfileManagerKeys
from prompt_studio.prompt_studio_core.exceptions import MaxProfilesReachedError

from backend.serializers import AuditSerializer

Expand Down Expand Up @@ -39,15 +38,3 @@ def to_representation(self, instance): # type: ignore
AdapterProcessor.get_adapter_instance_by_id(x2text)
)
return rep

def validate(self, data):
prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL)

profile_count = ProfileManager.objects.filter(
prompt_studio_tool=prompt_studio_tool
).count()

if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
raise MaxProfilesReachedError()

return data
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,11 @@ def _execute_prompts_in_single_pass(
doc_path, tool_id, org_id, user_id, document_id, run_id
):
prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)
prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES]
prompts = [
prompt
for prompt in prompts
if prompt.prompt_type != TSPKeys.NOTES and prompt.active
]
if not prompts:
logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
raise NoPromptsFound()
Expand Down
16 changes: 15 additions & 1 deletion backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from file_management.file_management_helper import FileManagerHelper
from permissions.permission import IsOwner, IsOwnerOrSharedUser
from prompt_studio.processor_loader import ProcessorConfig, load_plugins
from prompt_studio.prompt_profile_manager.constants import ProfileManagerErrors
from prompt_studio.prompt_profile_manager.constants import (
ProfileManagerErrors,
ProfileManagerKeys,
)
from prompt_studio.prompt_profile_manager.models import ProfileManager
from prompt_studio.prompt_profile_manager.serializers import ProfileManagerSerializer
from prompt_studio.prompt_studio.constants import ToolStudioPromptErrors
Expand All @@ -26,6 +29,7 @@
)
from prompt_studio.prompt_studio_core.exceptions import (
IndexingAPIError,
MaxProfilesReachedError,
ToolDeleteError,
)
from prompt_studio.prompt_studio_core.prompt_studio_helper import PromptStudioHelper
Expand Down Expand Up @@ -345,6 +349,16 @@ def create_profile_manager(self, request: HttpRequest, pk: Any = None) -> Respon
serializer = ProfileManagerSerializer(data=request.data, context=context)

serializer.is_valid(raise_exception=True)
# Check for the maximum number of profiles constraint
prompt_studio_tool = serializer.validated_data[
ProfileManagerKeys.PROMPT_STUDIO_TOOL
]
profile_count = ProfileManager.objects.filter(
prompt_studio_tool=prompt_studio_tool
).count()

if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
raise MaxProfilesReachedError()
try:
self.perform_create(serializer)
except IntegrityError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ class PromptStudioOutputManagerKeys:
DOCUMENT_MANAGER = "document_manager"
IS_SINGLE_PASS_EXTRACT = "is_single_pass_extract"
NOTES = "NOTES"


class PromptOutputManagerErrorMessage:
TOOL_VALIDATION = "tool_id parameter is required"
TOOL_NOT_FOUND = "Tool not found"
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_studio_output_manager/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
from .views import PromptStudioOutputView

prompt_doc_list = PromptStudioOutputView.as_view({"get": "list"})
get_output_for_tool_default = PromptStudioOutputView.as_view(
{"get": "get_output_for_tool_default"}
)

urlpatterns = format_suffix_patterns(
[
path("prompt-output/", prompt_doc_list, name="prompt-doc-list"),
path(
"prompt-output/prompt-default-profile/",
get_output_for_tool_default,
name="prompt-default-profile-outputs",
),
]
)
58 changes: 56 additions & 2 deletions backend/prompt_studio/prompt_studio_output_manager/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging
from typing import Optional
from typing import Any, Optional

from django.core.exceptions import ObjectDoesNotExist
from django.db.models import QuerySet
from django.http import HttpRequest
from prompt_studio.prompt_studio.models import ToolStudioPrompt
from prompt_studio.prompt_studio_output_manager.constants import (
PromptOutputManagerErrorMessage,
PromptStudioOutputManagerKeys,
)
from prompt_studio.prompt_studio_output_manager.serializers import (
PromptStudioOutputSerializer,
)
from rest_framework import viewsets
from rest_framework import status, viewsets
from rest_framework.exceptions import APIException
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.common_utils import CommonUtils
from utils.filtering import FilterHelper
Expand Down Expand Up @@ -49,3 +55,51 @@ def get_queryset(self) -> Optional[QuerySet]:
queryset = PromptStudioOutputManager.objects.filter(**filter_args)

return queryset

def get_output_for_tool_default(self, request: HttpRequest) -> Response:
# Get the tool_id from request parameters
# Get the tool_id from request parameters
tool_id = request.GET.get("tool_id")
document_manager_id = request.GET.get("document_manager")
tool_validation_message = PromptOutputManagerErrorMessage.TOOL_VALIDATION
tool_not_found = PromptOutputManagerErrorMessage.TOOL_NOT_FOUND
if not tool_id:
raise APIException(detail=tool_validation_message, code=400)

try:
# Fetch ToolStudioPrompt records based on tool_id
tool_studio_prompts = ToolStudioPrompt.objects.filter(tool_id=tool_id)
except ObjectDoesNotExist:
raise APIException(detail=tool_not_found, code=400)

# Initialize the result dictionary
result: dict[str, Any] = {}

# Iterate over ToolStudioPrompt records
for tool_prompt in tool_studio_prompts:
prompt_id = str(tool_prompt.prompt_id)
profile_manager_id = str(tool_prompt.profile_manager.profile_id)

# If profile_manager is not set, skip this record
if not profile_manager_id:
result[tool_prompt.prompt_key] = ""
continue

try:
queryset = PromptStudioOutputManager.objects.filter(
prompt_id=prompt_id,
profile_manager=profile_manager_id,
is_single_pass_extract=False,
document_manager_id=document_manager_id,
)

if not queryset.exists():
result[tool_prompt.prompt_key] = ""
continue

for output in queryset:
result[tool_prompt.prompt_key] = output.output
except ObjectDoesNotExist:
result[tool_prompt.prompt_key] = ""

return Response(result, status=status.HTTP_200_OK)
3 changes: 0 additions & 3 deletions backend/prompt_studio/prompt_studio_registry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ class JsonSchemaKey:
ENABLE_CHALLENGE = "enable_challenge"
CHALLENGE_LLM = "challenge_llm"
ENABLE_SINGLE_PASS_EXTRACTION = "enable_single_pass_extraction"
IMAGE_URL = "image_url"
IMAGE_NAME = "image_name"
IMAGE_TAG = "image_tag"
SUMMARIZE_PROMPT = "summarize_prompt"
SUMMARIZE_AS_SOURCE = "summarize_as_source"
ENABLE_HIGHLIGHT = "enable_highlight"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,14 @@ def get_tool_by_prompt_registry_id(
f"ID {prompt_registry_id}: {e} "
)
return None
# The below properties are introduced after 0.20.0
# So defaulting to 0.20.0 if the properties are not found
image_url = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_URL, "docker:unstract/tool-structure:0.0.20"
)
image_name = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_NAME, "unstract/tool-structure"
)
image_tag = prompt_registry_tool.tool_metadata.get(
JsonSchemaKey.IMAGE_TAG, "0.0.20"
)
return Tool(
tool_uid=prompt_registry_tool.prompt_registry_id,
properties=Properties.from_dict(prompt_registry_tool.tool_property),
spec=Spec.from_dict(prompt_registry_tool.tool_spec),
icon=prompt_registry_tool.icon,
image_url=image_url,
image_name=image_name,
image_tag=image_tag,
image_url=settings.STRUCTURE_TOOL_IMAGE_URL,
image_name=settings.STRUCTURE_TOOL_IMAGE_NAME,
image_tag=settings.STRUCTURE_TOOL_IMAGE_TAG,
)

@staticmethod
Expand Down Expand Up @@ -176,7 +165,6 @@ def update_or_create_psr_tool(
obj, created = PromptStudioRegistry.objects.update_or_create(
custom_tool=custom_tool,
created_by=custom_tool.created_by,
modified_by=custom_tool.modified_by,
defaults={
"name": custom_tool.tool_name,
"tool_property": properties.to_dict(),
Expand All @@ -190,7 +178,7 @@ def update_or_create_psr_tool(
logger.info(f"PSR {obj.prompt_registry_id} was created")
else:
logger.info(f"PSR {obj.prompt_registry_id} was updated")

obj.modified_by = custom_tool.modified_by
obj.shared_to_org = shared_with_org
if not shared_with_org:
obj.shared_users.clear()
Expand Down Expand Up @@ -242,9 +230,6 @@ def frame_export_json(
export_metadata[JsonSchemaKey.DESCRIPTION] = tool.description
export_metadata[JsonSchemaKey.AUTHOR] = tool.author
export_metadata[JsonSchemaKey.TOOL_ID] = str(tool.tool_id)
export_metadata[JsonSchemaKey.IMAGE_URL] = settings.STRUCTURE_TOOL_IMAGE_URL
export_metadata[JsonSchemaKey.IMAGE_NAME] = settings.STRUCTURE_TOOL_IMAGE_NAME
export_metadata[JsonSchemaKey.IMAGE_TAG] = settings.STRUCTURE_TOOL_IMAGE_TAG

default_llm_profile = ProfileManager.get_default_llm_profile(tool)
challenge_llm_instance: Optional[AdapterInstance] = tool.challenge_llm
Expand Down Expand Up @@ -283,6 +268,8 @@ def frame_export_json(
tool_settings[JsonSchemaKey.ENABLE_HIGHLIGHT] = tool.enable_highlight

for prompt in prompts:
if prompt.prompt_type == JsonSchemaKey.NOTES or not prompt.active:
continue

if not prompt.prompt:
invalidated_prompts.append(prompt.prompt_key)
Expand All @@ -298,8 +285,6 @@ def frame_export_json(
invalidated_outputs.append(prompt.prompt_key)
continue

if prompt.prompt_type == JsonSchemaKey.NOTES:
continue
if not prompt.profile_manager:
prompt.profile_manager = default_llm_profile

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Col, Modal, Row, Tabs, Typography } from "antd";
import PropTypes from "prop-types";
import { useState } from "react";
import { useEffect, useState } from "react";

import { ListOfConnectors } from "../list-of-connectors/ListOfConnectors";
import "./ConfigureConnectorModal.css";
Expand Down Expand Up @@ -28,6 +28,13 @@ function ConfigureConnectorModal({
setSelectedItemName,
}) {
const [activeKey, setActiveKey] = useState("1");
useEffect(() => {
if (connectorMetadata) {
setActiveKey("2"); // If connector is already configured
} else {
setActiveKey("1"); // default value
}
}, [open, connectorMetadata]);
const { setPostHogCustomEvent, posthogConnectorEventText } =
usePostHogEvents();
const tabItems = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function ConfigureFormsLayout({
connDetails={connDetails}
connType={connType}
selectedSourceName={selectedItemName}
formDataConfig={formDataConfig}
/>
)}
</div>
Expand Down
Loading

0 comments on commit 48bcb51

Please sign in to comment.