From 050433a359948f3000e0b7bedaa00dc199d03029 Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Wed, 19 Jul 2023 19:24:31 +0200 Subject: [PATCH] refactor(wip): implement update correct --- inline_snapshot/_inline_snapshot.py | 32 +++++++++++++---------------- tests/test_inline_snapshot.py | 7 ++----- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py index 40043d16..65f4b2b0 100644 --- a/inline_snapshot/_inline_snapshot.py +++ b/inline_snapshot/_inline_snapshot.py @@ -199,13 +199,7 @@ def _needs_fix(self): return not self.cmp(self._old_value, self._new_value) def get_result(self, flags): - if ( - flags.fix - and flags.trim - or flags.create - and self._old_value == undefined - or flags.update - ): + if flags.create and self._needs_create(): return self._new_value if flags.fix and self._needs_fix(): @@ -289,10 +283,8 @@ def _needs_fix(self): return any(item not in self._old_value for item in self._new_value) def get_result(self, flags): - if ( - (flags.fix and flags.trim) - or (flags.create and self._old_value is undefined) - or flags.update + if (flags.fix and self._needs_fix() and flags.trim and self._needs_trim()) or ( + flags.create and self._needs_create() ): return self._new_value @@ -538,7 +530,6 @@ def _change(self): if ( _update_flags.update - and not (needs_fix or needs_create or needs_trim) or _update_flags.fix and needs_fix or _update_flags.create @@ -559,11 +550,16 @@ def _change(self): ) def _current_tokens(self): - return [ - (t.type, t.string) - for t in self._expr.source.asttokens().get_tokens(self._expr.node.args[0]) - if t.type not in ignore_tokens - ] + if not self._expr.node.args: + return [] + else: + return [ + (t.type, t.string) + for t in self._expr.source.asttokens().get_tokens( + self._expr.node.args[0] + ) + if t.type not in ignore_tokens + ] @property def _flags(self): @@ -576,7 +572,7 @@ def _flags(self): s.add("create") if ( - not s + "create" not in s and self._expr is not None and self._current_tokens() != self._value_to_token(self._value._old_value) ): diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index 98009abe..f01732ec 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -225,13 +225,9 @@ def gen_code(ops, fixed): return code - all_flags = {op.flag for op in ops} - s = source(gen_code(ops, {})) - reported_flags = all_flags - if "update" in reported_flags and len(reported_flags) > 1: - reported_flags -= {"update"} + reported_flags = {op.flag for op in ops} assert s.flags == reported_flags @@ -241,6 +237,7 @@ def gen_code(ops, fixed): fixed_flags = set() for flag in flags: if flag in {"create", "fix", "trim"}: + # sub-snapshots are always updated if the snapshot is changed fixed_flags.add("update") s2 = s2.run(flag)