Skip to content

Commit

Permalink
[FIX] Fixes for Prompt studio Indexing and tool runs (#143)
Browse files Browse the repository at this point in the history
* Refactoring changed file names

* Roll version

* Update src/unstract/sdk/utils/tool_utils.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Gayathri <[email protected]>

* Update tests/test_fs_permanent.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Gayathri <[email protected]>

* Update tests/test_fs_permanent.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Gayathri <[email protected]>

* Update src/unstract/sdk/utils/tool_utils.py

Co-authored-by: Chandrasekharan M <[email protected]>
Signed-off-by: Gayathri <[email protected]>

* Address review comments

* Add support for passing length to mime_type

* Add recursive and fix mypy issue

* Change test case with new behavior to return FileNotFound in read()

* Remove typing kwargs.

* Resolve mypy issues

* Resolve mypy issues

* Remove unwanted conditionals/vars

* Remove pandoc and tesseract.

* Details of provider added to error message

* fixed enum conditional matching value

* Include EnvHelper in __init__

* Rename error handler

* Upgrade version

* Expose StorageType outside

* Resolve circular dependency issue

* Resolve circular dependency issue

* Indexing fixes + clean up

* Add deprecation warnings

* Add deprecation warnings

---------

Signed-off-by: Gayathri <[email protected]>
Co-authored-by: Chandrasekharan M <[email protected]>
  • Loading branch information
1 parent eaccd55 commit 2675e2b
Show file tree
Hide file tree
Showing 15 changed files with 91 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.54.0rc11"
__version__ = "0.54.0rc12"


def get_sdk_version():
Expand Down
15 changes: 15 additions & 0 deletions src/unstract/sdk/adapters/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from pathlib import Path

import filetype
Expand Down Expand Up @@ -76,6 +77,13 @@ def get_file_mime_type(
Returns:
str: MIME type of the file
"""
# Adding the following DeprecationWarning manually as the package "deprecated"
# does not support deprecation on static methods.
warnings.warn(
"`get_file_mime_type` is deprecated. "
"Use `FileStorage mime_type()` instead.",
DeprecationWarning,
)
sample_contents = fs.read(path=input_file, mode="rb", length=100)
input_file_mime = magic.from_buffer(sample_contents, mime=True)
return input_file_mime
Expand All @@ -93,6 +101,13 @@ def guess_extention(
Returns:
str: File extention
"""
# Adding the following DeprecationWarning manually as the package "deprecated"
# does not support deprecation on static methods.
warnings.warn(
"`guess_extention` is deprecated. "
"Use `FileStorage guess_extension()` instead.",
DeprecationWarning,
)
input_file_extention = ""
sample_contents = fs.read(path=input_file_path, mode="rb", length=100)
if sample_contents:
Expand Down
2 changes: 1 addition & 1 deletion src/unstract/sdk/adapters/x2text/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def process_document(
if not local_storage.exists(input_file_path):
fs.download(from_path=input_file_path, to_path=input_file_path)
with open(input_file_path, "rb") as input_f:
mime_type = AdapterUtils.get_file_mime_type(input_file=input_file_path)
mime_type = local_storage.mime_type(input_file=input_file_path)
files = {"file": (input_file_path, input_f, mime_type)}
response = UnstructuredHelper.make_request(
unstructured_adapter_config=unstructured_adapter_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from llama_parse import LlamaParse

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.utils import AdapterUtils
from unstract.sdk.adapters.x2text.dto import TextExtractionResult
from unstract.sdk.adapters.x2text.llama_parse.src.constants import LlamaParseConfig
from unstract.sdk.adapters.x2text.x2text_adapter import X2TextAdapter
Expand Down Expand Up @@ -62,9 +61,7 @@ def _call_parser(
file_extension = pathlib.Path(input_file_path).suffix
if not file_extension:
try:
input_file_extension = AdapterUtils.guess_extention(
input_file_path, fs
)
input_file_extension = fs.guess_extension(input_file_path)
input_file_path_copy = input_file_path
input_file_path = ".".join(
(input_file_path_copy, input_file_extension)
Expand Down
5 changes: 3 additions & 2 deletions src/unstract/sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ class ToolEnv:
DATA_DIR = "TOOL_DATA_DIR"
EXECUTION_BY_TOOL = "EXECUTION_BY_TOOL"
EXECUTION_DATA_DIR = "EXECUTION_DATA_DIR"
WORKFLOW_EXECUTION_FS_PROVIDER = "WORKFLOW_EXECUTION_FS_PROVIDER"
WORKFLOW_EXECUTION_FS_CREDENTIAL = "WORKFLOW_EXECUTION_FS_CREDENTIAL"
WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS = (
"WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS"
)


class ConnectorKeys:
Expand Down
3 changes: 2 additions & 1 deletion src/unstract/sdk/file_storage/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class FileOperationParams:
READ_ENTIRE_LENGTH = -1
MIME_TYPE_DEFAULT_READ_LENGTH = 100
EXTENSION_DEFAULT_READ_LENGTH = 100
DEFAULT_ENCODING = "utf-8"


Expand All @@ -15,7 +16,7 @@ class FileSeekPosition:

class StorageType(Enum):
PERMANENT = "permanent"
TEMPORARY = "temporary"
SHARED_TEMPORARY = "shared_temporary"


class CredentialKeyword:
Expand Down
2 changes: 1 addition & 1 deletion src/unstract/sdk/file_storage/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_storage(storage_type: StorageType, env_name: str) -> FileStorage:
credentials = file_storage_creds.get(CredentialKeyword.CREDENTIALS, "{}")
if storage_type == StorageType.PERMANENT:
file_storage = PermanentFileStorage(provider=provider, **credentials)
elif storage_type == StorageType.TEMPORARY:
elif storage_type == StorageType.SHARED_TEMPORARY:
file_storage = SharedTemporaryFileStorage(
provider=provider, **credentials
)
Expand Down
22 changes: 22 additions & 0 deletions src/unstract/sdk/file_storage/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hashlib import sha256
from typing import Any, Union

import filetype
import fsspec
import magic
import yaml
Expand Down Expand Up @@ -361,3 +362,24 @@ def yaml_load(
with self.fs.open(path=path) as f:
data: dict[str, Any] = yaml.safe_load(f)
return data

@skip_local_cache
def guess_extension(self, path: str) -> str:
    """Guess the extension of a file from its leading bytes.

    Reads the first ``FileOperationParams.EXTENSION_DEFAULT_READ_LENGTH``
    bytes of the file in binary mode and hands them to ``filetype.guess``.

    Args:
        path (str): String holding the path

    Returns:
        str: File extension, or an empty string when the file is empty
            or its type cannot be determined from the sampled bytes.
    """
    sample_contents = self.read(
        path=path,
        mode="rb",
        length=FileOperationParams.EXTENSION_DEFAULT_READ_LENGTH,
    )
    if not sample_contents:
        return ""
    file_type = filetype.guess(sample_contents)
    # filetype.guess() returns None when no known signature matches;
    # fall back to an empty extension instead of raising AttributeError.
    if file_type is None:
        return ""
    return file_type.EXTENSION
4 changes: 4 additions & 0 deletions src/unstract/sdk/file_storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,7 @@ def yaml_load(
path: str,
) -> dict[Any, Any]:
pass

@abstractmethod
def guess_extension(self, path: str) -> str:
    """Return the extension of the file at the given path.

    Args:
        path (str): String holding the path

    Returns:
        str: File extension
    """
    pass
2 changes: 1 addition & 1 deletion src/unstract/sdk/file_storage/shared_temporary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
):
if provider.value not in self.SUPPORTED_FILE_STORAGE_TYPES:
raise FileStorageError(
f"File storage provider is not supported in Permanent mode. "
f"File storage provider is not supported in Shared Temporary mode. "
f"Supported providers: {self.SUPPORTED_FILE_STORAGE_TYPES}"
)
if provider == FileStorageProvider.MINIO:
Expand Down
36 changes: 19 additions & 17 deletions src/unstract/sdk/tool/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
ToolEnv,
ToolExecKey,
)
from unstract.sdk.file_storage import SharedTemporaryFileStorage
from unstract.sdk.exceptions import FileStorageError
from unstract.sdk.file_storage import EnvHelper, StorageType
from unstract.sdk.tool.mixin import ToolConfigHelper
from unstract.sdk.tool.parser import ToolArgsParser
from unstract.sdk.tool.stream import StreamMixin
Expand Down Expand Up @@ -43,27 +44,28 @@ def __init__(self, log_level: LogLevel = LogLevel.INFO) -> None:
self.filestorage_provider = None
self.workflow_filestorage = None
self.execution_dir = None
filestorage_provider_env = os.environ.get(
ToolEnv.WORKFLOW_EXECUTION_FS_PROVIDER
filestorage_env = os.environ.get(
ToolEnv.WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS
)
if filestorage_provider_env:
if filestorage_env:
self.execution_dir = Path(self.get_env_or_die(ToolEnv.EXECUTION_DATA_DIR))
self.filestorage_provider = ToolUtils.get_filestorage_provider(
var_name=ToolEnv.WORKFLOW_EXECUTION_FS_PROVIDER
)

try:
self.filestorage_credentials = ToolUtils.get_filestorage_credentials(
ToolEnv.WORKFLOW_EXECUTION_FS_CREDENTIAL
self.workflow_filestorage = EnvHelper.get_storage(
StorageType.SHARED_TEMPORARY,
ToolEnv.WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS,
)
except KeyError as e:
self.stream_error_and_exit(
f"Required credentials is missing in the env: {str(e)}"
)
except json.JSONDecodeError:
raise ValueError(
"File storage credentials are not set properly. "
"Please check your settings."
except FileStorageError as e:
self.stream_error_and_exit(
"Error while initialising storage: %s",
e,
stack_info=True,
exc_info=True,
)
self.workflow_filestorage = SharedTemporaryFileStorage(
provider=self.filestorage_provider,
**self.filestorage_credentials,
)

@classmethod
def from_tool_args(cls, args: list[str]) -> "BaseTool":
Expand Down
5 changes: 2 additions & 3 deletions src/unstract/sdk/tool/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ def _validate_file_type(self, input_file: Path) -> None:
)
allowed_mimes.append(EXT_MIME_MAP[ext])
if self.tool.workflow_filestorage:
input_file_mime = ToolUtils.get_file_mime_type(
input_file=input_file, fs=self.tool.workflow_filestorage
)
tool_fs = self.tool.workflow_filestorage
input_file_mime = tool_fs.mime_type(input_file=input_file)
else:
input_file_mime = ToolUtils.get_file_mime_type(input_file=input_file)
self.tool.stream_log(f"Input file MIME: {input_file_mime}")
Expand Down
7 changes: 7 additions & 0 deletions src/unstract/sdk/utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def get_file_mime_type(
Returns:
str: MIME type of the file
"""
# Adding the following DeprecationWarning manually as the package "deprecated"
# does not support deprecation on static methods.
warnings.warn(
"`get_file_mime_type` is deprecated. "
"Use `FileStorage mime_type()` instead.",
DeprecationWarning,
)
input_file_mime = ""
sample_contents = fs.read(path=input_file, mode="rb", length=100)
input_file_mime = magic.from_buffer(sample_contents, mime=True)
Expand Down
24 changes: 11 additions & 13 deletions src/unstract/sdk/x2txt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ def _get_x2text(self) -> X2TextAdapter:
][Common.ADAPTER]
x2text_metadata = x2text_config.get(Common.ADAPTER_METADATA)
# Add x2text service host, port and platform_service_key
x2text_metadata[X2TextConstants.X2TEXT_HOST] = (
self._tool.get_env_or_die(X2TextConstants.X2TEXT_HOST)
)
x2text_metadata[X2TextConstants.X2TEXT_PORT] = (
self._tool.get_env_or_die(X2TextConstants.X2TEXT_PORT)
)
x2text_metadata[
X2TextConstants.X2TEXT_HOST
] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_HOST)
x2text_metadata[
X2TextConstants.X2TEXT_PORT
] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_PORT)

if not SdkHelper.is_public_adapter(
adapter_id=self._adapter_instance_id
):
x2text_metadata[X2TextConstants.PLATFORM_SERVICE_API_KEY] = (
self._tool.get_env_or_die(
X2TextConstants.PLATFORM_SERVICE_API_KEY
)
x2text_metadata[
X2TextConstants.PLATFORM_SERVICE_API_KEY
] = self._tool.get_env_or_die(
X2TextConstants.PLATFORM_SERVICE_API_KEY
)

self._x2text_instance = x2text_adapter(x2text_metadata)
Expand All @@ -94,9 +94,7 @@ def process(
fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL),
**kwargs: dict[Any, Any],
) -> TextExtractionResult:
if self._tool.workflow_filestorage:
fs = self._tool.workflow_filestorage
mime_type = ToolUtils.get_file_mime_type(input_file_path, fs)
mime_type = fs.mime_type(input_file_path)
text_extraction_result: TextExtractionResult = None
if mime_type == MimeType.TEXT:
extracted_text = fs.read(path=input_file_path, mode="r", encoding="utf-8")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def test_glob(file_storage, folder_path, expected_result):
FileStorageProvider.GCS,
),
(
StorageType.TEMPORARY,
StorageType.SHARED_TEMPORARY,
"TEST_TEMPORARY_STORAGE",
FileStorageProvider.MINIO,
),
Expand Down

0 comments on commit 2675e2b

Please sign in to comment.