Skip to content

Commit

Permalink
Python: Improve agent retrieval by passing necessary kwargs. Add unit…
Browse files Browse the repository at this point in the history
… tests. (microsoft#10116)

### Motivation and Context

The assistant retrieval methods for OpenAI and AzureOpenAI didn't
propagate the api_key or client, if specified. For AzureOpenAI, it's
also important to propagate more settings info.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

This PR:
- improves the handling for retrieving an agent by making sure we pass
along specified keyword arguments
- adds unit tests to ensure the new functionality works as expected

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Jan 9, 2025
1 parent f5f38c7 commit 12b6b82
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 6 deletions.
4 changes: 4 additions & 0 deletions python/samples/concepts/agents/assistant_agent_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ async def main():
enable_code_interpreter=True,
)

assistant_id = agent.assistant.id

retrieved_agent: AzureAssistantAgent = await AzureAssistantAgent.retrieve(
id=assistant_id,
kernel=kernel,
Expand All @@ -71,6 +73,8 @@ async def main():
enable_code_interpreter=True,
)

assistant_id = agent.assistant.id

# Retrieve the agent using the assistant_id
retrieved_agent: OpenAIAssistantAgent = await OpenAIAssistantAgent.retrieve(
id=assistant_id,
Expand Down
19 changes: 17 additions & 2 deletions python/semantic_kernel/agents/open_ai/azure_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,20 @@ async def retrieve(
assistant = await client.beta.assistants.retrieve(id)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)

return AzureAssistantAgent(kernel=kernel, assistant=assistant, **assistant_definition)
return AzureAssistantAgent(
kernel=kernel,
assistant=assistant,
client=client,
ad_token=ad_token,
api_key=api_key,
endpoint=endpoint,
api_version=api_version,
default_headers=default_headers,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
**assistant_definition,
)

@staticmethod
def _setup_client_and_token(
Expand Down Expand Up @@ -393,7 +406,9 @@ def _setup_client_and_token(

# If we still have no credentials, we can't proceed
if not client and not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationException("Please provide either api_key, ad_token or ad_token_provider.")
raise AgentInitializationException(
"Please provide either a client, an api_key, ad_token or ad_token_provider."
)

# Build the client if it's not supplied
if not client:
Expand Down
11 changes: 10 additions & 1 deletion python/semantic_kernel/agents/open_ai/open_ai_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,15 @@ async def retrieve(
)
assistant = await client.beta.assistants.retrieve(id)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)
return OpenAIAssistantAgent(kernel=kernel, assistant=assistant, **assistant_definition)
return OpenAIAssistantAgent(
kernel=kernel,
assistant=assistant,
client=client,
api_key=api_key,
default_headers=default_headers,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**assistant_definition,
)

# endregion
103 changes: 100 additions & 3 deletions python/tests/unit/agents/test_azure_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ async def test_retrieve_agent_missing_chat_deployment_name_throws(kernel, azure_
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_API_KEY"]], indirect=True)
async def test_retrieve_agent_missing_api_key_throws(kernel, azure_openai_unit_test_env):
with pytest.raises(
AgentInitializationException, match="Please provide either api_key, ad_token or ad_token_provider."
AgentInitializationException, match="Please provide either a client, an api_key, ad_token or ad_token_provider."
):
_ = await AzureAssistantAgent.retrieve(id="test_id", kernel=kernel, env_file_path="test.env")

Expand All @@ -352,7 +352,7 @@ def test_azure_openai_agent_create_missing_deployment_name(azure_openai_unit_tes
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_API_KEY"]], indirect=True)
def test_azure_openai_agent_create_missing_api_key(azure_openai_unit_test_env):
with pytest.raises(
AgentInitializationException, match="Please provide either api_key, ad_token or ad_token_provider."
AgentInitializationException, match="Please provide either a client, an api_key, ad_token or ad_token_provider."
):
AzureAssistantAgent(service_id="test_service", endpoint="https://example.com", env_file_path="test.env")

Expand Down Expand Up @@ -462,7 +462,7 @@ async def test_setup_client_and_token_no_credentials_raises_exception():
mock_settings.token_endpoint = None

with pytest.raises(
AgentInitializationException, match="Please provide either api_key, ad_token or ad_token_provider."
AgentInitializationException, match="Please provide either a client, an api_key, ad_token or ad_token_provider."
):
_ = AzureAssistantAgent._setup_client_and_token(
azure_openai_settings=mock_settings,
Expand All @@ -471,3 +471,100 @@ async def test_setup_client_and_token_no_credentials_raises_exception():
client=None,
default_headers=None,
)


@pytest.mark.parametrize(
"exclude_list, client, api_key, should_raise, expected_exception_msg, should_create_client_call",
[
([], None, "test_api_key", False, None, True),
([], AsyncMock(spec=AsyncAzureOpenAI), None, False, None, False),
(
[],
AsyncMock(spec=AsyncAzureOpenAI),
"test_api_key",
False,
None,
False,
),
(
["AZURE_OPENAI_API_KEY"],
None,
None,
True,
"Please provide either a client, an api_key, ad_token or ad_token_provider.",
False,
),
],
indirect=["exclude_list"],
)
async def test_retrieve_agent_handling_api_key_and_client(
azure_openai_unit_test_env,
exclude_list,
kernel,
client,
api_key,
should_raise,
expected_exception_msg,
should_create_client_call,
):
is_api_key_present = "AZURE_OPENAI_API_KEY" not in exclude_list

with (
patch.object(
AzureAssistantAgent,
"_create_azure_openai_settings",
return_value=MagicMock(
chat_model_id="test_model",
api_key=MagicMock(
get_secret_value=MagicMock(return_value="test_api_key" if is_api_key_present else None)
)
if is_api_key_present
else None,
),
),
patch.object(
AzureAssistantAgent,
"_create_client",
return_value=AsyncMock(spec=AsyncAzureOpenAI),
) as mock_create_client,
patch.object(
OpenAIAssistantBase,
"_create_open_ai_assistant_definition",
return_value={
"ai_model_id": "test_model",
"description": "test_description",
"id": "test_id",
"name": "test_name",
},
) as mock_create_def,
):
if client:
client.beta = MagicMock()
client.beta.assistants = MagicMock()
client.beta.assistants.retrieve = AsyncMock(return_value=MagicMock(spec=Assistant))
else:
mock_client_instance = mock_create_client.return_value
mock_client_instance.beta = MagicMock()
mock_client_instance.beta.assistants = MagicMock()
mock_client_instance.beta.assistants.retrieve = AsyncMock(return_value=MagicMock(spec=Assistant))

if should_raise:
with pytest.raises(AgentInitializationException, match=expected_exception_msg):
await AzureAssistantAgent.retrieve(id="test_id", kernel=kernel, api_key=api_key, client=client)
return

retrieved_agent = await AzureAssistantAgent.retrieve(
id="test_id", kernel=kernel, api_key=api_key, client=client
)

if should_create_client_call:
mock_create_client.assert_called_once()
else:
mock_create_client.assert_not_called()

assert retrieved_agent.ai_model_id == "test_model"
mock_create_def.assert_called_once()
if client:
client.beta.assistants.retrieve.assert_called_once_with("test_id")
else:
mock_client_instance.beta.assistants.retrieve.assert_called_once_with("test_id")
90 changes: 90 additions & 0 deletions python/tests/unit/agents/test_open_ai_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,93 @@ async def test_retrieve_agent(kernel, openai_unit_test_env):
}
mock_client_instance.beta.assistants.retrieve.assert_called_once_with("test_id")
mock_create_def.assert_called_once()


@pytest.mark.parametrize(
"exclude_list, client, api_key, should_raise, expected_exception_msg, should_create_client_call",
[
([], None, "test_api_key", False, None, True),
([], AsyncMock(spec=AsyncOpenAI), None, False, None, False),
([], AsyncMock(spec=AsyncOpenAI), "test_api_key", False, None, False),
(
["OPENAI_API_KEY"],
None,
None,
True,
"The OpenAI API key is required, if a client is not provided.",
False,
),
],
indirect=["exclude_list"],
)
async def test_retrieve_agent_handling_api_key_and_client(
openai_unit_test_env,
exclude_list,
kernel,
client,
api_key,
should_raise,
expected_exception_msg,
should_create_client_call,
):
is_api_key_present = "OPENAI_API_KEY" not in exclude_list

with (
patch.object(
OpenAIAssistantAgent,
"_create_open_ai_settings",
return_value=MagicMock(
chat_model_id="test_model",
api_key=MagicMock(
get_secret_value=MagicMock(return_value="test_api_key" if is_api_key_present else None)
)
if is_api_key_present
else None,
),
),
patch.object(
OpenAIAssistantAgent,
"_create_client",
return_value=AsyncMock(spec=AsyncOpenAI),
) as mock_create_client,
patch.object(
OpenAIAssistantBase,
"_create_open_ai_assistant_definition",
return_value={
"ai_model_id": "test_model",
"description": "test_description",
"id": "test_id",
"name": "test_name",
},
) as mock_create_def,
):
if client:
client.beta = MagicMock()
client.beta.assistants = MagicMock()
client.beta.assistants.retrieve = AsyncMock(return_value=MagicMock(spec=Assistant))
else:
mock_client_instance = mock_create_client.return_value
mock_client_instance.beta = MagicMock()
mock_client_instance.beta.assistants = MagicMock()
mock_client_instance.beta.assistants.retrieve = AsyncMock(return_value=MagicMock(spec=Assistant))

if should_raise:
with pytest.raises(AgentInitializationException, match=expected_exception_msg):
await OpenAIAssistantAgent.retrieve(id="test_id", kernel=kernel, api_key=api_key, client=client)
return

retrieved_agent = await OpenAIAssistantAgent.retrieve(
id="test_id", kernel=kernel, api_key=api_key, client=client
)

if should_create_client_call:
mock_create_client.assert_called_once()
else:
mock_create_client.assert_not_called()

assert retrieved_agent.ai_model_id == "test_model"
mock_create_def.assert_called_once()
if client:
client.beta.assistants.retrieve.assert_called_once_with("test_id")
else:
mock_client_instance.beta.assistants.retrieve.assert_called_once_with("test_id")

0 comments on commit 12b6b82

Please sign in to comment.