-
Notifications
You must be signed in to change notification settings - Fork 289
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
v2 changes of connector and connector_auth
- Loading branch information
1 parent
48826c4
commit 3ac0893
Showing
25 changed files
with
1,667 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from django.contrib import admin | ||
|
||
from .models import ConnectorAuth | ||
|
||
admin.site.register(ConnectorAuth) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class ConnectorAuthConfig(AppConfig): | ||
default_auto_field = "django.db.models.BigAutoField" | ||
name = "connector_auth_v2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
class ConnectorAuthKey: | ||
OAUTH_KEY = "oauth-key" | ||
|
||
|
||
class SocialAuthConstants: | ||
UID = "uid" | ||
PROVIDER = "provider" | ||
ACCESS_TOKEN = "access_token" | ||
REFRESH_TOKEN = "refresh_token" | ||
TOKEN_TYPE = "token_type" | ||
AUTH_TIME = "auth_time" | ||
EXPIRES = "expires" | ||
|
||
REFRESH_AFTER_FORMAT = "%d/%m/%Y %H:%M:%S" | ||
REFRESH_AFTER = "refresh_after" # Timestamp to refresh tokens after | ||
|
||
GOOGLE_OAUTH = "google-oauth2" | ||
GOOGLE_TOKEN_EXPIRY_FORMAT = "%d/%m/%Y %H:%M:%S" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Optional | ||
|
||
from rest_framework.exceptions import APIException | ||
|
||
|
||
class CacheMissException(APIException): | ||
status_code = 404 | ||
default_detail = "Key doesn't exist." | ||
|
||
|
||
class EnrichConnectorMetadataException(APIException): | ||
status_code = 500 | ||
default_detail = "Connector metadata could not be enriched" | ||
|
||
|
||
class MissingParamException(APIException): | ||
status_code = 400 | ||
default_detail = "Bad request, missing parameter." | ||
|
||
def __init__( | ||
self, | ||
code: Optional[str] = None, | ||
param: Optional[str] = None, | ||
) -> None: | ||
detail = f"Bad request, missing parameter: {param}" | ||
super().__init__(detail, code) | ||
|
||
|
||
class KeyNotConfigured(APIException): | ||
status_code = 500 | ||
default_detail = "Key is not configured correctly" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import logging | ||
import uuid | ||
from typing import Any | ||
|
||
from account_v2.models import User | ||
from connector_auth_v2.constants import SocialAuthConstants | ||
from connector_auth_v2.pipeline.google import GoogleAuthHelper | ||
from django.db import models | ||
from django.db.models.query import QuerySet | ||
from rest_framework.request import Request | ||
from social_django.fields import JSONField | ||
from social_django.models import AbstractUserSocialAuth, DjangoStorage | ||
from social_django.strategy import DjangoStrategy | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ConnectorAuthManager(models.Manager): | ||
def get_queryset(self) -> QuerySet: | ||
queryset = super().get_queryset() | ||
# TODO PAN-83: Decrypt here | ||
# for obj in queryset: | ||
# logger.info(f"Decrypting extra_data: {obj.extra_data}") | ||
|
||
return queryset | ||
|
||
|
||
class ConnectorAuth(AbstractUserSocialAuth): | ||
"""Social Auth association model, stores tokens. | ||
The relation with `account.User` is only for the library to work | ||
and should be NOT be used to access the secrets. | ||
Use the following static methods instead | ||
``` | ||
@classmethod | ||
def get_social_auth(cls, provider, id): | ||
@classmethod | ||
def create_social_auth(cls, user, uid, provider): | ||
``` | ||
""" | ||
|
||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) | ||
user = models.ForeignKey( | ||
User, | ||
related_name="connector_auths", | ||
on_delete=models.SET_NULL, | ||
null=True, | ||
) | ||
|
||
def __str__(self) -> str: | ||
return f"ConnectorAuth(provider: {self.provider}, uid: {self.uid})" | ||
|
||
def save(self, *args: Any, **kwargs: Any) -> Any: | ||
# TODO PAN-83: Encrypt here | ||
# logger.info(f"Encrypting extra_data: {self.extra_data}") | ||
return super().save(*args, **kwargs) | ||
|
||
def set_extra_data(self, extra_data=None): # type: ignore | ||
ConnectorAuth.check_credential_format(extra_data) | ||
if extra_data[SocialAuthConstants.PROVIDER] == SocialAuthConstants.GOOGLE_OAUTH: | ||
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) | ||
return super().set_extra_data(extra_data) | ||
|
||
def refresh_token(self, strategy, *args, **kwargs): # type: ignore | ||
"""Override of Python Social Auth (PSA)'s refresh_token functionality | ||
to store uid, provider.""" | ||
token = self.extra_data.get("refresh_token") or self.extra_data.get( | ||
"access_token" | ||
) | ||
backend = self.get_backend_instance(strategy) | ||
if token and backend and hasattr(backend, "refresh_token"): | ||
response = backend.refresh_token(token, *args, **kwargs) | ||
extra_data = backend.extra_data(self, self.uid, response, self.extra_data) | ||
extra_data[SocialAuthConstants.PROVIDER] = backend.name | ||
extra_data[SocialAuthConstants.UID] = self.uid | ||
if self.set_extra_data(extra_data): # type: ignore | ||
self.save() | ||
|
||
def get_and_refresh_tokens(self, request: Request = None) -> tuple[JSONField, bool]: | ||
"""Uses Social Auth's ability to refresh tokens if necessary. | ||
Returns: | ||
Tuple[JSONField, bool]: JSONField of connector metadata | ||
and flag indicating if tokens were refreshed | ||
""" | ||
# To avoid circular dependency error on import | ||
from social_django.utils import load_strategy | ||
|
||
refreshed_token = False | ||
strategy: DjangoStrategy = load_strategy(request=request) | ||
existing_access_token = self.access_token | ||
new_access_token = self.get_access_token(strategy) | ||
if new_access_token != existing_access_token: | ||
refreshed_token = True | ||
related_connector_instances = self.connectorinstance_set.all() | ||
for connector_instance in related_connector_instances: | ||
connector_instance.connector_metadata = self.extra_data | ||
connector_instance.save() | ||
logger.info( | ||
f"Refreshed access token for connector {connector_instance.id}, " | ||
f"provider: {self.provider}, uid: {self.uid}" | ||
) | ||
|
||
return self.extra_data, refreshed_token | ||
|
||
@staticmethod | ||
def check_credential_format( | ||
oauth_credentials: dict[str, str], raise_exception: bool = True | ||
) -> bool: | ||
if ( | ||
SocialAuthConstants.PROVIDER in oauth_credentials | ||
and SocialAuthConstants.UID in oauth_credentials | ||
): | ||
return True | ||
else: | ||
if raise_exception: | ||
raise ValueError( | ||
"Auth credential should have provider, uid and connector guid" | ||
) | ||
return False | ||
|
||
objects = ConnectorAuthManager() | ||
|
||
class Meta: | ||
app_label = "connector_auth_v2" | ||
verbose_name = "Connector Auth" | ||
verbose_name_plural = "Connector Auths" | ||
db_table = "connector_auth_v2" | ||
constraints = [ | ||
models.UniqueConstraint( | ||
fields=[ | ||
"provider", | ||
"uid", | ||
], | ||
name="unique_provider_uid_index", | ||
), | ||
] | ||
|
||
|
||
class ConnectorDjangoStorage(DjangoStorage): | ||
user = ConnectorAuth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import logging | ||
from typing import Any, Optional | ||
|
||
from account_v2.models import User | ||
from connector_auth_v2.constants import ConnectorAuthKey, SocialAuthConstants | ||
from connector_auth_v2.models import ConnectorAuth | ||
from connector_auth_v2.pipeline.google import GoogleAuthHelper | ||
from django.conf import settings | ||
from django.core.cache import cache | ||
from rest_framework.exceptions import PermissionDenied | ||
from social_core.backends.oauth import BaseOAuth2 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def check_user_exists(backend: BaseOAuth2, user: User, **kwargs: Any) -> dict[str, str]: | ||
"""Checks if user is authenticated (will be handled in auth middleware, | ||
present as a fail safe) | ||
Args: | ||
user (account.User): User model | ||
Raises: | ||
PermissionDenied: Unauthorized user | ||
Returns: | ||
dict: Carrying response details for auth pipeline | ||
""" | ||
if not user: | ||
raise PermissionDenied(backend) | ||
return {**kwargs} | ||
|
||
|
||
def cache_oauth_creds( | ||
backend: BaseOAuth2, | ||
details: dict[str, str], | ||
response: dict[str, str], | ||
uid: str, | ||
user: User, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> dict[str, str]: | ||
"""Used to cache the extra data JSON in redis against a key. | ||
This contains the access and refresh token along with details | ||
regarding expiry, uid (unique ID given by provider) and provider. | ||
""" | ||
cache_key = kwargs.get("cache_key") or backend.strategy.session_get( | ||
settings.SOCIAL_AUTH_FIELDS_STORED_IN_SESSION[0], | ||
ConnectorAuthKey.OAUTH_KEY, | ||
) | ||
extra_data = backend.extra_data(user, uid, response, details, *args, **kwargs) | ||
extra_data[SocialAuthConstants.PROVIDER] = backend.name | ||
extra_data[SocialAuthConstants.UID] = uid | ||
|
||
if backend.name == SocialAuthConstants.GOOGLE_OAUTH: | ||
extra_data = GoogleAuthHelper.enrich_connector_metadata(extra_data) | ||
|
||
cache.set( | ||
cache_key, | ||
extra_data, | ||
int(settings.SOCIAL_AUTH_EXTRA_DATA_EXPIRATION_TIME_IN_SECOND), | ||
) | ||
return {**kwargs} | ||
|
||
|
||
class ConnectorAuthHelper: | ||
@staticmethod | ||
def get_oauth_creds_from_cache( | ||
cache_key: str, delete_key: bool = True | ||
) -> Optional[dict[str, str]]: | ||
"""Retrieves oauth credentials from the cache. | ||
Args: | ||
cache_key (str): Key to obtain credentials from | ||
Returns: | ||
Optional[dict[str,str]]: Returns credentials. None if it doesn't exist | ||
""" | ||
oauth_creds: dict[str, str] = cache.get(cache_key) | ||
if delete_key: | ||
cache.delete(cache_key) | ||
return oauth_creds | ||
|
||
@staticmethod | ||
def get_or_create_connector_auth( | ||
oauth_credentials: dict[str, str], user: User = None # type: ignore | ||
) -> ConnectorAuth: | ||
"""Gets or creates a ConnectorAuth object. | ||
Args: | ||
user (User): Used while creation, can be removed if not required | ||
oauth_credentials (dict[str,str]): Needs to have provider and uid | ||
Returns: | ||
ConnectorAuth: Object for the respective provider/uid | ||
""" | ||
ConnectorAuth.check_credential_format(oauth_credentials) | ||
provider = oauth_credentials[SocialAuthConstants.PROVIDER] | ||
uid = oauth_credentials[SocialAuthConstants.UID] | ||
connector_oauth: ConnectorAuth = ConnectorAuth.get_social_auth( | ||
provider=provider, uid=uid | ||
) | ||
if not connector_oauth: | ||
connector_oauth = ConnectorAuth.create_social_auth( | ||
user, uid=uid, provider=provider | ||
) | ||
|
||
# TODO: Remove User's related manager access to ConnectorAuth | ||
connector_oauth.set_extra_data(oauth_credentials) # type: ignore | ||
return connector_oauth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from datetime import datetime, timedelta | ||
|
||
from connector_auth_v2.constants import SocialAuthConstants as AuthConstants | ||
from connector_auth_v2.exceptions import EnrichConnectorMetadataException | ||
from connector_processor.constants import ConnectorKeys | ||
|
||
from unstract.connectors.filesystems.google_drive.constants import GDriveConstants | ||
|
||
|
||
class GoogleAuthHelper: | ||
@staticmethod | ||
def enrich_connector_metadata(kwargs: dict[str, str]) -> dict[str, str]: | ||
token_expiry: datetime = datetime.now() | ||
auth_time = kwargs.get(AuthConstants.AUTH_TIME) | ||
expires = kwargs.get(AuthConstants.EXPIRES) | ||
if auth_time and expires: | ||
reference = datetime.utcfromtimestamp(float(auth_time)) | ||
token_expiry = reference + timedelta(seconds=float(expires)) | ||
else: | ||
raise EnrichConnectorMetadataException | ||
# Used by GDrive FS, apart from ACCESS_TOKEN and REFRESH_TOKEN | ||
kwargs[GDriveConstants.TOKEN_EXPIRY] = token_expiry.strftime( | ||
AuthConstants.GOOGLE_TOKEN_EXPIRY_FORMAT | ||
) | ||
|
||
# Used by Unstract | ||
kwargs[ConnectorKeys.PATH] = ( | ||
GDriveConstants.ROOT_PREFIX | ||
) # Acts as a prefix for all paths | ||
kwargs[AuthConstants.REFRESH_AFTER] = token_expiry.strftime( | ||
AuthConstants.REFRESH_AFTER_FORMAT | ||
) | ||
return kwargs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from django.urls import include, path, re_path | ||
from rest_framework.urlpatterns import format_suffix_patterns | ||
|
||
from .views import ConnectorAuthViewSet | ||
|
||
connector_auth_cache = ConnectorAuthViewSet.as_view( | ||
{ | ||
"get": "cache_key", | ||
} | ||
) | ||
|
||
urlpatterns = format_suffix_patterns( | ||
[ | ||
path("oauth/", include("social_django.urls", namespace="social")), | ||
re_path( | ||
"^oauth/cache-key/(?P<backend>.+)$", | ||
connector_auth_cache, | ||
name="connector-cache", | ||
), | ||
] | ||
) |
Oops, something went wrong.