Skip to content

Commit

Permalink
Text mode test case assertion updated to be fuzzy with threshold 0.99
Browse files Browse the repository at this point in the history
  • Loading branch information
chandrasekharan-zipstack committed Nov 1, 2024
1 parent 2625c1d commit 768257c
Showing 1 changed file with 19 additions and 79 deletions.
98 changes: 19 additions & 79 deletions tests/client_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import logging
import os
import unittest
from difflib import SequenceMatcher, unified_diff
from pathlib import Path

import pytest
import requests

from unstract.llmwhisperer import LLMWhispererClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,93 +36,37 @@ def test_get_usage_info(client):
)
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
file_path = os.path.join(data_dir, input_file)
response = client.whisper(
whisper_result = client.whisper(
processing_mode=processing_mode,
output_mode=output_mode,
file_path=file_path,
timeout=200,
)
logger.debug(response)
logger.debug(whisper_result)

exp_basename = f"{Path(input_file).stem}.{processing_mode}.{output_mode}.txt"
exp_file = os.path.join(data_dir, "expected", exp_basename)
with open(exp_file, encoding="utf-8") as f:
exp = f.read()

assert isinstance(response, dict)
assert response["status_code"] == 200

# For text based processing, perform a strict match
if processing_mode == "text" and output_mode == "text":
assert response["extracted_text"] == exp
# For OCR based processing, perform a fuzzy match
else:
extracted_text = response["extracted_text"]
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
threshold = 0.97

if similarity < threshold:
diff = "\n".join(
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
)
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
assert_extracted_text(exp_file, whisper_result, processing_mode, output_mode)


# TODO: Review and port to pytest based tests
class TestLLMWhispererClient(unittest.TestCase):
@unittest.skip("Skipping test_whisper")
def test_whisper(self):
client = LLMWhispererClient()
# response = client.whisper(
# url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
# )
response = client.whisper(
file_path="test_data/restaurant_invoice_photo.pdf",
timeout=200,
store_metadata_for_highlighting=True,
)
print(response)
# self.assertIsInstance(response, dict)
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
with open(file_path, encoding="utf-8") as f:
exp = f.read()

# @unittest.skip("Skipping test_whisper")
def test_whisper_stream(self):
client = LLMWhispererClient()
download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
# Create a stream of download_url and pass it to whisper
response_download = requests.get(download_url, stream=True)
response_download.raise_for_status()
response = client.whisper(
stream=response_download.iter_content(chunk_size=1024),
timeout=200,
store_metadata_for_highlighting=True,
)
print(response)
# self.assertIsInstance(response, dict)
assert isinstance(whisper_result, dict)
assert whisper_result["status_code"] == 200

@unittest.skip("Skipping test_whisper_status")
def test_whisper_status(self):
client = LLMWhispererClient()
response = client.whisper_status(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
logger.info(response)
self.assertIsInstance(response, dict)
# For OCR based processing
threshold = 0.97

@unittest.skip("Skipping test_whisper_retrieve")
def test_whisper_retrieve(self):
client = LLMWhispererClient()
response = client.whisper_retrieve(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
logger.info(response)
self.assertIsInstance(response, dict)
# For text based processing
if mode == "native_text" and output_mode == "text":
threshold = 0.99
extracted_text = whisper_result["extracted_text"]
similarity = SequenceMatcher(None, extracted_text, exp).ratio()

@unittest.skip("Skipping test_whisper_highlight_data")
def test_whisper_highlight_data(self):
client = LLMWhispererClient()
response = client.highlight_data(
whisper_hash="9924d865|5f1d285a7cf18d203de7af1a1abb0a3a",
search_text="Indiranagar",
if similarity < threshold:
diff = "\n".join(
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
)
logger.info(response)
self.assertIsInstance(response, dict)


if __name__ == "__main__":
unittest.main()
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")

0 comments on commit 768257c

Please sign in to comment.