Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Prompt versioning #485

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
7 changes: 7 additions & 0 deletions backend/backend/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from rest_framework.pagination import PageNumberPagination


class DefaultPagination(PageNumberPagination):
    """Project-wide page-number pagination settings.

    Pages hold 10 items by default; a client may ask for a different size
    through the ``page_size`` query parameter, up to a ceiling of 100.
    """

    # Ceiling applied to any client-requested page size.
    max_page_size = 100
    # Name of the query parameter clients use to request a page size.
    page_size_query_param = "page_size"
    # Page size used when the client does not request one.
    page_size = 10
2 changes: 2 additions & 0 deletions backend/backend/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def get_required_setting(
"prompt_studio.prompt_studio_output_manager",
"prompt_studio.prompt_studio_document_manager",
"prompt_studio.prompt_studio_index_manager",
"prompt_studio.prompt_version_manager",
"prompt_studio.tag_manager",
"usage",
)

Expand Down
8 changes: 8 additions & 0 deletions backend/backend/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_studio_index_manager.urls"),
),
path(
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.prompt_version_manager.urls"),
),
path(
UrlPathConstants.PROMPT_STUDIO,
include("prompt_studio.tag_manager.urls"),
),
]


Expand Down
13 changes: 13 additions & 0 deletions backend/prompt/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@ def create(self, request: Any) -> Response:
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)

def destroy(self, request, *args, **kwargs):
    """Delete the addressed object, soft-deleting it when its tool is tagged.

    * Tool has no tag: the row is deleted and 204 is returned.
    * Tool has a tag and the object is checked in: ``checked_in`` is set to
      ``False`` and saved (soft delete), then 204 is returned.
    * Tool has a tag and the object is already not checked in: 404.
    """
    target = self.get_object()

    # NOTE(review): the original comment refers to ``tool.tag_id`` while the
    # code reads ``tool.tag`` -- confirm which attribute the tool exposes.
    if not target.tool.tag:
        # Untagged tool: nothing references this row, remove it for real.
        target.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)

    if not target.checked_in:
        # Already soft-deleted; treat it as absent.
        return Response(status=status.HTTP_404_NOT_FOUND)

    # Soft delete: keep the row but mark it as no longer checked in.
    target.checked_in = False
    target.save()
    return Response(status=status.HTTP_204_NO_CONTENT)
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_profile_manager/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ class Meta:
model = ProfileManager
fields = "__all__"

def __init__(self, *args, **kwargs):
    """Initialise the serializer, optionally limited to a subset of fields.

    Args:
        *args: Positional arguments forwarded to the parent serializer.
        **kwargs: Keyword arguments forwarded to the parent serializer.
            An optional ``fields`` entry (iterable of field names) is
            consumed here; when given, only those fields are kept on this
            serializer instance. Omitting it, or passing ``"__all__"``,
            keeps every field.
    """
    fields = kwargs.pop("fields", None)
    super().__init__(*args, **kwargs)
    # Prune on the instance instead of assigning to ``self.Meta.fields``:
    # ``Meta`` is a class attribute shared by every serializer instance, so
    # mutating it would leak one caller's field selection into others.
    if fields and fields != "__all__":
        requested = set(fields)
        for field_name in set(self.fields) - requested:
            self.fields.pop(field_name)

def to_representation(self, instance): # type: ignore
rep: dict[str, str] = super().to_representation(instance)
llm = rep[ProfileManagerKeys.LLM]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Generated by Django 4.2.1 on 2024-07-14 08:57

from django.db import migrations, models


# Schema migration for the "toolstudioprompt" model: removes the three
# assertion-related fields and adds the two prompt-versioning fields.
class Migration(migrations.Migration):

# Applied after migration 0006 of the "prompt_studio" app.
dependencies = [
("prompt_studio", "0006_alter_toolstudioprompt_prompt_key_and_more"),
]

