Skip to content

Commit

Permalink
add documents to turn
Browse files Browse the repository at this point in the history
  • Loading branch information
dineshyv committed Jan 6, 2025
1 parent 827ca40 commit bd5786d
Show file tree
Hide file tree
Showing 9 changed files with 865 additions and 352 deletions.
901 changes: 574 additions & 327 deletions docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions docs/resources/llama-stack-spec.html
Original file line number Diff line number Diff line change
Expand Up @@ -3978,6 +3978,41 @@
"stream": {
"type": "boolean"
},
"documents": {
"type": "array",
"items": {
"type": "object",
"properties": {
"content": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/InterleavedContentItem"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
},
{
"$ref": "#/components/schemas/URL"
}
]
},
"mime_type": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"content",
"mime_type"
]
}
},
"tools": {
"type": "array",
"items": {
Expand Down Expand Up @@ -6699,6 +6734,9 @@
"gradient_accumulation_steps": {
"type": "integer"
},
"max_validation_steps": {
"type": "integer"
},
"data_config": {
"$ref": "#/components/schemas/DataConfig"
},
Expand All @@ -6718,6 +6756,7 @@
"n_epochs",
"max_steps_per_epoch",
"gradient_accumulation_steps",
"max_validation_steps",
"data_config",
"optimizer_config"
]
Expand Down
22 changes: 22 additions & 0 deletions docs/resources/llama-stack-spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,25 @@ components:
properties:
agent_id:
type: string
documents:
items:
additionalProperties: false
properties:
content:
oneOf:
- type: string
- $ref: '#/components/schemas/InterleavedContentItem'
- items:
$ref: '#/components/schemas/InterleavedContentItem'
type: array
- $ref: '#/components/schemas/URL'
mime_type:
type: string
required:
- content
- mime_type
type: object
type: array
messages:
items:
oneOf:
Expand Down Expand Up @@ -2920,6 +2939,8 @@ components:
type: integer
max_steps_per_epoch:
type: integer
max_validation_steps:
type: integer
n_epochs:
type: integer
optimizer_config:
Expand All @@ -2928,6 +2949,7 @@ components:
- n_epochs
- max_steps_per_epoch
- gradient_accumulation_steps
- max_validation_steps
- data_config
- optimizer_config
type: object
Expand Down
9 changes: 9 additions & 0 deletions llama_stack/apis/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class Attachment(BaseModel):
mime_type: str


class Document(BaseModel):
    """A document attached to an agent turn.

    The content is either inline interleaved content or a URL referencing
    the data; ``mime_type`` tells consumers how to interpret the bytes.
    """

    # Inline content, or a URL (http(s)://, file://, data:) pointing at the data.
    content: InterleavedContent | URL
    # MIME type of the content (e.g. "text/plain").
    mime_type: str


class StepCommon(BaseModel):
turn_id: str
step_id: str
Expand Down Expand Up @@ -272,6 +277,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
]
]

documents: Optional[List[Document]] = None
tools: Optional[List[AgentTool]] = None

stream: Optional[bool] = False


