Skip to content

Commit

Permalink
updated setup
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Nov 26, 2024
1 parent a8a83eb commit 92a7140
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 120 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/python-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ jobs:
- 50051:50051
cosmosdb:
image: mcr.microsoft.com/cosmosdb/linux/azure-cosmos-emulator:vnext-preview
options: >-
--protocol
https
ports:
- 8081:8081
- 1234:1234
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,23 @@ async def test_list_collection_names(
data_model_type: type,
):
"""Test list collection names."""
store = stores["azure_cosmos_db_no_sql"]

assert await store.list_collection_names() == []
async with stores["azure_cosmos_db_no_sql"] as store:
assert await store.list_collection_names() == []

collection_name = "list_collection_names"
collection = store.get_collection(collection_name, data_model_type)
await collection.create_collection()
collection_name = "list_collection_names"
collection = store.get_collection(collection_name, data_model_type)
await collection.create_collection()

collection_names = await store.list_collection_names()
assert collection_name in collection_names
collection_names = await store.list_collection_names()
assert collection_name in collection_names

await collection.delete_collection()
assert await collection.does_collection_exist() is False
collection_names = await store.list_collection_names()
assert collection_name not in collection_names
await collection.delete_collection()
assert await collection.does_collection_exist() is False
collection_names = await store.list_collection_names()
assert collection_name not in collection_names

# Deleting the collection doesn't remove it from the vector_record_collections list in the store
assert collection_name in store.vector_record_collections
# Deleting the collection doesn't remove it from the vector_record_collections list in the store
assert collection_name in store.vector_record_collections

