Skip to content

Commit

Permalink
fix in filter and 100% test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Dec 16, 2024
1 parent c7a371e commit 90a7827
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def _inner_search(
else:
query_vector = vector
if query_vector is None:
raise VectorSearchExecutionException("Search requires either a vector.")
raise VectorSearchExecutionException("Search requires a vector.")
results = await self.qdrant_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
Expand All @@ -214,7 +214,7 @@ def _get_score_from_result(self, result: ScoredPoint) -> float:
def _create_filter(self, options: VectorSearchOptions) -> Filter:
return Filter(
must=[
FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value))
FieldCondition(key=filter.field_name, match=MatchAny(any=[filter.value]))
for filter in options.filter.filters
]
)
Expand Down
84 changes: 71 additions & 13 deletions python/tests/unit/connectors/memory/qdrant/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@

from pytest import fixture, mark, raises
from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.models import Datatype, Distance, VectorParams
from qdrant_client.models import Datatype, Distance, FieldCondition, Filter, MatchAny, VectorParams

from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection
from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore
from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField
from semantic_kernel.data.vector_search.vector_search_filter import VectorSearchFilter
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.exceptions.memory_connector_exceptions import (
MemoryConnectorException,
MemoryConnectorInitializationError,
VectorStoreModelValidationError,
)
from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException

BASE_PATH = "qdrant_client.async_qdrant_client.AsyncQdrantClient"

Expand Down Expand Up @@ -119,9 +121,10 @@ def mock_search():
yield mock_search


def test_vector_store_defaults(vector_store):
assert vector_store.qdrant_client is not None
assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"
async def test_vector_store_defaults(vector_store):
async with vector_store:
assert vector_store.qdrant_client is not None
assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"


def test_vector_store_with_client():
Expand Down Expand Up @@ -162,18 +165,18 @@ def test_get_collection(vector_store, data_model_definition, qdrant_unit_test_en
assert vector_store.vector_record_collections["test"] == collection


def test_collection_init(data_model_definition, qdrant_unit_test_env):
collection = QdrantCollection(
async def test_collection_init(data_model_definition, qdrant_unit_test_env):
async with QdrantCollection(
data_model_type=dict,
collection_name="test",
data_model_definition=data_model_definition,
env_file_path="test.env",
)
assert collection.collection_name == "test"
assert collection.qdrant_client is not None
assert collection.data_model_type is dict
assert collection.data_model_definition == data_model_definition
assert collection.named_vectors
) as collection:
assert collection.collection_name == "test"
assert collection.qdrant_client is not None
assert collection.data_model_type is dict
assert collection.data_model_definition == data_model_definition
assert collection.named_vectors


def test_collection_init_fail(data_model_definition):
Expand Down Expand Up @@ -275,8 +278,63 @@ async def test_create_index_fail(collection_to_use, request):
await collection.create_collection()


async def test_search(collection):
async def test_search(collection, mock_search):
results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False))
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=[1.0, 2.0, 3.0],
query_filter=Filter(must=[]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_named_vectors(collection, mock_search):
collection.named_vectors = True
results = await collection._inner_search(
vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(vector_field_name="vector", include_vectors=False)
)
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=("vector", [1.0, 2.0, 3.0]),
query_filter=Filter(must=[]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_filter(collection, mock_search):
results = await collection._inner_search(
vector=[1.0, 2.0, 3.0],
options=VectorSearchOptions(include_vectors=False, filter=VectorSearchFilter.equal_to("id", "id1")),
)
async for result in results.results:
assert result.record["id"] == "id1"
break

assert mock_search.call_count == 1
mock_search.assert_called_with(
collection_name="test",
query_vector=[1.0, 2.0, 3.0],
query_filter=Filter(must=[FieldCondition(key="id", match=MatchAny(any=["id1"]))]),
with_vectors=False,
limit=3,
offset=0,
)


async def test_search_fail(collection):
with raises(VectorSearchExecutionException, match="Search requires a vector."):
await collection._inner_search(options=VectorSearchOptions(include_vectors=False))

0 comments on commit 90a7827

Please sign in to comment.