Skip to content

Commit

Permalink
Python: fix for add_function with conflicts (#6437)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Small change in the logic and detection and log.warning for when the
prompt and template and template_format are not the same, when using
kernel.add_function

Fixes #6412 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

Co-authored-by: Evan Mattson <[email protected]>
  • Loading branch information
eavanvalkenburg and moonbox3 authored May 29, 2024
1 parent 5d25f6a commit 7cb651a
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 12 deletions.
19 changes: 15 additions & 4 deletions python/semantic_kernel/functions/kernel_function_from_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@ def __init__(
template_format: TEMPLATE_FORMAT_TYPES = KERNEL_TEMPLATE_FORMAT_NAME,
prompt_template: PromptTemplateBase | None = None,
prompt_template_config: PromptTemplateConfig | None = None,
prompt_execution_settings: None | (
PromptExecutionSettings | list[PromptExecutionSettings] | dict[str, PromptExecutionSettings]
) = None,
prompt_execution_settings: None
| (PromptExecutionSettings | list[PromptExecutionSettings] | dict[str, PromptExecutionSettings]) = None,
) -> None:
"""Initializes a new instance of the KernelFunctionFromPrompt class.
Expand All @@ -85,6 +84,16 @@ def __init__(
through prompt_template_config or in the prompt_template."
)

if prompt and prompt_template_config and prompt_template_config.template != prompt:
logger.warning(
f"Prompt ({prompt}) and PromptTemplateConfig ({prompt_template_config.template}) both supplied, "
"using the template in PromptTemplateConfig, ignoring prompt."
)
if template_format and prompt_template_config and prompt_template_config.template_format != template_format:
logger.warning(
f"Template ({template_format}) and PromptTemplateConfig ({prompt_template_config.template_format}) "
"both supplied, using the template format in PromptTemplateConfig, ignoring template."
)
if not prompt_template:
if not prompt_template_config:
# prompt must be there if prompt_template and prompt_template_config is not supplied
Expand All @@ -94,7 +103,9 @@ def __init__(
template=prompt,
template_format=template_format,
)
prompt_template = TEMPLATE_FORMAT_MAP[template_format](prompt_template_config=prompt_template_config) # type: ignore
prompt_template = TEMPLATE_FORMAT_MAP[prompt_template_config.template_format](
prompt_template_config=prompt_template_config
) # type: ignore

try:
metadata = KernelFunctionMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def render(self, kernel: "Kernel", arguments: Optional["KernelArguments"]
arguments = KernelArguments()

arguments = self._get_trusted_arguments(arguments)
allow_unsafe_function_output = self._get_allow_unsafe_function_output()
allow_unsafe_function_output = self._get_allow_dangerously_set_function_output()
helpers: dict[str, Callable[..., Any]] = {}
for plugin in kernel.plugins.values():
helpers.update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def render(self, kernel: "Kernel", arguments: Optional["KernelArguments"]
arguments = KernelArguments()

arguments = self._get_trusted_arguments(arguments)
allow_unsafe_function_output = self._get_allow_unsafe_function_output()
allow_unsafe_function_output = self._get_allow_dangerously_set_function_output()
helpers: dict[str, Callable[..., Any]] = {}
helpers.update(JINJA2_SYSTEM_HELPERS)
for plugin in kernel.plugins.values():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def render_blocks(self, blocks: list[Block], kernel: "Kernel", arguments:
rendered_blocks: list[str] = []

arguments = self._get_trusted_arguments(arguments)
allow_unsafe_function_output = self._get_allow_unsafe_function_output()
allow_unsafe_function_output = self._get_allow_dangerously_set_function_output()
for block in blocks:
if isinstance(block, TextRenderer):
rendered_blocks.append(block.render(kernel, arguments))
Expand Down
10 changes: 5 additions & 5 deletions python/semantic_kernel/prompt_template/prompt_template_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ def _get_trusted_arguments(
new_args[name] = value
return new_args

def _get_allow_unsafe_function_output(self) -> bool:
"""Get the allow_unsafe_function_output flag.
def _get_allow_dangerously_set_function_output(self) -> bool:
"""Get the allow_dangerously_set_content flag.
If the prompt template allows unsafe content, then we do not encode the function output,
unless explicitly allowed by the prompt template config
"""
allow_unsafe_function_output = self.allow_dangerously_set_content
allow_dangerously_set_content = self.allow_dangerously_set_content
if self.prompt_template_config.allow_dangerously_set_content:
allow_unsafe_function_output = True
return allow_unsafe_function_output
allow_dangerously_set_content = True
return allow_dangerously_set_content

def _should_escape(self, name: str, input_variables: list["InputVariable"]) -> bool:
"""Check if the variable should be escaped.
Expand Down
35 changes: 35 additions & 0 deletions python/tests/unit/kernel/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.prompt_template.kernel_prompt_template import KernelPromptTemplate
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
from semantic_kernel.services.ai_service_client_base import AIServiceClientBase
from semantic_kernel.services.ai_service_selector import AIServiceSelector

Expand Down Expand Up @@ -229,6 +231,39 @@ def test_add_function_not_provided(kernel: Kernel):
kernel.add_function(function_name="TestFunction", plugin_name="TestPlugin")


def test_add_function_from_prompt_different_values(kernel: Kernel):
template = """
Write a short story about two Corgis on an adventure.
The story must be:
- G rated
- Have a positive message
- No sexism, racism or other bias/bigotry
- Be exactly {{$paragraph_count}} paragraphs long
- Be written in this language: {{$language}}
- The two names of the corgis are {{GenerateNames.generate_names}}
"""
prompt = "test"

kernel.add_function(
prompt=prompt,
function_name="TestFunction",
plugin_name="TestPlugin",
description="Write a short story.",
template_format="handlebars",
prompt_template_config=PromptTemplateConfig(
template=template,
),
execution_settings=PromptExecutionSettings(
extension_data={"max_tokens": 500, "temperature": 0.5, "top_p": 0.5}
),
)
func = kernel.get_function("TestPlugin", "TestFunction")
assert func.name == "TestFunction"
assert func.description == "Write a short story."
assert isinstance(func.prompt_template, KernelPromptTemplate)
assert len(func.parameters) == 2


def test_add_functions(kernel: Kernel):
@kernel_function(name="func1")
def func1(arg1: str) -> str:
Expand Down

0 comments on commit 7cb651a

Please sign in to comment.