Skip to content

Commit

Permalink
v2 changes of ProfileManager,core and studio (#471)
Browse files Browse the repository at this point in the history
* v2 changes of ProfileManager,core and studio

* tool_id in prompt model changed to cascade as discussed in PR

---------

Co-authored-by: Hari John Kuriakose <[email protected]>
  • Loading branch information
muhammad-ali-e and hari-kuriakose authored Jul 18, 2024
1 parent 5035622 commit 9b075c0
Show file tree
Hide file tree
Showing 34 changed files with 2,990 additions and 0 deletions.
Empty file.
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import ProfileManager

admin.site.register(ProfileManager)
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.apps import AppConfig


class ProfileManager(AppConfig):
name = "prompt_studio.prompt_profile_manager_v2"
18 changes: 18 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class ProfileManagerKeys:
CREATED_BY = "created_by"
TOOL_ID = "tool_id"
PROMPTS = "prompts"
ADAPTER_NAME = "adapter_name"
LLM = "llm"
VECTOR_STORE = "vector_store"
EMBEDDING_MODEL = "embedding_model"
X2TEXT = "x2text"
PROMPT_STUDIO_TOOL = "prompt_studio_tool"
MAX_PROFILE_COUNT = 4


class ProfileManagerErrors:
SERIALIZATION_FAILED = "Data Serialization Failed."
PROFILE_NAME_EXISTS = "A profile with this name already exists."
DUPLICATE_API = "It appears that a duplicate call may have been made."
PLATFORM_ERROR = "Seems an error occured in Platform Service."
6 changes: 6 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from rest_framework.exceptions import APIException


class PlatformServiceError(APIException):
status_code = 400
default_detail = "Seems an error occured in Platform Service."
115 changes: 115 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import uuid

from account_v2.models import User
from adapter_processor_v2.models import AdapterInstance
from django.db import models
from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError
from prompt_studio.prompt_studio_core_v2.models import CustomTool
from utils.models.base_model import BaseModel


class ProfileManager(BaseModel):
"""Model to store the LLM Triad management details for Prompt."""

class RetrievalStrategy(models.TextChoices):
SIMPLE = "simple", "Simple retrieval"
SUBQUESTION = "subquestion", "Subquestion retrieval"

profile_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
profile_name = models.TextField(blank=False)
vector_store = models.ForeignKey(
AdapterInstance,
db_comment="Field to store the chosen vector store.",
blank=False,
null=False,
on_delete=models.PROTECT,
related_name="profiles_vector_store",
)
embedding_model = models.ForeignKey(
AdapterInstance,
blank=False,
null=False,
on_delete=models.PROTECT,
related_name="profiles_embedding_model",
)
llm = models.ForeignKey(
AdapterInstance,
db_comment="Field to store the LLM chosen by the user",
blank=False,
null=False,
on_delete=models.PROTECT,
related_name="profiles_llm",
)
x2text = models.ForeignKey(
AdapterInstance,
db_comment="Field to store the X2Text Adapter chosen by the user",
blank=False,
null=False,
on_delete=models.PROTECT,
related_name="profiles_x2text",
)
chunk_size = models.IntegerField(null=True, blank=True)
chunk_overlap = models.IntegerField(null=True, blank=True)
reindex = models.BooleanField(default=False)
retrieval_strategy = models.TextField(
choices=RetrievalStrategy.choices,
blank=True,
db_comment="Field to store the retrieval strategy for prompts",
)
similarity_top_k = models.IntegerField(
blank=True,
null=True,
db_comment="Field to store number of top embeddings to take into context", # noqa: E501
)
section = models.TextField(
blank=True, null=True, db_comment="Field to store limit to section"
)
created_by = models.ForeignKey(
User,
on_delete=models.SET_NULL,
related_name="profile_managers_created",
null=True,
blank=True,
editable=False,
)
modified_by = models.ForeignKey(
User,
on_delete=models.SET_NULL,
related_name="profile_managers_modified",
null=True,
blank=True,
editable=False,
)

prompt_studio_tool = models.ForeignKey(
CustomTool, on_delete=models.CASCADE, null=True, related_name="profile_managers"
)
is_default = models.BooleanField(
default=False,
db_comment="Default LLM Profile used in prompt",
)

is_summarize_llm = models.BooleanField(
default=False,
db_comment="Default LLM Profile used for summarizing",
)

class Meta:
verbose_name = "Profile Manager"
verbose_name_plural = "Profile Managers"
db_table = "profile_manager_v2"
constraints = [
models.UniqueConstraint(
fields=["prompt_studio_tool", "profile_name"],
name="unique_prompt_studio_tool_profile_name_index",
),
]

@staticmethod
def get_default_llm_profile(tool: CustomTool) -> "ProfileManager":
try:
return ProfileManager.objects.get( # type: ignore
prompt_studio_tool=tool, is_default=True
)
except ProfileManager.DoesNotExist:
raise DefaultProfileError
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from prompt_studio.prompt_profile_manager_v2.models import ProfileManager


class ProfileManagerHelper:

@classmethod
def get_profile_manager(cls, profile_manager_id: str) -> ProfileManager:
try:
return ProfileManager.objects.get(profile_id=profile_manager_id)
except ProfileManager.DoesNotExist:
raise ValueError("ProfileManager does not exist.")
53 changes: 53 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging

from adapter_processor_v2.adapter_processor import AdapterProcessor
from prompt_studio.prompt_profile_manager_v2.constants import ProfileManagerKeys
from prompt_studio.prompt_studio_core_v2.exceptions import MaxProfilesReachedError

from backend.serializers import AuditSerializer

from .models import ProfileManager

logger = logging.getLogger(__name__)


class ProfileManagerSerializer(AuditSerializer):
class Meta:
model = ProfileManager
fields = "__all__"

def to_representation(self, instance): # type: ignore
rep: dict[str, str] = super().to_representation(instance)
llm = rep[ProfileManagerKeys.LLM]
embedding = rep[ProfileManagerKeys.EMBEDDING_MODEL]
vector_db = rep[ProfileManagerKeys.VECTOR_STORE]
x2text = rep[ProfileManagerKeys.X2TEXT]
if llm:
rep[ProfileManagerKeys.LLM] = AdapterProcessor.get_adapter_instance_by_id(
llm
)
if embedding:
rep[ProfileManagerKeys.EMBEDDING_MODEL] = (
AdapterProcessor.get_adapter_instance_by_id(embedding)
)
if vector_db:
rep[ProfileManagerKeys.VECTOR_STORE] = (
AdapterProcessor.get_adapter_instance_by_id(vector_db)
)
if x2text:
rep[ProfileManagerKeys.X2TEXT] = (
AdapterProcessor.get_adapter_instance_by_id(x2text)
)
return rep

def validate(self, data):
prompt_studio_tool = data.get(ProfileManagerKeys.PROMPT_STUDIO_TOOL)

profile_count = ProfileManager.objects.filter(
prompt_studio_tool=prompt_studio_tool
).count()

if profile_count >= ProfileManagerKeys.MAX_PROFILE_COUNT:
raise MaxProfilesReachedError()

return data
24 changes: 24 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.urls import path
from rest_framework.urlpatterns import format_suffix_patterns

from .views import ProfileManagerView

profile_manager_detail = ProfileManagerView.as_view(
{
"get": "retrieve",
"put": "update",
"patch": "partial_update",
"delete": "destroy",
}
)


urlpatterns = format_suffix_patterns(
[
path(
"profile-manager/<uuid:pk>/",
profile_manager_detail,
name="profile-manager-detail",
),
]
)
50 changes: 50 additions & 0 deletions backend/prompt_studio/prompt_profile_manager_v2/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Optional

from account_v2.custom_exceptions import DuplicateData
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http import HttpRequest
from permissions.permission import IsOwner
from prompt_studio.prompt_profile_manager_v2.constants import (
ProfileManagerErrors,
ProfileManagerKeys,
)
from prompt_studio.prompt_profile_manager_v2.serializers import ProfileManagerSerializer
from rest_framework import status, viewsets
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from utils.filtering import FilterHelper

from .models import ProfileManager


class ProfileManagerView(viewsets.ModelViewSet):
"""Viewset to handle all Custom tool related operations."""

versioning_class = URLPathVersioning
permission_classes = [IsOwner]
serializer_class = ProfileManagerSerializer

def get_queryset(self) -> Optional[QuerySet]:
filter_args = FilterHelper.build_filter_args(
self.request,
ProfileManagerKeys.CREATED_BY,
)
if filter_args:
queryset = ProfileManager.objects.filter(**filter_args)
else:
queryset = ProfileManager.objects.all()
return queryset

def create(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
) -> Response:
serializer: ProfileManagerSerializer = self.get_serializer(data=request.data)
# Overriding default exception behaviour
# TO DO : Handle model related exceptions.
serializer.is_valid(raise_exception=True)
try:
self.perform_create(serializer)
except IntegrityError:
raise DuplicateData(ProfileManagerErrors.PROFILE_NAME_EXISTS)
return Response(serializer.data, status=status.HTTP_201_CREATED)
Empty file.
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.contrib import admin

from .models import CustomTool

admin.site.register(CustomTool)
5 changes: 5 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.apps import AppConfig


class CustomTool(AppConfig):
name = "prompt_studio.prompt_studio_core_v2"
Loading

0 comments on commit 9b075c0

Please sign in to comment.