diff --git a/python/selfie-lib/selfie_lib/EscapeLeadingWhitespace.py b/python/selfie-lib/selfie_lib/EscapeLeadingWhitespace.py index 414eb377..b2e9695d 100644 --- a/python/selfie-lib/selfie_lib/EscapeLeadingWhitespace.py +++ b/python/selfie-lib/selfie_lib/EscapeLeadingWhitespace.py @@ -2,11 +2,56 @@ class EscapeLeadingWhitespace(Enum): + ALWAYS = auto() NEVER = auto() + ONLY_ON_SPACE = auto() + ONLY_ON_TAB = auto() - def escape_line(self, line: str, space: str, tab: str) -> str: # noqa: ARG002 - return line + def escape_line(self, line: str, space: str, tab: str) -> str: + if line.startswith(" "): + if ( + self == EscapeLeadingWhitespace.ALWAYS + or self == EscapeLeadingWhitespace.ONLY_ON_SPACE + ): + return f"{space}{line[1:]}" + else: + return line + elif line.startswith("\t"): + if ( + self == EscapeLeadingWhitespace.ALWAYS + or self == EscapeLeadingWhitespace.ONLY_ON_TAB + ): + return f"{tab}{line[1:]}" + else: + return line + else: + return line - @staticmethod - def appropriate_for(file_content: str) -> "EscapeLeadingWhitespace": # noqa: ARG004 - return EscapeLeadingWhitespace.NEVER + @classmethod + def appropriate_for(cls, file_content: str) -> "EscapeLeadingWhitespace": + MIXED = "m" + common_whitespace = None + + for line in file_content.splitlines(): + whitespace = "".join(c for c in line if c.isspace()) + if not whitespace: + continue + elif all(c == " " for c in whitespace): + whitespace = " " + elif all(c == "\t" for c in whitespace): + whitespace = "\t" + else: + whitespace = MIXED + + if common_whitespace is None: + common_whitespace = whitespace + elif common_whitespace != whitespace: + common_whitespace = MIXED + break + + if common_whitespace == " ": + return cls.ONLY_ON_TAB + elif common_whitespace == "\t": + return cls.ONLY_ON_SPACE + else: + return cls.ALWAYS diff --git a/python/selfie-lib/selfie_lib/Literals.py b/python/selfie-lib/selfie_lib/Literals.py index 0819b2c9..b3e17680 100644 --- a/python/selfie-lib/selfie_lib/Literals.py +++ b/python/selfie-lib/selfie_lib/Literals.py @@ -212,27 +212,10 @@ def _unescape_python(self, source: str) -> str: return value.getvalue() def parseMultiPython(self, source_with_quotes: str) -> str: - assert source_with_quotes.startswith(TRIPLE_QUOTE + "\n") + assert source_with_quotes.startswith(TRIPLE_QUOTE) assert source_with_quotes.endswith(TRIPLE_QUOTE) - - source = source_with_quotes[len(TRIPLE_QUOTE) + 1 : -len(TRIPLE_QUOTE)] - lines = source.split("\n") - - common_prefix = min( - (line[: len(line) - len(line.lstrip())] for line in lines if line.strip()), - default="", - ) - - def remove_common_prefix(line: str) -> str: - return line[len(common_prefix) :] if common_prefix else line - - def handle_escape_sequences(line: str) -> str: - return self._unescape_python(line.rstrip()) - - return "\n".join( - handle_escape_sequences(remove_common_prefix(line)) - for line in lines - if line.strip() + return self._unescape_python( + source_with_quotes[len(TRIPLE_QUOTE) : -len(TRIPLE_QUOTE)] ) diff --git a/python/selfie-lib/selfie_lib/SourceFile.py b/python/selfie-lib/selfie_lib/SourceFile.py index b621b4f0..24962b02 100644 --- a/python/selfie-lib/selfie_lib/SourceFile.py +++ b/python/selfie-lib/selfie_lib/SourceFile.py @@ -208,7 +208,7 @@ def _parse_code( # If all parentheses are closed, return the current index if parenthesis_count == 0: end_paren = i - end_arg = i - 1 + end_arg = i return (end_paren, end_arg) # else ... raise AssertionError( diff --git a/python/selfie-lib/tests/LiteralString_test.py b/python/selfie-lib/tests/LiteralString_test.py index d19c7a63..3fbaf2d2 100644 --- a/python/selfie-lib/tests/LiteralString_test.py +++ b/python/selfie-lib/tests/LiteralString_test.py @@ -50,10 +50,10 @@ def test_parse_single(self, value, expected): @pytest.mark.parametrize( ("value", "expected"), [ - ("\n123\nabc", "123\nabc"), - ("\n 123\n abc", "123\nabc"), - ("\n 123 \n abc\t", "123\nabc"), - ("\n 123 \\s\n abc\t\\s", "123 \nabc\t "), + ("\n123\nabc", "\n123\nabc"), + ("\n 123\n abc", "\n 123\n abc"), + ("\n 123 \n abc\t", "\n 123 \n abc\t"), + (" 123 \n abc\t", " 123 \n abc\t"), ], ) def test_parse_multi(self, value, expected):