Skip to content

Commit

Permalink
refactor(wip): implement update correct
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jul 19, 2023
1 parent 5df89af commit 050433a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 23 deletions.
32 changes: 14 additions & 18 deletions inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,7 @@ def _needs_fix(self):
return not self.cmp(self._old_value, self._new_value)

def get_result(self, flags):
if (
flags.fix
and flags.trim
or flags.create
and self._old_value == undefined
or flags.update
):
if flags.create and self._needs_create():
return self._new_value

if flags.fix and self._needs_fix():
Expand Down Expand Up @@ -289,10 +283,8 @@ def _needs_fix(self):
return any(item not in self._old_value for item in self._new_value)

def get_result(self, flags):
if (
(flags.fix and flags.trim)
or (flags.create and self._old_value is undefined)
or flags.update
if (flags.fix and self._needs_fix() and flags.trim and self._needs_trim()) or (
flags.create and self._needs_create()
):
return self._new_value

Expand Down Expand Up @@ -538,7 +530,6 @@ def _change(self):

if (
_update_flags.update
and not (needs_fix or needs_create or needs_trim)
or _update_flags.fix
and needs_fix
or _update_flags.create
Expand All @@ -559,11 +550,16 @@ def _change(self):
)

def _current_tokens(self):
return [
(t.type, t.string)
for t in self._expr.source.asttokens().get_tokens(self._expr.node.args[0])
if t.type not in ignore_tokens
]
if not self._expr.node.args:
return []
else:
return [
(t.type, t.string)
for t in self._expr.source.asttokens().get_tokens(
self._expr.node.args[0]
)
if t.type not in ignore_tokens
]

@property
def _flags(self):
Expand All @@ -576,7 +572,7 @@ def _flags(self):
s.add("create")

if (
not s
"create" not in s
and self._expr is not None
and self._current_tokens() != self._value_to_token(self._value._old_value)
):
Expand Down
7 changes: 2 additions & 5 deletions tests/test_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,9 @@ def gen_code(ops, fixed):

return code

all_flags = {op.flag for op in ops}

s = source(gen_code(ops, {}))

reported_flags = all_flags
if "update" in reported_flags and len(reported_flags) > 1:
reported_flags -= {"update"}
reported_flags = {op.flag for op in ops}

assert s.flags == reported_flags

Expand All @@ -241,6 +237,7 @@ def gen_code(ops, fixed):
fixed_flags = set()
for flag in flags:
if flag in {"create", "fix", "trim"}:
# sub-snapshots are always updated if the snapshot is changed
fixed_flags.add("update")

s2 = s2.run(flag)
Expand Down

0 comments on commit 050433a

Please sign in to comment.