-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
v2 changes of tool_instance and usage (#470)
* v2 changes of tool_instance and usage * Added the TODO comment for urls --------- Co-authored-by: Hari John Kuriakose <[email protected]>
- Loading branch information
1 parent
9b075c0
commit 9354480
Showing
20 changed files
with
1,386 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from django.contrib import admin | ||
|
||
from .models import ToolInstance | ||
|
||
admin.site.register(ToolInstance) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from django.apps import AppConfig | ||
|
||
|
||
class ToolInstanceConfig(AppConfig): | ||
name = "tool_instance_v2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
class ToolInstanceKey: | ||
"""Dict keys for ToolInstance model.""" | ||
|
||
PK = "id" | ||
TOOL_ID = "tool_id" | ||
VERSION = "version" | ||
METADATA = "metadata" | ||
STEP = "step" | ||
STATUS = "status" | ||
WORKFLOW = "workflow" | ||
INPUT = "input" | ||
OUTPUT = "output" | ||
TI_COUNT = "tool_instance_count" | ||
|
||
|
||
class JsonSchemaKey: | ||
"""Dict Keys for Tool's Json schema.""" | ||
|
||
PROPERTIES = "properties" | ||
THEN = "then" | ||
INPUT_FILE_CONNECTOR = "inputFileConnector" | ||
OUTPUT_FILE_CONNECTOR = "outputFileConnector" | ||
OUTPUT_FOLDER = "outputFolder" | ||
ROOT_FOLDER = "rootFolder" | ||
TENANT_ID = "tenant_id" | ||
INPUT_DB_CONNECTOR = "inputDBConnector" | ||
OUTPUT_DB_CONNECTOR = "outputDBConnector" | ||
ENUM = "enum" | ||
PROJECT_DEFAULT = "Project Default" | ||
|
||
|
||
class ToolInstanceErrors: | ||
TOOL_EXISTS = "Tool with this configuration already exists." | ||
DUPLICATE_API = "It appears that a duplicate call may have been made." | ||
|
||
|
||
class ToolKey: | ||
"""Dict keys for a Tool.""" | ||
|
||
NAME = "name" | ||
DESCRIPTION = "description" | ||
ICON = "icon" | ||
FUNCTION_NAME = "function_name" | ||
OUTPUT_TYPE = "output_type" | ||
INPUT_TYPE = "input_type" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Optional | ||
|
||
from rest_framework.exceptions import APIException | ||
|
||
|
||
class ToolInstanceBaseException(APIException): | ||
def __init__( | ||
self, | ||
detail: Optional[str] = None, | ||
code: Optional[int] = None, | ||
tool_name: Optional[str] = None, | ||
) -> None: | ||
detail = detail or self.default_detail | ||
if tool_name is not None: | ||
detail = f"{detail} Tool: {tool_name}" | ||
super().__init__(detail, code) | ||
|
||
|
||
class ToolFunctionIsMandatory(ToolInstanceBaseException): | ||
status_code = 400 | ||
default_detail = "Tool function is mandatory." | ||
|
||
|
||
class ToolDoesNotExist(ToolInstanceBaseException): | ||
status_code = 400 | ||
default_detail = "Tool doesn't exist." | ||
|
||
|
||
class FetchToolListFailed(ToolInstanceBaseException): | ||
status_code = 400 | ||
default_detail = "Failed to fetch tool list." | ||
|
||
|
||
class ToolInstantiationError(ToolInstanceBaseException): | ||
status_code = 500 | ||
default_detail = "Error instantiating tool." | ||
|
||
|
||
class BadRequestException(ToolInstanceBaseException): | ||
status_code = 400 | ||
default_detail = "Invalid input." | ||
|
||
|
||
class ToolSettingValidationError(APIException): | ||
status_code = 400 | ||
default_detail = "Error while validating tool's setting." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import uuid | ||
|
||
from account_v2.models import User | ||
from connector_v2.models import ConnectorInstance | ||
from django.db import models | ||
from django.db.models import QuerySet | ||
from utils.models.base_model import BaseModel | ||
from workflow_manager.workflow_v2.models.workflow import Workflow | ||
|
||
TOOL_ID_LENGTH = 64 | ||
TOOL_VERSION_LENGTH = 16 | ||
TOOL_STATUS_LENGTH = 32 | ||
|
||
|
||
class ToolInstanceManager(models.Manager): | ||
def get_instances_for_workflow( | ||
self, workflow: uuid.UUID | ||
) -> QuerySet["ToolInstance"]: | ||
return self.filter(workflow=workflow) | ||
|
||
|
||
class ToolInstance(BaseModel): | ||
class Status(models.TextChoices): | ||
PENDING = "PENDING", "Settings Not Configured" | ||
READY = "READY", "Ready to Start" | ||
INITIATED = "INITIATED", "Initialization in Progress" | ||
COMPLETED = "COMPLETED", "Process Completed" | ||
ERROR = "ERROR", "Error Encountered" | ||
|
||
workflow = models.ForeignKey( | ||
Workflow, | ||
on_delete=models.CASCADE, | ||
related_name="tool_instances", | ||
null=False, | ||
blank=False, | ||
) | ||
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) | ||
tool_id = models.CharField( | ||
max_length=TOOL_ID_LENGTH, | ||
db_comment="Function name of the tool being used", | ||
) | ||
input = models.JSONField(null=True, db_comment="Provisional WF input to a tool") | ||
output = models.JSONField(null=True, db_comment="Provisional WF output to a tool") | ||
version = models.CharField(max_length=TOOL_VERSION_LENGTH) | ||
metadata = models.JSONField(db_comment="Stores config for a tool") | ||
step = models.IntegerField() | ||
# TODO: Make as an enum supporting fixed values once we have clarity | ||
status = models.CharField(max_length=TOOL_STATUS_LENGTH, default="Ready to start") | ||
created_by = models.ForeignKey( | ||
User, | ||
on_delete=models.SET_NULL, | ||
related_name="tool_instances_created", | ||
null=True, | ||
blank=True, | ||
) | ||
modified_by = models.ForeignKey( | ||
User, | ||
on_delete=models.SET_NULL, | ||
related_name="tool_instances_modified", | ||
null=True, | ||
blank=True, | ||
) | ||
# Added these connectors separately | ||
# for file and db for scalability | ||
input_file_connector = models.ForeignKey( | ||
ConnectorInstance, | ||
on_delete=models.SET_NULL, | ||
related_name="input_file_connectors", | ||
null=True, | ||
blank=True, | ||
) | ||
output_file_connector = models.ForeignKey( | ||
ConnectorInstance, | ||
on_delete=models.SET_NULL, | ||
related_name="output_file_connectors", | ||
null=True, | ||
blank=True, | ||
) | ||
input_db_connector = models.ForeignKey( | ||
ConnectorInstance, | ||
on_delete=models.SET_NULL, | ||
related_name="input_db_connectors", | ||
null=True, | ||
blank=True, | ||
) | ||
output_db_connector = models.ForeignKey( | ||
ConnectorInstance, | ||
on_delete=models.SET_NULL, | ||
related_name="output_db_connectors", | ||
null=True, | ||
blank=True, | ||
) | ||
|
||
objects = ToolInstanceManager() | ||
|
||
class Meta: | ||
verbose_name = "Tool Instance" | ||
verbose_name_plural = "Tool Instances" | ||
db_table = "tool_instance_v2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import logging | ||
import uuid | ||
from typing import Any | ||
|
||
from prompt_studio.prompt_studio_registry_v2.constants import PromptStudioRegistryKeys | ||
from rest_framework.serializers import ListField, Serializer, UUIDField, ValidationError | ||
from tool_instance_v2.constants import ToolInstanceKey as TIKey | ||
from tool_instance_v2.constants import ToolKey | ||
from tool_instance_v2.exceptions import ToolDoesNotExist | ||
from tool_instance_v2.models import ToolInstance | ||
from tool_instance_v2.tool_instance_helper import ToolInstanceHelper | ||
from tool_instance_v2.tool_processor import ToolProcessor | ||
from unstract.tool_registry.dto import Tool | ||
from workflow_manager.workflow_v2.constants import WorkflowKey | ||
from workflow_manager.workflow_v2.models.workflow import Workflow | ||
|
||
from backend.constants import RequestKey | ||
from backend.serializers import AuditSerializer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ToolInstanceSerializer(AuditSerializer): | ||
workflow_id = UUIDField(write_only=True) | ||
|
||
class Meta: | ||
model = ToolInstance | ||
fields = "__all__" | ||
extra_kwargs = { | ||
TIKey.WORKFLOW: { | ||
"required": False, | ||
}, | ||
TIKey.VERSION: { | ||
"required": False, | ||
}, | ||
TIKey.METADATA: { | ||
"required": False, | ||
}, | ||
TIKey.STEP: { | ||
"required": False, | ||
}, | ||
} | ||
|
||
def to_representation(self, instance: ToolInstance) -> dict[str, str]: | ||
rep: dict[str, Any] = super().to_representation(instance) | ||
tool_function = rep.get(TIKey.TOOL_ID) | ||
|
||
if tool_function is None: | ||
raise ToolDoesNotExist() | ||
try: | ||
tool: Tool = ToolProcessor.get_tool_by_uid(tool_function) | ||
except ToolDoesNotExist: | ||
return rep | ||
rep[ToolKey.ICON] = tool.icon | ||
rep[ToolKey.NAME] = tool.properties.display_name | ||
# Need to Change it into better method | ||
if self.context.get(RequestKey.REQUEST): | ||
metadata = ToolInstanceHelper.get_altered_metadata(instance) | ||
if metadata: | ||
rep[TIKey.METADATA] = metadata | ||
return rep | ||
|
||
def create(self, validated_data: dict[str, Any]) -> Any: | ||
workflow_id = validated_data.pop(WorkflowKey.WF_ID) | ||
try: | ||
workflow = Workflow.objects.get(pk=workflow_id) | ||
except Workflow.DoesNotExist: | ||
raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") | ||
validated_data[TIKey.WORKFLOW] = workflow | ||
|
||
tool_uid = validated_data.get(TIKey.TOOL_ID) | ||
if not tool_uid: | ||
raise ToolDoesNotExist() | ||
|
||
tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid) | ||
# TODO: Handle other fields once tools SDK is out | ||
validated_data[TIKey.PK] = uuid.uuid4() | ||
# TODO: Use version from tool props | ||
validated_data[TIKey.VERSION] = "" | ||
validated_data[TIKey.METADATA] = { | ||
# TODO: Review and remove tool instance ID | ||
WorkflowKey.WF_TOOL_INSTANCE_ID: str(validated_data[TIKey.PK]), | ||
PromptStudioRegistryKeys.PROMPT_REGISTRY_ID: str(tool_uid), | ||
**ToolProcessor.get_default_settings(tool), | ||
} | ||
if TIKey.STEP not in validated_data: | ||
validated_data[TIKey.STEP] = workflow.tool_instances.count() + 1 | ||
# Workflow will get activated on adding tools to workflow | ||
if not workflow.is_active: | ||
workflow.is_active = True | ||
workflow.save() | ||
return super().create(validated_data) | ||
|
||
|
||
class ToolInstanceReorderSerializer(Serializer): | ||
workflow_id = UUIDField() | ||
tool_instances = ListField(child=UUIDField()) | ||
|
||
def validate(self, data: dict[str, Any]) -> dict[str, Any]: | ||
workflow_id = data.get(WorkflowKey.WF_ID) | ||
tool_instances = data.get(WorkflowKey.WF_TOOL_INSTANCES, []) | ||
|
||
# Check if the workflow exists | ||
try: | ||
workflow = Workflow.objects.get(pk=workflow_id) | ||
except Workflow.DoesNotExist: | ||
raise ValidationError(f"Workflow with ID {workflow_id} does not exist.") | ||
|
||
# Check if the number of tool instances matches the actual count | ||
tool_instance_count = workflow.tool_instances.count() | ||
if len(tool_instances) != tool_instance_count: | ||
msg = ( | ||
f"Incorrect number of tool instances passed: " | ||
f"{len(tool_instances)}, expected: {tool_instance_count}" | ||
) | ||
logger.error(msg) | ||
raise ValidationError(detail=msg) | ||
|
||
# Check if each tool instance exists in the workflow | ||
existing_tool_instance_ids = workflow.tool_instances.values_list( | ||
"id", flat=True | ||
) | ||
for tool_instance_id in tool_instances: | ||
if tool_instance_id not in existing_tool_instance_ids: | ||
raise ValidationError( | ||
"One or more tool instances do not exist in the workflow." | ||
) | ||
|
||
return data |
Oops, something went wrong.