@pytest.mark.asyncio
async def test_collection_not_created(
Expand All @@ -58,27 +57,27 @@ async def test_collection_not_created(
data_record: dict[str, Any],
):
"""Test get without collection."""
store = stores["azure_cosmos_db_no_sql"]
collection_name = "collection_not_created"
collection = store.get_collection(collection_name, data_model_type)
async with stores["azure_cosmos_db_no_sql"] as store:
collection_name = "collection_not_created"
collection = store.get_collection(collection_name, data_model_type)

assert await collection.does_collection_exist() is False
assert await collection.does_collection_exist() is False

with pytest.raises(
MemoryConnectorException, match="The collection does not exist yet. Create the collection first."
):
await collection.upsert(data_model_type(**data_record))
with pytest.raises(
MemoryConnectorException, match="The collection does not exist yet. Create the collection first."
):
await collection.upsert(data_model_type(**data_record))

with pytest.raises(
MemoryConnectorException, match="The collection does not exist yet. Create the collection first."
):
await collection.get(data_record["id"])
with pytest.raises(
MemoryConnectorException, match="The collection does not exist yet. Create the collection first."
):
await collection.get(data_record["id"])

with pytest.raises(MemoryConnectorException):
await collection.delete(data_record["id"])
with pytest.raises(MemoryConnectorException):
await collection.delete(data_record["id"])

with pytest.raises(MemoryConnectorException, match="Container could not be deleted."):
await collection.delete_collection()
with pytest.raises(MemoryConnectorException, match="Container could not be deleted."):
await collection.delete_collection()

@pytest.mark.asyncio
async def test_custom_partition_key(
Expand All @@ -88,33 +87,35 @@ async def test_custom_partition_key(
data_record: dict[str, Any],
):
"""Test custom partition key."""
store = stores["azure_cosmos_db_no_sql"]
collection_name = "custom_partition_key"
collection = store.get_collection(
collection_name,
data_model_type,
partition_key=PartitionKey(path="/product_type"),
)

composite_key = AzureCosmosDBNoSQLCompositeKey(key=data_record["id"], partition_key=data_record["product_type"])

# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))

# Verify
record = await collection.get(composite_key)
assert record is not None
assert isinstance(record, data_model_type)

# Remove
await collection.delete(composite_key)
record = await collection.get(composite_key)
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False
async with stores["azure_cosmos_db_no_sql"] as store:
collection_name = "custom_partition_key"
collection = store.get_collection(
collection_name,
data_model_type,
partition_key=PartitionKey(path="/product_type"),
)

composite_key = AzureCosmosDBNoSQLCompositeKey(
key=data_record["id"], partition_key=data_record["product_type"]
)

# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))

# Verify
record = await collection.get(composite_key)
assert record is not None
assert isinstance(record, data_model_type)

# Remove
await collection.delete(composite_key)
record = await collection.get(composite_key)
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False

@pytest.mark.asyncio
async def test_get_include_vector(
Expand All @@ -124,28 +125,28 @@ async def test_get_include_vector(
data_record: dict[str, Any],
):
"""Test get with include_vector."""
store = stores["azure_cosmos_db_no_sql"]
collection_name = "get_include_vector"
collection = store.get_collection(collection_name, data_model_type)
async with stores["azure_cosmos_db_no_sql"] as store:
collection_name = "get_include_vector"
collection = store.get_collection(collection_name, data_model_type)

# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))
# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))

# Verify
record = await collection.get(data_record["id"], include_vectors=True)
assert record is not None
assert isinstance(record, data_model_type)
assert record.vector == data_record["vector"]
# Verify
record = await collection.get(data_record["id"], include_vectors=True)
assert record is not None
assert isinstance(record, data_model_type)
assert record.vector == data_record["vector"]

# Remove
await collection.delete(data_record["id"])
record = await collection.get(data_record["id"])
assert record is None
# Remove
await collection.delete(data_record["id"])
record = await collection.get(data_record["id"])
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False
# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False

@pytest.mark.asyncio
async def test_get_not_include_vector(
Expand All @@ -155,28 +156,28 @@ async def test_get_not_include_vector(
data_record: dict[str, Any],
):
"""Test get with include_vector."""
store = stores["azure_cosmos_db_no_sql"]
collection_name = "get_not_include_vector"
collection = store.get_collection(collection_name, data_model_type)
async with stores["azure_cosmos_db_no_sql"] as store:
collection_name = "get_not_include_vector"
collection = store.get_collection(collection_name, data_model_type)

# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))
# Upsert
await collection.create_collection()
await collection.upsert(data_model_type(**data_record))

# Verify
record = await collection.get(data_record["id"], include_vectors=False)
assert record is not None
assert isinstance(record, data_model_type)
assert record.vector is None
# Verify
record = await collection.get(data_record["id"], include_vectors=False)
assert record is not None
assert isinstance(record, data_model_type)
assert record.vector is None

# Remove
await collection.delete(data_record["id"])
record = await collection.get(data_record["id"])
assert record is None
# Remove
await collection.delete(data_record["id"])
record = await collection.get(data_record["id"])
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False
# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False

@pytest.mark.asyncio
async def test_collection_with_key_as_key_field(
Expand All @@ -186,29 +187,29 @@ async def test_collection_with_key_as_key_field(
data_record_with_key_as_key_field: dict[str, Any],
):
"""Test collection with key as key field."""
store = stores["azure_cosmos_db_no_sql"]
collection_name = "collection_with_key_as_key_field"
collection = store.get_collection(collection_name, data_model_type_with_key_as_key_field)

# Upsert
await collection.create_collection()
result = await collection.upsert(data_model_type_with_key_as_key_field(**data_record_with_key_as_key_field))
assert data_record_with_key_as_key_field["key"] == result

# Verify
record = await collection.get(data_record_with_key_as_key_field["key"])
assert record is not None
assert isinstance(record, data_model_type_with_key_as_key_field)
assert record.key == data_record_with_key_as_key_field["key"]

# Remove
await collection.delete(data_record_with_key_as_key_field["key"])
record = await collection.get(data_record_with_key_as_key_field["key"])
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False
async with stores["azure_cosmos_db_no_sql"] as store:
collection_name = "collection_with_key_as_key_field"
collection = store.get_collection(collection_name, data_model_type_with_key_as_key_field)

# Upsert
await collection.create_collection()
result = await collection.upsert(data_model_type_with_key_as_key_field(**data_record_with_key_as_key_field))
assert data_record_with_key_as_key_field["key"] == result

# Verify
record = await collection.get(data_record_with_key_as_key_field["key"])
assert record is not None
assert isinstance(record, data_model_type_with_key_as_key_field)
assert record.key == data_record_with_key_as_key_field["key"]

# Remove
await collection.delete(data_record_with_key_as_key_field["key"])
record = await collection.get(data_record_with_key_as_key_field["key"])
assert record is None

# Remove collection
await collection.delete_collection()
assert await collection.does_collection_exist() is False

@pytest.mark.asyncio
async def test_custom_client(
Expand All @@ -219,13 +220,14 @@ async def test_custom_client(
url = os.environ.get("AZURE_COSMOS_DB_NO_SQL_URL")
key = os.environ.get("AZURE_COSMOS_DB_NO_SQL_KEY")

async with CosmosClient(url, key) as custom_client:
store = AzureCosmosDBNoSQLStore(
async with (
CosmosClient(url, key) as custom_client,
AzureCosmosDBNoSQLStore(
database_name="test_database",
cosmos_client=custom_client,
create_database=True,
)

) as store,
):
assert await store.list_collection_names() == []

collection_name = "list_collection_names"
Expand Down

0 comments on commit 92a7140

Please sign in to comment.