Skip to content

Commit

Permalink
v2 changes of connector and connector_auth
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-ali-e committed Jul 8, 2024
1 parent 48826c4 commit 3ac0893
Show file tree
Hide file tree
Showing 25 changed files with 1,667 additions and 0 deletions.
Empty file.
5 changes: 5 additions & 0 deletions backend/connector_auth_v2/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import ConnectorAuth

admin.site.register(ConnectorAuth)
6 changes: 6 additions & 0 deletions backend/connector_auth_v2/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class ConnectorAuthConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "connector_auth_v2"
18 changes: 18 additions & 0 deletions backend/connector_auth_v2/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class ConnectorAuthKey:
OAUTH_KEY = "oauth-key"


class SocialAuthConstants:
UID = "uid"
PROVIDER = "provider"
ACCESS_TOKEN = "access_token"
REFRESH_TOKEN = "refresh_token"
TOKEN_TYPE = "token_type"
AUTH_TIME = "auth_time"
EXPIRES = "expires"

REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S"
REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after

GOOGLE_OAUTH = "google-oauth2"
GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S"
31 changes: 31 additions & 0 deletions backend/connector_auth_v2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional

from rest_framework.exceptions import APIException


class CacheMissException(APIException):
status_code = 404
default_detail = "Key doesn't exist."


class EnrichConnectorMetadataException(APIException):
status_code = 500
default_detail = "Connector metadata could not be enriched"


class MissingParamException(APIException):
status_code = 400
default_detail = "Bad request, missing parameter."

def __init__(
self,
code: Optional[str] = None,
param: Optional[str] = None,
) -> None:
detail = f"Bad request, missing parameter: {param}"
super().__init__(detail, code)


class KeyNotConfigured(APIException):
status_code = 500
default_detail = "Key is not configured correctly"
141 changes: 141 additions & 0 deletions backend/connector_auth_v2/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import logging
import uuid
from typing import Any

from account_v2.models import User
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.pipeline.google import GoogleAuthHelper
from django.db import models
from django.db.models.query import QuerySet
from rest_framework.request import Request
from social_django.fields import JSONField
from social_django.models import AbstractUserSocialAuth, DjangoStorage
from social_django.strategy import DjangoStrategy

logger = logging.getLogger(__name__)


class ConnectorAuthManager(models.Manager):
def get_queryset(self) -> QuerySet:
queryset = super().get_queryset()
# TODO PAN-83: Decrypt here
# for obj in queryset:
# logger.info(f"Decrypting extra_data: {obj.extra_data}")

return queryset


class ConnectorAuth(AbstractUserSocialAuth):
"""Social Auth association model, stores tokens.
The relation with `account.User` is only for the library to work
and should be NOT be used to access the secrets.
Use the following static methods instead
```
@classmethod
def get_social_auth(cls, provider, id):
@classmethod
def create_social_auth(cls, user, uid, provider):
```
"""

id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(
User,
related_name="connector_auths",
on_delete=models.SET_NULL,
null=True,
)

def __str__(self) -> str:
return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})"

def save(self, *args: Any, **kwargs: Any) -> Any:
# TODO PAN-83: Encrypt here
# logger.info(f"Encrypting extra_data: {self.extra_data}")
return super().save(*args, **kwargs)

def set_extra_data(self, extra_data=None): # type: ignore
ConnectorAuth.check_credential_format(extra_data)
if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH:
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data)
return super().set_extra_data(extra_data)

def refresh_token(self, strategy, *args, **kwargs): # type: ignore
"""Override of Python Social Auth (PSA)'s refresh_token functionality
to store uid, provider."""
token = self.extra_data.get("refresh_token") or self.extra_data.get(
"access_token"
)
backend = self.get_backend_instance(strategy)
if token and backend and hasattr(backend, "refresh_token"):
response = backend.refresh_token(token, *args, **kwargs)
extra_data = backend.extra_data(self, self.uid, response, self.extra_data)
extra_data[SocialAuthConstants.PROVIDER] = backend.name
extra_data[SocialAuthConstants.UID] = self.uid
if self.set_extra_data(extra_data): # type: ignore
self.save()

def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]:
"""Uses Social Auth's ability to refresh tokens if necessary.
Returns:
Tuple[JSONField, bool]: JSONField of connector metadata
and flag indicating if tokens were refreshed
"""
# To avoid circular dependency error on import
from social_django.utils import load_strategy

refreshed_token = False
strategy: DjangoStrategy = load_strategy(request=request)
existing_access_token = self.access_token
new_access_token = self.get_access_token(strategy)
if new_access_token != existing_access_token:
refreshed_token = True
related_connector_instances = self.connectorinstance_set.all()
for connector_instance in related_connector_instances:
connector_instance.connector_metadata = self.extra_data
connector_instance.save()
logger.info(
f"Refreshed access token for connector {connector_instance.id}, "
f"provider: {self.provider}, uid: {self.uid}"
)

return self.extra_data, refreshed_token

@staticmethod
def check_credential_format(
oauth_credentials: dict[str, str], raise_exception: bool = True
) -> bool:
if (
SocialAuthConstants.PROVIDER in oauth_credentials
and SocialAuthConstants.UID in oauth_credentials
):
return True
else:
if raise_exception:
raise ValueError(
"Auth credential should have provider, uid and connector guid"
)
return False

