diff --git a/backend/adapter_processor_v2/adapter_processor.py b/backend/adapter_processor_v2/adapter_processor.py index 33fccec88..3120b985d 100644 --- a/backend/adapter_processor_v2/adapter_processor.py +++ b/backend/adapter_processor_v2/adapter_processor.py @@ -9,6 +9,7 @@ InValidAdapterId, TestAdapterError, ) +from cryptography.fernet import Fernet from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from platform_settings_v2.platform_auth_service import PlatformAuthenticationService @@ -23,6 +24,11 @@ logger = logging.getLogger(__name__) +try: + from plugins.subscription.time_trials.subscription_adapter import add_unstract_key +except ImportError: + add_unstract_key = None + class AdapterProcessor: @staticmethod @@ -91,6 +97,12 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id) if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT: + + if ( + adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY) + and add_unstract_key + ): + adapter_metadata = add_unstract_key(adapter_metadata) adapter_metadata[X2TextConstants.X2TEXT_HOST] = settings.X2TEXT_HOST adapter_metadata[X2TextConstants.X2TEXT_PORT] = settings.X2TEXT_PORT platform_key = PlatformAuthenticationService.get_active_platform_key() @@ -106,6 +118,21 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: e, adapter_name=adapter_metadata[AdapterKeys.ADAPTER_NAME] ) + @staticmethod + def update_adapter_metadata(adapter_metadata_b: Any) -> Any: + if add_unstract_key: + encryption_secret: str = settings.ENCRYPTION_KEY + f: Fernet = Fernet(encryption_secret.encode("utf-8")) + + adapter_metadata = json.loads( + f.decrypt(bytes(adapter_metadata_b).decode("utf-8")) + ) + adapter_metadata = add_unstract_key(adapter_metadata) + + adapter_metadata_b = f.encrypt(json.dumps(adapter_metadata).encode("utf-8")) + return adapter_metadata_b + return adapter_metadata_b + @staticmethod def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter: """Fetches a list of adapters that have an attribute matching key and diff --git a/backend/adapter_processor_v2/constants.py b/backend/adapter_processor_v2/constants.py index 3b849a72b..ea530e994 100644 --- a/backend/adapter_processor_v2/constants.py +++ b/backend/adapter_processor_v2/constants.py @@ -24,12 +24,13 @@ class AdapterKeys: X2TEXT_DEFAULT = "x2text_default" SHARED_USERS = "shared_users" ADAPTER_NAME_EXISTS = ( - "Configuration with this name already exists within your organisation. " + "Configuration with this name already exists within your organisation." "Please try with a different name." ) ADAPTER_NAME = "adapter_name" ADAPTER_CREATED_BY = "created_by_email" ADAPTER_CONTEXT_WINDOW_SIZE = "context_window_size" + PLATFORM_PROVIDED_UNSTRACT_KEY = "use_platform_provided_unstract_key" class AllowedDomains(Enum): diff --git a/backend/adapter_processor_v2/views.py b/backend/adapter_processor_v2/views.py index 299716d87..e3c1ee482 100644 --- a/backend/adapter_processor_v2/views.py +++ b/backend/adapter_processor_v2/views.py @@ -169,8 +169,30 @@ def get_serializer_class( def create(self, request: Any) -> Response: serializer = self.get_serializer(data=request.data) + + use_platform_unstract_key = False + adapter_metadata = request.data.get(AdapterKeys.ADAPTER_METADATA) + if adapter_metadata and adapter_metadata.get( + AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY, False + ): + use_platform_unstract_key = True + serializer.is_valid(raise_exception=True) try: + adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + + if adapter_type == AdapterKeys.X2TEXT and use_platform_unstract_key: + adapter_metadata_b = serializer.validated_data.get( + AdapterKeys.ADAPTER_METADATA_B + ) + adapter_metadata_b = AdapterProcessor.update_adapter_metadata( + adapter_metadata_b + ) + # Update the validated data with the new adapter_metadata + serializer.validated_data[AdapterKeys.ADAPTER_METADATA_B] = ( + adapter_metadata_b + ) + instance = serializer.save() organization_member = OrganizationMemberService.get_user_by_id( request.user.id @@ -185,7 +207,6 @@ def create(self, request: Any) -> Response: organization_member=organization_member ) - adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) if (adapter_type == AdapterKeys.LLM) and ( not user_default_adapter.default_llm_adapter ): diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index 4aecdc790..c132f9e20 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -457,6 +457,9 @@ def get_required_setting( # Whitelisting health check API WHITELISTED_PATHS.append("/health") +# These path will work without organization in request +ORGANIZATION_MIDDLEWARE_WHITELISTED_PATHS = [] + # API Doc Generator Settings # https://drf-yasg.readthedocs.io/en/stable/settings.html REDOC_SETTINGS = { diff --git a/backend/middleware/organization_middleware.py b/backend/middleware/organization_middleware.py index a2b9450ce..848dd2c2c 100644 --- a/backend/middleware/organization_middleware.py +++ b/backend/middleware/organization_middleware.py @@ -12,6 +12,14 @@ def process_request(self, request): # Check if the URL matches the pattern with organization ID match = re.match(pattern, request.path) if match: + # Check if the request path matches any of the whitelisted paths + if any( + re.match(path, request.path) + for path in settings.ORGANIZATION_MIDDLEWARE_WHITELISTED_PATHS + ): + request.path_info = "/" + request.path_info + return + org_id = match.group("org_id") request.organization_id = org_id new_path = re.sub(pattern, "/" + tenant_prefix, request.path_info) diff --git a/frontend/src/hooks/useRequestUrl.js b/frontend/src/hooks/useRequestUrl.js index 1bb202f2f..4fd39164f 100644 --- a/frontend/src/hooks/useRequestUrl.js +++ b/frontend/src/hooks/useRequestUrl.js @@ -6,7 +6,7 @@ const useRequestUrl = () => { const getUrl = (url) => { if (!url) return null; - const baseUrl = `/api/v1/${sessionDetails?.orgId}/`; + const baseUrl = `/api/v1/unstract/${sessionDetails?.orgId}/`; return baseUrl + url.replace(/^\//, ""); };