diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 5856d0358..55580e7bc 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -2,11 +2,9 @@ name: Unit tests on: push: - branches: - - master + branches: [main] pull_request: - branches: - - master + branches: [main] jobs: build: diff --git a/guidance/_grammar.py b/guidance/_grammar.py index 7a76293fa..6385dcad1 100644 --- a/guidance/_grammar.py +++ b/guidance/_grammar.py @@ -1,77 +1,218 @@ -import parsimonious - -# define the Guidance language grammar -grammar = parsimonious.grammar.Grammar( -r""" -template = template_chunk* -template_chunk = comment / slim_comment / escaped_command / unrelated_escape / command / command_block / content - -comment = comment_start comment_content* comment_end -comment_start = "{{!--" -comment_content = not_comment_end / ~r"[^-]*" -not_comment_end = "-" !"-}}" -comment_end = "--}}" - -slim_comment = slim_comment_start slim_comment_content* slim_comment_end -slim_comment_start = "{{" "~"? "!" -slim_comment_content = not_slim_comment_end / ~r"[^}]*" -not_slim_comment_end = "}" !"}" -slim_comment_end = "}}" - -command = command_start command_content command_end -command_block = command_block_open template (command_block_sep template)* command_block_close -command_block_open = command_start "#" block_command_call command_end -command_block_sep = command_start ("or" / "else") command_end -command_block_close = command_start "/" command_name command_end -command_start = "{{" !"!" "~"? -not_command_start = "{" !"{" -not_command_escape = "\\" !"{{" -command_end = "~"? "}}" -command_contents = ~'[^{]*' -block_command_call = command_name command_args? -command_content = command_call / variable_ref -command_call = command_name command_args -command_args = command_arg_and_ws+ -command_arg_and_ws = ws command_arg -command_arg = named_command_arg / positional_command_arg -positional_command_arg = command_arg_group / literal / variable_ref -named_command_arg = variable_name "=" (literal / variable_ref) -command_arg_group = "(" command_content ")" -ws = ~r'\s+' -command_contentasdf = ~"[a-z 0-9]*"i -command_name = ~r"[a-z][a-z_0-9\.]*"i / "<" / ">" / "==" / "!=" / ">=" / "<=" -variable_ref = not_exact_or not_exact_else ~r"[@a-z][a-z_0-9\.\[\]\"'-]*"i -not_exact_or = ~r"or[@a-z][a-z_0-9\.\[\]\"'-]"i / !"or" -not_exact_else = ~r"else[@a-z][a-z_0-9\.\[\]\"'-]"i / !"else" -variable_name = ~r"[@a-z][a-z_0-9]*"i -contentw = ~r'.*' -content = (not_command_start / not_command_escape / ~r"[^{\\]")+ -unrelated_escape = "\\" !command_start -escaped_command = "\\" command_start ~r"[^}]*" command_end - -literal = string_literal / number_literal / boolean_literal / array_literal / object_literal - -string_literal = ~r'"[^\"]*"' / ~r"'[^\']*'" - -number_literal = ~r"[0-9\.]+" - -boolean_literal = "True" / "False" - -array_literal = empty_array / single_item_array / multi_item_array -empty_array = array_start ws? array_end -single_item_array = array_start ws? array_item ws? array_end -array_sep = ws? "," ws? -multi_item_array = array_start ws? array_item (array_sep array_item)* ws? array_end -array_start = "[" -array_end = "]" +from typing import Any +import pyparsing as pp + +pp.ParserElement.enable_packrat() +# pp.enable_diag(pp.Diagnostics.enable_debug_on_named_expressions) +# pp.autoname_elements() + +program = pp.Forward() +program_chunk = pp.Forward() + +## whitespace ## + +ws = pp.White() +opt_ws = pp.Optional(ws) + + +## comments ## + +# long-form comments {{!-- my comment --}} +command_end = pp.Suppress(opt_ws + "}}") | pp.Suppress(opt_ws + "~}}" + opt_ws) +long_comment_start = pp.Suppress(pp.Literal("{{!--")) +long_comment_end = pp.Suppress(pp.Literal("--") + command_end) +not_long_comment_end = "-" + ~pp.FollowedBy("-}}") + ~pp.FollowedBy("-~}}") +long_comment_content = not_long_comment_end | pp.OneOrMore(pp.CharsNotIn("-")) +long_comment = pp.Group(pp.Combine(long_comment_start + pp.ZeroOrMore(long_comment_content) + long_comment_end))("long_comment").set_name("long_comment") + +# short-form comments {{! my comment }} +comment_start = pp.Suppress("{{" + pp.Optional("~") + "!") +not_comment_end = "}" + ~pp.FollowedBy("}") | "~" + ~pp.FollowedBy("}}") +comment_content = not_comment_end | pp.OneOrMore(pp.CharsNotIn("~}")) +comment = pp.Group(pp.Combine(comment_start + pp.ZeroOrMore(comment_content) + command_end))("comment") + + +## literals ## + +literal = pp.Forward().set_name("literal") + +# basic literals +string_literal = pp.Group(pp.Suppress('"') + pp.ZeroOrMore(pp.CharsNotIn('"')) + pp.Suppress('"') | pp.Suppress("'") + pp.ZeroOrMore(pp.CharsNotIn("'")) + pp.Suppress("'"))("string_literal") +number_literal = pp.Group(pp.Word(pp.srange("[0-9.]")))("number_literal") +boolean_literal = pp.Group("True" | pp.Literal("False"))("boolean_literal") + +# object literal +object_literal = pp.Forward().set_name("object_literal") +object_start = pp.Suppress("{") +object_end = pp.Suppress("}") +empty_object = object_start + object_end +object_item = string_literal + pp.Suppress(":") + literal +single_item_object = object_start + object_item + object_end +object_sep = pp.Suppress(",") +multi_item_object = object_start + object_item + pp.ZeroOrMore(object_sep + object_item) + object_end +object_literal <<= pp.Group(empty_object | single_item_object | multi_item_object)("object_literal") + +# array literal +array_literal = pp.Forward().set_name("array_literal") +array_start = pp.Suppress("[") +array_end = pp.Suppress("]") array_item = literal +empty_array = array_start + array_end +single_item_array = array_start + array_item + array_end +array_sep = pp.Suppress(",") +multi_item_array = array_start + array_item + pp.ZeroOrMore(array_sep + array_item) + array_end +array_literal <<= pp.Group(empty_array | single_item_array | multi_item_array)("array_literal") + +literal <<= string_literal | number_literal | boolean_literal | array_literal | object_literal + + +## infix operators ## + +code_chunk_no_infix = pp.Forward().set_name("code_chunk_no_infix") + +class OpNode: + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self.operator) + def __getitem__(self, item): + return getattr(self, item) + def get_name(self): + return self.name + +class UnOp(OpNode): + def __init__(self, tokens): + self.operator = tokens[0][0] + self.value = tokens[0][1] + self.name = "unary_operator" + +class BinOp(OpNode): + def __init__(self, tokens): + self.operator = tokens[0][1] + self.lhs = tokens[0][0] + self.rhs = tokens[0][2] + self.name = "binary_operator" + +infix_operator_block = pp.infix_notation(code_chunk_no_infix, [ + ('-', 1, pp.OpAssoc.RIGHT), + (pp.one_of('* /'), 2, pp.OpAssoc.LEFT, BinOp), + (pp.one_of('+ -'), 2, pp.OpAssoc.LEFT, BinOp), + (pp.one_of('< > <= >= == != is in'), 2, pp.OpAssoc.LEFT, BinOp), + (pp.one_of('and'), 2, pp.OpAssoc.LEFT, BinOp), + (pp.one_of('or'), 2, pp.OpAssoc.LEFT, BinOp), +]) + + +## commands ## + +code_chunk = pp.Forward().set_name("code_chunk") +not_keyword = ~pp.FollowedBy(pp.Keyword("or") | pp.Keyword("else") | pp.Keyword("elif")) +command_name = pp.Combine(not_keyword + pp.Word(pp.srange("[A-Za-z_]"), pp.srange("[A-Za-z_0-9]"))) +variable_name = pp.Word(pp.srange("[@A-Za-z_]"), pp.srange("[A-Za-z_0-9]")) +variable_ref = not_keyword + pp.Group(pp.Word(pp.srange("[@A-Za-z_]"), pp.srange("[A-Za-z_0-9\.\[\]\"'-]")))("variable_ref").set_name("variable_ref") +keyword = pp.Group(pp.Keyword("break") | pp.Keyword("continue"))("keyword") + +class SavedTextNode: + """A node that saves the text it matches.""" + def __init__(self, s, loc, tokens): + start_pos = tokens[0] + if len(tokens) == 3: + end_pos = tokens[2] + else: + end_pos = loc + self.text = s[start_pos:end_pos] + assert len(tokens[1]) == 1 + self.tokens = tokens[1][0] + def __repr__(self): + return "SavedTextNode({})".format(self.text) + self.tokens.__repr__() + def __getitem__(self, item): + return self.tokens[item] + def __len__(self): + return len(self.tokens) + def get_name(self): + return self.tokens.get_name() + def __contains__(self, item): + return item in self.tokens + def __getattr__(self, name): + return getattr(self.tokens, name) + def __call__(self, *args, **kwds): + return self.tokens(*args, **kwds) +def SavedText(node): + return pp.Located(node).add_parse_action(SavedTextNode) + +# command arguments +command_arg = pp.Forward() +named_command_arg = variable_name + "=" + code_chunk +command_arg <<= pp.Group(named_command_arg)("named_command_arg").set_name("named_command_arg") | pp.Group(code_chunk)("positional_command_arg").set_name("positional_command_arg") + +# whitespace command format {{my_command arg1 arg2=val}} +ws_command_call = pp.Forward().set_name("ws_command_call") +command_arg_and_ws = pp.Suppress(ws) + command_arg +ws_command_args = pp.OneOrMore(command_arg_and_ws) +# note that we have to list out all the operators here because we match before the infix operator checks +ws_command_call <<= command_name("name") + ~pp.FollowedBy(pp.one_of("+ - * / or not is and <= == >= != < >")) + ws_command_args + +# paren command format {{my_command(arg1, arg2=val)}} +paren_command_call = pp.Forward().set_name("paren_command_call") +command_arg_and_comma_ws = pp.Suppress(",") + command_arg +paren_command_args = pp.Optional(command_arg) + pp.ZeroOrMore(command_arg_and_comma_ws) +paren_command_call <<= (command_name("name") + pp.Suppress("(")).leave_whitespace() - paren_command_args + pp.Suppress(")") + +# code chunks +enclosed_code_chunk = pp.Forward().set_name("enclosed_code_chunk") +paren_group = (pp.Suppress("(") - enclosed_code_chunk + pp.Suppress(")")).set_name("paren_group") +enclosed_code_chunk_cant_infix = (pp.Group(ws_command_call)("command_call") | pp.Group(paren_command_call)("command_call") | literal | keyword | variable_ref | paren_group) + ~pp.FollowedBy(pp.one_of("+ - * / or not is and <= == >= != < >")) +enclosed_code_chunk <<= enclosed_code_chunk_cant_infix | infix_operator_block +code_chunk_no_infix <<= (paren_group | pp.Group(paren_command_call)("command_call") | literal | keyword | variable_ref) # used by infix_operator_block +code_chunk_cant_infix = code_chunk_no_infix + ~pp.FollowedBy(pp.one_of("+ - * / or not is and <= == >= != < >")) # don't match infix operators so we can run this before infix_operator_block +code_chunk_cant_infix.set_name("code_chunk_cant_infix") +code_chunk <<= code_chunk_cant_infix | infix_operator_block + +# command/variable +command_start = pp.Suppress("{{" + ~pp.FollowedBy("!") + pp.Optional("~")) +simple_command_start = pp.Suppress("{{" + ~pp.FollowedBy("!") + pp.Optional("~")) + ~pp.FollowedBy(pp.one_of("# / >")) +command = SavedText(pp.Group(simple_command_start + enclosed_code_chunk + command_end)("command")) + +# partial +always_call = pp.Group(paren_command_call | command_name("name") + pp.Optional(ws_command_args)) +partial = pp.Group(pp.Suppress(pp.Combine(command_start + ">")) + always_call("command_call") + command_end)("partial") + +# block command {{#my_command arg1 arg2=val}}...{{/my_command}} +separator = pp.Group(pp.Keyword("or") | pp.Keyword("else") | (pp.Keyword("elif") + ws_command_args))("separator").set_name("separator") +block_command = pp.Forward() +block_command_call = always_call("command_call") +block_command_open = pp.Suppress(pp.Combine(command_start + "#")) + block_command_call + command_end +block_command_sep = (command_start + separator + command_end)("block_command_sep").set_name("block_command_sep") +block_command_close = SavedText(pp.Group(command_start + pp.Suppress("/") + command_name + command_end)("block_command_close").set_name("block_command_close")) +block_command_content = (pp.Group(program)("block_content_chunk") + pp.ZeroOrMore(block_command_sep + pp.Group(program)("block_content_chunk"))).set_name("block_content") +block_command <<= (block_command_open + SavedText(pp.Group(block_command_content)("block_content")) + block_command_close).leave_whitespace() +block_command = SavedText(pp.Group(block_command)("block_command")).set_name("block_command") + +# block partial {{#>my_command arg1 arg2=val}}...{{/my_command}} +block_partial = pp.Forward() +block_partial_call = always_call("command_call") +block_partial_open = pp.Combine(command_start + pp.Suppress("#>")) + block_partial_call + command_end +block_partial_close = command_start + pp.Suppress("/") + command_name + command_end +block_partial <<= block_partial_open + program + pp.Suppress(block_partial_close) +block_partial = SavedText(pp.Group(block_partial)("block_partial")) + +# escaped commands \{{ not a command }} +not_command_end = "}" + ~pp.FollowedBy("}") +escaped_command = SavedText(pp.Group(pp.Suppress("\\") + command_start + pp.Combine(pp.ZeroOrMore(pp.CharsNotIn("}") | not_command_end)) + command_end)("escaped_command")) +unrelated_escape = "\\" + ~pp.FollowedBy(command_start) + + +## content ## + +not_command_start = "{" + ~pp.FollowedBy("{") +not_command_escape = "\\" + ~pp.FollowedBy("{{") +stripped_whitespace = pp.Suppress(pp.Word(" \t\r\n")) + pp.FollowedBy("{{~") +unstripped_whitespace = pp.Word(" \t\r\n") # no need for a negative FollowedBy because stripped_whitespace will match first +content = pp.Group(pp.Combine(pp.OneOrMore(stripped_whitespace | unstripped_whitespace | not_command_start | not_command_escape | pp.CharsNotIn("{\\ \t\r\n"))))("content").set_name("content") + +# keyword_command = SavedText(pp.Group(command_start + keyword + ws_command_args + command_end)("keyword_command")) +# block_content_chunk = long_comment | comment | escaped_command | unrelated_escape | block_partial | block_command | partial | command | content +# block_content <<= pp.ZeroOrMore(block_content_chunk)("program").leave_whitespace() + +## global program ## -object_literal = empty_object / single_item_object / multi_item_object -empty_object = object_start ws? object_end -single_item_object = object_start ws? object_item ws? object_end -object_sep = ws? "," ws? -multi_item_object = object_start ws? object_item (object_sep object_item)* ws? object_end -object_start = "{" -object_end = "}" -object_item = string_literal ws? ":" ws? literal -""") \ No newline at end of file +program_chunk <<= (long_comment | comment | escaped_command | unrelated_escape | block_partial | block_command | partial | command | content).leave_whitespace() +program <<= pp.ZeroOrMore(program_chunk)("program").leave_whitespace().set_name("program") +grammar = (program + pp.StringEnd()).parse_with_tabs() \ No newline at end of file diff --git a/guidance/_program.py b/guidance/_program.py index 93633ee50..f2168eb08 100644 --- a/guidance/_program.py +++ b/guidance/_program.py @@ -4,17 +4,18 @@ import html import uuid import sys -import parsimonious +# import parsimonious import logging import copy import asyncio import pathlib import os import traceback +import importlib import time import datetime import nest_asyncio -from .llms import _openai +# from .llms import _openai from . import _utils from ._program_executor import ProgramExecutor from . import library @@ -163,10 +164,10 @@ def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent=None, asy self.update_display = DisplayThrottler(self._update_display, self.display_throttle_limit) # see if we are in an ipython environment - try: - from IPython import get_ipython + # check if get_ipython variable exists + if hasattr(__builtins__, "get_ipython"): self._ipython = get_ipython() - except: + else: self._ipython = None # if we are echoing in ipython we assume we can display html @@ -424,7 +425,7 @@ async def execute(self): else: with self.llm.session(asynchronous=True) as llm_session: await self._executor.run(llm_session) - self._text = self._variables["_prefix"] + self._text = self._variables["@raw_prefix"] # delete the executor and so mark the program as not executing self._executor = None @@ -471,7 +472,7 @@ def text(self): @property def marked_text(self): if self._executor is not None: - return self._variables["_prefix"] + return self._variables["@raw_prefix"] else: return self._text @@ -681,7 +682,11 @@ def add_spaces(s): "if": library.if_, "unless": library.unless, "add": library.add, + "BINARY_OPERATOR_+": library.add, "subtract": library.subtract, + "BINARY_OPERATOR_-": library.subtract, + "multiply": library.multiply, + "BINARY_OPERATOR_*": library.multiply, "strip": library.strip, "block": library.block, "set": library.set, @@ -692,11 +697,13 @@ def add_spaces(s): "assistant": library.assistant, "break": library.break_, "equal": library.equal, - "==": library.equal, + "BINARY_OPERATOR_==": library.equal, + "notequal": library.notequal, + "BINARY_OPERATOR_!=": library.notequal, "greater": library.greater, - ">": library.greater, + "BINARY_OPERATOR_>": library.greater, "less": library.less, - "<": library.less, + "BINARY_OPERATOR_<": library.less, "contains": library.contains, "parse": library.parse } diff --git a/guidance/_program_executor.py b/guidance/_program_executor.py index 195495416..52f838b0a 100644 --- a/guidance/_program_executor.py +++ b/guidance/_program_executor.py @@ -5,7 +5,7 @@ import asyncio import warnings import logging -import parsimonious +import pyparsing as pp from ._utils import strip_markers from ._grammar import grammar from ._variable_stack import VariableStack @@ -26,65 +26,75 @@ def __init__(self, program): self._logging = hasattr(self.program.log, "append") # find all the handlebars-style partial inclusion tags and replace them with the partial template - def replace_partial(match): - parts = match.group(1).split(" ", 1) - partial_name = parts[0] + # def replace_partial(match): + # parts = match.group(1).split(" ", 1) + # partial_name = parts[0] - # ,args_string = match.group(1).split(" ", 1) - if partial_name not in program._variables: - raise ValueError("Partial '%s' not given in the keyword args:" % partial_name) - out = "{{#block '"+partial_name+"'" - if len(parts) > 1: - out += " " + parts[1] - out += "}}" + program._variables[partial_name].text + "{{/block}}" - # Update the current program variables using those from the partial, but do not overwrite. - # (Rebuilding the _variables map here would break returning new values to the program variables later, e.g. from gen.) - update_variables = { - k: v - for k, v in program[partial_name]._variables.items() - if k not in program._variables - } - program._variables.update(update_variables) - return out - text = re.sub(r"{{>(.*?)}}", replace_partial, program._text) + # # ,args_string = match.group(1).split(" ", 1) + # if partial_name not in program._variables: + # raise ValueError("Partial '%s' not given in the keyword args:" % partial_name) + # out = "{{#block '"+partial_name+"'" + # if len(parts) > 1: + # out += " " + parts[1] + # out += "}}" + program._variables[partial_name].text + "{{/block}}" + # # Update the current program variables using those from the partial, but do not overwrite. + # # (Rebuilding the _variables map here would break returning new values to the program variables later, e.g. from gen.) + # update_variables = { + # k: v + # for k, v in program[partial_name]._variables.items() + # if k not in program._variables + # } + # program._variables.update(update_variables) + # return out + # text = re.sub(r"{{>(.*?)}}", replace_partial, program._text) # parse the program text try: - self.parse_tree = grammar.parse(text) - except parsimonious.exceptions.ParseError as e: - self._check_for_simple_error(text) - raise e + self.parse_tree = grammar.parse_string(program._text) + except (pp.ParseException, pp.ParseSyntaxException) as e: + initial_str = program._text[max(0, e.loc-40):e.loc] + initial_str = initial_str.split("\n")[-1] # trim off any lines before the error + next_str = program._text[e.loc:e.loc+40] + error_string = str(e) + if next_str.startswith("{{#") or next_str.startswith("{{~#"): + error_string += "\nPerhaps the block command was not correctly closed?" + msg = error_string + "\n\n"+initial_str + # msg += "\033[91m" + program._text[e.loc:e.loc+40] + "\033[0m\n" + msg += program._text[e.loc:e.loc+40] + "\n" + msg += " " * len(initial_str) + "^\n" + + raise SyntaxException(msg, e) from None - def _check_for_simple_error(self, text): - """ Check for a simple errors in the program text, and give nice error messages. - """ - - vars = self.program._variables - - # missing block pound sign - for k in vars: - if getattr(vars[k], "is_block", False): - - # look for block commands that are missing the opening pound sign or closing slash - m = re.search(r"(^|[^\\]){{\s*"+k+"(\s|}|~)", text) - if m is not None: - # get the context around the matching error - start = max(0, m.start()-30) - end = min(len(text), m.end()+30) - context = text[start:end] - if start > 0: - context = "..."+context - if end < len(text): - context = context+"..." - raise ValueError("The guidance program is missing the opening pound (#) sign or closing slash (/) for the block level command `"+k+"` at:\n"+context) from None + # def _check_for_simple_error(self, text): + # """ Check for a simple errors in the program text, and give nice error messages. + # """ + + # vars = self.program._variables + + # # missing block pound sign + # for k in vars: + # if getattr(vars[k], "is_block", False): + + # # look for block commands that are missing the opening pound sign or closing slash + # m = re.search(r"(^|[^\\]){{\s*"+k+"(\s|}|~)", text) + # if m is not None: + # # get the context around the matching error + # start = max(0, m.start()-30) + # end = min(len(text), m.end()+30) + # context = text[start:end] + # if start > 0: + # context = "..."+context + # if end < len(text): + # context = context+"..." + # raise ValueError("The guidance program is missing the opening pound (#) sign or closing slash (/) for the block level command `"+k+"` at:\n"+context) from None - # look for block commands that are missing the closing tag - num_opens = len(re.findall(r"(^|[^\\]){{~?#\s*"+k+"(\s|}|~)", text)) - num_closes = len(re.findall(r"(^|[^\\]){{~?/\s*"+k+"(\s|}|~)", text)) - if num_opens > num_closes: - raise ValueError("The guidance program is missing a closing tag for the block level command `"+k+"`.") from None - if num_opens < num_closes: - raise ValueError("The guidance program is missing an opening tag for the block level command `"+k+"`.") from None + # # look for block commands that are missing the closing tag + # num_opens = len(re.findall(r"(^|[^\\]){{~?#\s*"+k+"(\s|}|~)", text)) + # num_closes = len(re.findall(r"(^|[^\\]){{~?/\s*"+k+"(\s|}|~)", text)) + # if num_opens > num_closes: + # raise ValueError("The guidance program is missing a closing tag for the block level command `"+k+"`.") from None + # if num_opens < num_closes: + # raise ValueError("The guidance program is missing an opening tag for the block level command `"+k+"`.") from None @@ -97,7 +107,7 @@ async def run(self, llm_session): # self.whitespace_control_visit(self.parse_tree) # now execute the program - self.program._variables["_prefix"] = "" + self.program._variables["@raw_prefix"] = "" await self.visit(self.parse_tree, VariableStack([self.program._variables], self)) except Exception as e: print(traceback.format_exc()) @@ -115,12 +125,12 @@ def stop(self): # return text # def whitespace_control_visit(self, node, next_node=None, prev_node=None, parent_node=None, grandparent_node=None): - # if node.expr_name in ('command', 'command_block_open', 'command_block_sep', 'command_block_close'): + # if node_name in ('command', 'block_command_open', 'block_command_sep', 'block_command_close'): # if node.text.startswith("{{~"): - # if prev_node and prev_node.expr_name == "content": + # if prev_node and prev_node_name == "content": # prev_node.text = prev_node.text + "{{!--GSTRIP--}}" # if node.text.endswith("~}}"): - # if next_node and next_node.expr_name == "content": + # if next_node and next_node_name == "content": # next_node.text = "{{!--GSTRIP--}}" + next_node.text # # visit all our children @@ -141,109 +151,139 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, # (note that this flag will be cleared once the loop is ended) if self.caught_stop_iteration: return "" + + node_name = node.get_name() - if node.expr_name == 'variable_name': + if node_name == 'variable_name': return node.text - elif node.expr_name == 'content': - text = node.text - - # check for white space stripping commands - if next_node is not None and next_node.text.startswith("{{~"): - text = text.rstrip() - if prev_node is not None and prev_node.text.endswith("~}}"): - text = text.lstrip() - - variable_stack["_prefix"] += text - return "" + elif node_name == 'content': + variable_stack["@raw_prefix"] += node[0] + return None - elif node.expr_name == 'comment': - variable_stack["_prefix"] += node.text + elif node_name == 'long_comment': + variable_stack["@raw_prefix"] += node.text return "" - elif node.expr_name == 'slim_comment': - variable_stack["_prefix"] += node.text + elif node_name == 'comment': + variable_stack["@raw_prefix"] += node.text return "" + + elif node_name == 'partial': + partial_program = variable_stack[node[0]["name"]] + tree = grammar.parse_string(partial_program._text) + variable_stack.push({k: v for k,v in partial_program.variables().items() if k not in ["llm", "logging"]}) + out = await self.visit(tree, variable_stack) + variable_stack.pop() + return out - elif node.expr_name == 'command_args': - visited_children = [await self.visit(child, variable_stack) for child in node.children] - return visited_children - - elif node.expr_name == 'command_arg_and_ws': - # visited_children = [await self.visit(child) for child in node.children] - return await self.visit(node.children[1], variable_stack) #visited_children[1] - - elif node.expr_name == 'positional_command_arg': - # visited_children = [await self.visit(child) for child in node.children] - return PositionalArgument(await self.visit(node.children[0], variable_stack)) + elif node_name == 'positional_command_arg': + return PositionalArgument(await self.visit(node[0], variable_stack)) - elif node.expr_name == 'named_command_arg': - # visited_children = [await self.visit(child) for child in node.children] - return NamedArgument(await self.visit(node.children[0], variable_stack), await self.visit(node.children[2], variable_stack)) + elif node_name == 'named_command_arg': + return NamedArgument(node[0], await self.visit(node[2], variable_stack)) - elif node.expr_name == 'command_name': + elif node_name == 'command_name': + return node.text + + elif node_name == 'command_name': return node.text - elif node.expr_name == 'escaped_command': - variable_stack["_prefix"] += node.text[1:] + elif node_name == 'escaped_command': + variable_stack["@raw_prefix"] += node.text[1:] return + + elif node_name == 'boolean_literal': + if node[0] == "True": + return True + elif node[0] == "False": + return False + else: + raise Exception("Invalid boolean literal") + + elif node_name == 'number_literal': + if "." in node[0]: + return float(node[0]) + else: + return int(node[0]) + + elif node_name == 'string_literal': + return node[0] + + elif node_name == 'object_literal': + out = {} + for i in range(0, len(node), 2): + key = await self.visit(node[i], variable_stack) + value = await self.visit(node[i + 1], variable_stack) + out[key] = value + return out + + elif node_name == 'array_literal': + return [await self.visit(node[i], variable_stack) for i in range(0, len(node))] - elif node.expr_name == 'literal': + elif node_name == 'literal': try: return ast.literal_eval(node.text) except Exception as e: raise Exception(f"Error parsing literal: {node.text} ({e})") - elif node.expr_name == 'command': + elif node_name == 'command': # if execution is already stopped before we start the command we just keep the command text if not self.executing: - variable_stack["_prefix"] += node.text + variable_stack["@raw_prefix"] += node.text return # mark our position in case we need to rewind # pos = len(self.prefix) # find the command name - command_head = node.children[1].children[0] - if command_head.expr_name == 'variable_ref': - if callable(variable_stack[command_head.children[0].text]): - name = command_head.children[0].text - else: - name = "variable_ref" - elif command_head.expr_name == 'command_call': - name = command_head.children[0].text - else: - raise Exception("Unknown command head type: "+command_head.expr_name) + if "variable_ref" in node: + name = "variable_ref" + elif "keyword" in node: + name = "keyword" + elif "command_call" in node: + name = node["command_call"]["name"] + else: # binary_operator and unary_operator + name = node[0].get_name() # add the start marker escaped_node_text = node.text.replace("$", "$").replace("{", "{").replace("}", "}") - variable_stack["_prefix"] += "{{!--"+f"GMARKER_START_{name}${escaped_node_text}$"+"--}}" + variable_stack["@raw_prefix"] += "{{!--"+f"GMARKER_START_{name}${escaped_node_text}$"+"--}}" # visit our children self.block_content.append([]) - visited_children = [await self.visit(child, variable_stack, next_node, next_next_node, prev_node, node, parent_node) for child in node.children] + visited_children = [await self.visit(child, variable_stack, next_node, next_next_node, prev_node, node, parent_node) for child in node] self.block_content.pop() out = "".join("" if c is None else str(c) for c in visited_children) - variable_stack["_prefix"] += out + "{{!--" + f"GMARKER_END_{name}$$" + "--}}" + variable_stack["@raw_prefix"] += out + "{{!--" + f"GMARKER_END_{name}$$" + "--}}" # if execution became stopped during the command, we append the command text if not self.executing: # self.reset_prefix(pos) - variable_stack["_prefix"] += node.text + variable_stack["@raw_prefix"] += node.text return - elif node.expr_name == 'command_arg_group': + elif node_name == 'paren_group': visited_children = [await self.visit(child, variable_stack) for child in node.children] return visited_children[1] - elif node.expr_name == 'command_call' or node.expr_name == 'variable_ref': - if node.expr_name == 'command_call': - visited_children = [await self.visit(child, variable_stack) for child in node.children] - command_name, args = visited_children + elif node_name == 'command_call' or node_name == 'variable_ref' or node_name == 'binary_operator' or node_name == 'unary_operator' or node_name == 'keyword': + if node_name == 'command_call': + command_name = node["name"] + args = [await self.visit(child, variable_stack) for child in node[1:]] + elif node_name == 'binary_operator': + command_name = "BINARY_OPERATOR_" + node["operator"] + args = [ + PositionalArgument(await self.visit(node["lhs"], variable_stack)), + PositionalArgument(await self.visit(node["rhs"], variable_stack)) + ] + elif node_name == 'unary_operator': + command_name = "UNARY_OPERATOR_" + node["operator"] + args = [PositionalArgument(await self.visit(node["value"], variable_stack))] else: - command_name = node.text + command_name = node[0] args = [] # return_value = "" @@ -251,7 +291,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, command_function = variable_stack[command_name] # we convert a variable reference to a function that returns the variable value - if node.expr_name == "variable_ref" and not callable(command_function): + if node_name == "variable_ref": command_value = command_function command_function = lambda: command_value @@ -265,15 +305,15 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, # return_value += "" if s is None else str(s) # If we are a top level command we extend the prefix - top_level = grandparent_node is not None and grandparent_node.expr_name == "command" + top_level = grandparent_node is not None and grandparent_node.get_name() == "command" # partial_output = self.extend_prefix # pass # otherwise we keep track of output locally so we can return it if not top_level: # partial_output = update_return_value - pos = len(variable_stack["_prefix"]) - variable_stack.push({"_prefix": variable_stack["_prefix"], "_no_display": True}) + pos = len(variable_stack["@raw_prefix"]) + variable_stack.push({"@raw_prefix": variable_stack["@raw_prefix"], "_no_display": True}) # create the arguments for the command positional_args = [] @@ -303,10 +343,10 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, "name": command_name, "positional_args": positional_args, "named_args": {k:v for k,v in named_args.items() if k != "_parser_context"}, - "prefix": variable_stack["prefix"], + "@prefix": variable_stack["@prefix"], # "node_id": id(node) }) - pos = len(variable_stack["prefix"]) + pos = len(variable_stack["@prefix"]) try: if inspect.iscoroutinefunction(command_function): await asyncio.sleep(0) # give other coroutines a chance to run @@ -317,11 +357,11 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, command_output = ret.value self.caught_stop_iteration = True if self._logging: - self.program.log.append({"type": "end", "name": command_name, "new_prefix": variable_stack["prefix"][pos:]}) + self.program.log.append({"type": "end", "name": command_name, "new_prefix": variable_stack["@prefix"][pos:]}) # call partial output if the command didn't itself (and we are still executing) if not top_level: - curr_prefix = variable_stack.pop()["_prefix"] # pop the variable stack we pushed earlier becuause we were hidden + curr_prefix = variable_stack.pop()["@raw_prefix"] # pop the variable stack we pushed earlier becuause we were hidden if command_output is not None: return command_output else: @@ -335,7 +375,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, return new_content else: if command_output is not None: - variable_stack["_prefix"] += str(command_output) + variable_stack["@raw_prefix"] += str(command_output) return "" else: # if the variable does not exist we just pause execution @@ -348,7 +388,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, # # if we are not a top level command we return the output instead of displaying it # if not top_level: - # return_value = variable_stack.pop()["_prefix"][pos:] + # return_value = variable_stack.pop()["@raw_prefix"][pos:] # # see if we got a list of outputs encoded as a string # parts = re.split(r"{{!--GMARKERmany[^}]+}}", return_value) @@ -359,7 +399,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, # else: # return "" - elif node.expr_name == 'block_command_call': + elif node_name == 'block_command_call': parts = [await self.visit(child, variable_stack) for child in node.children] if len(parts) > 1: command_name, args = parts @@ -368,43 +408,56 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, args = [] return command_name, args - elif node.expr_name == 'command_block_open': + elif node_name == 'block_command_open': return await self.visit(node.children[2], variable_stack) # visited_children = [await self.visit(child) for child in node.children] # return visited_children[2] - elif node.expr_name == 'command_block': + elif node_name == 'block_command': # if execution is already stopped before we start the command block we just return unchanged if not self.executing: - variable_stack["_prefix"] += node.text + variable_stack["@raw_prefix"] += node.text return "" # create a block content variable - block_content = [node.children[1]] - for child in node.children[2].children: - if child.text == '': - continue - block_content.append(child.children[0]) - block_content.append(child.children[1]) - self.block_content.append(block_content) + # block_content = node["block_content"] + # block_content = [node.children[1]] + # for child in node.children[2].children: + # if child.text == '': + # continue + # block_content.append(child.children[0]) + # block_content.append(child.children[1]) + assert node[1].get_name() == "block_content" # TODO: figure out why node["block_content"] doesn't work (has to do with SavedText messing up the keys) + self.block_content.append(node[1]) # get the command name and arguments - command_name, command_args = await self.visit(node.children[0], variable_stack) - - # make sure we have a matching end command - if not (node.text.endswith("/"+command_name+"}}") or node.text.endswith("/"+command_name+"~}}")): - raise SyntaxError("Guidance command block starting with `"+node.text[:20]+"...` does not end with a matching `{{/"+command_name+"}}` but instead ends with `..."+node.text[-20:]+"!") + call = node["command_call"] + command_name = call["name"] + command_args = [await self.visit(arg, variable_stack) for arg in call[1:]] + + # command_args = [] + # if "positional_command_arg" in call: + # for arg in call["positional_command_arg"]: + # command_args.append(await self.visit(arg, variable_stack)) + # if "named_command_arg" in call: + # for arg in call["named_command_arg"]: + # command_args.append(await self.visit(arg, variable_stack)) + # command_name, command_args = [await self.visit(arg, variable_stack) for arg in node["command_call"]["args"]] + + # make sure we have a matching end command TODO: move this to a parser action + # if not (node.text.endswith("/"+command_name+"}}") or node.text.endswith("/"+command_name+"~}}")): + # raise SyntaxError("Guidance command block starting with `"+node.text[:20]+"...` does not end with a matching `{{/"+command_name+"}}` but instead ends with `..."+node.text[-20:]+"!") # if execution stops while parsing the start command just return unchanged if not self.executing: - variable_stack["_prefix"] += node.text + variable_stack["@raw_prefix"] += node.text return "" # add the start marker escaped_node_text = node.text.replace("$", "$").replace("{", "{").replace("}", "}") start_marker = "{{!--"+f"GMARKER_START_{command_name}${escaped_node_text}$"+"--}}" - variable_stack["_prefix"] += start_marker + variable_stack["@raw_prefix"] += start_marker if command_name in variable_stack: command_function = variable_stack[command_name] @@ -420,16 +473,16 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, sig = inspect.signature(command_function) if "_parser_context" in sig.parameters: named_args["_parser_context"] = { - "parser_prefix": strip_markers(variable_stack["_prefix"]), + "parser_prefix": strip_markers(variable_stack["@raw_prefix"]), "parser": self, "block_content": self.block_content[-1], # "partial_output": self.extend_prefix, "variable_stack": variable_stack, "parser_node": node, - "block_close_node": node.children[-1], + "block_close_node": node[-1], "next_node": next_node, "next_next_node": next_next_node, - "prev_node": node.children[0] + "prev_node": node[0] } # call the optionally asyncronous command @@ -439,44 +492,44 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, "name": command_name, "positional_args": positional_args, "named_args": {k:v for k,v in named_args.items() if k != "_parser_context"}, - "prefix": variable_stack["prefix"], + "@prefix": variable_stack["@prefix"], # "node_id": id(node) }) - pos = len(variable_stack["prefix"]) + pos = len(variable_stack["@prefix"]) if inspect.iscoroutinefunction(command_function): command_output = await command_function(*positional_args, **named_args) else: command_output = command_function(*positional_args, **named_args) if self._logging: - self.program.log.append({"type": "end", "name": command_name, "new_prefix": variable_stack["prefix"][pos:]}) + self.program.log.append({"type": "end", "name": command_name, "new_prefix": variable_stack["@prefix"][pos:]}) # if the command didn't send partial output we do it here if command_output is not None: - variable_stack["_prefix"] += command_output + variable_stack["@raw_prefix"] += command_output # pop off the block content after the command call self.block_content.pop() - variable_stack["_prefix"] += "{{!--" + f"GMARKER_END_{command_name}$$" + "--}}" + variable_stack["@raw_prefix"] += "{{!--" + f"GMARKER_END_{command_name}$$" + "--}}" return else: visited_children = [] - for i, child in enumerate(node.children): - if len(node.children) > i + 1: - inner_next_node = node.children[i + 1] + for i, child in enumerate(node): + if len(node) > i + 1: + inner_next_node = node[i + 1] else: inner_next_node = next_node - if len(node.children) > i + 2: - inner_next_next_node = node.children[i + 2] - elif len(node.children) == i + 2: + if len(node) > i + 2: + inner_next_next_node = node[i + 2] + elif len(node) == i + 2: inner_next_next_node = next_node else: inner_next_next_node = next_next_node if i > 0: - inner_prev_node = node.children[i - 1] + inner_prev_node = node[i - 1] else: inner_prev_node = prev_node visited_children.append(await self.visit(child, variable_stack, inner_next_node, inner_next_next_node, inner_prev_node, node, parent_node)) @@ -488,7 +541,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, return "".join("" if c is None else c for c in visited_children) # def get_variable(self, name, default_value=None): - # parts = re.split(r"\.|\[", name) + # parts = re.split(r"\.|\[", name) 40 ms 2048 12B # for variables in reversed(self.variable_stack): # curr_pos = variables # found = True @@ -553,7 +606,7 @@ async def visit(self, node, variable_stack, next_node=None, next_next_node=None, # def extend_prefix(self, text, variable_stack): # if text == "" or text is None: # return - # variable_stack["_prefix"] += str(text) + # variable_stack["@raw_prefix"] += str(text) # self.program.update_display() # def reset_prefix(self, pos): @@ -573,3 +626,7 @@ def __init__(self, name, value): class StopCompletion(Exception): pass +class SyntaxException(Exception): + def __init__(self, msg, pyparsing_exception=None): + super().__init__(msg) + self.pyparsing_exception = pyparsing_exception \ No newline at end of file diff --git a/guidance/_utils.py b/guidance/_utils.py index 408a21f22..9d19bf3de 100644 --- a/guidance/_utils.py +++ b/guidance/_utils.py @@ -83,28 +83,28 @@ def __init__(self, variable_stack, hidden=False): self._variable_stack = variable_stack def __enter__(self): - self._pos = len(self._variable_stack["_prefix"]) + self._pos = len(self._variable_stack["@raw_prefix"]) if self._hidden: - self._variable_stack.push({"_prefix": self._variable_stack["_prefix"]}) + self._variable_stack.push({"@raw_prefix": self._variable_stack["@raw_prefix"]}) return self def __exit__(self, type, value, traceback): if self._hidden: new_content = str(self) self._variable_stack.pop() - self._variable_stack["_prefix"] += "{{!--GHIDDEN:"+new_content.replace("--}}", "--_END_END")+"--}}" + self._variable_stack["@raw_prefix"] += "{{!--GHIDDEN:"+new_content.replace("--}}", "--_END_END")+"--}}" def __str__(self): - return strip_markers(self._variable_stack["_prefix"][self._pos:]) + return strip_markers(self._variable_stack["@raw_prefix"][self._pos:]) def __iadd__(self, other): if other is not None: - self._variable_stack["_prefix"] += other + self._variable_stack["@raw_prefix"] += other return self def inplace_replace(self, old, new): """Replace all instances of old with new in the captured content.""" - self._variable_stack["_prefix"] = self._variable_stack["_prefix"][:self._pos] + self._variable_stack["_prefix"][self._pos:].replace(old, new) + self._variable_stack["@raw_prefix"] = self._variable_stack["@raw_prefix"][:self._pos] + self._variable_stack["@raw_prefix"][self._pos:].replace(old, new) class JupyterComm(): def __init__(self, target_id, ipython_handle, callback=None, on_open=None, mode="register"): diff --git a/guidance/_variable_stack.py b/guidance/_variable_stack.py index 5742ce727..c7a24f2c3 100644 --- a/guidance/_variable_stack.py +++ b/guidance/_variable_stack.py @@ -19,7 +19,7 @@ def pop(self): out = self._stack.pop() # if we are popping a _prefix variable state we need to update the display - if "_prefix" in self._stack[-1]: + if "@raw_prefix" in self._stack[-1]: self._executor.program.update_display() return out @@ -30,8 +30,8 @@ def __getitem__(self, key): def get(self, name, default_value=None): # prefix is a special variable that returns the current prefix without the marker tags - if name == "prefix": - return strip_markers(self.get("_prefix", "")) + if name == "@prefix": + return strip_markers(self.get("@raw_prefix", "")) parts = re.split(r"\.|\[", name) for variables in reversed(self._stack): @@ -58,6 +58,17 @@ def get(self, name, default_value=None): def __contains__(self, name): return self.get(name, _NO_VALUE) != _NO_VALUE + + def __delitem__(self, key): + """Note this only works for simple variables, not nested variables.""" + found = True + for variables in reversed(self._stack): + if key in variables: + del variables[key] + found = True + break + if not found: + raise KeyError(key) def __setitem__(self, key, value): parts = re.split(r"\.|\[", key) @@ -97,7 +108,7 @@ def __setitem__(self, key, value): self._stack[0][key] = value # if we changed the _prefix variable, update the display - if changed and key == "_prefix" and not self.get("_no_display", False): + if changed and key == "@raw_prefix" and not self.get("_no_display", False): self._executor.program.update_display() def copy(self): diff --git a/guidance/library/__init__.py b/guidance/library/__init__.py index f66c6356e..4c159139d 100644 --- a/guidance/library/__init__.py +++ b/guidance/library/__init__.py @@ -6,6 +6,7 @@ from ._if import if_ from ._unless import unless from ._add import add +from ._multiply import multiply from ._select import select from ._each import each from ._geneach import geneach @@ -20,4 +21,5 @@ from ._greater import greater from ._less import less from ._contains import contains -from ._parse import parse \ No newline at end of file +from ._parse import parse +from ._notequal import notequal \ No newline at end of file diff --git a/guidance/library/_await.py b/guidance/library/_await.py index fe3c7b41b..1735e122a 100644 --- a/guidance/library/_await.py +++ b/guidance/library/_await.py @@ -16,11 +16,12 @@ async def await_(name, _parser_context=None): # this will result in a partially completed program that we can then finish # later (by calling it again with the variable we need) parser = _parser_context['parser'] - if name not in parser.program: + variable_stack = _parser_context['variable_stack'] + if name not in variable_stack: parser.executing = False else: - value = parser.program[name] - del parser.program[name] + value = variable_stack[name] + del variable_stack[name] return value # cache = parser.program._await_cache diff --git a/guidance/library/_each.py b/guidance/library/_each.py index 8601d6448..45ad089c6 100644 --- a/guidance/library/_each.py +++ b/guidance/library/_each.py @@ -44,11 +44,11 @@ async def each(list, hidden=False, parallel=False, _parser_context=None): "@first": i == 0, "@last": i == len(list) - 1, "this": item, - "_prefix": variable_stack["_prefix"], # create a local copy of the prefix since we are hidden + "@raw_prefix": variable_stack["@raw_prefix"], # create a local copy of the prefix since we are hidden "_no_display": True }) coroutines.append(parser.visit( - block_content[0], + block_content, variable_stack.copy(), next_node=_parser_context["next_node"], next_next_node=_parser_context["next_next_node"], @@ -78,7 +78,7 @@ async def each(list, hidden=False, parallel=False, _parser_context=None): }) with ContentCapture(variable_stack, hidden) as new_content: new_content += await parser.visit( - block_content[0], + block_content, variable_stack, next_node=_parser_context["next_node"], next_next_node=_parser_context["next_next_node"], diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py index 3ef8eb870..1fd722538 100644 --- a/guidance/library/_gen.py +++ b/guidance/library/_gen.py @@ -69,10 +69,10 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t next_next_node = _parser_context["next_next_node"] prev_node = _parser_context["prev_node"] # partial_output = _parser_context["partial_output"] - pos = len(variable_stack["_prefix"]) # save the current position in the prefix + pos = len(variable_stack["@raw_prefix"]) # save the current position in the prefix if hidden: - variable_stack.push({"_prefix": variable_stack["_prefix"]}) + variable_stack.push({"@raw_prefix": variable_stack["@raw_prefix"]}) if list_append: assert name is not None, "You must provide a variable name when using list_append=True" @@ -80,12 +80,10 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t # if stop is None then we use the text of the node after the generate command if stop is None: - next_text = next_node.text if next_node is not None else "" - prev_text = prev_node.text if prev_node is not None else "" - if next_next_node and next_next_node.text.startswith("{{~"): - next_text = next_text.lstrip() - if next_next_node and next_text == "": - next_text = next_next_node.text + next_text = getattr(next_node, "text", next_node) if next_node is not None else "" + prev_text = getattr(prev_node, "text", prev_node) if prev_node is not None else "" + if next_next_node and next_text == "": + next_text = getattr(next_next_node, "text", next_next_node) # auto-detect quote stop tokens quote_types = ["'''", '"""', '```', '"', "'", "`"] @@ -96,7 +94,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t # auto-detect role stop tags if stop is None: - m = re.match(r"^{{~?/(user|assistant|system|role)~?}}.*", next_text) + m = re.match(r"^{{~?/\w*(user|assistant|system|role)\w*~?}}.*", next_text) if m: stop = parser.program.llm.role_end(m.group(1)) @@ -132,7 +130,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t # save the prompt if requested if save_prompt: - variable_stack[save_prompt] = variable_stack["_prefix"]+prefix + variable_stack[save_prompt] = variable_stack["@raw_prefix"]+prefix if logprobs is None: logprobs = parser.program.logprobs @@ -141,14 +139,14 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t # call the LLM gen_obj = await parser.llm_session( - variable_stack["prefix"]+prefix, stop=stop, stop_regex=stop_regex, max_tokens=max_tokens, n=n, pattern=pattern, + variable_stack["@prefix"]+prefix, stop=stop, stop_regex=stop_regex, max_tokens=max_tokens, n=n, pattern=pattern, temperature=temperature, top_p=top_p, logprobs=logprobs, cache_seed=cache_seed, token_healing=token_healing, echo=parser.program.logprobs is not None, stream=stream, caching=parser.program.caching, **llm_kwargs ) if n == 1: generated_value = prefix - variable_stack["_prefix"] += prefix + variable_stack["@raw_prefix"] += prefix logprobs_out = [] if not isinstance(gen_obj, (types.AsyncGeneratorType, types.GeneratorType, list, tuple)): gen_obj = [gen_obj] @@ -170,7 +168,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t break # log.debug("resp", resp) generated_value += resp["choices"][0]["text"] - variable_stack["_prefix"] += resp["choices"][0]["text"] + variable_stack["@raw_prefix"] += resp["choices"][0]["text"] if logprobs is not None: logprobs_out.extend(resp["choices"][0]["logprobs"]["top_logprobs"]) if list_append: @@ -191,16 +189,16 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t if hasattr(gen_obj, 'close'): gen_obj.close() generated_value += suffix - variable_stack["_prefix"] += suffix + variable_stack["@raw_prefix"] += suffix if list_append: variable_stack[name][list_ind] = generated_value elif name is not None: variable_stack[name] = generated_value if hidden: - new_content = variable_stack["_prefix"][pos:] + new_content = variable_stack["@raw_prefix"][pos:] variable_stack.pop() - variable_stack["_prefix"] += "{{!--GHIDDEN:"+new_content.replace("--}}", "--_END_END")+"--}}" + variable_stack["@raw_prefix"] += "{{!--GHIDDEN:"+new_content.replace("--}}", "--_END_END")+"--}}" # stop executing if we were interrupted if parser.should_stop: @@ -229,7 +227,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t # this just uses the first generated value for completion and the rest as alternatives only used for the variable storage # we mostly support this so that the echo=False hiding behavior does not make multiple outputs more complicated than it needs to be in the UX # if echo: - # variable_stack["_prefix"] += generated_value + # variable_stack["@raw_prefix"] += generated_value id = uuid.uuid4().hex l = len(generated_values) @@ -242,7 +240,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t out += escape_template_block(value) else: out += value - variable_stack["_prefix"] += out + "--}}{{!--" + f"GMARKERmany_generate_end${id}$" + "--}}" + variable_stack["@raw_prefix"] += out + "--}}{{!--" + f"GMARKERmany_generate_end${id}$" + "--}}" return # return "{{!--GMARKERmany_generate_start$$}}" + "{{!--GMARKERmany_generate$$}}".join([v for v in generated_values]) + "{{!--GMARKERmany_generate_end$$}}" # return "".join([v for v in generated_values]) diff --git a/guidance/library/_geneach.py b/guidance/library/_geneach.py index 20cae538e..c664351ff 100644 --- a/guidance/library/_geneach.py +++ b/guidance/library/_geneach.py @@ -46,7 +46,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu # parser_prefix = _parser_context["parser_prefix"] parser_node = _parser_context["parser_node"] - assert len(block_content) == 1 + # assert len(block_content) == 1 assert not (hidden and single_call), "Cannot use hidden=True and single_call together" assert isinstance(list_name, str), "Must provide a variable name to save the generated list to" assert not hidden or num_iterations is not None, "Cannot use hidden=True and variable length iteration together yet..." @@ -61,8 +61,8 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu max_iterations = 1e10 # give the list a default name - if list_name is None: - list_name = 'generated_list' + # if list_name is None: + # list_name = 'generated_list' # if stop is None then we use the text of the node after the generate command # if stop is None: @@ -119,7 +119,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu # visit the block content new_content += await parser.visit( - block_content[0], + block_content, variable_stack, next_node=_parser_context["next_node"], next_next_node=_parser_context["next_next_node"], @@ -145,7 +145,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu # we run a quick generation to see if we have reached the end of the list (note the +2 tokens is to help be tolorant to whitespace) if stop is not False and i >= min_iterations and i < max_iterations: try: - gen_obj = await parser.llm_session(variable_stack["prefix"], stop=stop, max_tokens=max_stop_tokens, temperature=0, cache_seed=0) + gen_obj = await parser.llm_session(variable_stack["@prefix"], stop=stop, max_tokens=max_stop_tokens, temperature=0, cache_seed=0) except Exception: raise Exception(f"Error generating stop tokens for geneach loop. Perhaps you are outside of role tags (assistant/user/system)? If you don't want the loop to check for stop tokens, set stop=False or set num_iterations.") if gen_obj["choices"][0]["finish_reason"] == "stop": @@ -157,7 +157,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu pattern = re.sub( r'{{gen [\'"]([^\'"]+)[\'"][^}]*}}', lambda x: r"(?P<"+_escape_group_name(x.group(1))+">.*?)", - block_content[0].text + block_content.text ) # fixed prefixes can be used if we know we have at least one iteration @@ -177,7 +177,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu parser.program.cache_seed += 1 else: cache_seed = 0 - gen_stream = await parser.llm_session(variable_stack["_prefix"]+fixed_prefix, stop=stop, max_tokens=single_call_max_tokens, temperature=single_call_temperature, top_p=single_call_top_p, cache_seed=cache_seed, stream=True) + gen_stream = await parser.llm_session(variable_stack["@raw_prefix"]+fixed_prefix, stop=stop, max_tokens=single_call_max_tokens, temperature=single_call_temperature, top_p=single_call_top_p, cache_seed=cache_seed, stream=True) generated_value = fixed_prefix num_items = 0 data = [] @@ -205,20 +205,20 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu next_item = d # update the list variable (we do this each time we get a new item so that streaming works) - parser.set_variable(list_name, parser.get_variable(list_name, default_value=[]) + [next_item]) + variable_stack[list_name] = variable_stack.get(list_name, []) + [next_item] # recreate the output string with format markers added item_out = re.sub( r"{{(?!~?gen)(.*?)}}", lambda x: match_dict[_escape_group_name(x.group(1))], - block_content[0].text + block_content.text ) item_out = re.sub( r"{{gen [\'\"]([^\'\"]+)[\'\"][^}]*}}", lambda x: "{{!--GMARKER_START_gen$"+x.group().replace("$", "$").replace("{", "{").replace("}", "}")+"$--}}"+match_dict[_escape_group_name(x.group(1))]+"{{!--GMARKER_END_gen$$--}}", item_out ) - partial_output("{{!--GMARKER_each$$--}}" + item_out) # marker and content of the item + variable_stack["@raw_prefix"] += "{{!--GMARKER_each$$--}}" + item_out # marker and content of the item num_items += 1 # out.append(item_out) @@ -233,7 +233,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu # if we have stopped executing, we need to add the loop to the output so it can be executed later if not parser.executing: - variable_stack["_prefix"] += parser_node.text + variable_stack["@raw_prefix"] += parser_node.text # return "" diff --git a/guidance/library/_if.py b/guidance/library/_if.py index 57a358f90..a4c4d6e63 100644 --- a/guidance/library/_if.py +++ b/guidance/library/_if.py @@ -1,6 +1,6 @@ import re -async def if_(value, invert=False, _parser_context=None): +async def if_(value, *, invert=False, _parser_context=None): ''' Standard if/else statement. Parameters @@ -8,28 +8,34 @@ async def if_(value, invert=False, _parser_context=None): value : bool The value to check. If `True` then the first block will be executed, otherwise the second block (the one after the `{{else}}`) will be executed. - invert : bool + invert : bool [DEPRECATED] If `True` then the value will be inverted before checking. ''' block_content = _parser_context['block_content'] variable_stack = _parser_context['variable_stack'] - assert len(block_content) in [1,3] # we don't support elseif yet... - options = [block_content[0]] - for i in range(1, len(block_content), 2): - assert re.match(r"{{~?else~?}}", block_content[i].text), "Expected else statement" - options.append(block_content[i+1]) - - # if isinstance(value, str): - # value2 = value - # value = value.lower().strip() in ["true", "yes", "on", "t", "y", "ok", "okay"] + parser = _parser_context['parser'] + assert len(block_content) % 2 == 1, "Unexpected number of blocks for `if` command: " + str(len(block_content)) + + # parse the first block if invert: value = not value - if value: - return await _parser_context['parser'].visit(options[0], variable_stack) - elif len(options) > 1: - return await _parser_context['parser'].visit(options[1], variable_stack) - else: - return "" + return await parser.visit(block_content[0], variable_stack) + + # parse the rest of the blocks + for i in range(1, len(block_content), 2): + + # elif block + if block_content[i][0] == "elif": + if parser.visit(block_content[i][1], variable_stack): + return await parser.visit(block_content[i+1], variable_stack) + + # else block + elif block_content[i][0] == "else": + return await parser.visit(block_content[i+1], variable_stack) + + else: + raise ValueError("Unexpected block content separator for `if` command: " + block_content[i].text) + return "" if_.is_block = True \ No newline at end of file diff --git a/guidance/library/_multiply.py b/guidance/library/_multiply.py new file mode 100644 index 000000000..d873069a4 --- /dev/null +++ b/guidance/library/_multiply.py @@ -0,0 +1,6 @@ +import math + +def multiply(*args): + ''' Multiply the given variables together. + ''' + return math.prod(args) \ No newline at end of file diff --git a/guidance/library/_notequal.py b/guidance/library/_notequal.py new file mode 100644 index 000000000..195046ba2 --- /dev/null +++ b/guidance/library/_notequal.py @@ -0,0 +1,4 @@ +def notequal(arg1, arg2): + ''' Check that the arguments are not equal. + ''' + return arg1 != arg2 \ No newline at end of file diff --git a/guidance/library/_parse.py b/guidance/library/_parse.py index 167546a82..ef4b31ef5 100644 --- a/guidance/library/_parse.py +++ b/guidance/library/_parse.py @@ -21,7 +21,7 @@ async def parse(string, name=None, hidden=False, _parser_context=None): with ContentCapture(variable_stack, hidden) as new_content: # parse and visit the given string - subtree = grammar.parse(string) + subtree = grammar.parse_string(string) new_content += await parser.visit(subtree, variable_stack) # save the content in a variable if needed diff --git a/guidance/library/_role.py b/guidance/library/_role.py index 62fbb9b61..f35b7ee20 100644 --- a/guidance/library/_role.py +++ b/guidance/library/_role.py @@ -15,7 +15,7 @@ async def role(name, hidden=False, _parser_context=None): # visit the block content new_content += await parser.visit( - block_content[0], + block_content, variable_stack, next_node=_parser_context["block_close_node"], prev_node=_parser_context["prev_node"], diff --git a/guidance/library/_select.py b/guidance/library/_select.py index 854a0fae0..5f3f11374 100644 --- a/guidance/library/_select.py +++ b/guidance/library/_select.py @@ -1,6 +1,7 @@ import itertools import pygtrie import numpy as np +from .._utils import ContentCapture async def select(variable_name="selected", options=None, logprobs=None, list_append=False, _parser_context=None): ''' Select a value from a list of choices. @@ -33,10 +34,14 @@ async def select(variable_name="selected", options=None, logprobs=None, list_app assert options is None, "You cannot provide an options list when using the select command in block mode." if options is None: - options = [block_content[0].text] + with ContentCapture(variable_stack) as new_content: + new_content += await parser.visit(block_content[0], variable_stack) + options = [str(new_content)] for i in range(1, len(block_content), 2): - assert block_content[i].text == "{{or}}" - options.append(block_content[i+1].text) + assert block_content[i][0] == "or", "You must provide a {{or}} between each option in a select block." + with ContentCapture(variable_stack) as new_content: + new_content += await parser.visit(block_content[i+1], variable_stack) + options.append(str(new_content))#block_content[i+1].text) # find what text follows the select command and append it to the options. # we do this so we can differentiate between select options where one is a prefix of another @@ -50,10 +55,10 @@ async def select(variable_name="selected", options=None, logprobs=None, list_app options = [option + next_text for option in options] # TODO: this retokenizes the whole prefix many times, perhaps this could become a bottleneck? - options_tokens = [parser.program.llm.encode(variable_stack["prefix"] + option) for option in options] + options_tokens = [parser.program.llm.encode(variable_stack["@prefix"] + option) for option in options] # encoding the prefix and then decoding it might change the length, so we need to account for that - recoded_parser_prefix_length = len(parser.program.llm.decode(parser.program.llm.encode(variable_stack["prefix"]))) + recoded_parser_prefix_length = len(parser.program.llm.decode(parser.program.llm.encode(variable_stack["@prefix"]))) # build a trie of the options token_map = pygtrie.Trie() @@ -172,6 +177,6 @@ async def recursive_select(current_prefix, allow_token_extension=True): if max(option_logprobs.values()) <= -1000: raise ValueError("No valid option generated in #select! Please post a GitHub issue since this should not happen :)") - variable_stack["_prefix"] += selected_option + variable_stack["@raw_prefix"] += selected_option select.is_block = True \ No newline at end of file diff --git a/guidance/library/_set.py b/guidance/library/_set.py index ceaf135cf..d78b11d22 100644 --- a/guidance/library/_set.py +++ b/guidance/library/_set.py @@ -1,4 +1,4 @@ -def set(name, value=None, hidden=None, _parser_context=None): +def set(name, value=None, hidden=True, _parser_context=None): ''' Set the value of a variable or set of variables. Parameters diff --git a/setup.py b/setup.py index 1b16dcba1..a695558a2 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def find_version(*file_paths): "diskcache", "gptcache", "openai>=0.27", - "parsimonious", + "pyparsing", "pygtrie", "platformdirs", "tiktoken>=0.3", diff --git a/tests/library/test_add.py b/tests/library/test_add.py index 9cba7cf81..20eec65f0 100644 --- a/tests/library/test_add.py +++ b/tests/library/test_add.py @@ -14,4 +14,20 @@ def test_add_multi(): program = guidance("""Write a number: {{set 'user_response' (add 20 5 variable)}}""") assert program(variable=10)["user_response"] == 35 - assert program(variable=20.1)["user_response"] == 45.1 \ No newline at end of file + assert program(variable=20.1)["user_response"] == 45.1 + +def test_add_infix(): + """ Basic infix test of `add`. + """ + + program = guidance("""Write a number: {{set 'user_response' 20 + variable}}""") + assert program(variable=10)["user_response"] == 30 + assert program(variable=20.1)["user_response"] == 40.1 + +if __name__ == "__main__": + # find all the test functions in this file + import sys, inspect + test_functions = [obj for name, obj in inspect.getmembers(sys.modules[__name__]) if (inspect.isfunction(obj) and name.startswith("test_"))] + # run each test function + for test_function in test_functions: + test_function() \ No newline at end of file diff --git a/tests/library/test_await.py b/tests/library/test_await.py index 263768c01..f2e176917 100644 --- a/tests/library/test_await.py +++ b/tests/library/test_await.py @@ -5,8 +5,8 @@ def test_await(): """ prompt = guidance("""Is Everest very tall? -User response: '{{set 'user_response' (await 'user_response')}}'""") +User response: '{{set 'user_response' (await 'user_response') hidden=False}}'""") waiting_prompt = prompt() - assert str(waiting_prompt) == "Is Everest very tall?\nUser response: '{{set 'user_response' (await 'user_response')}}'" + assert str(waiting_prompt) == str(prompt) out = waiting_prompt(user_response="Yes") assert str(out) == "Is Everest very tall?\nUser response: 'Yes'" \ No newline at end of file diff --git a/tests/library/test_block.py b/tests/library/test_block.py index b9b7d2546..d93692989 100644 --- a/tests/library/test_block.py +++ b/tests/library/test_block.py @@ -20,11 +20,15 @@ def test_empty_block(): assert out.text == '' def test_name_capture(): - """ Test the behavior of a completely empty `block`. - """ - prompt = guidance( "This is a block: {{#block 'my_block'}}text inside block{{/block}}", ) out = prompt() - assert out["my_block"] == 'text inside block' \ No newline at end of file + assert out["my_block"] == 'text inside block' + +def test_name_capture_whitespace(): + prompt = guidance( + "This is a block: {{#block 'my_block'}} text inside block {{/block}}", + ) + out = prompt() + assert out["my_block"] == ' text inside block ' \ No newline at end of file diff --git a/tests/library/test_equal.py b/tests/library/test_equal.py index 4dfe4ead2..58eeada18 100644 --- a/tests/library/test_equal.py +++ b/tests/library/test_equal.py @@ -9,11 +9,8 @@ def test_equal(): assert str(program(val=5)) == "are equal" assert str(program(val="5")) == "not equal" -def test_equal_with_symbol(): - """ Test the behavior of `equal` written as `==`. - """ - - program = guidance("""{{#if (== val 5)}}are equal{{else}}not equal{{/if}}""") +def test_equal_infix(): + program = guidance("""{{#if val == 5}}are equal{{else}}not equal{{/if}}""") assert str(program(val=4)) == "not equal" assert str(program(val=5)) == "are equal" assert str(program(val="5")) == "not equal" \ No newline at end of file diff --git a/tests/library/test_geneach.py b/tests/library/test_geneach.py index eda563b1a..ce40e8f2b 100644 --- a/tests/library/test_geneach.py +++ b/tests/library/test_geneach.py @@ -11,6 +11,48 @@ def test_geneach(): }) prompt = guidance('''Generate a list of three names {{#geneach 'names' stop=""}} +{{gen 'this'}}{{/geneach}}''', llm=llm) + out = prompt() + assert len(out["names"]) == 3 + assert out["names"] == ["Bob", "Sue", "Joe"] + assert str(out) == """Generate a list of three names + +Bob +Sue +Joe""" + +def test_geneach_with_join(): + """ Test a geneach loop. + """ + + llm = guidance.llms.Mock({ + 'Joe': {"text": '', "finish_reason": "stop"}, + '': {"text": '\n' : ["Bob", "Sue", "Joe"], + }) + prompt = guidance('''Generate a list of three names +{{#geneach 'names' join="" stop=""}} +{{gen 'this'}}{{/geneach}}''', llm=llm) + out = prompt() + assert len(out["names"]) == 3 + assert out["names"] == ["Bob", "Sue", "Joe"] + assert str(out) == """Generate a list of three names + +Bob +Sue +Joe""" + +def test_geneach_single_call(): + """ Test a geneach loop. + """ + + llm = guidance.llms.Mock(''' +Bob +Sue +Jow +''') + prompt = guidance('''Generate a list of three names +{{#geneach 'names' single_call=True stop=""}} {{gen 'this'}}{{/geneach}}"''', llm=llm) out = prompt() assert len(out["names"]) == 3 diff --git a/tests/library/test_greater.py b/tests/library/test_greater.py index 088382ae1..d984c5657 100644 --- a/tests/library/test_greater.py +++ b/tests/library/test_greater.py @@ -9,11 +9,8 @@ def test_greater(): assert str(program(val=6)) == "greater" assert str(program(val=5.3)) == "greater" -def test_greater_with_symbol(): - """ Test the behavior of `greater` used as `>`. - """ - - program = guidance("""{{#if (> val 5)}}greater{{else}}not greater{{/if}}""") +def test_greater_infix(): + program = guidance("""{{#if val > 5}}greater{{else}}not greater{{/if}}""") assert str(program(val=4)) == "not greater" assert str(program(val=6)) == "greater" assert str(program(val=5.3)) == "greater" \ No newline at end of file diff --git a/tests/library/test_if.py b/tests/library/test_if.py index 573c1c0d6..f129c0b62 100644 --- a/tests/library/test_if.py +++ b/tests/library/test_if.py @@ -14,6 +14,13 @@ def test_if(): out = prompt(flag=flag) assert str(out) == "Answer: " +def test_if_complex_block(): + prompt = guidance("""Answer: {{#if True}}Yes {{my_var}} we{{/if}}""") + + out = prompt(my_var="then") + + assert str(out) == "Answer: Yes then we" + def test_if_else(): """ Test the behavior of `if` with an `else` clause. """ @@ -26,4 +33,22 @@ def test_if_else(): for flag in [False, 0, ""]: out = prompt(flag=flag) - assert str(out) == "Answer 'Yes' or 'No': 'No'" \ No newline at end of file + assert str(out) == "Answer 'Yes' or 'No': 'No'" + +def test_if_complex_blockwith_else(): + prompt = guidance("""Answer: {{#if flag}}Yes {{my_var}} we{{else}}No {{my_var}}{{/if}}""") + + out = prompt(my_var="then", flag=True) + assert str(out) == "Answer: Yes then we" + + out = prompt(my_var="then", flag=False) + assert str(out) == "Answer: No then" + +def test_elif_else(): + """ Test the behavior of `if` with an `else` clause. + """ + + prompt = guidance("""Answer 'Yes' or 'No': '{{#if flag}}Yes{{elif flag2}}maybe{{else}}No{{/if}}'""") + + out = prompt(flag=False, flag2=True) + assert str(out) == "Answer 'Yes' or 'No': 'maybe'" \ No newline at end of file diff --git a/tests/library/test_less.py b/tests/library/test_less.py index d0588e945..4b9b13b39 100644 --- a/tests/library/test_less.py +++ b/tests/library/test_less.py @@ -9,11 +9,11 @@ def test_less(): assert str(program(val=4)) == "less" assert str(program(val=4.3)) == "less" -def test_less_with_symbol(): +def test_less_infix(): """ Test the behavior of `less` used as `<`. """ - program = guidance("""{{#if (< val 5)}}less{{else}}not less{{/if}}""") + program = guidance("""{{#if val < 5}}less{{else}}not less{{/if}}""") assert str(program(val=6)) == "not less" assert str(program(val=4)) == "less" assert str(program(val=4.3)) == "less" \ No newline at end of file diff --git a/tests/library/test_parse.py b/tests/library/test_parse.py index b66007258..91018d017 100644 --- a/tests/library/test_parse.py +++ b/tests/library/test_parse.py @@ -5,4 +5,9 @@ def test_parse(): """ program = guidance("""This is parsed: {{parse template}}""") - assert str(program(template="My name is {{name}}", name="Bob")) == "This is parsed: My name is Bob" \ No newline at end of file + assert str(program(template="My name is {{name}}", name="Bob")) == "This is parsed: My name is Bob" + +def test_parse_with_name(): + program = guidance("""This is parsed: {{parse template name="parsed"}}""") + executed_program = program(template="My name is {{name}}", name="Bob") + assert executed_program["parsed"] == "My name is Bob" \ No newline at end of file diff --git a/tests/library/test_role.py b/tests/library/test_role.py index 062506e7b..2bfa4bba5 100644 --- a/tests/library/test_role.py +++ b/tests/library/test_role.py @@ -15,7 +15,7 @@ def test_role(): {{~/role}} {{#role 'assistant'~}} -{{gen}} +{{gen(max_tokens=23)}} {{~/role}} """, llm=llm) @@ -37,7 +37,7 @@ def test_short_roles(): {{~/user}} {{#assistant~}} -{{gen}} +{{gen()}} {{~/assistant}} """, llm=llm) diff --git a/tests/library/test_set.py b/tests/library/test_set.py index edb9a30e8..3caae7cce 100644 --- a/tests/library/test_set.py +++ b/tests/library/test_set.py @@ -4,11 +4,21 @@ def test_set(): """ Test the behavior of `set`. """ - program = guidance("""{{set 'output' 234}}{{output}}""") + program = guidance("""{{set 'output' 234 hidden=False}}{{output}}""") assert str(program()) == "234234" - program = guidance("""{{set 'output' 234 hidden=True}}{{output}}""") + program = guidance("""{{set 'output' 234}}{{output}}""") assert str(program()) == "234" program = guidance("""{{set 'output' 849203984939}}{{output}}""") - assert str(program()['output']) == "849203984939" \ No newline at end of file + assert str(program()['output']) == "849203984939" + +def test_set_dict(): + + program = guidance("""{{set {'output':234}}}{{output}}""") + assert str(program()) == "234" + +def test_set_array(): + + program = guidance("""{{set 'output' [3, 234]}}{{output}}""") + assert str(program()) == "[3, 234]" \ No newline at end of file diff --git a/tests/library/test_subtract.py b/tests/library/test_subtract.py index 3f4f7b21a..582d09f1d 100644 --- a/tests/library/test_subtract.py +++ b/tests/library/test_subtract.py @@ -6,4 +6,9 @@ def test_subtract(): program = guidance("""Write a number: {{set 'user_response' (subtract 20 variable)}}""") assert program(variable=10)["user_response"] == 10 + assert abs(program(variable=20.1)["user_response"] + 0.1) < 1e-5 + +def test_subtract_infix(): + program = guidance("""Write a number: {{set 'user_response' (20 - variable)}}""") + assert program(variable=10)["user_response"] == 10 assert abs(program(variable=20.1)["user_response"] + 0.1) < 1e-5 \ No newline at end of file diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 984fb6c05..1145b5c13 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -8,24 +8,46 @@ def test_geneach_chat_gpt(): guidance.llm = get_llm("openai:gpt-3.5-turbo") chat_loop = guidance(''' - {{#system~}} - You are a helpful assistant - {{~/system}} +{{#system~}} +You are a helpful assistant +{{~/system}} - {{~#geneach 'conversation' stop=False}} - {{#user~}} - This is great! - {{~/user}} +{{~#geneach 'conversation' stop=False}} +{{#user~}} +This is great! +{{~/user}} - {{#assistant~}} - {{gen 'this.response' temperature=0 max_tokens=3}} - {{~/assistant}} - {{#if (> @index 0)}}{{break}}{{/if}} - {{~/geneach}}''') +{{#assistant~}} +{{gen 'this.response' temperature=0 max_tokens=3}} +{{~/assistant}} +{{#if @index > 0}}{{break}}{{/if}} +{{~/geneach}}''') out = chat_loop() assert len(out["conversation"]) == 2 +def test_syntax_match(): + """ Test a geneach loop with ChatGPT. + """ + + guidance.llm = get_llm("openai:gpt-3.5-turbo") + + chat_loop = guidance(''' +{{~#system~}} +You are a helpful assistant +{{~/system~}} + +{{~#user~}} +This is great! +{{~/user~}} + +{{~#assistant~}} +Indeed +{{~/assistant~}}''') + + out = chat_loop() + assert str(out) == '<|im_start|>system\nYou are a helpful assistant<|im_end|><|im_start|>user\nThis is great!<|im_end|><|im_start|>assistant\nIndeed<|im_end|>' + def test_rest_nostream(): guidance.llm = get_llm('openai:text-davinci-003', endpoint="https://api.openai.com/v1/completions", rest_call=True) a = guidance('''Hello, my name is{{gen 'name' stream=False max_tokens=5}}''', stream=False) diff --git a/tests/test_grammar.py b/tests/test_grammar.py new file mode 100644 index 000000000..ca6e9e9ec --- /dev/null +++ b/tests/test_grammar.py @@ -0,0 +1,52 @@ +import guidance + +def test_variable_interpolation(): + """ Test variable interpolation in prompt + """ + + prompt = guidance("Hello, {{name}}!") + assert str(prompt(name="Guidance")) == "Hello, Guidance!" + +def test_command_call(): + prompt = guidance("Hello, {{add 1 2}}!") + assert str(prompt(name="Guidance")) == "Hello, 3!" + +def test_paren_command_call(): + prompt = guidance("Hello, {{add(1, 2)}}!") + assert str(prompt(name="Guidance")) == "Hello, 3!" + +def test_nested_command_call(): + prompt = guidance("Hello, {{add (add 1 2) 3}}!") + assert str(prompt(name="Guidance")) == "Hello, 6!" + +def test_nested_paren_command_call(): + prompt = guidance("Hello, {{add add(1, 2) 3}}!") + assert str(prompt(name="Guidance")) == "Hello, 6!" + +def test_infix_plus(): + prompt = guidance("Hello, {{1 + 2}}!") + assert str(prompt()) == "Hello, 3!" + +def test_infix_plus_nested(): + prompt = guidance("Hello, {{set 'variable' 1 + 2}}!") + assert prompt()["variable"] == 3 + +def test_comment(): + prompt = guidance("Hello, {{! this is a comment}}Bob!") + assert str(prompt()) == "Hello, Bob!" + +def test_long_comment(): + prompt = guidance("Hello, {{!-- this is a comment --}}Bob!") + assert str(prompt()) == "Hello, Bob!" + +def test_long_comment_ws_strip(): + prompt = guidance("Hello, {{~!-- this is a comment --~}} Bob!") + assert str(prompt()) == "Hello,Bob!" + +def test_comment_ws_strip(): + prompt = guidance("Hello, {{~! this is a comment ~}} Bob!") + assert str(prompt()) == "Hello,Bob!" + +def test_escape_command(): + prompt = guidance("Hello, \{{command}} Bob!") + assert str(prompt()) == "Hello, {{command}} Bob!" \ No newline at end of file diff --git a/tests/test_program.py b/tests/test_program.py index 9aa22e64c..f2f0f3802 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -2,13 +2,6 @@ import pytest from .utils import get_llm -def test_variable_interpolation(): - """ Test variable interpolation in prompt - """ - - prompt = guidance("Hello, {{name}}!") - assert str(prompt(name="Guidance")) == "Hello, Guidance!" - def test_chat_stream(): """ Test the behavior of `stream=True` for an openai chat endpoint. """ @@ -75,7 +68,7 @@ def test_stream_loop(llm): llm = get_llm(llm) program = guidance("""Generate a list of 5 company names: {{#geneach 'companies' num_iterations=5~}} -{{@index}}. "{{gen 'this'}}" +{{@index}}. "{{gen 'this' max_tokens=5}}" {{/geneach}}""", llm=llm) partials = [] @@ -98,7 +91,7 @@ def test_stream_loop_async(llm): async def f(): program = guidance("""Generate a list of 5 company names: {{#geneach 'companies' num_iterations=5~}} -{{@index}}. "{{gen 'this'}}" +{{@index}}. "{{gen 'this' max_tokens=5}}" {{/geneach}}""", llm=llm) partials = []