diff --git a/.githooks.ini b/.githooks.ini index 523df812..e4ecf4ea 100644 --- a/.githooks.ini +++ b/.githooks.ini @@ -1,2 +1,3 @@ [pre-commit] command = inv lint + diff --git a/setup.py b/setup.py index d355dc1d..0b46636e 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ def readme() -> str: use_scm_version={"write_to": "version.txt"}, package_dir={"": "src"}, packages=["syrupy"], - py_modules=["syrupy"], zip_safe=False, install_requires=[], setup_requires=["setuptools_scm"], diff --git a/src/syrupy/__init__.py b/src/syrupy/__init__.py index befe8f61..e98e4ca0 100644 --- a/src/syrupy/__init__.py +++ b/src/syrupy/__init__.py @@ -52,8 +52,11 @@ def snapshot(request): classname=request.cls.__name__ if request.cls else None, methodname=request.function.__name__ if request.function else None, nodename=getattr(request.node, "name", ""), - testname=getattr(request.node, "name", "") - or (request.function.__name__ if request.function else None), + testname=getattr( + request.node, + "name", + request.function.__name__ if request.function else None, + ), ) return SnapshotAssertion( update_snapshots=request.config.option.update_snapshots, diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 7175a38e..f21a2448 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -1,7 +1,7 @@ import traceback import pytest import os -from typing import List, Optional, Any +from typing import Any, Callable, List, Optional, Type from .exceptions import SnapshotDoesNotExist @@ -15,8 +15,8 @@ def __init__( self, *, update_snapshots: bool, - io_class: SnapshotIO, - serializer_class: SnapshotSerializer, + io_class: Type[SnapshotIO], + serializer_class: Type[SnapshotSerializer], test_location: TestLocation, session, ): @@ -25,10 +25,14 @@ def __init__( self._serializer_class = serializer_class self._test_location = test_location self._executions = 0 - self._session = session + + from .session import SnapshotSession + + self._session: SnapshotSession = session + self._session.register_request(self) @property - def io(self): + def io(self) -> SnapshotIO: if not getattr(self, "_io", None): self._io = self._io_class( test_location=self._test_location, file_hook=self._file_hook @@ -36,13 +40,19 @@ def io(self): return self._io @property - def serializer(self): + def serializer(self) -> SnapshotSerializer: if not getattr(self, "_serializer", None): self._serializer = self._serializer_class() return self._serializer + @property + def num_executions(self) -> int: + return int(self._executions) + def with_class( - self, io_class: SnapshotIO = None, serializer_class: SnapshotSerializer = None + self, + io_class: Type[SnapshotIO] = None, + serializer_class: Type[SnapshotSerializer] = None, ): return self.__class__( update_snapshots=self._update_snapshots, @@ -56,7 +66,7 @@ def assert_match(self, data) -> bool: return self._assert(data) def get_assert_diff(self, data) -> List[str]: - deserialized = self._recall_data(index=self._executions - 1) + deserialized = self._recall_data(index=self.num_executions - 1) if deserialized is None: return ["Snapshot does not exist!"] @@ -65,11 +75,11 @@ def get_assert_diff(self, data) -> List[str]: return [] - def _file_hook(self, filepath): - self._session.add_visited_file(filepath) + def _file_hook(self, filepath, snapshot_name): + self._session.add_visited_snapshots({filepath: {snapshot_name}}) def __repr__(self) -> str: - return f"" + return f"" def __call__(self, data) -> bool: return self._assert(data) @@ -78,26 +88,25 @@ def __eq__(self, other) -> bool: return self._assert(other) def _assert(self, data) -> bool: - executions = self._executions - self._executions += 1 - - if self._update_snapshots: - serialized_data = self.serializer.encode(data) - self.io.pre_write(serialized_data, index=executions) - filepath = self.io.write(serialized_data, index=executions) - self.io.post_write(serialized_data, index=executions) + self._session.register_assertion(self) + try: + if self._update_snapshots: + serialized_data = self.serializer.encode(data) + self.io.create_or_update_snapshot( + serialized_data, index=self.num_executions + ) + return True + + deserialized = self._recall_data(index=self.num_executions) + if deserialized is None or data != deserialized: + return False return True - - deserialized = self._recall_data(index=executions) - if deserialized is None or data != deserialized: - return False - return True + finally: + self._executions += 1 def _recall_data(self, index: int) -> Optional[Any]: try: - self.io.pre_read(index=index) - saved_data = self.io.read(index=index) - self.io.post_read(index=index) + saved_data = self.io.read_snapshot(index=index) return self.serializer.decode(saved_data) except SnapshotDoesNotExist: return None diff --git a/src/syrupy/io.py b/src/syrupy/io.py index a6755563..7db16842 100644 --- a/src/syrupy/io.py +++ b/src/syrupy/io.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Callable, Optional, Set import os import yaml @@ -6,6 +6,7 @@ from .constants import SNAPSHOT_DIRNAME from .exceptions import SnapshotDoesNotExist from .location import TestLocation +from .types import SnapshotFiles class SnapshotIO: @@ -13,33 +14,79 @@ def __init__(self, test_location: TestLocation, file_hook): self._test_location = test_location self._file_hook = file_hook - def pre_write(self, data: Any, index: int = 0): - self.ensure_snapshot_dir(index) + @property + def test_location(self): + return self._test_location - def write(self, data: Any, index: int = 0): - snapshot_name = self.get_snapshot_name(index) - snapshots = self._load_documents(index) - snapshots[snapshot_name] = snapshots.get(snapshot_name, {}) - snapshots[snapshot_name]["data"] = data - with open(self.get_filepath(index), "w") as f: - yaml.safe_dump(snapshots, f) + @property + def dirname(self) -> str: + test_dirname = os.path.dirname(self.test_location.filename) + snapshot_dir = self._get_snapshot_dirname() + if snapshot_dir is not None: + return os.path.join(test_dirname, SNAPSHOT_DIRNAME, snapshot_dir) + return os.path.join(test_dirname, SNAPSHOT_DIRNAME) - def post_write(self, data: Any, index: int = 0): - self._file_hook(self.get_filepath(index)) + def discover_snapshots(self, filepath: str) -> Set[str]: + """ + Utility method for getting all the snapshots from a file. + Returns an empty set if the file cannot be read. + """ + try: + return set(self._read_file(filepath).keys()) + except: + return set() + + def read_snapshot(self, index: int) -> Any: + """ + Utility method for reading the contents of a snapshot assertion. + Will call `pre_read`, then `read` and finally `post_read`, + returning the contents parsed from the `read` method. + """ + try: + self.pre_read(index=index) + return self.read(index=index) + finally: + self.post_read(index=index) + + def create_or_update_snapshot(self, serialized_data: Any, index: int): + """ + Utility method for reading the contents of a snapshot assertion. + Will call `pre_write`, then `write` and finally `post_write`. + """ + self.pre_write(serialized_data, index=index) + self.write(serialized_data, index=index) + self.post_write(serialized_data, index=index) + + def delete_snapshot(self, snapshot_file: str, snapshot_name: str): + """ + Utility method for removing a snapshot from a snapshot file. + """ + self._write_snapshot_or_remove_file(snapshot_file, snapshot_name, None) def pre_read(self, index: int = 0): pass def read(self, index: int = 0) -> Any: + snapshot_file = self.get_filepath(index) snapshot_name = self.get_snapshot_name(index) - snapshots = self._load_documents(index) - snapshot = snapshots.get(snapshot_name, None) + snapshot = self._read_snapshot_from_file(snapshot_file, snapshot_name) if snapshot is None: raise SnapshotDoesNotExist() - return snapshot["data"] + return snapshot def post_read(self, index: int = 0): - self._file_hook(self.get_filepath(index)) + self._snap_file_hook(index) + + def pre_write(self, data: Any, index: int = 0): + self._ensure_snapshot_dir(index) + + def write(self, data: Any, index: int = 0): + snapshot_file = self.get_filepath(index) + snapshot_name = self.get_snapshot_name(index) + self._write_snapshot_or_remove_file(snapshot_file, snapshot_name, data) + + def post_write(self, data: Any, index: int = 0): + self._snap_file_hook(index) def get_snapshot_name(self, index: int = 0) -> str: index_suffix = f".{index}" if index > 0 else "" @@ -49,37 +96,74 @@ def get_snapshot_name(self, index: int = 0) -> str: return f"{self._test_location.classname}.{methodname}{index_suffix}" return f"{methodname}{index_suffix}" - def get_snapshot_dirname(self) -> Optional[str]: - return None - def get_filepath(self, index: int) -> str: basename = self.get_file_basename(index=index) - return os.path.join(self._get_dirname(), basename) + return os.path.join(self.dirname, basename) def get_file_basename(self, index: int) -> str: - return f"{os.path.basename(self._test_location.filename)[: -len('.py')]}.yaml" + return f"{os.path.splitext(os.path.basename(self._test_location.filename))[0]}.yaml" + + def _get_snapshot_dirname(self) -> Optional[str]: + return None - def ensure_snapshot_dir(self, index: int): + def _ensure_snapshot_dir(self, index: int): + """ + Ensures the folder path for the snapshot file exists. + """ try: os.makedirs(os.path.dirname(self.get_filepath(index))) except FileExistsError: pass - @property - def test_location(self): - return self._test_location - - def _load_documents(self, index: int) -> dict: + def _read_snapshot_from_file(self, snapshot_file: str, snapshot_name: str) -> Any: + """ + Read the snapshot file and get only the snapshot data for assertion + """ + snapshots = self._read_file(snapshot_file) + return snapshots.get(snapshot_name, {}).get("data", None) + + def _read_file(self, filepath: str) -> Any: + """ + Read the snapshot data from the snapshot file into a python instance. + """ try: - with open(self.get_filepath(index), "r") as f: + with open(filepath, "r") as f: return yaml.safe_load(f) or {} except FileNotFoundError: pass return {} - def _get_dirname(self) -> str: - test_dirname = os.path.dirname(self._test_location.filename) - snapshot_dir = self.get_snapshot_dirname() - if snapshot_dir is not None: - return os.path.join(test_dirname, SNAPSHOT_DIRNAME, snapshot_dir) - return os.path.join(test_dirname, SNAPSHOT_DIRNAME) + def _write_snapshot_or_remove_file( + self, snapshot_file: str, snapshot_name: str, data: Any + ): + """ + Adds the snapshot data to the snapshots read from the file + or removes the snapshot entry if data is `None`. + If the snapshot file will be empty remove the entire file. + """ + snapshots = self._read_file(snapshot_file) + if data is None and snapshot_name in snapshots: + del snapshots[snapshot_name] + else: + snapshots[snapshot_name] = snapshots.get(snapshot_name, {}) + snapshots[snapshot_name]["data"] = data + + if snapshots: + self._write_file(snapshot_file, snapshots) + else: + os.remove(snapshot_file) + + def _write_file(self, filepath: str, data: Any): + """ + Writes the snapshot data into the snapshot file that be read later. + """ + with open(filepath, "w") as f: + yaml.safe_dump(data, f) + + def _snap_file_hook(self, index: int): + """ + Notify the assertion of an access to a snapshot in a file + """ + snapshot_file = self.get_filepath(index) + snapshot_name = self.get_snapshot_name(index) + self._file_hook(snapshot_file, snapshot_name) diff --git a/src/syrupy/plugins/image/__init__.py b/src/syrupy/plugins/image/__init__.py index 6bfdf9c9..3c9b358d 100644 --- a/src/syrupy/plugins/image/__init__.py +++ b/src/syrupy/plugins/image/__init__.py @@ -4,11 +4,11 @@ class PNGImageSnapshotIO(AbstractImageSnapshotIO): @property - def extension(self): + def extension(self) -> str: return "png" class SVGImageSnapshotIO(AbstractImageSnapshotIO): @property - def extension(self): + def extension(self) -> str: return "svg" diff --git a/src/syrupy/plugins/image/io.py b/src/syrupy/plugins/image/io.py index 775a49c2..3987f16b 100644 --- a/src/syrupy/plugins/image/io.py +++ b/src/syrupy/plugins/image/io.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Set, Optional from abc import ABC, abstractmethod import re @@ -6,35 +6,47 @@ from syrupy.io import SnapshotIO from syrupy.exceptions import SnapshotDoesNotExist +from syrupy.types import SnapshotFiles class AbstractImageSnapshotIO(ABC, SnapshotIO): @property @abstractmethod - def extension(self): - return None + def extension(self) -> str: + pass - def write(self, data, index: int = 0): - with open(self.get_filepath(index), "wb") as f: - f.write(data) + def discover_snapshots(self, filepath: str) -> Set[str]: + return {os.path.splitext(os.path.basename(filepath))[0]} + + def get_file_basename(self, index: int) -> str: + maybe_extension = f".{self.extension}" if self.extension else "" + sanitized_name = self._clean_filename(self.get_snapshot_name(index=index)) + return f"{sanitized_name}{maybe_extension}" + + def _get_snapshot_dirname(self): + return os.path.splitext(os.path.basename(str(self.test_location.filename)))[0] - def read(self, index: int = 0) -> Any: + def _read_snapshot_from_file(self, snapshot_file: str, _): + return self._read_file(snapshot_file) + + def _read_file(self, filepath: str) -> Any: try: - with open(self.get_filepath(index), "rb") as f: + with open(filepath, "rb") as f: return f.read() except FileNotFoundError: - raise SnapshotDoesNotExist() + return None - def get_snapshot_dirname(self) -> str: - return os.path.basename(str(self.test_location.filename)[: -len(".py")]) + def _write_snapshot_or_remove_file(self, snapshot_file: str, _: str, data: Any): + if data: + self._write_file(snapshot_file, data) + else: + os.remove(snapshot_file) - def get_file_basename(self, index: int) -> str: - ext = f".{self.extension}" - sanitized_name = self.get_valid_filename(self.get_snapshot_name(index=index))[ - : 255 - len(ext) - ] - return f"{sanitized_name}{ext}" + def _write_file(self, filepath: str, data: Any): + with open(filepath, "wb") as f: + f.write(data) - def get_valid_filename(self, filename: str) -> str: + def _clean_filename(self, filename: str) -> str: filename = str(filename).strip().replace(" ", "_") - return re.sub(r"(?u)[^-\w.]", "", filename) + max_filename_length = 255 - len(self.extension or "") + return re.sub(r"(?u)[^-\w.]", "", filename)[:max_filename_length] diff --git a/src/syrupy/serializer.py b/src/syrupy/serializer.py index 1a790e20..ad975dee 100644 --- a/src/syrupy/serializer.py +++ b/src/syrupy/serializer.py @@ -1,3 +1,6 @@ +from typing import Callable + + class SnapshotSerializer: def __init__(self): pass diff --git a/src/syrupy/session.py b/src/syrupy/session.py index e079e394..c7a4c2cf 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -1,44 +1,52 @@ import os +from collections import defaultdict +from functools import lru_cache from gettext import ngettext, gettext -from typing import List, Set +from typing import Dict, List, Set, Tuple +from .assertion import SnapshotAssertion from .constants import SNAPSHOT_DIRNAME from .terminal import yellow, bold +from .types import SnapshotFiles class SnapshotSession: def __init__(self, *, update_snapshots: bool, base_dir: str): self.update_snapshots = update_snapshots self.base_dir = base_dir - self.discovered_snapshots: Set[str] = set() - self.visited_snapshots: Set[str] = set() + self.discovered_snapshots: SnapshotFiles = dict() + self.visited_snapshots: SnapshotFiles = dict() self.report: List[str] = [] + self._assertions: Dict[str, Dict[str, SnapshotAssertion]] = dict() - def start(self): - self.report = [] - self.visited_snapshots = set() - self.discovered_snapshots = set( - filepath for filepath in self._walk_dir(self.base_dir) + @property + def unused_snapshots(self) -> SnapshotFiles: + return self._diff_snapshot_files( + self.discovered_snapshots, self.visited_snapshots ) @property - def unused_snapshots(self): - return self.discovered_snapshots - self.visited_snapshots + def written_snapshots(self) -> SnapshotFiles: + return self._diff_snapshot_files( + self.visited_snapshots, self.discovered_snapshots + ) @property - def written_snapshots(self): - return self.visited_snapshots - self.discovered_snapshots + def num_unused_snapshots(self): + return self._count_snapshots(self.unused_snapshots) - def add_visited_file(self, filepath: str): - if self._in_snapshot_dir(filepath): - self.visited_snapshots.add(filepath) + @property + def num_written_snapshots(self): + return self._count_snapshots(self.written_snapshots) - def add_report_line(self, line: str = ""): - self.report += [line] + def start(self): + self.report = [] + self.visited_snapshots = dict() + self.discovered_snapshots = dict() def finish(self): - n_unused = len(self.unused_snapshots) - n_written = len(self.written_snapshots) + n_unused = self.num_unused_snapshots + n_written = self.num_written_snapshots self.add_report_line() @@ -46,15 +54,13 @@ def finish(self): if self.update_snapshots and n_written: summary_lines += [ ngettext( - "{} snapshot file generated.", - "{} snapshot files generated.", - n_written, + "{} snapshot generated.", "{} snapshots generated.", n_written, ).format(bold(n_written)) ] summary_lines += [ - ngettext( - "{} snapshot file unused.", "{} snapshot files unused.", n_unused - ).format(bold(n_unused)) + ngettext("{} snapshot unused.", "{} snapshots unused.", n_unused).format( + bold(n_unused) + ) ] summary_line = " ".join(summary_lines) self.add_report_line( @@ -63,8 +69,18 @@ def finish(self): else summary_line ) - for filepath in self.unused_snapshots: - self.add_report_line(f" {os.path.relpath(filepath, self.base_dir)}") + for filepath, snapshots in self.unused_snapshots.items(): + count = self._count_snapshots({filepath: snapshots}) + if not count: + continue + path_to_file = os.path.relpath(filepath, self.base_dir) + self.add_report_line( + ngettext( + f"{{}} at {path_to_file}", + f"{{}} in {path_to_file} → {', '.join(snapshots)}", + count, + ).format(bold(count)) + ) if n_unused: self.add_report_line() @@ -72,28 +88,83 @@ def finish(self): self.remove_unused_snapshots() self.add_report_line( ngettext( - "This file has been deleted.", - "These files have been deleted.", + "This snapshot has been deleted.", + "These snapshots have been deleted.", n_unused, ) ) else: self.add_report_line( gettext( - "Re-run pytest with --update-snapshots to delete these files." + "Re-run pytest with --update-snapshots to delete the snapshots." ) ) + def add_report_line(self, line: str = ""): + self.report += [line] + + def register_request(self, assertion: SnapshotAssertion): + discovered = { + filepath: assertion.io.discover_snapshots(filepath) + for filepath in self._walk_dir(assertion.io.dirname) + } + self.add_discovered_snapshots(discovered) + + def register_assertion(self, assertion: SnapshotAssertion): + filepath = assertion.io.get_filepath(assertion.num_executions) + snapshot = assertion.io.get_snapshot_name(assertion.num_executions) + self.add_visited_snapshots({filepath: {snapshot}}) + + if filepath not in self._assertions: + self._assertions[filepath] = dict() + self._assertions[filepath][snapshot] = assertion + + def add_discovered_snapshots(self, snapshots: SnapshotFiles): + self._merge_snapshot_files_into(self.discovered_snapshots, snapshots) + + def add_visited_snapshots(self, snapshots: SnapshotFiles): + self._merge_snapshot_files_into(self.visited_snapshots, snapshots) + def remove_unused_snapshots(self): - for snapshot_file in self.unused_snapshots: - os.remove(snapshot_file) + for snapshot_file, unused_snapshots in self.unused_snapshots.items(): + if self.discovered_snapshots[snapshot_file] == unused_snapshots: + os.remove(snapshot_file) + continue + snapshot_assertion, *_ = self._assertions[snapshot_file].values() + for snapshot_name in unused_snapshots: + snapshot_assertion.io.delete_snapshot(snapshot_file, snapshot_name) + + def _merge_snapshot_files_into( + self, snapshot_files: SnapshotFiles, *snapshot_files_to_merge: SnapshotFiles, + ): + """ + Add snapshots from other files into the first one + """ + for snapshot_file in snapshot_files_to_merge: + for filepath, snapshots in snapshot_file.items(): + if self._in_snapshot_dir(filepath): + if filepath not in snapshot_files: + snapshot_files[filepath] = set() + snapshot_files[filepath].update(snapshots) + + def _diff_snapshot_files( + self, snapshot_files1: SnapshotFiles, snapshot_files2: SnapshotFiles, + ) -> SnapshotFiles: + return { + filename: snapshots1 - snapshot_files2.get(filename, set()) + for filename, snapshots1 in snapshot_files1.items() + } + + def _count_snapshots(self, snapshot_files: SnapshotFiles) -> int: + return sum(len(snaps) for snaps in snapshot_files.values()) def _in_snapshot_dir(self, path: str) -> bool: parts = path.split(os.path.sep) return SNAPSHOT_DIRNAME in parts + @lru_cache(maxsize=32) def _walk_dir(self, root: str): - for (dirpath, dirnames, filenames) in os.walk(root): + for (dirpath, _, filenames) in os.walk(root): if not self._in_snapshot_dir(dirpath): continue for filename in filenames: diff --git a/src/syrupy/types.py b/src/syrupy/types.py new file mode 100644 index 00000000..b0fcc99a --- /dev/null +++ b/src/syrupy/types.py @@ -0,0 +1,3 @@ +from typing import Dict, Set + +SnapshotFiles = Dict[str, Set[str]] diff --git a/tasks.py b/tasks.py index cd1428a9..7cab86da 100644 --- a/tasks.py +++ b/tasks.py @@ -29,9 +29,12 @@ def lint(ctx, fix=False): @task -def test(ctx, update_snapshots=False): +def test(ctx, update_snapshots=False, verbose=False): ctx.run( - f"python -m pytest . {'--update-snapshots' if update_snapshots else ''}", + "python -m pytest ." + f"{' -s' if verbose else ''}" + f"{' --update-snapshots' if update_snapshots else ''}", + env={"PYTHONPATH": "./src"}, pty=True, ) diff --git a/tests/__snapshots__/test_snapshots.yaml b/tests/__snapshots__/test_snapshots.yaml index d6787149..428a0bd9 100644 --- a/tests/__snapshots__/test_snapshots.yaml +++ b/tests/__snapshots__/test_snapshots.yaml @@ -1,4 +1,4 @@ -test_dict: +test_dict[actual0]: data: a: e: false @@ -7,7 +7,7 @@ test_dict: d: - '1' - 2 -test_dict_1: +test_dict[actual1]: data: a: e: false diff --git a/tests/test_injection.py b/tests/test_injection.py index c59ba2f2..d8f1f01f 100644 --- a/tests/test_injection.py +++ b/tests/test_injection.py @@ -1,10 +1,9 @@ def test_fixture(testdir): - testdir.makepyfile( - """ - def test_sth(snapshot): - assert snapshot is not None - """ - ) + pyfile_content = """ +def test_sth(snapshot): + assert snapshot is not None +""" + testdir.makepyfile(pyfile_content) result = testdir.runpytest("-v") result.stdout.fnmatch_lines(["*::test_sth PASSED*"]) diff --git a/tests/test_snapshots.py b/tests/test_snapshots.py index 10dda08f..500614eb 100644 --- a/tests/test_snapshots.py +++ b/tests/test_snapshots.py @@ -20,11 +20,12 @@ def test_parametrized_with_special_char(snapshot, expected): assert expected == snapshot -def test_dict(snapshot): - actual = {"b": True, "c": "Some text.", "d": ["1", 2], "a": {"e": False}} - assert actual == snapshot - - -def test_dict_1(snapshot): - actual = {"b": True, "c": "Some ttext.", "d": ["1", 2], "a": {"e": False}} +@pytest.mark.parametrize( + "actual", + [ + {"b": True, "c": "Some text.", "d": ["1", 2], "a": {"e": False}}, + {"b": True, "c": "Some ttext.", "d": ["1", 2], "a": {"e": False}}, + ], +) +def test_dict(snapshot, actual): assert actual == snapshot