Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2 changes of connector and connector_auth #461

Merged
merged 6 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
5 changes: 5 additions & 0 deletions backend/connector_auth_v2/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import ConnectorAuth

admin.site.register(ConnectorAuth)
6 changes: 6 additions & 0 deletions backend/connector_auth_v2/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class ConnectorAuthConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "connector_auth_v2"
18 changes: 18 additions & 0 deletions backend/connector_auth_v2/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class ConnectorAuthKey:
OAUTH_KEY = "oauth-key"


class SocialAuthConstants:
UID = "uid"
PROVIDER = "provider"
ACCESS_TOKEN = "access_token"
REFRESH_TOKEN = "refresh_token"
TOKEN_TYPE = "token_type"
AUTH_TIME = "auth_time"
EXPIRES = "expires"

REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S"
REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after

GOOGLE_OAUTH = "google-oauth2"
GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S"
31 changes: 31 additions & 0 deletions backend/connector_auth_v2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional

from rest_framework.exceptions import APIException


class CacheMissException(APIException):
status_code = 404
default_detail = "Key doesn't exist."


class EnrichConnectorMetadataException(APIException):
status_code = 500
default_detail = "Connector metadata could not be enriched"


class MissingParamException(APIException):
status_code = 400
default_detail = "Bad request, missing parameter."

def __init__(
self,
code: Optional[str] = None,
param: Optional[str] = None,
) -> None:
detail = f"Bad request, missing parameter: {param}"
super().__init__(detail, code)


class KeyNotConfigured(APIException):
status_code = 500
default_detail = "Key is not configured correctly"
141 changes: 141 additions & 0 deletions backend/connector_auth_v2/models.py
hari-kuriakose marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import logging
import uuid
from typing import Any

from account_v2.models import User
from connector_auth_v2.constants import SocialAuthConstants
from connector_auth_v2.pipeline.google import GoogleAuthHelper
from django.db import models
from django.db.models.query import QuerySet
from rest_framework.request import Request
from social_django.fields import JSONField
from social_django.models import AbstractUserSocialAuth, DjangoStorage
from social_django.strategy import DjangoStrategy

logger = logging.getLogger(__name__)


class ConnectorAuthManager(models.Manager):
def get_queryset(self) -> QuerySet:
queryset = super().get_queryset()
# TODO PAN-83: Decrypt here
# for obj in queryset:
# logger.info(f"Decrypting extra_data: {obj.extra_data}")

return queryset


class ConnectorAuth(AbstractUserSocialAuth):
"""Social Auth association model, stores tokens.
The relation with `account.User` is only for the library to work
and should be NOT be used to access the secrets.
Use the following static methods instead
```
@classmethod
def get_social_auth(cls, provider, id):

@classmethod
def create_social_auth(cls, user, uid, provider):
```
"""

id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
user = models.ForeignKey(
User,
related_name="connector_auths",
on_delete=models.SET_NULL,
null=True,
)

def __str__(self) -> str:
return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})"

def save(self, *args: Any, **kwargs: Any) -> Any:
# TODO PAN-83: Encrypt here
# logger.info(f"Encrypting extra_data: {self.extra_data}")
return super().save(*args, **kwargs)

def set_extra_data(self, extra_data=None): # type: ignore
ConnectorAuth.check_credential_format(extra_data)
if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH:
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data)
return super().set_extra_data(extra_data)

def refresh_token(self, strategy, *args, **kwargs): # type: ignore
"""Override of Python Social Auth (PSA)'s refresh_token functionality
to store uid, provider."""
token = self.extra_data.get("refresh_token") or self.extra_data.get(
"access_token"
)
backend = self.get_backend_instance(strategy)
if token and backend and hasattr(backend, "refresh_token"):
response = backend.refresh_token(token, *args, **kwargs)
extra_data = backend.extra_data(self, self.uid, response, self.extra_data)
extra_data[SocialAuthConstants.PROVIDER] = backend.name
extra_data[SocialAuthConstants.UID] = self.uid
if self.set_extra_data(extra_data): # type: ignore
self.save()

def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]:
"""Uses Social Auth's ability to refresh tokens if necessary.

Returns:
Tuple[JSONField, bool]: JSONField of connector metadata
and flag indicating if tokens were refreshed
"""
# To avoid circular dependency error on import
from social_django.utils import load_strategy

refreshed_token = False
strategy: DjangoStrategy = load_strategy(request=request)
existing_access_token = self.access_token
new_access_token = self.get_access_token(strategy)
if new_access_token != existing_access_token:
refreshed_token = True
related_connector_instances = self.connectorinstance_set.all()
for connector_instance in related_connector_instances:
connector_instance.connector_metadata = self.extra_data
connector_instance.save()
logger.info(
f"Refreshed access token for connector {connector_instance.id}, "
f"provider: {self.provider}, uid: {self.uid}"
)

return self.extra_data, refreshed_token

@staticmethod
def check_credential_format(
oauth_credentials: dict[str, str], raise_exception: bool = True
) -> bool:
if (
SocialAuthConstants.PROVIDER in oauth_credentials
and SocialAuthConstants.UID in oauth_credentials
):
return True
else:
if raise_exception:
raise ValueError(
"Auth credential should have provider, uid and connector guid"
)
return False