operations = [
# Remove the assertion fields (assert_prompt, assertion_failure_prompt,
# is_assert) from the model.
migrations.RemoveField(
model_name="toolstudioprompt",
name="assert_prompt",
),
migrations.RemoveField(
model_name="toolstudioprompt",
name="assertion_failure_prompt",
),
migrations.RemoveField(
model_name="toolstudioprompt",
name="is_assert",
),
# New boolean flag; the default of True is applied to existing rows.
migrations.AddField(
model_name="toolstudioprompt",
name="checked_in",
field=models.BooleanField(
db_comment="Currently checked-in prompt", default=True
),
),
# New version label (at most 10 characters); existing rows get "v1".
migrations.AddField(
model_name="toolstudioprompt",
name="loaded_version",
field=models.CharField(
db_comment="Current loaded version of prompt",
default="v1",
max_length=10,
),
),
]
20 changes: 6 additions & 14 deletions backend/prompt_studio/prompt_studio/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,6 @@ class Mode(models.TextChoices):
blank=True,
)
output = models.TextField(blank=True)
# TODO: Remove below 3 fields related to assertion
assert_prompt = models.TextField(
blank=True,
null=True,
db_comment="Field to store the asserted prompt",
unique=False,
)
assertion_failure_prompt = models.TextField(
blank=True,
null=True,
db_comment="Field to store the prompt key",
unique=False,
)
is_assert = models.BooleanField(default=False)
active = models.BooleanField(default=True, null=False, blank=False)
output_metadata = models.JSONField(
db_column="output_metadata",
Expand All @@ -102,6 +88,12 @@ class Mode(models.TextChoices):
blank=True,
editable=False,
)
loaded_version = models.CharField(
max_length=10, default="v1", db_comment="Current loaded version of prompt"
)
checked_in = models.BooleanField(
default=True, db_comment="Currently checked-in prompt"
)
# Eval settings for the prompt
# NOTE:
# - Field name format is eval_<metric_type>_<metric_name>
Expand Down
25 changes: 25 additions & 0 deletions backend/prompt_studio/prompt_studio/serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any

from prompt_studio.prompt_version_manager.helper import PromptVersionHelper
from rest_framework import serializers

from backend.serializers import AuditSerializer
Expand All @@ -6,9 +9,31 @@


class ToolStudioPromptSerializer(AuditSerializer):
    """Serializer for ``ToolStudioPrompt`` that stamps a version on update.

    ``loaded_version`` is read-only for clients; it is filled in by
    ``update`` from ``PromptVersionHelper.get_prompt_version``.
    """

    class Meta:
        model = ToolStudioPrompt
        fields = "__all__"
        read_only_fields = ["loaded_version"]

    def update(self, instance: Any, validated_data: dict[str, Any]) -> Any:
        """Update ``instance`` and record the version of the resulting prompt.

        When a request is present in the serializer context and the instance
        is of type ``"PROMPT"``, an unsaved copy of the prompt is built with
        the incoming changes applied, its version is looked up, and that
        version is written to ``validated_data["loaded_version"]`` before the
        regular update runs. Otherwise the regular update runs unchanged.
        """
        request = self.context.get("request")
        if not request or instance.prompt_type != "PROMPT":
            return super().update(instance, validated_data)

        # Unsaved preview of the prompt as it will look after this update.
        preview = ToolStudioPrompt(
            tool_id=instance.tool_id,
            prompt_id=instance.prompt_id,
            prompt_key=instance.prompt_key,
            prompt=instance.prompt,
            enforce_type=instance.enforce_type,
            profile_manager=instance.profile_manager,
        )
        for field_name in validated_data:
            setattr(preview, field_name, validated_data[field_name])

        version = PromptVersionHelper.get_prompt_version(preview)
        validated_data["loaded_version"] = version
        return super().update(instance, validated_data)


class ToolStudioIndexSerializer(serializers.Serializer):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 4.2.1 on 2024-07-14 08:57

from django.db import migrations, models


# Schema migration for the "customtool" model: adds the nullable "tag_id"
# text field.
class Migration(migrations.Migration):

# Applied after migration 0013 of the "prompt_studio_core" app.
dependencies = [
("prompt_studio_core", "0013_customtool_enable_highlight"),
]

