Skip to content

Commit

Permalink
Merge pull request #17 from Zipstack/update-client_for_status_api
Browse files Browse the repository at this point in the history
Updated the status API check based on the changes at llm whisperer side
  • Loading branch information
jaseemjaskp authored Feb 5, 2025
2 parents 82baf42 + ba6ef48 commit c497fd3
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/unstract/llmwhisperer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.0.0"
__version__ = "2.1.0"

from .client import LLMWhispererClient # noqa: F401
from .client_v2 import LLMWhispererClientV2 # noqa: F401
Expand Down
68 changes: 30 additions & 38 deletions src/unstract/llmwhisperer/client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ class LLMWhispererClientV2:
client's activities and errors.
"""

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
log_stream_handler = logging.StreamHandler()
log_stream_handler.setFormatter(formatter)
Expand Down Expand Up @@ -108,7 +106,6 @@ def __init__(
self.logger.setLevel(logging.ERROR)
self.logger.setLevel(logging_level)
self.logger.debug("logging_level set to %s", logging_level)

if base_url == "":
self.base_url = os.getenv("LLMWHISPERER_BASE_URL_V2", BASE_URL_V2)
else:
Expand All @@ -121,6 +118,15 @@ def __init__(
self.api_key = api_key

self.headers = {"unstract-key": self.api_key}
# For test purpose
# self.headers = {
# "Subscription-Id": "python-client",
# "Subscription-Name": "python-client",
# "User-Id": "python-client-user",
# "Product-Id": "python-client-product",
# "Product-Name": "python-client-product",
# "Start-Date": "2024-07-09",
# }

def get_usage_info(self) -> dict:
"""Retrieves the usage information of the LLMWhisperer API.
Expand Down Expand Up @@ -283,9 +289,7 @@ def generate():
)
else:
params["url_in_post"] = True
req = requests.Request(
"POST", api_url, params=params, headers=self.headers, data=url
)
req = requests.Request("POST", api_url, params=params, headers=self.headers, data=url)
prepared = req.prepare()
s = requests.Session()
response = s.send(prepared, timeout=wait_timeout, stream=should_stream)
Expand All @@ -310,42 +314,30 @@ def generate():
message["message"] = "Whisper client operation failed"
message["extraction"] = {}
return message
if status["status"] == "accepted":
self.logger.debug(f'Whisper-hash:{whisper_hash} | STATUS: {status["status"]}...')
if status["status"] == "processing":
self.logger.debug(
f"Whisper-hash:{whisper_hash} | STATUS: processing..."
)
elif status["status"] == "delivered":
self.logger.debug(
f"Whisper-hash:{whisper_hash} | STATUS: Already delivered!"
)
raise LLMWhispererClientException(
{
"status_code": -1,
"message": "Whisper operation already delivered",
}
)
elif status["status"] == "unknown":
self.logger.debug(
f"Whisper-hash:{whisper_hash} | STATUS: unknown..."
)
raise LLMWhispererClientException(
{
"status_code": -1,
"message": "Whisper operation status unknown",
}
)
elif status["status"] == "failed":
self.logger.debug(
f"Whisper-hash:{whisper_hash} | STATUS: failed..."
)
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processing...")

elif status["status"] == "error":
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: failed...")
self.logger.error(f'Whisper-hash:{whisper_hash} | STATUS: failed with {status["message"]}')
message["status_code"] = -1
message["message"] = status["message"]
message["status"] = "error"
message["extraction"] = {}
return message
elif "error" in status["status"]:
# for backward compatabity
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: failed...")
self.logger.error(f'Whisper-hash:{whisper_hash} | STATUS: failed with {status["status"]}')
message["status_code"] = -1
message["message"] = "Whisper operation failed"
message["message"] = status["status"]
message["status"] = "error"
message["extraction"] = {}
return message
elif status["status"] == "processed":
self.logger.debug(
f"Whisper-hash:{whisper_hash} | STATUS: processed!"
)
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processed!")
resultx = self.whisper_retrieve(whisper_hash=whisper_hash)
if resultx["status_code"] == 200:
message["status_code"] = 200
Expand Down
119 changes: 89 additions & 30 deletions tests/integration/client_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_get_usage_info(client_v2):
"current_page_count_form",
"current_page_count_high_quality",
"current_page_count_native_text",
"current_page_count_excel",
"daily_quota",
"monthly_quota",
"overage_page_count",
Expand All @@ -44,7 +45,10 @@ def test_get_usage_info(client_v2):
def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
file_path = os.path.join(data_dir, input_file)
whisper_result = client_v2.whisper(
mode=mode, output_mode=output_mode, file_path=file_path, wait_for_completion=True
mode=mode,
output_mode=output_mode,
file_path=file_path,
wait_for_completion=True,
)
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")

Expand All @@ -54,24 +58,62 @@ def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
assert_extracted_text(exp_file, whisper_result, mode, output_mode)


@pytest.mark.parametrize(
"output_mode, mode, input_file",
[
("layout_preserving", "high_quality", "test.json"),
],
)
def test_whisper_v2_error(client_v2, data_dir, output_mode, mode, input_file):
file_path = os.path.join(data_dir, input_file)