objects = ConnectorAuthManager()

class Meta:
app_label = "connector_auth_v2"
verbose_name = "Connector Auth"
verbose_name_plural = "Connector Auths"
db_table = "connector_auth_v2"
constraints = [
models.UniqueConstraint(
fields=[
"provider",
"uid",
],
name="unique_provider_uid_index",
),
]


class ConnectorDjangoStorage(DjangoStorage):
user = ConnectorAuth
111 changes: 111 additions & 0 deletions backend/connector_auth_v2/pipeline/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import logging
from typing import Any, Optional

from account_v2.models import User
from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants
from connector_auth_v2.models import ConnectorAuth
from connector_auth_v2.pipeline.google import GoogleAuthHelper
from django.conf import settings
from django.core.cache import cache
from rest_framework.exceptions import PermissionDenied
from social_core.backends.oauth import BaseOAuth2

logger = logging.getLogger(__name__)


def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]:
"""Checks if user is authenticated (will be handled in auth middleware,
present as a fail safe)

Args:
user (account.User): User model

Raises:
PermissionDenied: Unauthorized user

Returns:
dict: Carrying response details for auth pipeline
"""
if not user:
raise PermissionDenied(backend)
return {**kwargs}


def cache_oauth_creds(
backend: BaseOAuth2,
details: dict[str, str],
response: dict[str, str],
uid: str,
user: User,
*args: Any,
**kwargs: Any,
) -> dict[str, str]:
"""Used to cache the extra data JSON in redis against a key.

This contains the access and refresh token along with details
regarding expiry, uid (unique ID given by provider) and provider.
"""
cache_key = kwargs.get("cache_key") or backend.strategy.session_get(
settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0],
ConnectorAuthKey.OAUTH_KEY,
)
extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs)
extra_data[SocialAuthConstants.PROVIDER] = backend.name
extra_data[SocialAuthConstants.UID] = uid

if backend.name == SocialAuthConstants.GOOGLE_OAUTH:
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data)

cache.set(
cache_key,
extra_data,
int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND),
)
return {**kwargs}


class ConnectorAuthHelper:
@staticmethod
def get_oauth_creds_from_cache(
cache_key: str, delete_key: bool = True
) -> Optional[dict[str, str]]:
"""Retrieves oauth credentials from the cache.

Args:
cache_key (str): Key to obtain credentials from

Returns:
Optional[dict[str,str]]: Returns credentials. None if it doesn't exist
"""
oauth_creds: dict[str, str] = cache.get(cache_key)
if delete_key:
cache.delete(cache_key)
return oauth_creds

@staticmethod
def get_or_create_connector_auth(
oauth_credentials: dict[str, str], user: User = None # type: ignore
) -> ConnectorAuth:
"""Gets or creates a ConnectorAuth object.

Args:
user (User): Used while creation, can be removed if not required
oauth_credentials (dict[str,str]): Needs to have provider and uid

Returns:
ConnectorAuth: Object for the respective provider/uid
"""
ConnectorAuth.check_credential_format(oauth_credentials)
provider = oauth_credentials[SocialAuthConstants.PROVIDER]
uid = oauth_credentials[SocialAuthConstants.UID]
connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth(
provider=provider, uid=uid
)
if not connector_oauth:
connector_oauth = ConnectorAuth.create_social_auth(
user, uid=uid, provider=provider
)

# TODO: Remove User's related manager access to ConnectorAuth
connector_oauth.set_extra_data(oauth_credentials) # type: ignore
return connector_oauth
33 changes: 33 additions & 0 deletions backend/connector_auth_v2/pipeline/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from datetime import datetime, timedelta

from connector_auth_v2.constants import SocialAuthConstants as AuthConstants
from connector_auth_v2.exceptions import EnrichConnectorMetadataException
from connector_processor.constants import ConnectorKeys

from unstract.connectors.filesystems.google_drive.constants import GDriveConstants


class GoogleAuthHelper:
@staticmethod
def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]:
token_expiry: datetime = datetime.now()
auth_time = kwargs.get(AuthConstants.AUTH_TIME)
expires = kwargs.get(AuthConstants.EXPIRES)
if auth_time and expires:
reference = datetime.utcfromtimestamp(float(auth_time))
token_expiry = reference + timedelta(seconds=float(expires))
else:
raise EnrichConnectorMetadataException
# Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN
kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime(
AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT
)

# Used by Unstract
kwargs[ConnectorKeys.PATH] = (
GDriveConstants.ROOT_PREFIX
) # Acts as a prefix for all paths
kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime(
AuthConstants.REFRESH_AFTER_FORMAT
)
return kwargs
21 changes: 21 additions & 0 deletions backend/connector_auth_v2/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from django.urls import include, path, re_path
from rest_framework.urlpatterns import format_suffix_patterns

from .views import ConnectorAuthViewSet

connector_auth_cache = ConnectorAuthViewSet.as_view(
{
"get": "cache_key",
}
)

urlpatterns = format_suffix_patterns(
[
path("oauth/", include("social_django.urls", namespace="social")),
re_path(
"^oauth/cache-key/(?P<backend>.+)$",
connector_auth_cache,
name="connector-cache",
),
]
)
Loading
Loading