operations = [
# Optional field (null and blank allowed) defaulting to None, so existing
# rows need no backfill value.
migrations.AddField(
model_name="customtool",
name="tag_id",
field=models.TextField(
blank=True,
db_comment="Currently checked-in tag id",
default=None,
null=True,
),
),
]
6 changes: 6 additions & 0 deletions backend/prompt_studio/prompt_studio_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class CustomTool(BaseModel):
enable_highlight = models.BooleanField(
db_comment="Flag to enable or disable document highlighting", default=False
)
tag_id = models.TextField(
null=True,
blank=True,
default=None,
db_comment="Currently checked-in tag id",
)

# Introduced field to establish M2M relation between users and custom_tool.
# This will introduce intermediary table which relates both the models.
Expand Down
35 changes: 25 additions & 10 deletions backend/prompt_studio/prompt_studio_core/prompt_studio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from prompt_studio.prompt_studio_output_manager.output_manager_helper import (
OutputManagerHelper,
)
from prompt_studio.prompt_version_manager.helper import PromptVersionHelper
from unstract.sdk.constants import LogLevel
from unstract.sdk.exceptions import IndexingError, SdkError
from unstract.sdk.index import Index
Expand Down Expand Up @@ -255,8 +256,8 @@ def get_select_fields() -> dict[str, Any]:
return response

@staticmethod
def _fetch_prompt_from_id(id: str) -> ToolStudioPrompt:
"""Internal function used to fetch prompt from ID.
def fetch_prompt_from_id(id: str) -> ToolStudioPrompt:
"""Method used to fetch prompt from ID.

Args:
id (str): UUID of the prompt
Expand All @@ -268,17 +269,26 @@ def _fetch_prompt_from_id(id: str) -> ToolStudioPrompt:
return prompt_instance

@staticmethod
def fetch_prompt_from_tool(tool_id: str) -> list[ToolStudioPrompt]:
"""Internal function used to fetch mapped prompts from ToolID.
def fetch_prompt_from_tool(
tool_id: str, include_notes: bool = True, checked_in_only: bool = True
) -> list[ToolStudioPrompt]:
"""Function used to fetch mapped prompts from ToolID.

Args:
tool_id (_type_): UUID of the tool
tool_id (str): UUID of the tool
include_notes (bool): Whether to include notes
checked_in_only (bool): Whether to only include checked_in prompts

