From e2606d5b37e9a98709147f198de59050869436b4 Mon Sep 17 00:00:00 2001 From: Hai Zhu <35182391+cocolato@users.noreply.github.com> Date: Tue, 6 Feb 2024 08:14:07 -0500 Subject: [PATCH] Support comprehensions inside functions when use strict_undefined flag. Fixes: https://github.com/sqlalchemy/mako/issues/320 Now the test code works as expected if strict_undefined is set to true: ```python from mako.template import Template text = """ <% mydict = { 'foo': 1 } ## Uncomment the following line to workaround the error ##k = None def getkeys(x): return [ k for k in x.keys() ] %> ${ ','.join( getkeys(mydict) ) } """ tmpl = Template(text=text, strict_undefined=True) out = tmpl.render() print(out) ``` output: ``` foo ``` Closes: #386 Pull-request: https://github.com/sqlalchemy/mako/pull/386 Pull-request-sha: cc6a3e0694fb5615db2c3fec2cd23bc9e8a70066 Change-Id: I0591873a83837f8f35b0963c0536df1e2675012f --- doc/build/unreleased/320.rst | 9 ++++++ mako/pyparser.py | 20 ++++++++++++ test/test_ast.py | 36 ++++++++++++++++++++++ test/test_template.py | 59 ++++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+) create mode 100644 doc/build/unreleased/320.rst diff --git a/doc/build/unreleased/320.rst b/doc/build/unreleased/320.rst new file mode 100644 index 00000000..20deeb9b --- /dev/null +++ b/doc/build/unreleased/320.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, parser + :tickets: 320 + + Fixed unexpected syntax error in strict_undefined mode that occurred + when using comprehensions within a function in a Mako Python code block. + Now, the local variable in comprehensions won't be added to the checklist + when using strict_undefined mode. + Pull request courtesy Hai Zhu. \ No newline at end of file diff --git a/mako/pyparser.py b/mako/pyparser.py index 9b63dc33..b25ef6e4 100644 --- a/mako/pyparser.py +++ b/mako/pyparser.py @@ -90,6 +90,26 @@ def visit_FunctionDef(self, node): self._add_declared(node.name) self._visit_function(node, False) + def visit_ListComp(self, node): + if self.in_function: + if not isinstance(node.elt, _ast.Name): + self.visit(node.elt) + for comp in node.generators: + self.visit(comp.iter) + else: + self.generic_visit(node) + + visit_SetComp = visit_GeneratorExp = visit_ListComp + + def visit_DictComp(self, node): + if self.in_function: + if not isinstance(node.key, _ast.Name): + self.visit(node.elt) + for comp in node.generators: + self.visit(comp.iter) + else: + self.generic_visit(node) + def _expand_tuples(self, args): for arg in args: if isinstance(arg, _ast.Tuple): diff --git a/test/test_ast.py b/test/test_ast.py index 6b3a3e2f..84e23380 100644 --- a/test/test_ast.py +++ b/test/test_ast.py @@ -222,6 +222,42 @@ def test_locate_identifiers_17(self): parsed = ast.PythonCode(code, **exception_kwargs) eq_(parsed.undeclared_identifiers, {"x", "y", "Foo", "Bar"}) + def test_locate_identifiers_18(self): + code = """ +def func(): + return [i for i in range(10)] +""" + parsed = ast.PythonCode(code, **exception_kwargs) + eq_(parsed.declared_identifiers, {"func"}) + eq_(parsed.undeclared_identifiers, {"range"}) + + def test_locate_identifiers_19(self): + code = """ +def func(): + return (i for i in range(10)) +""" + parsed = ast.PythonCode(code, **exception_kwargs) + eq_(parsed.declared_identifiers, {"func"}) + eq_(parsed.undeclared_identifiers, {"range"}) + + def test_locate_identifiers_20(self): + code = """ +def func(): + return {i for i in range(10)} +""" + parsed = ast.PythonCode(code, **exception_kwargs) + eq_(parsed.declared_identifiers, {"func"}) + eq_(parsed.undeclared_identifiers, {"range"}) + + def test_locate_identifiers_21(self): + code = """ +def func(): + return {i: i**2 for i in range(10)} +""" + parsed = ast.PythonCode(code, **exception_kwargs) + eq_(parsed.declared_identifiers, {"func"}) + eq_(parsed.undeclared_identifiers, {"range"}) + def test_no_global_imports(self): code = """ from foo import * diff --git a/test/test_template.py b/test/test_template.py index e03415e2..94b255f3 100644 --- a/test/test_template.py +++ b/test/test_template.py @@ -1717,3 +1717,62 @@ def test_inline_percent(self): "% foo", "bar %% baz", ] + + def test_listcomp_in_func_strict(self): + t = Template( + """ +<% + mydict = { 'foo': 1 } + def getkeys(x): + return [ k for k in x.keys() ] +%> + +${ ','.join( getkeys(mydict) ) } +""", + strict_undefined=True, + ) + assert result_raw_lines(t.render()) == ["foo"] + + def test_setcomp_in_func_strict(self): + t = Template( + """ +<% + mydict = { 'foo': 1 } + def getkeys(x): + return { k for k in x.keys() } +%> + +${ ','.join( getkeys(mydict) ) } +""", + strict_undefined=True, + ) + assert result_raw_lines(t.render()) == ["foo"] + + def test_generator_in_func_strict(self): + t = Template( + """ +<% + mydict = { 'foo': 1 } + def getkeys(x): + return ( k for k in x.keys()) +%> + +${ ','.join( getkeys(mydict) ) } +""", + strict_undefined=True, + ) + assert result_raw_lines(t.render()) == ["foo"] + + def test_dictcomp_in_func_strict(self): + t = Template( + """ +<% + def square(): + return {i: i**2 for i in range(10)} +%> + +${ square()[3] } +""", + strict_undefined=True, + ) + assert result_raw_lines(t.render()) == ["9"]