Skip to content

Commit

Permalink
Merge branch 'main' into fix/avoid-conditionally-hiding-the-mark-hori…
Browse files Browse the repository at this point in the history
…zontal-lines-field
  • Loading branch information
tahierhussain authored Jan 24, 2025
2 parents 674c041 + 01f0e65 commit ab26646
Show file tree
Hide file tree
Showing 88 changed files with 2,657 additions and 1,359 deletions.
27 changes: 27 additions & 0 deletions backend/adapter_processor_v2/adapter_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
InValidAdapterId,
TestAdapterError,
)
from cryptography.fernet import Fernet
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings_v2.platform_auth_service import PlatformAuthenticationService
Expand All @@ -23,6 +24,11 @@

logger = logging.getLogger(__name__)

try:
from plugins.subscription.time_trials.subscription_adapter import add_unstract_key
except ImportError:
add_unstract_key = None


class AdapterProcessor:
@staticmethod
Expand Down Expand Up @@ -91,6 +97,12 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id)

if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT:

if (
adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY)
and add_unstract_key
):
adapter_metadata = add_unstract_key(adapter_metadata)
adapter_metadata[X2TextConstants.X2TEXT_HOST] = settings.X2TEXT_HOST
adapter_metadata[X2TextConstants.X2TEXT_PORT] = settings.X2TEXT_PORT
platform_key = PlatformAuthenticationService.get_active_platform_key()
Expand All @@ -106,6 +118,21 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
e, adapter_name=adapter_metadata[AdapterKeys.ADAPTER_NAME]
)

@staticmethod
def update_adapter_metadata(adapter_metadata_b: Any) -> Any:
if add_unstract_key:
encryption_secret: str = settings.ENCRYPTION_KEY
f: Fernet = Fernet(encryption_secret.encode("utf-8"))

adapter_metadata = json.loads(
f.decrypt(bytes(adapter_metadata_b).decode("utf-8"))
)
adapter_metadata = add_unstract_key(adapter_metadata)

adapter_metadata_b = f.encrypt(json.dumps(adapter_metadata).encode("utf-8"))
return adapter_metadata_b
return adapter_metadata_b