Returns:
List[ToolStudioPrompt]: List of instance of the model
List[ToolStudioPrompt]: List of instances of the model
"""
filter_args = {"tool_id": tool_id}
if checked_in_only:
filter_args["checked_in"] = True
if not include_notes:
filter_args["prompt_type"] = "PROMPT"
prompt_instances: list[ToolStudioPrompt] = ToolStudioPrompt.objects.filter(
tool_id=tool_id
**filter_args
).order_by(TSPKeys.SEQUENCE_NUMBER)
return prompt_instances

Expand Down Expand Up @@ -425,8 +435,10 @@ def _execute_single_prompt(
run_id,
profile_manager_id,
):
prompt_instance = PromptStudioHelper._fetch_prompt_from_id(id)
prompt_instance = PromptStudioHelper.fetch_prompt_from_id(id)
prompt_name = prompt_instance.prompt_key
# Check and create a new prompt version
PromptVersionHelper.create_prompt_version([prompt_instance])
PromptStudioHelper._publish_log(
{
"tool_id": tool_id,
Expand Down Expand Up @@ -497,8 +509,11 @@ def _execute_single_prompt(
def _execute_prompts_in_single_pass(
doc_path, tool_id, org_id, user_id, document_id, run_id
):
prompts = PromptStudioHelper.fetch_prompt_from_tool(tool_id)
prompts = [prompt for prompt in prompts if prompt.prompt_type != TSPKeys.NOTES]
prompts = PromptStudioHelper.fetch_prompt_from_tool(
tool_id=tool_id, include_notes=False
)
# Check and create a new prompt version for all prompts
PromptVersionHelper.create_prompt_version(prompts)
if not prompts:
logger.error(f"[{tool_id or 'NA'}] No prompts found for id: {id}")
raise NoPromptsFound()
Expand Down
26 changes: 24 additions & 2 deletions backend/prompt_studio/prompt_studio_core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from prompt_studio.prompt_studio.serializers import ToolStudioPromptSerializer
from prompt_studio.prompt_studio_core.constants import ToolStudioKeys as TSKeys
from prompt_studio.prompt_studio_core.exceptions import DefaultProfileError
from prompt_studio.tag_manager.models import TagManager
from rest_framework import serializers
from utils.FileValidator import FileValidator

Expand All @@ -28,6 +29,7 @@ class CustomToolSerializer(AuditSerializer):
class Meta:
model = CustomTool
fields = "__all__"
read_only_fields = ["tag_id"]

def to_representation(self, instance): # type: ignore
data = super().to_representation(instance)
Expand Down Expand Up @@ -65,8 +67,28 @@ def to_representation(self, instance): # type: ignore
logger.error(f"Error occured while appending prompts {e}")
return data

data["created_by_email"] = instance.created_by.email

# Resolve the tool's currently checked-in tag and list every tag that
# exists for this tool. Both stay empty when no tag is checked in or the
# stored tag id no longer matches a TagManager row.
tag_id = instance.tag_id
current_tag = None
available_tags = []
if tag_id:
    # All TagManager rows that belong to this tool.
    tool_tags = TagManager.objects.filter(tool=instance)
    try:
        current_tag = tool_tags.get(id=tag_id).tag
    except TagManager.DoesNotExist:
        # Catch the exception on the model class: the original referenced
        # a local that is unbound when ``get`` raises, so the handler
        # itself failed with UnboundLocalError.
        logger.warning(f"Tag {tag_id} not found for tool {instance.pk}")
    else:
        available_tags = [
            {"id": tool_tag.id, "tag": tool_tag.tag} for tool_tag in tool_tags
        ]
data.update(
    {
        "current_tag": current_tag,
        "available_tags": available_tags,
        "created_by_email": instance.created_by.email,
    }
)
return data


Expand Down
7 changes: 4 additions & 3 deletions backend/prompt_studio/prompt_studio_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def create(self, request: HttpRequest) -> Response:
)
return Response(serializer.data, status=status.HTTP_201_CREATED)

def perform_create(self, serializer):
    """Save the validated serializer and return what ``save()`` returns."""
    saved = serializer.save()
    return saved

def perform_destroy(self, instance: CustomTool) -> None:
organization_id = UserSessionUtils.get_organization_id(self.request)
instance.delete(organization_id)
Expand Down Expand Up @@ -169,7 +172,7 @@ def list_profiles(self, request: HttpRequest, pk: Any = None) -> Response:
)

serialized_instances = ProfileManagerSerializer(
profile_manager_instances, many=True
instance=profile_manager_instances, many=True
).data

return Response(serialized_instances)
Expand Down Expand Up @@ -272,7 +275,6 @@ def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response:
if not run_id:
# Generate a run_id
run_id = CommonUtils.generate_uuid()

response: dict[str, Any] = PromptStudioHelper.prompt_responder(
id=id,
tool_id=tool_id,
Expand Down Expand Up @@ -330,7 +332,6 @@ def create_prompt(self, request: HttpRequest, pk: Any = None) -> Response:
serializer = ToolStudioPromptSerializer(data=request.data, context=context)
serializer.is_valid(raise_exception=True)
try:
# serializer.save()
self.perform_create(serializer)
except IntegrityError:
raise DuplicateData(
Expand Down
8 changes: 8 additions & 0 deletions backend/prompt_studio/prompt_studio_registry/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ class PromptStudioRegistry(BaseModel):
shared_users = models.ManyToManyField(User, related_name="shared_exported_tools")

objects = PromptStudioRegistryModelManager()

# class Meta:
# constraints = [
# models.UniqueConstraint(
# fields=["prompt_id", "version"],
# name="unique_tool_prompt_version",
# ),
# ]
Loading