whisper_result = client_v2.whisper(
mode=mode,
output_mode=output_mode,
file_path=file_path,
wait_for_completion=True,
)
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")

assert_error_message(whisper_result)


@pytest.mark.parametrize(
"output_mode, mode, url, input_file, page_count",
[
("layout_preserving", "native_text", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
"credit_card.pdf", 7),
("layout_preserving", "low_cost", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
"credit_card.pdf", 7),
("layout_preserving", "high_quality", "https://unstractpocstorage.blob.core.windows.net/public/scanned_bill.pdf",
"restaurant_invoice_photo.pdf", 1),
("layout_preserving", "form", "https://unstractpocstorage.blob.core.windows.net/public/scanned_form.pdf",
"handwritten-form.pdf", 1),
]
(
"layout_preserving",
"native_text",
"https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
"credit_card.pdf",
7,
),
(
"layout_preserving",
"low_cost",
"https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
"credit_card.pdf",
7,
),
(
"layout_preserving",
"high_quality",
"https://unstractpocstorage.blob.core.windows.net/public/scanned_bill.pdf",
"restaurant_invoice_photo.pdf",
1,
),
(
"layout_preserving",
"form",
"https://unstractpocstorage.blob.core.windows.net/public/scanned_form.pdf",
"handwritten-form.pdf",
1,
),
],
)
def test_whisper_v2_url_in_post(client_v2, data_dir, output_mode, mode, url, input_file, page_count):
usage_before = client_v2.get_usage_info()
whisper_result = client_v2.whisper(
mode=mode, output_mode=output_mode, url=url, wait_for_completion=True
)
whisper_result = client_v2.whisper(mode=mode, output_mode=output_mode, url=url, wait_for_completion=True)
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")

exp_basename = f"{Path(input_file).stem}.{mode}.{output_mode}.txt"
Expand All @@ -83,6 +125,12 @@ def test_whisper_v2_url_in_post(client_v2, data_dir, output_mode, mode, url, inp
verify_usage(usage_before, usage_after, page_count, mode)


def assert_error_message(whisper_result):
assert isinstance(whisper_result, dict)
assert whisper_result["status"] == "error"
assert "error" in whisper_result["message"]


def assert_extracted_text(file_path, whisper_result, mode, output_mode):
with open(file_path, encoding="utf-8") as f:
exp = f.read()
Expand All @@ -91,34 +139,45 @@ def assert_extracted_text(file_path, whisper_result, mode, output_mode):
assert whisper_result["status_code"] == 200

# For OCR based processing
threshold = 0.97
threshold = 0.94

# For text based processing
if mode == "native_text" and output_mode == "text":
threshold = 0.99
elif mode == "low_cost":
threshold = 0.90
extracted_text = whisper_result["extraction"]["result_text"]
similarity = SequenceMatcher(None, extracted_text, exp).ratio()

if similarity < threshold:
diff = "\n".join(
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
unified_diff(
exp.splitlines(),
extracted_text.splitlines(),
fromfile="Expected",
tofile="Extracted",
)
)
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
pytest.fail(f"Diff:\n{diff}.\n Texts are not similar enough: {similarity * 100:.2f}% similarity. ")


def verify_usage(before_extract, after_extract, page_count, mode='form'):
all_modes = ['form', 'high_quality', 'low_cost', 'native_text']
def verify_usage(before_extract, after_extract, page_count, mode="form"):
all_modes = ["form", "high_quality", "low_cost", "native_text"]
all_modes.remove(mode)
assert (after_extract['today_page_count'] == before_extract['today_page_count'] + page_count), \
"today_page_count calculation is wrong"
assert (after_extract['current_page_count'] == before_extract['current_page_count'] + page_count), \
"current_page_count calculation is wrong"
if after_extract['overage_page_count'] > 0:
assert (after_extract['overage_page_count'] == before_extract['overage_page_count'] + page_count), \
"overage_page_count calculation is wrong"
assert (after_extract[f'current_page_count_{mode}'] == before_extract[f'current_page_count_{mode}'] + page_count), \
f"{mode} mode calculation is wrong"
assert (
after_extract["today_page_count"] == before_extract["today_page_count"] + page_count
), "today_page_count calculation is wrong"
assert (
after_extract["current_page_count"] == before_extract["current_page_count"] + page_count
), "current_page_count calculation is wrong"
if after_extract["overage_page_count"] > 0:
assert (
after_extract["overage_page_count"] == before_extract["overage_page_count"] + page_count
), "overage_page_count calculation is wrong"
assert (
after_extract[f"current_page_count_{mode}"] == before_extract[f"current_page_count_{mode}"] + page_count
), f"{mode} mode calculation is wrong"
for i in range(len(all_modes)):
assert (after_extract[f'current_page_count_{all_modes[i]}'] ==
before_extract[f'current_page_count_{all_modes[i]}']), \
f"{all_modes[i]} mode calculation is wrong"
assert (
after_extract[f"current_page_count_{all_modes[i]}"] == before_extract[f"current_page_count_{all_modes[i]}"]
), f"{all_modes[i]} mode calculation is wrong"
1 change: 1 addition & 0 deletions tests/test_data/test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"test": "HelloWorld"}

0 comments on commit c497fd3

Please sign in to comment.