Skip to content

Commit

Permalink
Consolidating Memory tests under client-sdk
Browse files Browse the repository at this point in the history
Summary:
Part of #651

Requirements
* add more integration tests in tests/client-sdk covering functionalities in llama-stack-apps

Porting tests from
* llama_stack/providers/tests/memory/test_memory.py

Ensuring we cover some basic functions
* MemoryResource src/llama_stack_client/resources/memory.py
* MemoryBanksResource src/llama_stack_client/resources/memory_banks.py

Test Plan:
Run against the stack as lib
```
LLAMA_STACK_CONFIG=tests/client-sdk/memory/resources/run.yaml pytest tests/client-sdk/memory -v

tests/client-sdk/memory/test_memory.py::test_memory_bank_list PASSED                                                                                                                     [ 20%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_register PASSED                                                                                                                 [ 40%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_unregister PASSED                                                                                                               [ 60%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_insert_inline_and_query PASSED                                                                                                  [ 80%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_insert_from_url_and_query PASSED                                                                                                [100%]
```


Run against the local server
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest tests/client-sdk/memory -v


tests/client-sdk/memory/test_memory.py::test_memory_bank_list PASSED                                                                                                                     [ 20%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_register PASSED                                                                                                                 [ 40%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_unregister PASSED                                                                                                               [ 60%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_insert_inline_and_query PASSED                                                                                                  [ 80%]
tests/client-sdk/memory/test_memory.py::test_memory_bank_insert_from_url_and_query PASSED                                                                                                [100%]
```
  • Loading branch information
vladimirivic committed Jan 1, 2025
1 parent a6c206e commit ff09600
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 15 deletions.
108 changes: 108 additions & 0 deletions tests/client-sdk/memory/resources/run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Llama Stack run configuration for the client-sdk memory tests.
# NOTE: the scraped diff had all indentation stripped, which makes the file
# invalid YAML; indentation below is restored to the canonical llama-stack
# run.yaml layout (providers nested under `providers:`, list items under
# their provider category, etc.).
version: '2'
image_name: ollama
docker_image: null
conda_env: ollama
apis:
- agents
- datasetio
- eval
- inference
- memory
- safety
- scoring
- telemetry
providers:
  inference:
  - provider_id: ollama
    provider_type: remote::ollama
    config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  - provider_id: code-scanner
    provider_type: inline::code-scanner
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config: {}
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: ollama
  provider_model_id: null
  model_type: llm
- metadata: {}
  model_id: ${env.SAFETY_MODEL}
  provider_id: ollama
  provider_model_id: null
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  provider_model_id: null
  model_type: embedding
shields:
- params: null
  shield_id: ${env.SAFETY_MODEL}
  provider_id: llama-guard
  provider_shield_id: null
- params: null
  shield_id: CodeScanner
  provider_id: code-scanner
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
28 changes: 28 additions & 0 deletions tests/client-sdk/memory/run_tests.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Using Llama Stack as Library
```
LLAMA_STACK_CONFIG=tests/client-sdk/memory/resources/run.yaml pytest tests/client-sdk/memory -v
# Alternatively, you can use distribution template names e.g. "ollama", "together", "vllm"
LLAMA_STACK_CONFIG=ollama pytest tests/client-sdk/memory -v
```

# Using local Llama Stack server instance
```
# Export Llama Stack naming vars
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
# Export Ollama naming vars
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
export OLLAMA_SAFETY_MODEL="llama-guard3:1b"
# Start Ollama instance
ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m
ollama run $OLLAMA_SAFETY_MODEL --keepalive 60m
# Start the Llama Stack server
llama stack run ./llama_stack/templates/ollama/run-with-safety.yaml
# Run the tests
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest tests/client-sdk/memory -v
```
218 changes: 203 additions & 15 deletions tests/client-sdk/memory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,199 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import random

import pytest

from llama_stack_client.types.memory_insert_params import Document
from llama_stack.apis.memory import MemoryBankDocument


def test_memory_bank(llama_stack_client):
providers = llama_stack_client.providers.list()
if "memory" not in providers:
pytest.skip("No memory provider available")
@pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client):
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)


@pytest.fixture(scope="function")
def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_registry):
memory_bank_id = f"test_bank_{random.randint(1000, 9999)}"
llama_stack_client.memory_banks.register(
memory_bank_id=memory_bank_id,
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
provider_id="faiss",
)
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
]
return memory_banks


@pytest.fixture(scope="session")
def sample_documents():
return [
MemoryBankDocument(
document_id="test-doc-1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
MemoryBankDocument(
document_id="test-doc-2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
MemoryBankDocument(
document_id="test-doc-3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
MemoryBankDocument(
document_id="test-doc-4",
content="Neural networks are inspired by biological neural networks.",
metadata={"category": "AI", "difficulty": "advanced"},
),
]


def assert_valid_response(response):
    """Sanity-check a memory query response.

    Requires at least one chunk, a score for every chunk, and each chunk to
    carry string content plus a non-None document id.
    """
    chunk_count = len(response.chunks)
    score_count = len(response.scores)
    assert chunk_count > 0
    assert score_count > 0
    assert chunk_count == score_count
    assert all(isinstance(chunk.content, str) for chunk in response.chunks)
    assert all(chunk.document_id is not None for chunk in response.chunks)


def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
    """Registering a bank then retrieving it returns every registration property."""
    bank_id = f"test_bank_{random.randint(1000, 9999)}"
    llama_stack_client.memory_banks.register(
        memory_bank_id=bank_id,
        params={
            "memory_bank_type": "vector",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
        provider_id="faiss",
    )

    retrieved = llama_stack_client.memory_banks.retrieve(memory_bank_id=bank_id)
    assert retrieved is not None
    # Every field echoed back must match what was registered.
    expected_attrs = {
        "identifier": bank_id,
        "type": "memory_bank",
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
        "provider_id": "faiss",
        "provider_resource_id": bank_id,
    }
    for attr_name, expected_value in expected_attrs.items():
        assert getattr(retrieved, attr_name) == expected_value


def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
    """With the registry emptied by the fixture, list() must report no banks."""
    remaining = [bank.identifier for bank in llama_stack_client.memory_banks.list()]
    assert remaining == []


def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
    """A freshly registered bank must be the only entry in the registry listing."""
    provider = "faiss"
    bank_id = f"test_bank_{random.randint(1000, 9999)}"
    vector_params = {
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    }
    llama_stack_client.memory_banks.register(
        memory_bank_id=bank_id,
        params=vector_params,
        provider_id=provider,
    )

    listed = [bank.identifier for bank in llama_stack_client.memory_banks.list()]
    assert listed == [bank_id]


def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
    """Unregistering the single registered bank leaves the registry empty."""

    def current_bank_ids():
        # Helper: re-query the registry on each call to observe server state.
        return [bank.identifier for bank in llama_stack_client.memory_banks.list()]

    before = current_bank_ids()
    assert len(before) == 1

    llama_stack_client.memory_banks.unregister(memory_bank_id=before[0])

    assert len(current_bank_ids()) == 0


def test_memory_bank_insert_inline_and_query(
    llama_stack_client, single_entry_memory_bank_registry, sample_documents
):
    """Insert sample documents inline, then exercise query matching, result
    limits, and score thresholds against the single registered bank."""
    bank_id = single_entry_memory_bank_registry[0]
    llama_stack_client.memory.insert(
        bank_id=bank_id,
        documents=sample_documents,
    )

    def run_query(text, **query_params):
        # Only pass `params` when the caller supplied any, matching the plain
        # query calls below that omit it entirely.
        kwargs = {"bank_id": bank_id, "query": text}
        if query_params:
            kwargs["params"] = query_params
        return llama_stack_client.memory.query(**kwargs)

    # Direct keyword match should surface the Python document.
    direct = run_query("programming language")
    assert_valid_response(direct)
    assert any("Python" in chunk.content for chunk in direct.chunks)

    # Semantic similarity should surface the neural-network document.
    semantic = run_query("AI and brain-inspired computing")
    assert_valid_response(semantic)
    assert any("neural networks" in chunk.content.lower() for chunk in semantic.chunks)

    # max_chunks caps how many results come back.
    capped = run_query("computer", max_chunks=2)
    assert_valid_response(capped)
    assert len(capped.chunks) <= 2

    # score_threshold filters out low-similarity results.
    thresholded = run_query("computer", score_threshold=0.01)
    assert_valid_response(thresholded)
    assert all(score >= 0.01 for score in thresholded.scores)


# get memory provider id
def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry
):
providers = llama_stack_client.providers.list()
assert "memory" in providers
assert len(providers["memory"]) > 0

memory_provider_id = providers["memory"][0].provider_id
Expand All @@ -36,12 +219,13 @@ def test_memory_bank(llama_stack_client):
]
assert memory_bank_id in available_memory_banks

# add documents to memory bank
# URLs of documents to insert
# TODO: Move to test/memory/resources then update the url to
# https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/memory/resources/{url}
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
]
documents = [
Document(
Expand All @@ -58,14 +242,18 @@ def test_memory_bank(llama_stack_client):
documents=documents,
)

# query documents
response = llama_stack_client.memory.query(
# Query for the name of method
response1 = llama_stack_client.memory.query(
bank_id=memory_bank_id,
query="How do I use lora",
query="What's the name of the fine-tunning method used?",
)
assert_valid_response(response1)
assert any("lora" in chunk.content.lower() for chunk in response1.chunks)

assert len(response.chunks) > 0
assert len(response.chunks) == len(response.scores)

contents = [chunk.content for chunk in response.chunks]
assert "lora" in contents[0].lower()
# Query for the name of model
response2 = llama_stack_client.memory.query(
bank_id=memory_bank_id,
query="Which Llama model is mentioned?",
)
assert_valid_response(response1)
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)

0 comments on commit ff09600

Please sign in to comment.