From b1de9b82529d500c0f030b3c17f71241f401689a Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Tue, 29 Nov 2022 15:25:38 +0200 Subject: [PATCH] sqlite: add basic implementation --- .github/workflows/tests.yml | 4 +- .gitignore | 3 + .pre-commit-config.yaml | 6 + MANIFEST.in | 1 + README.rst | 2 +- pyproject.toml | 28 +++ setup.cfg | 16 +- src/sqltrie/.trie.py.swp | Bin 12288 -> 0 bytes src/sqltrie/__init__.py | 22 ++- src/sqltrie/serialized.py | 89 +++++++++ src/sqltrie/sqlite.py | 159 ---------------- src/sqltrie/sqlite/__init__.py | 1 + src/sqltrie/sqlite/diff.sql | 117 ++++++++++++ src/sqltrie/sqlite/init.sql | 13 ++ src/sqltrie/sqlite/items.sql | 38 ++++ src/sqltrie/sqlite/sqlite.py | 305 +++++++++++++++++++++++++++++++ src/sqltrie/sqlite/steps.sql | 78 ++++++++ src/sqltrie/trie.py | 149 +++++++-------- tests/benchmarks/test_sqltrie.py | 38 ++++ tests/test_sqltrie.py | 76 ++++++-- 20 files changed, 882 insertions(+), 263 deletions(-) create mode 100644 MANIFEST.in delete mode 100644 src/sqltrie/.trie.py.swp create mode 100644 src/sqltrie/serialized.py delete mode 100644 src/sqltrie/sqlite.py create mode 100644 src/sqltrie/sqlite/__init__.py create mode 100644 src/sqltrie/sqlite/diff.sql create mode 100644 src/sqltrie/sqlite/init.sql create mode 100644 src/sqltrie/sqlite/items.sql create mode 100644 src/sqltrie/sqlite/sqlite.py create mode 100644 src/sqltrie/sqlite/steps.sql create mode 100644 tests/benchmarks/test_sqltrie.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5ab3450..58c56b9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,10 +20,8 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-latest, macos-latest] + os: [ubuntu-22.04, windows-latest, macos-latest] pyv: ['3.8', '3.9', '3.10', '3.11'] - include: - - {os: ubuntu-latest, pyv: 'pypy3.8'} steps: - name: Check out the repository diff --git a/.gitignore b/.gitignore index a81c8ee..7dd8d78 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,6 @@ dmypy.json # Cython debug symbols cython_debug/ + +# vim +*.swp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 000fa4a..0f6390c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,3 +53,9 @@ repos: - id: bandit args: [-c, pyproject.toml] additional_dependencies: ["toml"] +# NOTE: temporarily skipped +# - repo: https://github.com/sqlfluff/sqlfluff +# rev: 1.4.2 +# hooks: +# - id: sqlfluff-fix +# args: [--FIX-EVEN-UNPARSABLE, --force] diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..06be1e3 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +global-include *.sql diff --git a/README.rst b/README.rst index a0df42c..f34fd42 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ SQLTrie -======= +======== |PyPI| |Status| |Python Version| |License| diff --git a/pyproject.toml b/pyproject.toml index 2cd2184..11852d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,8 +63,16 @@ warn_redundant_casts = true warn_unreachable = true files = ["src", "tests"] +[tool.pylint.master] +load-plugins = ["pylint_pytest"] + [tool.pylint.message_control] enable = ["c-extension-no-member", "no-else-return"] +disable = [ + "fixme", + "missing-function-docstring", "missing-module-docstring", + "missing-class-docstring", +] [tool.pylint.variables] dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" @@ -76,3 +84,23 @@ ignore-words-list = " " [tool.bandit] exclude_dirs = ["tests"] skips = ["B101"] + +[tool.sqlfluff.core] +dialect = "sqlite" +exclude_rules = "L031" + +[tool.sqlfluff.rules] +tab_space_size = 4 +max_line_length = 80 +indent_unit = "space" +allow_scalar = true +single_table_references = "consistent" +unquoted_identifiers_policy = "all" + +[tool.sqlfluff.rules.L010] +capitalisation_policy = "upper" + +[tool.sqlfluff.rules.L029] +# these are not reserved in sqlite, +# see https://www.sqlite.org/lang_keywords.html +ignore_words = ["name", "value", "depth"] diff --git a/setup.cfg b/setup.cfg index 81f9983..1b445a8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,10 +5,18 @@ long_description = file: README.rst long_description_content_type = text/x-rst license = Apache-2.0 license_file = LICENSE -url = https://github.com/efiop/sqltrie +url = https://github.com/iterative/sqltrie platforms=any -authors = Ruslan Kuprieiev -maintainer_email = ruslan@iterative.ai +authors = DVC team +maintainer_email = support@dvc.org +keywords = + sqlite + sqlite3 + sql + trie + prefix tree + data-science + diskcache classifiers = Programming Language :: Python :: 3 Programming Language :: Python :: 3.8 @@ -23,11 +31,13 @@ zip_safe = False package_dir= =src packages = find: +include_package_data = True install_requires= [options.extras_require] tests = pytest==7.2.0 + pytest-benchmark pytest-sugar==0.9.5 pytest-cov==3.0.0 pytest-mock==3.8.2 diff --git a/src/sqltrie/.trie.py.swp b/src/sqltrie/.trie.py.swp deleted file mode 100644 index 316b9bc23dcce402133c666c51b92921949e6cea..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2OKclO7{{lCM;{al9B40enjn-!u2b3z7+E5;L5*m7NZJ4rVi@m^?ZxArWoF#O zNNt4>oZ-NM#|4RU0s-Q2K!_6u#DxRmDUje$D3@N4fCEVUXLlXjp^dTIOU118%dU53 zzWL_=eeJSIIFJATcfS9Bc_SeofcL?B;03SM3>sh_90t3=-`5iI7x)~! z3ElwDf)E7Y2$%*_U^Dpn8bZ#3Pr>Wp40sAK@F=(k>;gN%S;X!e;9~fdi=)510{?Rb zXnrN!7~-BoQKeS%6G9s?X@sm6XsklPsywhtw+l4U_)*eo`KXG;R0)eserX;VL60gs z-%O-vir9*9PHUyp8qTUzaZMP@G6T8ohG-9G#nhFymu@C=LYPtASmCn4tXknZ9N9*q zlE{i?>cVg$pk>$bf{Ph%NjpMJjIep*%VUqRP)Ii@#<~WX(#*&x7&2cNgLfCTRiG!h zYz=L)wt>mM7zo?aF|8;i(}5K0+%Pv0CSAtTJE)r)#ly2pZ0Y!Y#}?=3rq=S@)ikY& z`+;6pgVAg+&mK9x&N8~ch}D4;aBT?LwCx1Q_TlV^YL<4)wLn!XhJMv#=mPbeafeGGCaoTY;K7mr|{i^ZQgcKzaR4 zg}jv;cDK#_T$=vWx=_CmaEM;h+Lzp?n-4>DnpGdF*MYI=k^J%-9$uyYb=j&*a0G57L989LCPs#eX^S|Ck#W~yN4+YUu=X#t9RMMW5 zyjcl3T`f`OJ{LnFlC)s8K%1heda~Xv%~|2v+ML$t-NVRwAV^PZaW62xvn7|)6)%X?KLx1 zqI7~LZ5hS3M8nANiWFs(wA=~D>cHn?eT1kzLD$_4P_5%|JR^F*pte_L$qw7FyPMvA~lAPPnk8<|-{q zLnRDOik8-KcBDMpohfoy1rg*8=}KhlG_t90w QioixHHpQN*{X-@H0AXYFtN;K2 diff --git a/src/sqltrie/__init__.py b/src/sqltrie/__init__.py index 907ba6c..4246759 100644 --- a/src/sqltrie/__init__.py +++ b/src/sqltrie/__init__.py @@ -1,5 +1,17 @@ -"""SQLTrie.""" - -from .trie import AbstractTrie, ShortKeyError -from .sqlite import SQLiteTrie - +from .serialized import ( # noqa: F401, pylint: disable=unused-import + JSONTrie, + SerializedTrie, +) +from .sqlite import SQLiteTrie # noqa: F401, pylint: disable=unused-import +from .trie import ( # noqa: F401, pylint: disable=unused-import + ADD, + DELETE, + MODIFY, + RENAME, + UNCHANGED, + AbstractTrie, + Change, + ShortKeyError, + TrieKey, + TrieNode, +) diff --git a/src/sqltrie/serialized.py b/src/sqltrie/serialized.py new file mode 100644 index 0000000..5cc5ad8 --- /dev/null +++ b/src/sqltrie/serialized.py @@ -0,0 +1,89 @@ +import json +from abc import abstractmethod +from typing import Any, Optional + +from .trie import AbstractTrie, Iterator, TrieKey + + +class SerializedTrie(AbstractTrie): + @property + @abstractmethod + def _trie(self): + pass + + @abstractmethod + def _load(self, key: TrieKey, value: Optional[bytes]) -> Optional[Any]: + pass + + @abstractmethod + def _dump(self, key: TrieKey, value: Optional[Any]) -> Optional[bytes]: + pass + + def __setitem__(self, key, value): + self._trie[key] = self._dump(key, value) + + def __getitem__(self, key): + raw = self._trie[key] + return self._load(key, raw) + + def __delitem__(self, key): + del self._trie[key] + + def __len__(self): + return len(self._trie) + + def view(self, key: Optional[TrieKey] = None) -> "SerializedTrie": + if not key: + return self + + raw_trie = self._trie.view(key) + trie = type(self)() + # pylint: disable-next=protected-access + trie._trie = raw_trie # type: ignore + return trie + + def items(self, *args, **kwargs): + yield from ( + (key, self._load(key, raw)) + for key, raw in self._trie.items(*args, **kwargs) + ) + + def ls(self, key): + yield from self._trie.ls(key) + + def traverse(self, node_factory, prefix=None): + def _node_factory_wrapper(path_conv, path, children, value): + return node_factory( + path_conv, path, children, self._load(path, value) + ) + + return self._trie.traverse(_node_factory_wrapper, prefix=prefix) + + def diff(self, *args, **kwargs): + yield from self._trie.diff(*args, **kwargs) + + def has_node(self, key): + return self._trie.has_node(key) + + def shortest_prefix(self, key): + skey, raw = self._trie.shortest_prefix(key) + return key, self._load(skey, raw) + + def longest_prefix(self, key): + lkey, raw = self._trie.longest_prefix(key) + return lkey, self._load(lkey, raw) + + def __iter__(self) -> Iterator[TrieKey]: + yield from self._trie + + +class JSONTrie(SerializedTrie): # pylint: disable=abstract-method + def _load(self, key: TrieKey, value: Optional[bytes]) -> Optional[Any]: + if value is None: + return None + return json.loads(value.decode("utf-8")) + + def _dump(self, key: TrieKey, value: Optional[Any]) -> Optional[bytes]: + if value is None: + return None + return json.dumps(value).encode("utf-8") diff --git a/src/sqltrie/sqlite.py b/src/sqltrie/sqlite.py deleted file mode 100644 index 47a1a87..0000000 --- a/src/sqltrie/sqlite.py +++ /dev/null @@ -1,159 +0,0 @@ -import sqlite3 -from functools import cached_property -from .trie import AbstractTrie, ShortKeyError - -ROOT_ID = 1 - -class SQLiteTrie(AbstractTrie): - def __init__(self, *args, root_id=None, **kwargs): - self._root_id = root_id or ROOT_ID -# super().__init__(*args, **kwargs) - - @cached_property - def _conn(self): - conn = sqlite3.connect(":memory:") - conn.row_factory = sqlite3.Row - conn.executescript( - """ - CREATE TABLE nodes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - pid INTEGER, - name TEXT, - value TEXT, - UNIQUE(pid, name) - ); - CREATE INDEX nodes_pid_idx ON nodes (pid); - INSERT INTO nodes (id, pid, name, value) VALUES (1, NULL, 'root', NULL) - """ - ) - return conn - - def _create_node(self, key): - pid = self._root_id - for name in key: - ret = self._conn.execute( - """ - SELECT id FROM nodes WHERE nodes.pid = ? AND nodes.name = ? - """, - (pid, name), - ).fetchone() - if ret is None: - self._conn.execute( - """ - INSERT OR IGNORE INTO nodes (pid, name) VALUES (?, ?) - """, - (pid, name), - ) - # FIXME this might not work on IGNORE - ret = self._conn.execute( - "SELECT last_insert_rowid() AS id" - ).fetchone() - pid = ret["id"] - return pid - - def _get_node(self, key): - if not key: - return { - "id": self._root_id, - "pid": None, - "name": None, - "value": None, - } - - pid = self._root_id - row = None - for name in key: - row = self._conn.execute( - """ - SELECT id, pid, name, value FROM nodes WHERE nodes.pid = ? AND nodes.name = ? - """, - (pid, name), - ).fetchone() - if row is None: - raise KeyError - pid = row["id"] - return row - - def _get_children(self, key, limit=None): - node = self._get_node(key) - - limit_sql = "" - if limit: - limit_sql = f"LIMIT {limit}" - - return self._conn.execute( - f""" - SELECT * FROM nodes WHERE nodes.pid == ? {limit_sql} - """, - (node["id"],) - ).fetchall() - - def _delete_node(self, key): - node = self._get_node(key) - self._conn.execute( - """ - DELETE FROM nodes WHERE id = ? - """, - (node["id"],), - ) - - def __setitem__(self, key, value): - nid = self._create_node(key) - self._conn.execute( - """ - UPDATE nodes SET value = ? WHERE id = ? - """, - ( - value, - nid, - ) - ) - - def __getitem__(self, key): - return self._get_node(key)["value"] - - def __delitem__(self, key): - node = self._get_node(key) - self._conn.execute( - f""" - UPDATE nodes SET value = NULL WHERE id == ? - """, - (node["id"],), - ) - - def __len__(self): - return self._conn.execute( - """ - SELECT COUNT(*) AS count FROM nodes WHERE nodes.value is not NULL - """ - ).fetchone()["count"] - - def iteritems(self, prefix=None, shallow=False): - assert not shallow - if prefix: - pid = self._get_node(prefix)["id"] - else: - pid = self._root_id - - cursor = self._conn.execute( - """ - WITH RECURSIVE - myfunc (id, pid, name, value) AS ( - SELECT * FROM nodes WHERE nodes.pid == ? - - UNION ALL - - SELECT nodes.id, nodes.pid, nodes.name, nodes.value - FROM nodes, myfunc WHERE myfunc.id == nodes.pid - ) - SELECT * FROM myfunc - """, - (pid,) - ) - - for row in cursor: - # FIXME can join name in the CTE - yield row - - def clear(self): - self._conn.execute("DELETE FROM nodes") diff --git a/src/sqltrie/sqlite/__init__.py b/src/sqltrie/sqlite/__init__.py new file mode 100644 index 0000000..162eebc --- /dev/null +++ b/src/sqltrie/sqlite/__init__.py @@ -0,0 +1 @@ +from .sqlite import SQLiteTrie # noqa: F401, pylint: disable=unused-import diff --git a/src/sqltrie/sqlite/diff.sql b/src/sqltrie/sqlite/diff.sql new file mode 100644 index 0000000..3999797 --- /dev/null +++ b/src/sqltrie/sqlite/diff.sql @@ -0,0 +1,117 @@ +DROP TABLE IF EXISTS temp_old_items; +DROP TABLE IF EXISTS temp_new_items; +DROP TABLE IF EXISTS temp_diff; + +CREATE TEMP TABLE temp_old_items AS +WITH RECURSIVE old_items (id, pid, name, path, value) AS ( + SELECT + nodes.id, + nodes.pid, + nodes.name, + nodes.name, + nodes.value + FROM nodes WHERE nodes.pid == {old_root} + + UNION ALL + + SELECT + nodes.id, + nodes.pid, + nodes.name, + old_items.path || '/' || nodes.name, + nodes.value + FROM nodes, old_items WHERE old_items.id == nodes.pid +) + +SELECT * FROM old_items; + +CREATE TEMP TABLE temp_new_items AS +WITH RECURSIVE new_items (id, pid, name, path, value) AS ( + SELECT + nodes.id, + nodes.pid, + nodes.name, + nodes.name, + nodes.value + FROM nodes WHERE nodes.pid == {new_root} + + UNION ALL + + SELECT + nodes.id, + nodes.pid, + nodes.name, + new_items.path || '/' || nodes.name, + nodes.value + FROM nodes, new_items WHERE new_items.id == nodes.pid +) + +SELECT * FROM new_items; + +CREATE TEMP TABLE temp_diff AS +WITH RECURSIVE diff ( + old_id, + old_pid, + old_name, + old_path, + old_value, + new_id, + new_pid, + new_name, + new_path, + new_value +) AS ( + /* FULL OUTER JOIN is not supported, so we have to use two LEFT JOINs :( */ + SELECT + old.id, + old.pid, + old.name, + old.path, + old.value, + new.id, + new.pid, + new.name, + new.path, + new.value + FROM + temp_old_items AS old + LEFT JOIN + temp_new_items AS new + ON old.path == new.path + + UNION + + SELECT + old.id, + old.pid, + old.name, + old.path, + old.value, + new.id, + new.pid, + new.name, + new.path, + new.value + FROM + temp_new_items AS new + LEFT JOIN + temp_old_items AS old + ON old.path == new.path +) + +SELECT + ( + CASE WHEN old_id IS NULL THEN 'add' ELSE ( + CASE WHEN new_id IS NULL THEN 'delete' ELSE ( + CASE + WHEN old_value != new_value THEN 'modify' ELSE 'unchanged' + END + ) END + ) END + ) AS type, + * +FROM diff +WHERE ( + {with_unchanged} + OR type != 'unchanged' +); diff --git a/src/sqltrie/sqlite/init.sql b/src/sqltrie/sqlite/init.sql new file mode 100644 index 0000000..8701cbd --- /dev/null +++ b/src/sqltrie/sqlite/init.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS nodes ( + id integer PRIMARY KEY AUTOINCREMENT, + pid integer, + name text, + has_value boolean, + value blob, + UNIQUE(pid, name), + UNIQUE(id, pid), + CHECK(id != pid) +); +CREATE INDEX IF NOT EXISTS nodes_pid_idx ON nodes (pid); +INSERT OR IGNORE INTO nodes (id, pid, name, has_value, value) +VALUES (1, NULL, "", FALSE, NULL); diff --git a/src/sqltrie/sqlite/items.sql b/src/sqltrie/sqlite/items.sql new file mode 100644 index 0000000..2764f46 --- /dev/null +++ b/src/sqltrie/sqlite/items.sql @@ -0,0 +1,38 @@ +DROP TABLE IF EXISTS temp_items; + +CREATE TEMP TABLE temp_items AS +WITH RECURSIVE children ( + id, pid, name, path, has_value, value, found_value +) AS ( + SELECT + nodes.id, + nodes.pid, + nodes.name, + nodes.name, + nodes.has_value, + nodes.value, + nodes.has_value + FROM nodes WHERE nodes.pid == {root} + + UNION ALL + + SELECT + nodes.id, + nodes.pid, + nodes.name, + children.path || '/' || nodes.name, + nodes.has_value, + nodes.value, + children.found_value OR nodes.has_value + FROM nodes, children + WHERE children.id == nodes.pid AND (NOT {shallow} OR NOT children.found_value OR nodes.has_value) +) + +SELECT + id, + pid, + name, + path, + has_value, + value +FROM children WHERE has_value; diff --git a/src/sqltrie/sqlite/sqlite.py b/src/sqltrie/sqlite/sqlite.py new file mode 100644 index 0000000..031afdb --- /dev/null +++ b/src/sqltrie/sqlite/sqlite.py @@ -0,0 +1,305 @@ +import sqlite3 +import threading +from pathlib import Path +from typing import Iterator, Optional, Tuple +from uuid import uuid4 + +from ..trie import AbstractTrie, Change, ShortKeyError, TrieKey, TrieNode + +# NOTE: seems like "named" doesn't work without changing this global var, +# so unfortunately we have to stick with qmark. +assert sqlite3.paramstyle == "qmark" + +scripts = Path(__file__).parent + +ROOT_ID = 1 +ROOT_NAME = "/" + +INIT_SQL = (scripts / "init.sql").read_text() + +STEPS_SQL = (scripts / "steps.sql").read_text() +STEPS_TABLE = "temp_steps" + +ITEMS_SQL = (scripts / "items.sql").read_text() +ITEMS_TABLE = "temp_items" + +DIFF_SQL = (scripts / "diff.sql").read_text() +DIFF_TABLE = "temp_diff" + +DEFAULT_DB_FMT = "file:sqlitetrie_{id}?mode=memory&cache=shared" + + +class SQLiteTrie(AbstractTrie): + def __init__(self, *args, **kwargs): + self._root_id = ROOT_ID + self._path = DEFAULT_DB_FMT.format(id=uuid4()) + self._local = threading.local() + self._ids = {} + super().__init__(*args, **kwargs) + + @classmethod + def open(cls, path): + trie = cls() + trie._path = path + return trie + + def close(self): + conn = getattr(self._local, "conn", None) + if conn is None: + return + + conn.close() + + try: + delattr(self._local, "conn") + except AttributeError: + pass + + @property + def _conn(self): # pylint: disable=method-hidden + conn = getattr(self._local, "conn", None) + if conn is None: + conn = self._local.conn = sqlite3.connect( + self._path, isolation_level=None + ) + conn.row_factory = sqlite3.Row + conn.executescript(INIT_SQL) + + return conn + + def _create_node(self, key): + try: + return self._ids[key] + except KeyError: + pass + + rows = self._traverse(key) + if rows: + longest_prefix = tuple(rows[-1]["path"].split("/")) + pid = rows[-1]["id"] + else: + longest_prefix = () + pid = self._root_id + self._ids[longest_prefix] = pid + + node_key = longest_prefix + for name in key[len(longest_prefix) :]: + node_key = (*node_key, name) + row = self._conn.execute( + """ + INSERT OR IGNORE + INTO nodes (pid, name) + VALUES (?, ?) + RETURNING id + """, + (pid, name), + ).fetchone() + nid = row["id"] + self._ids[node_key] = nid + pid = nid + + return pid + + def _traverse(self, key): + self._conn.executescript( + STEPS_SQL.format(path="/".join(key), root=self._root_id) + ) + + return self._conn.execute( # nosec + f"SELECT * FROM {STEPS_TABLE}" + ).fetchall() + + def _get_node(self, key): + if not key: + return { + "id": self._root_id, + "pid": None, + "name": None, + "value": None, + } + + rows = list(self._traverse(key)) + if len(rows) != len(key): + raise KeyError(key) + + return rows[-1] + + def _get_children(self, key, limit=None): + node = self._get_node(key) + + limit_sql = "" + if limit: + limit_sql = f"LIMIT {limit}" + + return self._conn.execute( # nosec + f""" + SELECT * FROM nodes WHERE nodes.pid == ? {limit_sql} + """, + (node["id"],), + ).fetchall() + + def _delete_node(self, key): + node = self._get_node(key) + del self._ids[key] + self._conn.execute( + """ + DELETE FROM nodes WHERE id = ? + """, + (node["id"],), + ) + + def __setitem__(self, key, value): + pid = self._create_node(key[:-1]) + self._conn.execute( + """ + INSERT INTO + nodes (pid, name, has_value, value) + VALUES (?1, ?2, True, ?3) + ON CONFLICT (pid, name) DO UPDATE SET value=?3 + """, + ( + pid, + key[-1], + value, + ), + ) + + def __iter__(self): + yield from (key for key, _ in self.items()) + + def __getitem__(self, key): + value = self._get_node(key)["value"] + if not value: + raise ShortKeyError(key) + return value + + def __delitem__(self, key): + node = self._get_node(key) + self._conn.execute( + """ + UPDATE nodes SET has_value = False, value = NULL WHERE id == ? + """, + (node["id"],), + ) + + def __len__(self): + self._conn.executescript( + ITEMS_SQL.format(root=self._root_id, shallow=False) + ) + return self._conn.execute( # nosec + f""" + SELECT COUNT(*) AS count FROM {ITEMS_TABLE} + """ + ).fetchone()["count"] + + def shortest_prefix( + self, key: TrieKey + ) -> Tuple[Optional[TrieKey], Optional[bytes]]: + skey: TrieKey = () + value = None + for row in self._traverse(key): + skey = (*skey, row["name"]) # type: ignore + value = row["value"] + if value is not None: + break + + return skey, value + + def longest_prefix(self, key) -> Tuple[Optional[TrieKey], Optional[bytes]]: + rows = self._traverse(key) + lkey: TrieKey = () + value = None + for idx, row in enumerate(reversed(rows)): + if row["value"] is None: + continue + lkey = tuple( # type: ignore + row["name"] for row in rows[: len(rows) - idx] + ) + value = row["value"] + break + return lkey, value + + def view( # type: ignore + self, + key: Optional[TrieKey] = None, + ) -> "SQLiteTrie": + if not key: + return self + + node = self._get_node(key) + + trie = SQLiteTrie() + trie._path = self._path # pylint: disable=protected-access + trie._root_id = node["id"] # pylint: disable=protected-access + return trie + + def items(self, prefix=None, shallow=False): + if prefix: + pid = self._get_node(prefix)["id"] + else: + prefix = () + pid = self._root_id + + self._conn.executescript(ITEMS_SQL.format(root=pid, shallow=shallow)) + rows = self._conn.execute(f"SELECT * FROM {ITEMS_TABLE}") # nosec + + yield from ( + ((*prefix, *row["path"].split("/")), row["value"]) for row in rows + ) + + def clear(self): + self._conn.execute("DELETE FROM nodes") + + def has_node(self, key: TrieKey) -> bool: + try: + value = self[key] + return value is not None + except KeyError: + return False + + def ls(self, key: TrieKey) -> Iterator[TrieKey]: + yield from ( # type: ignore + (*key, row["name"]) for row in self._get_children(key) + ) + + def traverse(self, node_factory, prefix=None): + key = prefix or () + row = self._get_node(prefix) + value = row["value"] + + children_keys = ( + (*key, row["name"]) for row in self._get_children(key) + ) + children = ( + self.traverse(node_factory, child) for child in children_keys + ) + + return node_factory(None, key, children, value) + + def diff(self, old, new, with_unchanged=False): + old_id = self._get_node(old)["id"] + new_id = self._get_node(new)["id"] + + self._conn.executescript( + DIFF_SQL.format( + old_root=old_id, + new_root=new_id, + with_unchanged=with_unchanged, + ) + ) + + rows = self._conn.execute(f"SELECT * FROM {DIFF_TABLE}") # nosec + yield from ( + Change( + row["type"], + TrieNode( + tuple(row["old_path"].split("/")), + row["old_value"], + ), + TrieNode( + tuple(row["new_path"].split("/")), + row["new_value"], + ), + ) + for row in rows + ) diff --git a/src/sqltrie/sqlite/steps.sql b/src/sqltrie/sqlite/steps.sql new file mode 100644 index 0000000..66fd99e --- /dev/null +++ b/src/sqltrie/sqlite/steps.sql @@ -0,0 +1,78 @@ +DROP TABLE IF EXISTS temp_split; +DROP TABLE IF EXISTS temp_steps; + +CREATE TEMP TABLE temp_split AS +WITH RECURSIVE +path (path) AS ( + VALUES('{path}') +), + +split (depth, name, rpath) AS ( + SELECT + 1, + ( + CASE WHEN instr(path, '/') == 0 THEN path ELSE substr(path, 0, instr(path, '/')) END + ), + ( + CASE WHEN instr(path, '/') == 0 THEN '' ELSE substr(path, instr(path, '/') + 1) END + ) + FROM path + + UNION ALL + + SELECT + split.depth + 1, + ( + CASE WHEN instr(split.rpath, '/') == 0 THEN split.rpath ELSE substr(split.rpath, 0, instr(split.rpath, '/')) END + ), + ( + CASE WHEN instr(split.rpath, '/') == 0 THEN '' ELSE substr(split.rpath, instr(split.rpath, '/') + 1) END + ) + FROM split WHERE split.rpath != '' +) + +SELECT + depth, + name +FROM split; + +CREATE TEMP TABLE temp_steps AS +WITH RECURSIVE +steps (id, pid, name, path, has_value, value, depth) AS ( + SELECT + nodes.id, + nodes.pid, + nodes.name, + nodes.name, + nodes.has_value, + nodes.value, + temp_split.depth + FROM nodes, temp_split + WHERE + temp_split.depth == 1 AND nodes.pid == {root} AND nodes.name == temp_split.name + + UNION ALL + + SELECT + nodes.id, + nodes.pid, + nodes.name, + steps.path || '/' || nodes.name, + nodes.has_value, + nodes.value, + steps.depth + 1 + FROM nodes, steps, temp_split + WHERE + nodes.pid == steps.id + AND temp_split.depth == steps.depth + 1 + AND temp_split.name == nodes.name +) + +SELECT + id, + pid, + name, + path, + has_value, + value +FROM steps; diff --git a/src/sqltrie/trie.py b/src/sqltrie/trie.py index 647ed8e..0c3922a 100644 --- a/src/sqltrie/trie.py +++ b/src/sqltrie/trie.py @@ -1,5 +1,8 @@ -from collections.abc import MutableMapping from abc import abstractmethod +from collections.abc import MutableMapping +from typing import Iterator, NamedTuple, Optional, Tuple, Union + +from attrs import define class ShortKeyError(KeyError): @@ -7,104 +10,88 @@ class ShortKeyError(KeyError): but does not have a value associated with itself.""" -class AbstractTrie(MutableMapping): - def __init__(self, *args, **kwargs): - self.update(*args, **kwargs) - - def enable_sorting(self, enable=True): - raise NotImplementedError - - def clear(self): - raise NotImplementedError - - def update(self, *args, **kwargs): # pylint: disable=arguments-differ - raise NotImplementedError - - def merge(self, other, overwrite=False): - raise NotImplementedError - - def copy(self, __make_copy=lambda x: x): - raise NotImplementedError - - def __copy__(self): - return self.copy() - - def __deepcopy__(self, memo): - return self.copy(lambda x: _copy.deepcopy(x, memo)) - - @classmethod - def fromkeys(cls, keys, value=None): - raise NotImplementedError - - def __iter__(self): - return self.iterkeys() +TrieKey = Union[Tuple[()], Tuple[str]] - def iteritems(self, prefix=None, shallow=False): - raise NotImplementedError - def iterkeys(self, prefix=None, shallow=False): - raise NotImplementedError +class TrieNode(NamedTuple): + key: TrieKey + value: Optional[bytes] - def itervalues(self, prefix=None, shallow=False): - raise NotImplementedError - def items(self, prefix=None, shallow=False): - return list(self.iteritems(prefix=prefix, shallow=shallow)) +ADD = "add" +MODIFY = "modify" +RENAME = "rename" +DELETE = "delete" +UNCHANGED = "unchanged" - def keys(self, prefix=None, shallow=False): - return list(self.iterkeys(prefix=prefix, shallow=shallow)) - def values(self, prefix=None, shallow=False): - return list(self.itervalues(prefix=prefix, shallow=shallow)) +@define(frozen=True, hash=True, order=True) +class Change: + typ: str + old: Optional[TrieNode] + new: Optional[TrieNode] - def __len__(self): - raise NotImplementedError + @property + def key(self) -> TrieKey: + if self.typ == RENAME: + raise ValueError - def __bool__(self): - raise NotImplementedError + if self.typ == ADD: + entry = self.new + else: + entry = self.old - __nonzero__ = __bool__ - __hash__ = None + assert entry + assert entry.key + return entry.key - def has_node(self, key): - raise NotImplementedError + def __bool__(self) -> bool: + return self.typ != UNCHANGED - def has_key(self, key): - return bool(self.has_node(key) & self.HAS_VALUE) - def has_subtrie(self, key): - return bool(self.has_node(key) & self.HAS_SUBTRIE) - - def __getitem__(self, key_or_slice): - raise NotImplementedError - - def __setitem__(self, key_or_slice, value): - raise NotImplementedError - - def __delitem__(self, key_or_slice): - raise NotImplementedError - - def setdefault(self, key, default=None): - raise NotImplementedError +class AbstractTrie(MutableMapping): + def __init__(self, *args, **kwargs): + self.update(*args, **kwargs) - def pop(self, key, default=None): - raise NotImplementedError + @abstractmethod + def items( # type: ignore + self, prefix: Optional[TrieKey] = None, shallow: Optional[bool] = False + ) -> Iterator[Tuple[TrieKey, bytes]]: + pass - def popitem(self): - raise NotImplementedError + @abstractmethod + def view(self, key: Optional[TrieKey] = None) -> "AbstractTrie": + pass - def walk_towards(self, key): - raise NotImplementedError + @abstractmethod + def has_node(self, key: TrieKey) -> bool: + pass - def prefixes(self, key): - raise NotImplementedError + @abstractmethod + def shortest_prefix( + self, key: TrieKey + ) -> Tuple[Optional[TrieKey], Optional[bytes]]: + pass - def shortest_prefix(self, key): - raise NotImplementedError + @abstractmethod + def longest_prefix( + self, key: TrieKey + ) -> Tuple[Optional[TrieKey], Optional[bytes]]: + pass - def longest_prefix(self, key): - raise NotImplementedError + @abstractmethod + # pylint: disable-next=invalid-name + def ls(self, key: TrieKey) -> Iterator[TrieKey]: + pass - def traverse(self, node_factory, prefix=None): + @abstractmethod + def traverse( + self, node_factory, prefix: Optional[TrieKey] + ) -> Iterator[Tuple[TrieKey, bytes]]: pass + @abstractmethod + def diff( + self, old: TrieKey, new: TrieKey, with_unchanged: bool = False + ) -> Iterator[Change]: + pass diff --git a/tests/benchmarks/test_sqltrie.py b/tests/benchmarks/test_sqltrie.py new file mode 100644 index 0000000..c26505c --- /dev/null +++ b/tests/benchmarks/test_sqltrie.py @@ -0,0 +1,38 @@ +import pytest + +from sqltrie import SQLiteTrie + + +@pytest.fixture(scope="session") +def items(): + ret = {} + + files = {str(idx): bytes(idx) for idx in range(10000)} + for subdir in ["foo", "bar", "baz"]: + ret[subdir] = files.copy() + + return ret + + +def test_set(benchmark, items): + def _set(): + trie = SQLiteTrie() + + for subdir in ["foo", "bar", "baz"]: + for idx in range(10000): + trie[(subdir, str(idx))] = bytes(idx) + + benchmark(_set) + + +def test_items(benchmark, items): + trie = SQLiteTrie() + + for subdir in ["foo", "bar", "baz"]: + for idx in range(10000): + trie[(subdir, str(idx))] = bytes(idx) + + def _items(): + list(trie.items()) + + benchmark(_items) diff --git a/tests/test_sqltrie.py b/tests/test_sqltrie.py index a8302da..5be4f01 100644 --- a/tests/test_sqltrie.py +++ b/tests/test_sqltrie.py @@ -1,24 +1,78 @@ """Tests for `sqltrie` package.""" import pytest -from sqltrie import SQLiteTrie, ShortKeyError +from sqltrie import UNCHANGED, Change, ShortKeyError, SQLiteTrie, TrieNode + def test_trie(): trie = SQLiteTrie() - trie[("foo",)] = "foo-value" - trie[("foo", "bar", "baz")] = "baz-value" + trie[("foo",)] = b"foo-value" + trie[("foo", "bar", "baz")] = b"baz-value" assert len(trie) == 2 - assert trie[("foo",)] == "foo-value" - assert trie[("foo", "bar")] == None - assert trie[("foo", "bar", "baz")] == "baz-value" + assert trie[("foo",)] == b"foo-value" + with pytest.raises(ShortKeyError): + trie[("foo", "bar")] # pylint: disable=pointless-statement + assert trie[("foo", "bar", "baz")] == b"baz-value" del trie[("foo",)] assert len(trie) == 1 - # FIXME the next two should raise ShortKeyError - assert trie[("foo",)] == None - assert trie[("foo", "bar")] == None - assert trie[("foo", "bar", "baz")] == "baz-value" + assert trie[("foo", "bar", "baz")] == b"baz-value" + + with pytest.raises(ShortKeyError): + trie[("foo",)] # pylint: disable=pointless-statement + + with pytest.raises(ShortKeyError): + trie[("foo", "bar")] # pylint: disable=pointless-statement + + with pytest.raises(KeyError): + trie[("non-existent",)] # pylint: disable=pointless-statement + + with pytest.raises(KeyError): + trie[("foo", "non-existent")] # pylint: disable=pointless-statement + + assert trie.longest_prefix(("non-existent",)) == ((), None) + assert trie.longest_prefix(("foo",)) == ((), None) + assert trie.longest_prefix(("foo", "non-existent")) == ((), None) + assert trie.longest_prefix(("foo", "bar", "baz", "qux")) == ( + ("foo", "bar", "baz"), + b"baz-value", + ) + + assert set(trie.items()) == { + (("foo", "bar", "baz"), b"baz-value"), + } + assert set(trie.items(shallow=True)) == { + (("foo", "bar", "baz"), b"baz-value"), + } + assert set(trie.items(("foo",))) == { + (("foo", "bar", "baz"), b"baz-value"), + } + assert set(trie.items(("foo", "bar"))) == { + (("foo", "bar", "baz"), b"baz-value"), + } + assert set(trie.items(("foo", "bar", "baz"))) == set() + + assert set(trie.view(("foo",)).items()) == { + (("bar", "baz"), b"baz-value"), + } + assert set(trie.view(("foo", "bar", "baz")).items()) == set() + + assert list(trie.ls(())) == [("foo",)] + assert list(trie.ls(("foo",))) == [("foo", "bar")] + assert list(trie.ls(("foo", "bar"))) == [("foo", "bar", "baz")] - assert set(trie.iteritems()) == set() + assert not list(trie.diff(("foo",), ("foo",))) + assert list(trie.diff(("foo",), ("foo",), with_unchanged=True)) == [ + Change( + typ=UNCHANGED, + old=TrieNode(key=("bar",), value=None), + new=TrieNode(key=("bar",), value=None), + ), + Change( + typ=UNCHANGED, + old=TrieNode(key=("bar", "baz"), value=b"baz-value"), + new=TrieNode(key=("bar", "baz"), value=b"baz-value"), + ), + ]