objects = ConnectorAuthManager()

class Meta:
app_label = "connector_auth_v2"
verbose_name = "Connector Auth"
verbose_name_plural = "Connector Auths"
db_table = "connector_auth_v2"
constraints = [
models.UniqueConstraint(
fields=[
"provider",
"uid",
],
name="unique_provider_uid_index",
),
]


class ConnectorDjangoStorage(DjangoStorage):
user = ConnectorAuth
111 changes: 111 additions & 0 deletions backend/connector_auth_v2/pipeline/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import logging
from typing import Any, Optional

from account_v2.models import User
from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants
from connector_auth_v2.models import ConnectorAuth
from connector_auth_v2.pipeline.google import GoogleAuthHelper
from django.conf import settings
from django.core.cache import cache
from rest_framework.exceptions import PermissionDenied
from social_core.backends.oauth import BaseOAuth2

logger = logging.getLogger(__name__)


def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]:
"""Checks if user is authenticated (will be handled in auth middleware,
present as a fail safe)
Args:
user (account.User): User model
Raises:
PermissionDenied: Unauthorized user
Returns:
dict: Carrying response details for auth pipeline
"""
if not user:
raise PermissionDenied(backend)
return {**kwargs}


def cache_oauth_creds(
backend: BaseOAuth2,
details: dict[str, str],
response: dict[str, str],
uid: str,
user: User,
*args: Any,
**kwargs: Any,
) -> dict[str, str]:
"""Used to cache the extra data JSON in redis against a key.
This contains the access and refresh token along with details
regarding expiry, uid (unique ID given by provider) and provider.
"""
cache_key = kwargs.get("cache_key") or backend.strategy.session_get(
settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0],
ConnectorAuthKey.OAUTH_KEY,
)
extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs)
extra_data[SocialAuthConstants.PROVIDER] = backend.name
extra_data[SocialAuthConstants.UID] = uid

if backend.name == SocialAuthConstants.GOOGLE_OAUTH:
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data)

cache.set(
cache_key,
extra_data,
int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND),
)
return {**kwargs}


class ConnectorAuthHelper:
@staticmethod
def get_oauth_creds_from_cache(
cache_key: str, delete_key: bool = True
) -> Optional[dict[str, str]]:
"""Retrieves oauth credentials from the cache.
Args:
cache_key (str): Key to obtain credentials from
Returns:
Optional[dict[str,str]]: Returns credentials. None if it doesn't exist
"""
oauth_creds: dict[str, str] = cache.get(cache_key)
if delete_key:
cache.delete(cache_key)
return oauth_creds

@staticmethod
def get_or_create_connector_auth(
oauth_credentials: dict[str, str], user: User = None # type: ignore
) -> ConnectorAuth:
"""Gets or creates a ConnectorAuth object.
Args:
user (User): Used while creation, can be removed if not required
oauth_credentials (dict[str,str]): Needs to have provider and uid
Returns:
ConnectorAuth: Object for the respective provider/uid
"""
ConnectorAuth.check_credential_format(oauth_credentials)
provider = oauth_credentials[SocialAuthConstants.PROVIDER]
uid = oauth_credentials[SocialAuthConstants.UID]
connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth(
provider=provider, uid=uid
)
if not connector_oauth:
connector_oauth = ConnectorAuth.create_social_auth(
user, uid=uid, provider=provider
)

# TODO: Remove User's related manager access to ConnectorAuth
connector_oauth.set_extra_data(oauth_credentials) # type: ignore
return connector_oauth
33 changes: 33 additions & 0 deletions backend/connector_auth_v2/pipeline/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime, timedelta

from connector_auth_v2.constants import SocialAuthConstants as AuthConstants
from connector_auth_v2.exceptions import EnrichConnectorMetadataException
from connector_processor.constants import ConnectorKeys

from unstract.connectors.filesystems.google_drive.constants import GDriveConstants


class GoogleAuthHelper:
@staticmethod
def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]:
token_expiry: datetime = datetime.now()
auth_time = kwargs.get(AuthConstants.AUTH_TIME)
expires = kwargs.get(AuthConstants.EXPIRES)
if auth_time and expires:
reference = datetime.utcfromtimestamp(float(auth_time))
token_expiry = reference + timedelta(seconds=float(expires))
else:
raise EnrichConnectorMetadataException
# Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN
kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime(
AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT
)

# Used by Unstract
kwargs[ConnectorKeys.PATH] = (
GDriveConstants.ROOT_PREFIX
) # Acts as a prefix for all paths
kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime(
AuthConstants.REFRESH_AFTER_FORMAT
)
return kwargs
21 changes: 21 additions & 0 deletions backend/connector_auth_v2/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from django.urls import include, path, re_path
from rest_framework.urlpatterns import format_suffix_patterns

from .views import ConnectorAuthViewSet

connector_auth_cache = ConnectorAuthViewSet.as_view(
{
"get": "cache_key",
}
)

urlpatterns = format_suffix_patterns(
[
path("oauth/", include("social_django.urls", namespace="social")),
re_path(
"^oauth/cache-key/(?P<backend>.+)$",
connector_auth_cache,
name="connector-cache",
),
]
)
Loading

0 comments on commit 3ac0893

Please sign in to comment.