diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 3c2b89fb..1ff7b353 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -377,11 +377,7 @@ def _recall_data( ) -> Tuple[Optional["SerializableData"], bool]: try: return ( - self.extension.read_snapshot( - test_location=self.test_location, - index=index, - session_id=str(id(self.session)), - ), + self.session.recall_snapshot(self.extension, self.test_location, index), False, ) except SnapshotDoesNotExist: diff --git a/src/syrupy/location.py b/src/syrupy/location.py index 0f955bb8..8a85d0b9 100644 --- a/src/syrupy/location.py +++ b/src/syrupy/location.py @@ -13,7 +13,7 @@ from syrupy.constants import PYTEST_NODE_SEP -@dataclass +@dataclass(frozen=True) class PyTestLocation: item: "pytest.Item" nodename: Optional[str] = field(init=False) @@ -23,27 +23,42 @@ class PyTestLocation: filepath: str = field(init=False) def __post_init__(self) -> None: + # NB. we're in a frozen dataclass, but need to transform the values that the caller + # supplied... we do so by (ab)using object.__setattr__ to forcibly set the attributes. (See + # rejected PEP-0712 for an example of a better way to handle this.) + # + # This is safe because this all happens during initialization: `self` hasn't been hashed + # (or, e.g., stored in a dict), so the mutation won't be noticed. if self.is_doctest: return self.__attrs_post_init_doc__() self.__attrs_post_init_def__() def __attrs_post_init_def__(self) -> None: node_path: Path = getattr(self.item, "path") # noqa: B009 - self.filepath = str(node_path.absolute()) + # See __post_init__ for discussion of object.__setattr__ + object.__setattr__(self, "filepath", str(node_path.absolute())) obj = getattr(self.item, "obj") # noqa: B009 - self.modulename = obj.__module__ - self.methodname = obj.__name__ - self.nodename = getattr(self.item, "name", None) - self.testname = self.nodename or self.methodname + object.__setattr__(self, "modulename", obj.__module__) + object.__setattr__(self, "methodname", obj.__name__) + object.__setattr__(self, "nodename", getattr(self.item, "name", None)) + object.__setattr__(self, "testname", self.nodename or self.methodname) def __attrs_post_init_doc__(self) -> None: doctest = getattr(self.item, "dtest") # noqa: B009 - self.filepath = doctest.filename + # See __post_init__ for discussion of object.__setattr__ + object.__setattr__(self, "filepath", doctest.filename) test_relfile, test_node = self.nodeid.split(PYTEST_NODE_SEP) test_relpath = Path(test_relfile) - self.modulename = ".".join([*test_relpath.parent.parts, test_relpath.stem]) - self.nodename = test_node.replace(f"{self.modulename}.", "") - self.testname = self.nodename or self.methodname + object.__setattr__( + self, + "modulename", + ".".join([*test_relpath.parent.parts, test_relpath.stem]), + ) + object.__setattr__(self, "methodname", None) + object.__setattr__( + self, "nodename", test_node.replace(f"{self.modulename}.", "") + ) + object.__setattr__(self, "testname", self.nodename or self.methodname) @property def classname(self) -> Optional[str]: diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 9770948a..cba70a65 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -46,6 +46,10 @@ class ItemStatus(Enum): SKIPPED = "skipped" +_QueuedWriteExtensionKey = Tuple[Type["AbstractSyrupyExtension"], str] +_QueuedWriteTestLocationKey = Tuple["PyTestLocation", "SnapshotIndex"] + + @dataclass class SnapshotSession: pytest_session: "pytest.Session" @@ -62,10 +66,28 @@ class SnapshotSession: default_factory=lambda: defaultdict(set) ) - _queued_snapshot_writes: Dict[ - Tuple[Type["AbstractSyrupyExtension"], str], - List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], - ] = field(default_factory=dict) + # For performance, we buffer snapshot writes in memory before flushing them to disk. In + # particular, we want to be able to write to a file on disk only once, rather than having to + # repeatedly rewrite it. + # + # That batching leads to using two layers of dicts here: the outer layer represents the + # extension/file-location pair that will be written, and the inner layer represents the + # snapshots within that, "indexed" to allow efficient recall. + _queued_snapshot_writes: DefaultDict[ + _QueuedWriteExtensionKey, + Dict[_QueuedWriteTestLocationKey, "SerializedData"], + ] = field(default_factory=lambda: defaultdict(dict)) + + def _snapshot_write_queue_keys( + self, + extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", + index: "SnapshotIndex", + ) -> Tuple[_QueuedWriteExtensionKey, _QueuedWriteTestLocationKey]: + snapshot_location = extension.get_location( + test_location=test_location, index=index + ) + return (extension.__class__, snapshot_location), (test_location, index) def queue_snapshot_write( self, @@ -74,13 +96,10 @@ def queue_snapshot_write( data: "SerializedData", index: "SnapshotIndex", ) -> None: - snapshot_location = extension.get_location( - test_location=test_location, index=index + ext_key, loc_key = self._snapshot_write_queue_keys( + extension, test_location, index ) - key = (extension.__class__, snapshot_location) - queue = self._queued_snapshot_writes.get(key, []) - queue.append((data, test_location, index)) - self._queued_snapshot_writes[key] = queue + self._queued_snapshot_writes[ext_key][loc_key] = data def flush_snapshot_write_queue(self) -> None: for ( @@ -89,9 +108,33 @@ def flush_snapshot_write_queue(self) -> None: ), queued_write in self._queued_snapshot_writes.items(): if queued_write: extension_class.write_snapshot( - snapshot_location=snapshot_location, snapshots=queued_write + snapshot_location=snapshot_location, + snapshots=[ + (data, loc, index) + for (loc, index), data in queued_write.items() + ], ) - self._queued_snapshot_writes = {} + self._queued_snapshot_writes.clear() + + def recall_snapshot( + self, + extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", + index: "SnapshotIndex", + ) -> Optional["SerializedData"]: + """Find the current value of the snapshot, for this session, either a pending write or the actual snapshot.""" + + ext_key, loc_key = self._snapshot_write_queue_keys( + extension, test_location, index + ) + data = self._queued_snapshot_writes[ext_key].get(loc_key) + if data is not None: + return data + + # No matching write queued, so just read the snapshot directly: + return extension.read_snapshot( + test_location=test_location, index=index, session_id=str(id(self)) + ) @property def update_snapshots(self) -> bool: diff --git a/tests/integration/test_snapshot_diff.py b/tests/integration/test_snapshot_diff.py new file mode 100644 index 00000000..302b4d94 --- /dev/null +++ b/tests/integration/test_snapshot_diff.py @@ -0,0 +1,56 @@ +import pytest + +_TEST = """ +def test_foo(snapshot): + assert {**base} == snapshot(name="a") + assert {**base, **extra} == snapshot(name="b", diff="a") +""" + + +def _make_file(testdir, base, extra): + testdir.makepyfile( + test_file="\n\n".join([f"base = {base!r}", f"extra = {extra!r}", _TEST]) + ) + + +def _run_test(testdir, base, extra, expected_update_lines): + _make_file(testdir, base=base, extra=extra) + + # Run with --snapshot-update, to generate/update snapshots: + result = testdir.runpytest( + "-v", + "--snapshot-update", + ) + result.stdout.re_match_lines((expected_update_lines,)) + assert result.ret == 0 + + # Run without --snapshot-update, to validate the snapshots are actually up-to-date + result = testdir.runpytest("-v") + result.stdout.re_match_lines((r"2 snapshots passed\.",)) + assert result.ret == 0 + + +def test_diff_lifecycle(testdir) -> pytest.Testdir: + # first: create both snapshots completely from scratch + _run_test( + testdir, + base={"A": 1}, + extra={"X": 10}, + expected_update_lines=r"2 snapshots generated\.", + ) + + # second: edit the base data, to change the data for both snapshots (only changes the serialized output for the base snapshot `a`). + _run_test( + testdir, + base={"A": 1, "B": 2}, + extra={"X": 10}, + expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.", + ) + + # third: edit just the extra data (only changes the serialized output for the diff snapshot `b`) + _run_test( + testdir, + base={"A": 1, "B": 2}, + extra={"X": 10, "Y": 20}, + expected_update_lines=r"1 snapshot passed. 1 snapshot updated\.", + )