@staticmethod
def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter:
"""Fetches a list of adapters that have an attribute matching key and
Expand Down
3 changes: 2 additions & 1 deletion backend/adapter_processor_v2/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ class AdapterKeys:
X2TEXT_DEFAULT = "x2text_default"
SHARED_USERS = "shared_users"
ADAPTER_NAME_EXISTS = (
"Configuration with this name already exists within your organisation. "
"Configuration with this name already exists within your organisation."
"Please try with a different name."
)
ADAPTER_NAME = "adapter_name"
ADAPTER_CREATED_BY = "created_by_email"
ADAPTER_CONTEXT_WINDOW_SIZE = "context_window_size"
PLATFORM_PROVIDED_UNSTRACT_KEY = "use_platform_provided_unstract_key"


class AllowedDomains(Enum):
Expand Down
23 changes: 22 additions & 1 deletion backend/adapter_processor_v2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,30 @@ def get_serializer_class(

def create(self, request: Any) -> Response:
serializer = self.get_serializer(data=request.data)

use_platform_unstract_key = False
adapter_metadata = request.data.get(AdapterKeys.ADAPTER_METADATA)
if adapter_metadata and adapter_metadata.get(
AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY, False
):
use_platform_unstract_key = True

serializer.is_valid(raise_exception=True)
try:
adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE)

if adapter_type == AdapterKeys.X2TEXT and use_platform_unstract_key:
adapter_metadata_b = serializer.validated_data.get(
AdapterKeys.ADAPTER_METADATA_B
)
adapter_metadata_b = AdapterProcessor.update_adapter_metadata(
adapter_metadata_b
)
# Update the validated data with the new adapter_metadata
serializer.validated_data[AdapterKeys.ADAPTER_METADATA_B] = (
adapter_metadata_b
)

instance = serializer.save()
organization_member = OrganizationMemberService.get_user_by_id(
request.user.id
Expand All @@ -185,7 +207,6 @@ def create(self, request: Any) -> Response:
organization_member=organization_member
)

adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE)
if (adapter_type == AdapterKeys.LLM) and (
not user_default_adapter.default_llm_adapter
):
Expand Down
2 changes: 2 additions & 0 deletions backend/api_v2/api_deployment_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def post(
include_metadata = serializer.validated_data.get(ApiExecution.INCLUDE_METADATA)
include_metrics = serializer.validated_data.get(ApiExecution.INCLUDE_METRICS)
use_file_history = serializer.validated_data.get(ApiExecution.USE_FILE_HISTORY)
tag_names = serializer.validated_data.get(ApiExecution.TAGS)
if not file_objs or len(file_objs) == 0:
raise InvalidAPIRequest("File shouldn't be empty")
response = DeploymentHelper.execute_workflow(
Expand All @@ -64,6 +65,7 @@ def post(
include_metadata=include_metadata,
include_metrics=include_metrics,
use_file_history=use_file_history,
tag_names=tag_names,
)
if "error" in response and response["error"]:
return Response(
Expand Down
1 change: 1 addition & 0 deletions backend/api_v2/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class ApiExecution:
INCLUDE_METRICS: str = "include_metrics"
USE_FILE_HISTORY: str = "use_file_history" # Undocumented parameter
EXECUTION_ID: str = "execution_id"
TAGS: str = "tags"
5 changes: 5 additions & 0 deletions backend/api_v2/deployment_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from rest_framework.request import Request
from rest_framework.serializers import Serializer
from rest_framework.utils.serializer_helpers import ReturnDict
from tags.models import Tag
from utils.constants import Account, CeleryQueue
from utils.local_context import StateStore
from workflow_manager.endpoint_v2.destination import DestinationConnector
Expand Down Expand Up @@ -138,6 +139,7 @@ def execute_workflow(
include_metadata: bool = False,
include_metrics: bool = False,
use_file_history: bool = False,
tag_names: list[str] = [],
) -> ReturnDict:
"""Execute workflow by api.
Expand All @@ -147,16 +149,19 @@ def execute_workflow(
file_obj (UploadedFile): input file
use_file_history (bool): Use FileHistory table to return results on already
processed files. Defaults to False
tag_names (list(str)): list of tag names
Returns:
ReturnDict: execution status/ result
"""
workflow_id = api.workflow.id
pipeline_id = api.id
tags = Tag.bulk_get_or_create(tag_names=tag_names)
workflow_execution = WorkflowExecutionServiceHelper.create_workflow_execution(
workflow_id=workflow_id,
pipeline_id=pipeline_id,
mode=WorkflowExecution.Mode.QUEUE,
tags=tags,
)
execution_id = workflow_execution.id

Expand Down
5 changes: 4 additions & 1 deletion backend/api_v2/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Serializer,
ValidationError,
)
from tags.serializers import TagParamsSerializer
from utils.serializer.integrity_error_mixin import IntegrityErrorMixin
from workflow_manager.workflow_v2.exceptions import ExecutionDoesNotExistError
from workflow_manager.workflow_v2.models.execution import WorkflowExecution
Expand Down Expand Up @@ -99,7 +100,7 @@ def to_representation(self, instance: APIKey) -> OrderedDict[str, Any]:
return representation


class ExecutionRequestSerializer(Serializer):
class ExecutionRequestSerializer(TagParamsSerializer):
"""Execution request serializer.
Attributes:
Expand All @@ -110,6 +111,8 @@ class ExecutionRequestSerializer(Serializer):
use_file_history (bool): Flag to use FileHistory to save and retrieve
responses quickly. This is undocumented to the user and can be
helpful for demos.
tags (str): Comma-separated List of tags to associate with the execution.
e.g:'tag1,tag2-name,tag3_name'
"""

timeout = IntegerField(
Expand Down
12 changes: 9 additions & 3 deletions backend/backend/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def get_required_setting(
"django.contrib.messages",
"django.contrib.staticfiles",
"django.contrib.admindocs",
"django_filters",
# Third party apps should go below this line,
"rest_framework",
# Connector OAuth
Expand All @@ -226,8 +227,6 @@ def get_required_setting(
"commands",
# health checks
"health",
)
v2_apps = (
"migrating.v2",
"connector_auth_v2",
"tenant_account_v2",
Expand All @@ -250,8 +249,8 @@ def get_required_setting(
"prompt_studio.prompt_studio_output_manager_v2",
"prompt_studio.prompt_studio_document_manager_v2",
"prompt_studio.prompt_studio_index_manager_v2",
"tags",
)
SHARED_APPS += v2_apps
TENANT_APPS = []

INSTALLED_APPS = list(SHARED_APPS) + [
Expand Down Expand Up @@ -432,6 +431,10 @@ def get_required_setting(
"DEFAULT_PERMISSION_CLASSES": [], # TODO: Update once auth is figured
"TEST_REQUEST_DEFAULT_FORMAT": "json",
"EXCEPTION_HANDLER": "middleware.exception.drf_logging_exc_handler",
"DEFAULT_FILTER_BACKENDS": [
"django_filters.rest_framework.DjangoFilterBackend",
"rest_framework.filters.OrderingFilter",
],
}

# These paths will work without authentication
Expand All @@ -451,6 +454,9 @@ def get_required_setting(
# Whitelisting health check API
WHITELISTED_PATHS.append("/health")

# These path will work without organization in request
ORGANIZATION_MIDDLEWARE_WHITELISTED_PATHS = []

# API Doc Generator Settings
# https://drf-yasg.readthedocs.io/en/stable/settings.html
REDOC_SETTINGS = {
Expand Down
1 change: 1 addition & 0 deletions backend/backend/urls_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_studio_index_manager_v2.urls"),
),
path("tags/", include("tags.urls")),
]
8 changes: 8 additions & 0 deletions backend/middleware/organization_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ def process_request(self, request):
# Check if the URL matches the pattern with organization ID
match = re.match(pattern, request.path)
if match:
# Check if the request path matches any of the whitelisted paths
if any(
re.match(path, request.path)
for path in settings.ORGANIZATION_MIDDLEWARE_WHITELISTED_PATHS
):
request.path_info = "/" + request.path_info
return

org_id = match.group("org_id")
request.organization_id = org_id
new_path = re.sub(pattern, "/" + tenant_prefix, request.path_info)
Expand Down
Loading

0 comments on commit ab26646

Please sign in to comment.