Expand Down Expand Up @@ -308,6 +316,7 @@ async def create_agent_turn(
]
],
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
tools: Optional[List[AgentTool]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

Expand Down
134 changes: 130 additions & 4 deletions llama_stack/providers/inline/agents/meta_reference/agent_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
Attachment,
Document,
InferenceStep,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
Expand All @@ -55,8 +60,8 @@
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.memory import Memory, MemoryBankDocument
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import KVStore
Expand Down Expand Up @@ -190,6 +195,7 @@ async def create_and_execute_turn(
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
tools_for_turn=request.tools,
):
if isinstance(chunk, CompletionMessage):
Expand Down Expand Up @@ -240,6 +246,7 @@ async def run(
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
Expand All @@ -257,7 +264,13 @@ async def run(
yield res

async for res in self._run(
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
session_id,
turn_id,
input_messages,
sampling_params,
stream,
documents,
tools_for_turn,
):
if isinstance(res, bool):
return
Expand Down Expand Up @@ -352,6 +365,7 @@ async def _run(
input_messages: List[Message],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
tools_for_turn: Optional[List[AgentTool]] = None,
) -> AsyncGenerator:
tool_args = {}
Expand All @@ -361,6 +375,7 @@ async def _run(
tool_args[tool.name] = tool.args

tool_defs = await self._get_tool_defs(tools_for_turn)
await self.handle_documents(session_id, documents, input_messages, tool_defs)
if "memory" in tool_defs and len(input_messages) > 0:
with tracing.span("memory_tool") as span:
step_id = str(uuid.uuid4())
Expand All @@ -378,6 +393,11 @@ async def _run(
"query": input_messages[-1],
**extra_args,
}

session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
args["memory_bank_id"] = session_info.memory_bank_id
serialized_args = tracing.serialize_value(args)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
Expand Down Expand Up @@ -732,6 +752,112 @@ async def _get_tool_defs(

return ret

async def handle_documents(
    self,
    session_id: str,
    documents: List[Document],
    input_messages: List[Message],
    tool_defs: Dict[str, ToolDefinition],
) -> None:
    """Attach per-turn documents to the turn, routed by available tools.

    Depending on which of the ``memory`` and ``code_interpreter`` tools are
    enabled, document data is inserted into the session memory bank,
    written to a tempdir for the code interpreter, appended as an
    attachment message, or folded into the last message's context.

    :param session_id: the session this turn belongs to.
    :param documents: documents supplied with the turn (may be None/empty).
    :param input_messages: the turn's messages; mutated in place.
    :param tool_defs: tool name -> definition for the tools enabled this turn.
    """
    memory_tool = tool_defs.get("memory", None)
    code_interpreter_tool = tool_defs.get("code_interpreter", None)
    if not documents:
        return

    # InterleavedContent is a typing Union (it includes a parameterized
    # List[...]), so isinstance() against it raises TypeError at runtime.
    # Partition on the URL variant instead: everything that is not a URL
    # is inline content.
    content_items = [d for d in documents if not isinstance(d.content, URL)]
    url_documents = [d for d in documents if isinstance(d.content, URL)]
    pattern = re.compile("^(https?://|file://|data:)")
    # d.content is a URL model, not a string: regex-match its .uri field
    # (matching the model object itself would raise TypeError).
    url_items = [d.content for d in url_documents if pattern.match(d.content.uri)]

    # Save the contents to a tempdir and use its path as a URL if code interpreter is present
    if code_interpreter_tool:
        for c in content_items:
            temp_file_path = os.path.join(self.tempdir, f"{make_random_string()}.txt")
            # NOTE(review): assumes c.content is a plain string; other
            # InterleavedContent variants would need conversion first — confirm.
            with open(temp_file_path, "w") as temp_file:
                temp_file.write(c.content)
            url_items.append(URL(uri=f"file://{temp_file_path}"))

    if memory_tool and code_interpreter_tool:
        # if both memory and code_interpreter are available, we download the URLs
        # and attach the data to the last message.
        msg = await attachment_message(self.tempdir, url_items)
        input_messages.append(msg)
        # Since memory is present, add all the data to the memory bank
        await self.add_to_session_memory_bank(session_id, documents)
    elif code_interpreter_tool:
        # if only code_interpreter is available, we download the URLs to a tempdir
        # and attach the path to them as a message to inference with the
        # assumption that the model invokes the code_interpreter tool with the path
        msg = await attachment_message(self.tempdir, url_items)
        input_messages.append(msg)
    elif memory_tool:
        # if only memory is available, we load the data from the URLs and content items to the memory bank
        await self.add_to_session_memory_bank(session_id, documents)
    else:
        # if no memory or code_interpreter tool is available,
        # we try to load the data from the URLs and content items as a message to inference
        # and add it to the last message's context.
        # load_data_from_urls is a coroutine function: it must be awaited,
        # otherwise a coroutine object would be concatenated into the context.
        # NOTE(review): content_items holds Document models, not their raw
        # content — confirm the intended context payload shape.
        input_messages[-1].context = content_items + await load_data_from_urls(
            url_items
        )

async def _ensure_memory_bank(self, session_id: str) -> str:
    """Return the session's memory bank id, registering one on first use.

    Raises ValueError if the session does not exist.
    """
    info = await self.storage.get_session_info(session_id)
    if info is None:
        raise ValueError(f"Session {session_id} not found")

    # Fast path: the session already has a bank associated with it.
    if info.memory_bank_id is not None:
        return info.memory_bank_id

    # First use: register a vector bank and remember it on the session.
    new_bank_id = f"memory_bank_{session_id}"
    await self.memory_banks_api.register_memory_bank(
        memory_bank_id=new_bank_id,
        params=VectorMemoryBankParams(
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
        ),
    )
    await self.storage.add_memory_bank_to_session(session_id, new_bank_id)
    return new_bank_id

async def add_to_session_memory_bank(
    self, session_id: str, data: List[Document]
) -> None:
    """Insert the given documents into this session's memory bank.

    The bank is created on demand via ``_ensure_memory_bank``; each document
    is wrapped in a MemoryBankDocument with a fresh random id.
    """
    bank_id = await self._ensure_memory_bank(session_id)
    bank_docs = []
    for doc in data:
        bank_docs.append(
            MemoryBankDocument(
                document_id=str(uuid.uuid4()),
                content=doc.content,
                mime_type=doc.mime_type,
                metadata={},
            )
        )
    await self.memory_api.insert_documents(
        bank_id=bank_id,
        documents=bank_docs,
    )


async def load_data_from_urls(urls: List[URL]) -> List[str]:
    """Load the text behind each URL, in order.

    Supports ``file://`` URIs (read from the local filesystem) and
    ``http(s)`` URIs (fetched over the network). URIs with any other
    scheme are silently skipped, matching the original behavior.

    :param urls: URL models whose ``.uri`` field is the location to load.
    :returns: one string of content per handled URL.
    """
    data: List[str] = []
    # Reuse a single HTTP client for all fetches instead of constructing a
    # new AsyncClient per URL; connections are only opened on first request.
    async with httpx.AsyncClient() as client:
        for url in urls:
            uri = url.uri
            if uri.startswith("file://"):
                filepath = uri[len("file://") :]
                with open(filepath, "r") as f:
                    data.append(f.read())
            elif uri.startswith("http"):
                r = await client.get(uri)
                data.append(r.text)
    return data


async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
Expand Down
3 changes: 3 additions & 0 deletions llama_stack/providers/inline/agents/meta_reference/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AgentStepResponse,
AgentTool,
AgentTurnCreateRequest,
Document,
Session,
Turn,
)
Expand Down Expand Up @@ -147,6 +148,7 @@ async def create_agent_turn(
]
],
tools: Optional[List[AgentTool]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
Expand All @@ -155,6 +157,7 @@ async def create_agent_turn(
messages=messages,
stream=True,
tools=tools,
documents=documents,
)
if stream:
return self._create_agent_turn_streaming(request)
Expand Down
12 changes: 12 additions & 0 deletions llama_stack/providers/inline/agents/meta_reference/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
class AgentSessionInfo(BaseModel):
    """Persisted metadata for a single agent session."""

    session_id: str
    session_name: str
    # Memory bank backing this session; None until one is registered for
    # the session.
    memory_bank_id: Optional[str] = None
    # When the session was started.
    started_at: datetime


Expand Down Expand Up @@ -51,6 +52,17 @@ async def get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:

return AgentSessionInfo(**json.loads(value))

async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
    """Record *bank_id* on the session identified by *session_id*.

    Raises ValueError if the session does not exist.
    """
    info = await self.get_session_info(session_id)
    if info is None:
        raise ValueError(f"Session {session_id} not found")

    info.memory_bank_id = bank_id
    session_key = f"session:{self.agent_id}:{session_id}"
    await self.kvstore.set(key=session_key, value=info.model_dump_json())

async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
Expand Down
Loading

0 comments on commit bd5786d

Please sign in to comment.