diff --git a/.gitignore b/.gitignore index 55b4fed..36195c1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,5 @@ # Ignore Goland and JetBrains IDE files .idea/ -# Visual Studio Code +# Ignore vscode files .vscode diff --git a/Makefile b/Makefile index 6d5ab99..c9d4774 100644 --- a/Makefile +++ b/Makefile @@ -34,11 +34,11 @@ check_godoc: .PHONY: test_all test_all: ## runs the test suite - go test -v -p 1 ./... -mod=readonly -race + go test -v -p 1 -count=1 ./... -mod=readonly -race .PHONY: test_badger test_badger: ## runs the badger KVStore submodule's test suite - go test -v -p 1 ./kvstore/badger/... -mod=readonly -race + go test -v -p 1 -count=1 ./kvstore/badger/... -mod=readonly -race ##################### diff --git a/bulk_test.go b/bulk_test.go index 0f378d1..a0ab126 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -85,7 +85,7 @@ func bulkOperations(t *testing.T, operations int, insert int, update int, delete if err != nil && err != ErrKeyNotFound { t.Fatalf("error: %v", err) } - kv[ki].val = defaultValue + kv[ki].val = defaultEmptyValue } } diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 0000000..7dcd4c4 --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,25 @@ +# FAQ + +- [History](#history) + - [Fork](#fork) +- [Implementation](#implementation) + - [What's the story behind Extension Node Implementation?](#whats-the-story-behind-extension-node-implementation) + +This documentation is meant to capture common questions that come up and act +as a supplement or secondary reference to the primary documentation. + +## History + +### Fork + +This library was originally forked off of [celestiaorg/smt](https://github.com/celestiaorg/smt) +which was archived on Feb 27th, 2023. + +## Implementation + +### What's the story behind Extension Node Implementation? 
+ +The [SMT extension node](./smt.md#extension-nodes) is very similar to that of +Ethereum's [Modified Merkle Patricia Trie](https://ethereum.org/developers/docs/data-structures-and-encoding/patricia-merkle-trie). + +A quick primer on it can be found in this [5P;1R post](https://olshansky.substack.com/p/5p1r-ethereums-modified-merkle-patricia). diff --git a/docs/merkle-sum-trie.md b/docs/merkle-sum-trie.md index 299c0e6..ee881db 100644 --- a/docs/merkle-sum-trie.md +++ b/docs/merkle-sum-trie.md @@ -2,16 +2,17 @@ -- [Overview](#overview) -- [Implementation](#implementation) - * [Sum Encoding](#sum-encoding) - * [Digests](#digests) - * [Visualisations](#visualisations) - + [General Trie Structure](#general-trie-structure) - + [Binary Sum Digests](#binary-sum-digests) -- [Sum](#sum) -- [Roots](#roots) -- [Nil Values](#nil-values) +- [Sparse Merkle Sum Trie (smst)](#sparse-merkle-sum-trie-smst) + - [Overview](#overview) + - [Implementation](#implementation) + - [Sum Encoding](#sum-encoding) + - [Digests](#digests) + - [Visualizations](#visualizations) + - [General Trie Structure](#general-trie-structure) + - [Binary Sum Digests](#binary-sum-digests) + - [Sum](#sum) + - [Roots](#roots) + - [Nil Values](#nil-values) @@ -64,34 +65,34 @@ The golang `encoding/binary` package is used to encode the sum with `binary.BigEndian.PutUint64(sumBz[:], sum)` into a byte array `sumBz`. In order for the SMST to include the sum into a leaf node the SMT the SMST -initialises the SMT with the `WithValueHasher(nil)` option so that the SMT does +initializes the SMT with the `WithValueHasher(nil)` option so that the SMT does **not** hash any values. The SMST will then hash the value and append the sum bytes to the end of the hashed value, using whatever `ValueHasher` was given to -the SMST on initialisation. +the SMST on initialization. 
```mermaid graph TD - subgraph KVS[Key-Value-Sum] - K1["Key: foo"] - K2["Value: bar"] - K3["Sum: 10"] - end - subgraph SMST[SMST] - SS1[ValueHasher: SHA256] - subgraph SUM["SMST.Update()"] - SU1["valueHash = ValueHasher(Value)"] - SU2["sumBytes = binary(Sum)"] - SU3["valueHash = append(valueHash, sumBytes...)"] - end - end - subgraph SMT[SMT] - SM1[ValueHasher: nil] - subgraph UPD["SMT.Update()"] - U2["SMT.nodeStore.Set(Key, valueHash)"] - end - end - KVS --"Key + Value + Sum"--> SMST - SMST --"Key + valueHash"--> SMT + subgraph KVS[Key-Value-Sum] + K1["Key: foo"] + K2["Value: bar"] + K3["Sum: 10"] + end + subgraph SMST[SMST] + SS1[ValueHasher: SHA256] + subgraph SUM["SMST.Update()"] + SU1["valueHash = ValueHasher(Value)"] + SU2["sumBytes = binary(Sum)"] + SU3["valueHash = append(valueHash, sumBytes...)"] + end + end + subgraph SMT[SMT] + SM1[ValueHasher: nil] + subgraph UPD["SMT.Update()"] + U2["SMT.nodeStore.Set(Key, valueHash)"] + end + end + KVS --"Key + Value + Sum"--> SMST + SMST --"Key + valueHash"--> SMT ``` ### Digests @@ -128,10 +129,10 @@ Therefore for the following node types, the digests are computed as follows: This means that with a hasher such as `sha256.New()` whose hash size is `32 bytes`, the digest of any node will be `40 bytes` in length. -### Visualisations +### Visualizations The following diagrams are representations of how the trie and its components -can be visualised. +can be visualized. #### General Trie Structure @@ -142,45 +143,45 @@ nodes as an extra field. 
```mermaid graph TB - subgraph Root - A1["Digest: Hash(Hash(Path+H1)+Hash(H2+(Hash(H3+H4)))+Binary(20))+Binary(20)"] + subgraph Root + A1["Digest: Hash(Hash(Path+H1)+Hash(H2+(Hash(H3+H4)))+Binary(20))+Binary(20)"] A2[Sum: 20] - end - subgraph BI[Inner Node] - B1["Digest: Hash(H2+(Hash(H3+H4))+Binary(12))+Binary(12)"] + end + subgraph BI[Inner Node] + B1["Digest: Hash(H2+(Hash(H3+H4))+Binary(12))+Binary(12)"] B2[Sum: 12] - end - subgraph BE[Extension Node] - B3["Digest: Hash(Path+H1+Binary(8))+Binary(8)"] + end + subgraph BE[Extension Node] + B3["Digest: Hash(Path+H1+Binary(8))+Binary(8)"] B4[Sum: 8] - end - subgraph CI[Inner Node] - C1["Digest: Hash(H3+H4+Binary(7))+Binary(7)"] + end + subgraph CI[Inner Node] + C1["Digest: Hash(H3+H4+Binary(7))+Binary(7)"] C2[Sum: 7] - end - subgraph CL[Leaf Node] - C3[Digest: H2] + end + subgraph CL[Leaf Node] + C3[Digest: H2] C4[Sum: 5] - end - subgraph DL1[Leaf Node] - D1[Digest: H3] + end + subgraph DL1[Leaf Node] + D1[Digest: H3] D2[Sum: 4] - end - subgraph DL2[Leaf Node] - D3[Digest: H4] + end + subgraph DL2[Leaf Node] + D3[Digest: H4] D4[Sum: 3] - end - subgraph EL[Leaf Node] - E1[Digest: H1] + end + subgraph EL[Leaf Node] + E1[Digest: H1] E2[Sum: 8] - end - Root-->|0| BE - Root-->|1| BI - BI-->|0| CL - BI-->|1| CI - CI-->|0| DL1 - CI-->|1| DL2 - BE-->EL + end + Root-->|0| BE + Root-->|1| BI + BI-->|0| CL + BI-->|1| CI + CI-->|0| DL1 + CI-->|1| DL2 + BE-->EL ``` #### Binary Sum Digests @@ -192,56 +193,56 @@ exception of the leaf nodes where the sum is shown as part of its value. 
```mermaid graph TB - subgraph RI[Inner Node] - RIA["Root Hash: Hash(D6+D7+Binary(18))+Binary(18)"] + subgraph RI[Inner Node] + RIA["Root Hash: Hash(D6+D7+Binary(18))+Binary(18)"] RIB[Sum: 15] - end - subgraph I1[Inner Node] - I1A["D7: Hash(D1+D5+Binary(11))+Binary(11)"] + end + subgraph I1[Inner Node] + I1A["D7: Hash(D1+D5+Binary(11))+Binary(11)"] I1B[Sum: 11] - end - subgraph I2[Inner Node] - I2A["D6: Hash(D3+D4+Binary(7))+Binary(7)"] + end + subgraph I2[Inner Node] + I2A["D6: Hash(D3+D4+Binary(7))+Binary(7)"] I2B[Sum: 7] - end - subgraph L1[Leaf Node] - L1A[Path: 0b0010000] - L1B["Value: 0x01+Binary(6)"] + end + subgraph L1[Leaf Node] + L1A[Path: 0b0010000] + L1B["Value: 0x01+Binary(6)"] L1C["H1: Hash(Path+Value+Binary(6))"] L1D["D1: H1+Binary(6)"] - end - subgraph L3[Leaf Node] - L3A[Path: 0b1010000] - L3B["Value: 0x03+Binary(3)"] + end + subgraph L3[Leaf Node] + L3A[Path: 0b1010000] + L3B["Value: 0x03+Binary(3)"] L3C["H3: Hash(Path+Value+Binary(3))"] L3D["D3: H3+Binary(3)"] - end - subgraph L4[Leaf Node] - L4A[Path: 0b1100000] - L4B["Value: 0x04+Binary(4)"] + end + subgraph L4[Leaf Node] + L4A[Path: 0b1100000] + L4B["Value: 0x04+Binary(4)"] L4C["H4: Hash(Path+Value+Binary(4))"] L4D["D4: H4+Binary(4)"] - end - subgraph E1[Extension Node] - E1A[Path: 0b01100101] - E1B["Path Bounds: [2, 6)"] + end + subgraph E1[Extension Node] + E1A[Path: 0b01100101] + E1B["Path Bounds: [2, 6)"] E1C[Sum: 5] E1D["H5: Hash(Path+PathBounds+D2+Binary(5))"] E1E["D5: H5+Binary(5)"] - end - subgraph L2[Leaf Node] - L2A[Path: 0b01100101] - L2B["Value: 0x02+Binary(5)"] + end + subgraph L2[Leaf Node] + L2A[Path: 0b01100101] + L2B["Value: 0x02+Binary(5)"] L2C["H2: Hash(Path+Value+Hex(5))+Binary(5)"] L2D["D2: H2+Binary(5)"] - end - RI -->|0| I1 - RI -->|1| I2 - I1 -->|0| L1 - I1 -->|1| E1 - E1 --> L2 - I2 -->|0| L3 - I2 -->|1| L4 + end + RI -->|0| I1 + RI -->|1| I2 + I1 -->|0| L1 + I1 -->|1| E1 + E1 --> L2 + I2 -->|0| L3 + I2 -->|1| L4 ``` ## Sum diff --git a/docs/smt.md b/docs/smt.md index 
b7bdfad..af421fd 100644 --- a/docs/smt.md +++ b/docs/smt.md @@ -2,16 +2,16 @@ - [Overview](#overview) - [Implementation](#implementation) + - [Leaf Nodes](#leaf-nodes) - [Inner Nodes](#inner-nodes) - [Extension Nodes](#extension-nodes) - - [Leaf Nodes](#leaf-nodes) - [Lazy Nodes](#lazy-nodes) - [Lazy Loading](#lazy-loading) - - [Visualisations](#visualisations) + - [Visualizations](#visualizations) - [General Trie Structure](#general-trie-structure) - [Lazy Nodes](#lazy-nodes-1) - [Paths](#paths) - - [Visualisation](#visualisation) + - [Visualization](#visualization) - [Values](#values) - [Nil values](#nil-values) - [Hashers \& Digests](#hashers--digests) @@ -47,44 +47,54 @@ See [smt.go](../smt.go) for more details on the implementation. The SMT has 4 node types that are used to construct the trie: -- Inner Nodes - - Prefixed `[]byte{1}` - - `digest = hash([]byte{1} + leftChild.digest + rightChild.digest)` -- Extension Nodes - - Prefixed `[]byte{2}` - - `digest = hash([]byte{2} + pathBounds + path + child.digest)` -- Leaf Nodes - - Prefixed `[]byte{0}` - - `digest = hash([]byte{0} + path + value)` -- Lazy Nodes - - Prefix of the actual node type is stored in the persisted digest as +- [Inner Nodes](#inner-nodes) +- [Extension Nodes](#extension-nodes) +- [Leaf Nodes](#leaf-nodes) +- [Lazy Nodes](#lazy-nodes) + - Prefix of the actual node type is stored in the persisted preimage as determined above - `digest = persistedDigest` +### Leaf Nodes + +Leaf nodes store the full path associated with the `key`. A leaf node also +store the hash of the `value` stored. + +The `digest` of a leaf node is the hash of concatenation of the leaf node's +prefix, path and value. + +By default, the SMT only stores the hashes of the values in the trie, and not the +raw values themselves. In order to store the raw values in the underlying database, +the option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` +constructor. 
+ +- _Prefix_: `[]byte{0}` +- _Digest_: `hash([]byte{0} + path + value)` + ### Inner Nodes Inner nodes represent a branch in the trie with two **non-nil** child nodes. The inner node has an internal `digest` which represents the hash of the child nodes concatenated hashes. +- _Prefix_: `[]byte{1}` +- _Digest_: `hash([]byte{1} + leftChild.digest + rightChild.digest)` + ### Extension Nodes Extension nodes represent a singly linked chain of inner nodes, with a single -child. They are used to represent a common path in the trie and as such contain -the path and bounds of the path they represent. The `digest` of an extension -node is the hash of its path bounds, the path itself and the child nodes digest -concatenated. +child. In other words, they are an optimization to avoid having a long chain of +inner nodes where each inner node only has one child. -### Leaf Nodes +They are used to represent a common path in the trie and as such contain the path +and bounds of the path they represent. -Leaf nodes store the full path which they represent and also the hash of the -value they store. The `digest` of a leaf node is the hash of the leaf nodes path -and value concatenated. +The `digest` of an extension node is the hash of its path bounds, the path itself +and the child node digest. Note that an extension node can only have exactly one +child node. -The SMT stores only the hashes of the values in the trie, not the raw values -themselves. In order to store the raw values in the underlying database the -option `WithValueHasher(nil)` must be passed into the `NewSparseMerkleTrie` -constructor. +- _Prefix_: `[]byte{2}` +- _Digest_: `hash([]byte{2} + pathBounds + path + child.digest)` ### Lazy Nodes @@ -111,7 +121,7 @@ Once the `Commit()` function is called the trie will delete any orphaned nodes from the database and write the key-value pairs of all the unpersisted leaf nodes' hashes and their values to the database. 
-### Visualisations +### Visualizations The following diagrams are representations of how the trie and its components can be visualised. @@ -197,18 +207,20 @@ Where `Hash(Hash1 + Hash2)` is the same root hash as the previous example. ## Paths -Paths are **only** stored in two types of nodes: Leaf nodes and Extension nodes. +Paths are **only** stored in two types of nodes: `Leaf` nodes and `Extension` nodes. -- Extension nodes contain not only the path they represent but also the path - bounds (ie. the start and end of the path they cover). -- Leaf nodes contain the full path which they represent, as well as the value - stored at that path. +- `Leaf` nodes contain: + - The full path which it represent + - The (hashed) value stored at that path +- `Extension` nodes contain: + - not only the path they represent but also the path + bounds (ie. the start and end of the path that they cover). Inner nodes do **not** contain a path, as they represent a branch in the trie and not a path. As such their children, _if they are extension nodes or leaf nodes_, will hold a path value. -### Visualisation +### Visualization The following diagram shows how paths are stored in the different nodes of the trie. In the actual SMT paths are not 8 bit binary strings but are instead the diff --git a/extension_node.go b/extension_node.go new file mode 100644 index 0000000..fd2dc03 --- /dev/null +++ b/extension_node.go @@ -0,0 +1,150 @@ +package smt + +// Ensure extensionNode satisfies the trieNode interface +var _ trieNode = (*extensionNode)(nil) + +// A compressed chain of singly-linked inner nodes. +// +// Extension nodes are used to captures a series of inner nodes that only +// have one child in a succinct `pathBounds` for optimization purposes. +// +// TODO_TECHDEBT(@Olshansk): Does this assumption still hold? +// +// Assumption: the path is <=256 bits +type extensionNode struct { + // The path (starting at the root) to this extension node. 
+ path []byte + // The path (starting at pathBounds[0] and ending at pathBounds[1]) of + // inner nodes that this single extension node replaces. + pathBounds [2]byte + // A child node from this extension node. + // It will always be an innerNode, leafNode or lazyNode. + child trieNode + // Bool whether or not the node has been flushed to disk + persisted bool + // The cached digest of the node trie + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *extensionNode) Persisted() bool { + return node.persisted +} + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *extensionNode) CachedDigest() []byte { + return node.digest +} + +// Length returns the length of the path segment represented by this single +// extensionNode. Since the SMT is a binary trie, the length represents both +// the depth and the number of nodes replaced by a single extension node. If +// this SMT were to have k-ary support, the depth would be strictly less than +// the number of nodes replaced. +func (ext *extensionNode) length() int { + return ext.pathEnd() - ext.pathStart() +} + +func (ext *extensionNode) pathStart() int { + return int(ext.pathBounds[0]) +} + +func (ext *extensionNode) pathEnd() int { + return int(ext.pathBounds[1]) +} + +// setDirty marks the node as dirty (i.e. not flushed to disk) and clears +// its digest +func (ext *extensionNode) setDirty() { + ext.persisted = false + ext.digest = nil +} + +// boundsMatch returns the length of the matching prefix between `ext.pathBounds` +// and `path` starting at index `depth`, along with a bool if a full match is found. 
+func (extNode *extensionNode) boundsMatch(path []byte, depth int) (int, bool) { + if depth != extNode.pathStart() { + panic("depth != extNode.pathStart") + } + for pathIdx := extNode.pathStart(); pathIdx < extNode.pathEnd(); pathIdx++ { + if getPathBit(extNode.path, pathIdx) != getPathBit(path, pathIdx) { + return pathIdx - extNode.pathStart(), false + } + } + return extNode.length(), true +} + +// split splits the node in-place by returning a new extensionNode and a child +// node at the split and split depth. +func (extNode *extensionNode) split(path []byte) (trieNode, *trieNode, int) { + // Start path to extNode.pathBounds until there is no match + var extNodeBit, pathBit int + pathIdx := extNode.pathStart() + for ; pathIdx < extNode.pathEnd(); pathIdx++ { + extNodeBit = getPathBit(extNode.path, pathIdx) + pathBit = getPathBit(path, pathIdx) + if extNodeBit != pathBit { + break + } + } + // Return the extension node's child if path fully matches extNode.pathBounds + if pathIdx == extNode.pathEnd() { + return extNode, &extNode.child, pathIdx + } + + child := extNode.child + var branch innerNode + var head trieNode + var tail *trieNode + if extNodeBit == leftChildBit { + tail = &branch.leftChild + } else { + tail = &branch.rightChild + } + + // Split at first bit: chain starts with new node + if pathIdx == extNode.pathStart() { + head = &branch + extNode.pathBounds[0]++ // Shrink the extension from front + if extNode.length() == 0 { + *tail = child + } else { + *tail = extNode + } + } else { + // Split inside: chain ends at index + head = extNode + extNode.child = &branch + if pathIdx == extNode.pathEnd()-1 { + *tail = child + } else { + *tail = &extensionNode{ + path: extNode.path, + pathBounds: [2]byte{ + byte(pathIdx + 1), + extNode.pathBounds[1], + }, + child: child, + } + } + extNode.pathBounds[1] = byte(pathIdx) + } + var b trieNode = &branch + return head, &b, pathIdx +} + +// expand returns the inner node that represents the start of the singly +// linked 
list that this extension node represents +func (extNode *extensionNode) expand() trieNode { + last := extNode.child + for i := extNode.pathEnd() - 1; i >= extNode.pathStart(); i-- { + var next innerNode + if getPathBit(extNode.path, i) == leftChildBit { + next.leftChild = last + } else { + next.rightChild = last + } + last = &next + } + return last +} diff --git a/godoc.go b/godoc.go index 7354c89..67fc078 100644 --- a/godoc.go +++ b/godoc.go @@ -1,11 +1,13 @@ // Package smt provides an implementation of a Sparse Merkle Trie for a -// key-value map. +// key-value map or engine. // // The trie implements the same optimizations specified in the JMT -// whitepaper to account for empty and single-node subtrees. Unlike the -// JMT, it only supports binary trees and does not optimise for RockDB -// on-disk storage. +// whitepaper to account for empty and single-node subtrees. +// +// Unlike the JMT, it only supports binary trees and does not implement the +// same RocksDB optimizations as specified in the original JMT library when +// optimizing for disk iops. // -// This package implements novel features that include native in-node -// weight sums, as well as support for ClosestProof mechanics. +// This package implements additional SMT specific functionality related to +// tree sums and closest proof mechanics. package smt diff --git a/hasher.go b/hasher.go index c6f452c..3676a49 100644 --- a/hasher.go +++ b/hasher.go @@ -1,19 +1,17 @@ package smt import ( - "bytes" "encoding/binary" "hash" ) -var ( - leafPrefix = []byte{0} - innerPrefix = []byte{1} - extPrefix = []byte{2} -) +// TODO_IMPROVE: Improve how the `hasher` file is consolidated with +// `node_encoders.go` since the two are very similar. 
+// Ensure the hasher interfaces are satisfied var ( _ PathHasher = (*pathHasher)(nil) + _ PathHasher = (*nilPathHasher)(nil) _ ValueHasher = (*valueHasher)(nil) ) @@ -33,26 +31,43 @@ type ValueHasher interface { ValueHashSize() int } +// trieHasher is a common hasher for all trie hashers (paths & values). type trieHasher struct { hasher hash.Hash zeroValue []byte } + +// pathHasher is a hasher for trie paths. type pathHasher struct { trieHasher } + +// valueHasher is a hasher for leaf values. type valueHasher struct { trieHasher } -func newTrieHasher(hasher hash.Hash) *trieHasher { +// nilPathHasher is a dummy hasher that returns its input - it should not be used outside of the closest proof verification logic +type nilPathHasher struct { + hashSize int +} + +// NewTrieHasher returns a new trie hasher with the given hash function. +func NewTrieHasher(hasher hash.Hash) *trieHasher { th := trieHasher{hasher: hasher} th.zeroValue = make([]byte, th.hashSize()) return &th } +// newNilPathHasher returns a new nil path hasher with the given hash size. +// It is not exported as the validation logic for the ClosestProof automatically handles this case. 
+func newNilPathHasher(hasherSize int) PathHasher { + return &nilPathHasher{hashSize: hasherSize} +} + // Path returns the digest of a key produced by the path hasher func (ph *pathHasher) Path(key []byte) []byte { - return ph.digest(key)[:ph.PathSize()] + return ph.digestData(key)[:ph.PathSize()] } // PathSize returns the length (in bytes) of digests produced by the path hasher @@ -63,7 +78,7 @@ func (ph *pathHasher) PathSize() int { // HashValue hashes the produces a digest of the data provided by the value hasher func (vh *valueHasher) HashValue(data []byte) []byte { - return vh.digest(data) + return vh.digestData(data) } // ValueHashSize returns the length (in bytes) of digests produced by the value hasher @@ -74,44 +89,73 @@ func (vh *valueHasher) ValueHashSize() int { return vh.hasher.Size() } -func (th *trieHasher) digest(data []byte) []byte { +// Path satisfies the PathHasher#Path interface +func (n *nilPathHasher) Path(key []byte) []byte { + return key[:n.hashSize] +} + +// PathSize satisfies the PathHasher#PathSize interface +func (n *nilPathHasher) PathSize() int { + return n.hashSize +} + +// digestData returns the hash of the data provided using the trie hasher. +func (th *trieHasher) digestData(data []byte) []byte { th.hasher.Write(data) - sum := th.hasher.Sum(nil) + digest := th.hasher.Sum(nil) th.hasher.Reset() - return sum + return digest } -func (th *trieHasher) digestLeaf(path []byte, leafData []byte) ([]byte, []byte) { - value := encodeLeaf(path, leafData) - return th.digest(value), value +// digestLeafNode returns the encoded leaf data as well as its hash (i.e. digest) +func (th *trieHasher) digestLeafNode(path, data []byte) (digest, value []byte) { + value = encodeLeafNode(path, data) + digest = th.digestData(value) + return } -func (th *trieHasher) digestSumLeaf(path []byte, leafData []byte) ([]byte, []byte) { - value := encodeLeaf(path, leafData) - digest := th.digest(value) - digest = append(digest, value[len(value)-sumSize:]...) 
- return digest, value +// digestInnerNode returns the encoded inner node data as well as its hash (i.e. digest) +func (th *trieHasher) digestInnerNode(leftData, rightData []byte) (digest, value []byte) { + value = encodeInnerNode(leftData, rightData) + digest = th.digestData(value) + return } -func (th *trieHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []byte) { - value := encodeInner(leftData, rightData) - return th.digest(value), value +// digestSumNode returns the encoded leaf node data as well as its hash (i.e. digest) +func (th *trieHasher) digestSumLeafNode(path, data []byte) (digest, value []byte) { + value = encodeLeafNode(path, data) + digest = th.digestData(value) + digest = append(digest, value[len(value)-sumSizeBytes:]...) + return } -func (th *trieHasher) digestSumNode(leftData []byte, rightData []byte) ([]byte, []byte) { - value := encodeSumInner(leftData, rightData) - digest := th.digest(value) - digest = append(digest, value[len(value)-sumSize:]...) - return digest, value +// digestSumInnerNode returns the encoded inner node data as well as its hash (i.e. digest) +func (th *trieHasher) digestSumInnerNode(leftData, rightData []byte) (digest, value []byte) { + value = encodeSumInnerNode(leftData, rightData) + digest = th.digestData(value) + digest = append(digest, value[len(value)-sumSizeBytes:]...) 
+ return } -func (th *trieHasher) parseNode(data []byte) ([]byte, []byte) { - return data[len(innerPrefix) : th.hashSize()+len(innerPrefix)], data[len(innerPrefix)+th.hashSize():] +// parseInnerNode returns the encoded left and right nodes +func (th *trieHasher) parseInnerNode(data []byte) (leftData, rightData []byte) { + leftData = data[len(innerNodePrefix) : th.hashSize()+len(innerNodePrefix)] + rightData = data[len(innerNodePrefix)+th.hashSize():] + return } -func (th *trieHasher) parseSumNode(data []byte) ([]byte, []byte) { - sumless := data[:len(data)-sumSize] - return sumless[len(innerPrefix) : th.hashSize()+sumSize+len(innerPrefix)], sumless[len(innerPrefix)+th.hashSize()+sumSize:] +// parseSumInnerNode returns the encoded left and right nodes as well as the sum of the current node +func (th *trieHasher) parseSumInnerNode(data []byte) (leftData, rightData []byte, sum uint64) { + // Extract the sum from the encoded node data + var sumBz [sumSizeBytes]byte + copy(sumBz[:], data[len(data)-sumSizeBytes:]) + sum = binary.BigEndian.Uint64(sumBz[:]) + + // Extract the left and right children + dataWithoutSum := data[:len(data)-sumSizeBytes] + leftData = dataWithoutSum[len(innerNodePrefix) : len(innerNodePrefix)+th.hashSize()+sumSizeBytes] + rightData = dataWithoutSum[len(innerNodePrefix)+th.hashSize()+sumSizeBytes:] + return } func (th *trieHasher) hashSize() int { @@ -121,90 +165,3 @@ func (th *trieHasher) placeholder() []byte { return th.zeroValue } - -func isLeaf(data []byte) bool { - return bytes.Equal(data[:len(leafPrefix)], leafPrefix) -} - -func isExtension(data []byte) bool { - return bytes.Equal(data[:len(extPrefix)], extPrefix) -} - -func parseLeaf(data []byte, ph PathHasher) ([]byte, []byte) { - return data[len(leafPrefix) : ph.PathSize()+len(leafPrefix)], data[len(leafPrefix)+ph.PathSize():] -} - -func parseExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte) { - return data[len(extPrefix) : 
len(extPrefix)+2], // +2 represents the length of the pathBounds - data[len(extPrefix)+2 : len(extPrefix)+2+ph.PathSize()], - data[len(extPrefix)+2+ph.PathSize():] -} - -func parseSumExtension(data []byte, ph PathHasher) (pathBounds, path, childData []byte, sum [sumSize]byte) { - var sumBz [sumSize]byte - copy(sumBz[:], data[len(data)-sumSize:]) - return data[len(extPrefix) : len(extPrefix)+2], // +2 represents the length of the pathBounds - data[len(extPrefix)+2 : len(extPrefix)+2+ph.PathSize()], - data[len(extPrefix)+2+ph.PathSize() : len(data)-sumSize], - sumBz -} - -// encodeLeaf encodes both normal and sum leaves as in the sum leaf the -// sum is appended to the end of the valueHash -func encodeLeaf(path []byte, leafData []byte) []byte { - value := make([]byte, 0, len(leafPrefix)+len(path)+len(leafData)) - value = append(value, leafPrefix...) - value = append(value, path...) - value = append(value, leafData...) - return value -} - -func encodeInner(leftData []byte, rightData []byte) []byte { - value := make([]byte, 0, len(innerPrefix)+len(leftData)+len(rightData)) - value = append(value, innerPrefix...) - value = append(value, leftData...) - value = append(value, rightData...) - return value -} - -func encodeSumInner(leftData []byte, rightData []byte) []byte { - value := make([]byte, 0, len(innerPrefix)+len(leftData)+len(rightData)) - value = append(value, innerPrefix...) - value = append(value, leftData...) - value = append(value, rightData...) - var sum [sumSize]byte - leftSum := uint64(0) - rightSum := uint64(0) - leftSumBz := leftData[len(leftData)-sumSize:] - rightSumBz := rightData[len(rightData)-sumSize:] - if !bytes.Equal(leftSumBz, defaultSum[:]) { - leftSum = binary.BigEndian.Uint64(leftSumBz) - } - if !bytes.Equal(rightSumBz, defaultSum[:]) { - rightSum = binary.BigEndian.Uint64(rightSumBz) - } - binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) - value = append(value, sum[:]...) 
- return value -} - -func encodeExtension(pathBounds [2]byte, path []byte, childData []byte) []byte { - value := make([]byte, 0, len(extPrefix)+len(path)+2+len(childData)) - value = append(value, extPrefix...) - value = append(value, pathBounds[:]...) - value = append(value, path...) - value = append(value, childData...) - return value -} - -func encodeSumExtension(pathBounds [2]byte, path []byte, childData []byte) []byte { - value := make([]byte, 0, len(extPrefix)+len(path)+2+len(childData)) - value = append(value, extPrefix...) - value = append(value, pathBounds[:]...) - value = append(value, path...) - value = append(value, childData...) - var sum [sumSize]byte - copy(sum[:], childData[len(childData)-sumSize:]) - value = append(value, sum[:]...) - return value -} diff --git a/inner_node.go b/inner_node.go new file mode 100644 index 0000000..96ed2f1 --- /dev/null +++ b/inner_node.go @@ -0,0 +1,25 @@ +package smt + +// Ensure innerNode satisfies the trieNode interface +var _ trieNode = (*innerNode)(nil) + +// A branch within the binary trie pointing to a left & right child. +type innerNode struct { + // Left and right child nodes. + // Both child nodes are always expected to be non-nil. + leftChild, rightChild trieNode + persisted bool + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *innerNode) Persisted() bool { return node.persisted } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *innerNode) CachedDigest() []byte { return node.digest } + +// setDirty marks the node as dirty (i.e. 
not flushed to disk) and clears the cached digest +func (node *innerNode) setDirty() { + node.persisted = false + node.digest = nil +} diff --git a/kvstore/simplemap/simplemap.go b/kvstore/simplemap/simplemap.go index 9b548ab..c20e4bf 100644 --- a/kvstore/simplemap/simplemap.go +++ b/kvstore/simplemap/simplemap.go @@ -19,6 +19,14 @@ func NewSimpleMap() kvstore.MapStore { } } +// NewSimpleMap creates a new SimpleMap instance using the map provided. +// This is useful for testing & debugging purposes. +func NewSimpleMapWithMap(m map[string][]byte) kvstore.MapStore { + return &simpleMap{ + m: m, + } +} + // Get gets the value for a key. func (sm *simpleMap) Get(key []byte) ([]byte, error) { if len(key) == 0 { diff --git a/lazy_node.go b/lazy_node.go new file mode 100644 index 0000000..fa60fef --- /dev/null +++ b/lazy_node.go @@ -0,0 +1,15 @@ +package smt + +// Ensure lazyNode satisfies the trieNode interface +var _ trieNode = (*lazyNode)(nil) + +// lazyNode represents an uncached persisted node +type lazyNode struct { + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *lazyNode) Persisted() bool { return true } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *lazyNode) CachedDigest() []byte { return node.digest } diff --git a/leaf_node.go b/leaf_node.go new file mode 100644 index 0000000..2c84c2f --- /dev/null +++ b/leaf_node.go @@ -0,0 +1,18 @@ +package smt + +// Ensure leafNode satisfies the trieNode interface +var _ trieNode = (*leafNode)(nil) + +// leafNode stores a full key-value pair in the trie +type leafNode struct { + path []byte + valueHash []byte + persisted bool + digest []byte +} + +// Persisted satisfied the trieNode#Persisted interface +func (node *leafNode) Persisted() bool { return node.persisted } + +// Persisted satisfied the trieNode#CachedDigest interface +func (node *leafNode) CachedDigest() []byte { return node.digest } diff --git a/node_encoders.go b/node_encoders.go new file mode 
100644 index 0000000..fdc6a06 --- /dev/null +++ b/node_encoders.go @@ -0,0 +1,124 @@ +package smt + +import ( + "bytes" + "encoding/binary" +) + +// TODO_TECHDEBT: All of the parsing, encoding and checking functions in this file +// can be abstracted out into the `trieNode` interface. + +// TODO_IMPROVE: We should create well-defined structs for every type of node +// to streamline the process of encoding & decoding and to improve readability. +// If decoding needs to be language agnostic (to implement POKT clients), in other +// languages, protobufs should be considered. If decoding does not need to be +// language agnostic, we can use Go's gob package for more efficient serialization. + +// NB: In this file, all references to the variable `data` should be treated as `encodedNodeData`. +// It was abbreviated to `data` for brevity. + +// TODO_TECHDEBT: We can easily use `iota` and ENUMS to create a way to have +// more expressive code, and leverage switch statements throughout. +var ( + leafNodePrefix = []byte{0} + innerNodePrefix = []byte{1} + extNodePrefix = []byte{2} + prefixLen = 1 +) + +// NB: We use `prefixLen` a lot throughout this file, so to make the code more readable, we +// define it as a constant but need to assert on its length just in case the code evolves +// in the future. 
+func init() { + if len(leafNodePrefix) != prefixLen || + len(innerNodePrefix) != prefixLen || + len(extNodePrefix) != prefixLen { + panic("invalid prefix length") + } +} + +// isLeafNode returns true if the encoded node data is a leaf node +func isLeafNode(data []byte) bool { + return bytes.Equal(data[:prefixLen], leafNodePrefix) +} + +// isExtNode returns true if the encoded node data is an extension node +func isExtNode(data []byte) bool { + return bytes.Equal(data[:prefixLen], extNodePrefix) +} + +// isInnerNode returns true if the encoded node data is an inner node +func isInnerNode(data []byte) bool { + return bytes.Equal(data[:prefixLen], innerNodePrefix) +} + +// encodeLeafNode encodes leaf nodes. This function applies to both the SMT and +// SMST since the weight of the node is appended to the end of the valueHash. +func encodeLeafNode(path, leafData []byte) (data []byte) { + data = append(data, leafNodePrefix...) + data = append(data, path...) + data = append(data, leafData...) + return +} + +// encodeInnerNode encodes inner node given the data for both children +func encodeInnerNode(leftData, rightData []byte) (data []byte) { + data = append(data, innerNodePrefix...) + data = append(data, leftData...) + data = append(data, rightData...) + return +} + +// encodeExtensionNode encodes the data of an extension nodes +func encodeExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { + data = append(data, extNodePrefix...) + data = append(data, pathBounds[:]...) + data = append(data, path...) + data = append(data, childData...) 
+ return +} + +// encodeSumInnerNode encodes an inner node for an smst given the data for both children +func encodeSumInnerNode(leftData, rightData []byte) (data []byte) { + // Compute the sum of the current node + var sum [sumSizeBytes]byte + leftSum := parseSum(leftData) + rightSum := parseSum(rightData) + // TODO_CONSIDERATION: ` I chose BigEndian for readability but most computers + // now are optimized for LittleEndian encoding could be a micro optimization one day.` + binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) + + // Prepare and return the encoded inner node data + data = encodeInnerNode(leftData, rightData) + data = append(data, sum[:]...) + return +} + +// encodeSumExtensionNode encodes the data of a sum extension nodes +func encodeSumExtensionNode(pathBounds [2]byte, path, childData []byte) (data []byte) { + // Compute the sum of the current node + var sum [sumSizeBytes]byte + copy(sum[:], childData[len(childData)-sumSizeBytes:]) + + // Prepare and return the encoded inner node data + data = encodeExtensionNode(pathBounds, path, childData) + data = append(data, sum[:]...) + return +} + +// checkPrefix panics if the prefix of the data does not match the expected prefix +func checkPrefix(data, prefix []byte) { + if !bytes.Equal(data[:prefixLen], prefix) { + panic("invalid prefix") + } +} + +// parseSum parses the sum from the encoded node data +func parseSum(data []byte) uint64 { + sum := uint64(0) + sumBz := data[len(data)-sumSizeBytes:] + if !bytes.Equal(sumBz, defaultEmptySum[:]) { + sum = binary.BigEndian.Uint64(sumBz) + } + return sum +} diff --git a/options.go b/options.go index 884d559..92ec8fb 100644 --- a/options.go +++ b/options.go @@ -1,14 +1,15 @@ package smt -// Option is a function that configures SparseMerkleTrie. -type Option func(*TrieSpec) +// TrieSpecOption is a function that configures SparseMerkleTrie. 
+type TrieSpecOption func(*TrieSpec) // WithPathHasher returns an Option that sets the PathHasher to the one provided -func WithPathHasher(ph PathHasher) Option { +// this MUST not be nil or unknown behaviour will occur. +func WithPathHasher(ph PathHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.ph = ph } } // WithValueHasher returns an Option that sets the ValueHasher to the one provided -func WithValueHasher(vh ValueHasher) Option { +func WithValueHasher(vh ValueHasher) TrieSpecOption { return func(ts *TrieSpec) { ts.vh = vh } } diff --git a/proofs.go b/proofs.go index be33a8f..b834468 100644 --- a/proofs.go +++ b/proofs.go @@ -49,17 +49,17 @@ func (proof *SparseMerkleProof) Unmarshal(bz []byte) error { return dec.Decode(proof) } +// validateBasic performs basic sanity check on the proof so that a malicious +// proof cannot cause the verifier to fatally exit (e.g. due to an index +// out-of-range error) or cause a CPU DoS attack. func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { - // Do a basic sanity check on the proof, so that a malicious proof cannot - // cause the verifier to fatally exit (e.g. due to an index out-of-range - // error) or cause a CPU DoS attack. - - // Check that the number of supplied sidenodes does not exceed the maximum possible. + // Verify the number of supplied sideNodes does not exceed the possible maximum. if len(proof.SideNodes) > spec.ph.PathSize()*8 { return fmt.Errorf("too many side nodes: got %d but max is %d", len(proof.SideNodes), spec.ph.PathSize()*8) } + // Check that leaf data for non-membership proofs is a valid size. 
- lps := len(leafPrefix) + spec.ph.PathSize() + lps := len(leafNodePrefix) + spec.ph.PathSize() if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps { return fmt.Errorf( "invalid non-membership leaf data size: got %d but min is %d", @@ -68,18 +68,25 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error { ) } - // Check that all supplied sidenodes are the correct size. - for _, v := range proof.SideNodes { - if len(v) != hashSize(spec) { - return fmt.Errorf("invalid side node size: got %d but want %d", len(v), hashSize(spec)) - } + // Verify that the non-membership leaf data is of the correct size. + leafPathSize := len(leafNodePrefix) + spec.ph.PathSize() + if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < leafPathSize { + return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), leafPathSize) } // Check that the sibling data hashes to the first side node if not nil if proof.SiblingData == nil || len(proof.SideNodes) == 0 { return nil } - siblingHash := hashPreimage(spec, proof.SiblingData) + + // Check that all supplied sideNodes are the correct size. 
+ for _, sideNodeValue := range proof.SideNodes { + if len(sideNodeValue) != spec.hashSize() { + return fmt.Errorf("invalid side node size: got %d but want %d", len(sideNodeValue), spec.hashSize()) + } + } + + siblingHash := spec.hashPreimage(proof.SiblingData) if eq := bytes.Equal(proof.SideNodes[0], siblingHash); !eq { return fmt.Errorf("invalid sibling data hash: got %x but want %x", siblingHash, proof.SideNodes[0]) } @@ -199,7 +206,7 @@ func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { return nil } if spec.sumTrie { - return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes] } return proof.ClosestValueHash } @@ -207,8 +214,8 @@ func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte { func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error { // ensure the proof length is the same size (in bytes) as the path // hasher of the spec provided - if len(proof.Path) != spec.PathHasherSize() { - return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.PathHasherSize()) + if len(proof.Path) != spec.ph.PathSize() { + return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.ph.PathSize()) } // ensure the depth of the leaf node being proven is within the path size @@ -258,8 +265,8 @@ type SparseCompactMerkleClosestProof struct { func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) error { // Ensure the proof length is the same size (in bytes) as the path // hasher of the spec provided - if len(proof.Path) != spec.PathHasherSize() { - return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.PathHasherSize()) + if len(proof.Path) != spec.ph.PathSize() { + return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.ph.PathSize()) } // Do a basic sanity check on the proof on the fields of the proof specific to @@ -317,12 
+324,12 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp // VerifySumProof verifies a Merkle proof for a sum trie. func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { - var sumBz [sumSize]byte + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) - valueHash := spec.digestValue(value) + valueHash := spec.valueHash(value) valueHash = append(valueHash, sumBz[:]...) - if bytes.Equal(value, defaultValue) && sum == 0 { - valueHash = defaultValue + if bytes.Equal(value, defaultEmptyValue) && sum == 0 { + valueHash = defaultEmptyValue } smtSpec := &TrieSpec{ th: spec.th, @@ -346,7 +353,7 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie // will invalidate the proof. nilSpec := &TrieSpec{ th: spec.th, - ph: NewNilPathHasher(spec.ph.PathSize()), + ph: newNilPathHasher(spec.ph.PathSize()), vh: spec.vh, sumTrie: spec.sumTrie, } @@ -356,19 +363,20 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie if proof.ClosestValueHash == nil { return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, nilSpec) } - sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] + sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSizeBytes:] sum := binary.BigEndian.Uint64(sumBz) - valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] + + valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSizeBytes] return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, nilSpec) } +// verifyProofWithUpdates func verifyProofWithUpdates( proof *SparseMerkleProof, - root []byte, - key []byte, - value []byte, + root, key, value []byte, spec *TrieSpec, ) (bool, [][][]byte, error) { + // Retrieve the trie path for the key being proven path := spec.ph.Path(key) if err := proof.validateBasic(spec); err != nil { @@ -379,39 +387,40 @@ func 
verifyProofWithUpdates( // Determine what the leaf hash should be. var currentHash, currentData []byte - if bytes.Equal(value, defaultValue) { // Non-membership proof. - if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value. - currentHash = placeholder(spec) - } else { // Leaf is an unrelated leaf. + if bytes.Equal(value, defaultEmptyValue) { + // Non-membership proof if `value` is empty. + if proof.NonMembershipLeafData == nil { + // Leaf is a placeholder value. + currentHash = spec.placeholder() + } else { + // Leaf is an unrelated leaf. var actualPath, valueHash []byte - actualPath, valueHash = parseLeaf(proof.NonMembershipLeafData, spec.ph) + actualPath, valueHash = spec.parseLeafNode(proof.NonMembershipLeafData) if bytes.Equal(actualPath, path) { // This is not an unrelated leaf; non-membership proof failed. return false, nil, errors.Join(ErrBadProof, errors.New("non-membership proof on related leaf")) } - currentHash, currentData = digestLeaf(spec, actualPath, valueHash) - - update := make([][]byte, 2) - update[0], update[1] = currentHash, currentData - updates = append(updates, update) + currentHash, currentData = spec.digestLeaf(actualPath, valueHash) } - } else { // Membership proof. - valueHash := spec.digestValue(value) - currentHash, currentData = digestLeaf(spec, path, valueHash) - update := make([][]byte, 2) - update[0], update[1] = currentHash, currentData - updates = append(updates, update) + } else { + // Membership proof if `value` is non-empty. + valueHash := spec.valueHash(value) + currentHash, currentData = spec.digestLeaf(path, valueHash) } + update := make([][]byte, 2) + update[0], update[1] = currentHash, currentData + updates = append(updates, update) + // Recompute root. 
for i := 0; i < len(proof.SideNodes); i++ { - node := make([]byte, hashSize(spec)) + node := make([]byte, spec.hashSize()) copy(node, proof.SideNodes[i]) - if getPathBit(path, len(proof.SideNodes)-1-i) == left { - currentHash, currentData = digestNode(spec, currentHash, node) + if getPathBit(path, len(proof.SideNodes)-1-i) == leftChildBit { + currentHash, currentData = spec.digestInnerNode(currentHash, node) } else { - currentHash, currentData = digestNode(spec, node, currentHash) + currentHash, currentData = spec.digestInnerNode(node, currentHash) } update := make([][]byte, 2) @@ -464,9 +473,9 @@ func CompactProof(proof *SparseMerkleProof, spec *TrieSpec) (*SparseCompactMerkl bitMask := make([]byte, int(math.Ceil(float64(len(proof.SideNodes))/float64(8)))) var compactedSideNodes [][]byte for i := 0; i < len(proof.SideNodes); i++ { - node := make([]byte, hashSize(spec)) + node := make([]byte, spec.hashSize()) copy(node, proof.SideNodes[i]) - if bytes.Equal(node, placeholder(spec)) { + if bytes.Equal(node, spec.placeholder()) { setPathBit(bitMask, i) } else { compactedSideNodes = append(compactedSideNodes, node) @@ -492,7 +501,7 @@ func DecompactProof(proof *SparseCompactMerkleProof, spec *TrieSpec) (*SparseMer position := 0 for i := 0; i < proof.NumSideNodes; i++ { if getPathBit(proof.BitMask, i) == 1 { - decompactedSideNodes[i] = placeholder(spec) + decompactedSideNodes[i] = spec.placeholder() } else { decompactedSideNodes[i] = proof.SideNodes[position] position++ diff --git a/proofs_test.go b/proofs_test.go index 6248e5c..b673b2f 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -28,7 +28,7 @@ func TestSparseMerkleProof_Marshal(t *testing.T) { require.Greater(t, len(bz2), 0) require.NotEqual(t, bz, bz2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) bz3, err := proof3.Marshal() require.NoError(t, err) require.NotNil(t, bz3) @@ -59,7 +59,7 @@ func TestSparseMerkleProof_Unmarshal(t *testing.T) { require.NoError(t, uproof2.Unmarshal(bz2)) 
require.Equal(t, proof2, uproof2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) bz3, err := proof3.Marshal() require.NoError(t, err) require.NotNil(t, bz3) @@ -91,7 +91,7 @@ func TestSparseCompactMerkleProof_Marshal(t *testing.T) { require.Greater(t, len(bz2), 0) require.NotEqual(t, bz, bz2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) compactProof3, err := CompactProof(proof3, trie.Spec()) require.NoError(t, err) bz3, err := compactProof3.Marshal() @@ -134,7 +134,7 @@ func TestSparseCompactMerkleProof_Unmarshal(t *testing.T) { require.NoError(t, err) require.Equal(t, proof2, uproof2) - proof3 := randomiseProof(proof) + proof3 := randomizeProof(proof) compactProof3, err := CompactProof(proof3, trie.Spec()) require.NoError(t, err) bz3, err := compactProof3.Marshal() @@ -162,7 +162,7 @@ func setupTrie(t *testing.T) *SMT { return trie } -func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { +func randomizeProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { sideNodes[i] = make([]byte, len(proof.SideNodes[i])) @@ -174,12 +174,12 @@ func randomiseProof(proof *SparseMerkleProof) *SparseMerkleProof { } } -func randomiseSumProof(proof *SparseMerkleProof) *SparseMerkleProof { +func randomizeSumProof(proof *SparseMerkleProof) *SparseMerkleProof { sideNodes := make([][]byte, len(proof.SideNodes)) for i := range sideNodes { - sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSize) + sideNodes[i] = make([]byte, len(proof.SideNodes[i])-sumSizeBytes) rand.Read(sideNodes[i]) // nolint: errcheck - sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSize:]...) + sideNodes[i] = append(sideNodes[i], proof.SideNodes[i][len(proof.SideNodes[i])-sumSizeBytes:]...) 
} return &SparseMerkleProof{ SideNodes: sideNodes, diff --git a/root_test.go b/root_test.go index da6293c..08f8b6d 100644 --- a/root_test.go +++ b/root_test.go @@ -3,7 +3,6 @@ package smt_test import ( "crypto/sha256" "crypto/sha512" - "encoding/binary" "fmt" "hash" "testing" @@ -59,8 +58,9 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { for i := uint64(0); i < 10; i++ { require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i)) } - root := trie.Root() - require.Equal(t, root.Sum(), getSumBzHelper(t, root)) + require.NotNil(t, trie.Sum()) + require.EqualValues(t, 45, trie.Sum()) + return } trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher) @@ -73,10 +73,3 @@ func TestMerkleRoot_TrieTypes(t *testing.T) { }) } } - -func getSumBzHelper(t *testing.T, r []byte) uint64 { - sumSize := len(r) % 32 - sumBz := make([]byte, sumSize) - copy(sumBz[:], []byte(r)[len([]byte(r))-sumSize:]) - return binary.BigEndian.Uint64(sumBz[:]) -} diff --git a/smst.go b/smst.go index 6198ec8..a2bc49c 100644 --- a/smst.go +++ b/smst.go @@ -8,6 +8,11 @@ import ( "github.com/pokt-network/smt/kvstore" ) +const ( + // The number of bytes used to represent the sum of a node + sumSizeBytes = 8 +) + var _ SparseMerkleSumTrie = (*SMST)(nil) // SMST is an object wrapping a Sparse Merkle Trie for custom encoding @@ -20,25 +25,32 @@ type SMST struct { func NewSparseMerkleSumTrie( nodes kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMST { + trieSpec := newTrieSpec(hasher, true) + for _, option := range options { + option(&trieSpec) + } + + // Initialize a non-sum SMT and modify it to have a nil value hasher. + // NB: We are using a nil value hasher because the SMST pre-hashes its paths. + // This results in double path hashing because the SMST is a wrapper + // around the SMT. The reason the SMST uses its own path hashing logic is + // to account for the additional sum in the encoding/decoding process. 
+ // Therefore, the underlying SMT underneath needs a nil path hasher, while + // the outer SMST does all the (non nil) path hashing itself. + // TODO_TECHDEBT(@Olshansk): Look for ways to simplify / cleanup the above. smt := &SMT{ - TrieSpec: NewTrieSpec(hasher, true), + TrieSpec: trieSpec, nodes: nodes, } - for _, option := range options { - option(&smt.TrieSpec) - } - nvh := WithValueHasher(nil) - nvh(&smt.TrieSpec) - smst := &SMST{ - TrieSpec: NewTrieSpec(hasher, true), + nilValueHasher := WithValueHasher(nil) + nilValueHasher(&smt.TrieSpec) + + return &SMST{ + TrieSpec: trieSpec, SMT: smt, } - for _, option := range options { - option(&smst.TrieSpec) - } - return smst } // ImportSparseMerkleSumTrie returns a pointer to an SMST struct with the root hash provided @@ -46,11 +58,11 @@ func ImportSparseMerkleSumTrie( nodes kvstore.MapStore, hasher hash.Hash, root []byte, - options ...Option, + options ...TrieSpecOption, ) *SMST { smst := NewSparseMerkleSumTrie(nodes, hasher, options...) - smst.trie = &lazyNode{root} - smst.savedRoot = root + smst.root = &lazyNode{root} + smst.rootHash = root return smst } @@ -59,31 +71,50 @@ func (smst *SMST) Spec() *TrieSpec { return &smst.TrieSpec } -// Get returns the digest of the value stored at the given key and the weight -// of the leaf node -func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { - valueHash, err := smst.SMT.Get(key) +// Get retrieves the value digest for the given key and the digest of the value +// along with its weight provided a leaf node exists. 
+func (smst *SMST) Get(key []byte) (valueDigest []byte, weight uint64, err error) { + // Retrieve the value digest from the trie for the given key + valueDigest, err = smst.SMT.Get(key) if err != nil { return nil, 0, err } - if bytes.Equal(valueHash, defaultValue) { - return defaultValue, 0, nil + + // Check if it is an empty branch + if bytes.Equal(valueDigest, defaultEmptyValue) { + return defaultEmptyValue, 0, nil } - var weightBz [sumSize]byte - copy(weightBz[:], valueHash[len(valueHash)-sumSize:]) - weight := binary.BigEndian.Uint64(weightBz[:]) - return valueHash[:len(valueHash)-sumSize], weight, nil + + // Retrieve the node weight + var weightBz [sumSizeBytes]byte + copy(weightBz[:], valueDigest[len(valueDigest)-sumSizeBytes:]) + weight = binary.BigEndian.Uint64(weightBz[:]) + + // Remove the weight from the value digest + valueDigest = valueDigest[:len(valueDigest)-sumSizeBytes] + + // Return the value digest and weight + return valueDigest, weight, nil } -// Update sets the value for the given key, to the digest of the provided value -// appended with the binary representation of the weight provided. The weight -// is used to compute the interim and total sum of the trie. +// Update inserts the value and weight into the trie for the given key. +// +// A digest (i.e. hash) of the value is computed and appended with the byte +// representation of the weight integer provided. +// +// The weight is used to compute the interim sum of the node which then percolates +// up to the total sum of the trie. func (smst *SMST) Update(key, value []byte, weight uint64) error { - valueHash := smst.digestValue(value) - var weightBz [sumSize]byte + // Convert the node weight to a byte slice + var weightBz [sumSizeBytes]byte binary.BigEndian.PutUint64(weightBz[:], weight) - valueHash = append(valueHash, weightBz[:]...) 
- return smst.SMT.Update(key, valueHash) + + // Compute the digest of the value and append the weight to it + valueDigest := smst.valueHash(value) + valueDigest = append(valueDigest, weightBz[:]...) + + // Return the result of the trie update + return smst.SMT.Update(key, valueDigest) } // Delete removes the node at the path corresponding to the given key @@ -116,8 +147,14 @@ func (smst *SMST) Root() MerkleRoot { return smst.SMT.Root() // [digest]+[binary sum] } -// Sum returns the uint64 sum of the entire trie +// Sum returns the sum of the entire trie stored in the root. +// If the tree is not a sum tree, it will panic. func (smst *SMST) Sum() uint64 { - digest := smst.Root() - return digest.Sum() + rootDigest := smst.Root() + if !smst.Spec().sumTrie { + panic("SMST: not a merkle sum trie") + } + var sumBz [sumSizeBytes]byte + copy(sumBz[:], []byte(rootDigest)[len([]byte(rootDigest))-sumSizeBytes:]) + return binary.BigEndian.Uint64(sumBz[:]) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index c909d23..d0d8c9d 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -33,7 +33,7 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, placeholder(base), []byte("testKey3"), defaultValue, 0, base) + result, err = VerifySumProof(proof, base.placeholder(), []byte("testKey3"), defaultEmptyValue, 0, base) require.NoError(t, err) require.True(t, result) result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 5, base) @@ -80,7 +80,7 @@ func TestSMST_Proof_Operations(t *testing.T) { require.NoError(t, err) require.False(t, result) result, err = VerifySumProof( - randomiseSumProof(proof), + randomizeSumProof(proof), root, []byte("testKey"), []byte("testValue"), @@ -106,7 +106,7 @@ func TestSMST_Proof_Operations(t *testing.T) { require.NoError(t, err) require.False(t, result) result, err = 
VerifySumProof( - randomiseSumProof(proof), + randomizeSumProof(proof), root, []byte("testKey2"), []byte("testValue"), @@ -117,16 +117,16 @@ func TestSMST_Proof_Operations(t *testing.T) { require.False(t, result) // Try proving a default value for a non-default leaf. - var sum [sumSize]byte + var sum [sumSizeBytes]byte binary.BigEndian.PutUint64(sum[:], 5) - tval := base.digestValue([]byte("testValue")) + tval := base.valueHash([]byte("testValue")) tval = append(tval, sum[:]...) - _, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) + _, leafData := base.th.digestSumLeafNode(base.ph.Path([]byte("testKey2")), tval) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultValue, 0, base) + result, err = VerifySumProof(proof, root, []byte("testKey2"), defaultEmptyValue, 0, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) @@ -134,20 +134,20 @@ func TestSMST_Proof_Operations(t *testing.T) { proof, err = smst.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 0, base) // valid + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 0, base) // valid require.NoError(t, err) require.True(t, result) result, err = VerifySumProof(proof, root, []byte("testKey3"), []byte("badValue"), 0, base) // wrong value require.NoError(t, err) require.False(t, result) - result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultValue, 5, base) // wrong sum + result, err = VerifySumProof(proof, root, []byte("testKey3"), defaultEmptyValue, 5, base) // wrong sum require.NoError(t, err) require.False(t, result) result, err = VerifySumProof( - randomiseSumProof(proof), + randomizeSumProof(proof), root, []byte("testKey3"), - defaultValue, + defaultEmptyValue, 0, base, ) // invalid proof @@ -208,7 
+208,7 @@ func TestSMST_Proof_ValidateBasic(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smst.Prove([]byte("testKey1")) - proof.SiblingData = base.th.digest(proof.SiblingData) + proof.SiblingData = base.th.digestData(proof.SiblingData) require.EqualError( t, proof.validateBasic(base), @@ -301,7 +301,7 @@ func TestSMST_ProveClosest(t *testing.T) { var result bool var root []byte var err error - var sumBz [sumSize]byte + var sumBz [sumSizeBytes]byte smn = simplemap.NewSimpleMap() require.NoError(t, err) @@ -397,7 +397,7 @@ func TestSMST_ProveClosest_Empty(t *testing.T) { Path: path[:], FlippedBits: []int{0}, Depth: 0, - ClosestPath: placeholder(smst.Spec()), + ClosestPath: smst.placeholder(), ClosestProof: &SparseMerkleProof{}, }) @@ -427,7 +427,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") - var sumBz [sumSize]byte + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], 5) closestValueHash = append(closestValueHash, sumBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ diff --git a/smst_test.go b/smst_test.go index d8e0b3c..e331994 100644 --- a/smst_test.go +++ b/smst_test.go @@ -17,7 +17,7 @@ import ( func NewSMSTWithStorage( nodes, preimages kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMSTWithStorage { return &SMSTWithStorage{ SMST: NewSparseMerkleSumTrie(nodes, hasher, options...), @@ -38,7 +38,7 @@ func TestSMST_TrieUpdateBasic(t *testing.T) { // Test getting an empty key. 
value, sum, err = smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value) + require.Equal(t, defaultEmptyValue, value) require.Equal(t, uint64(0), sum) has, err = smst.Has([]byte("testKey")) @@ -132,7 +132,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err := smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") has, err := smst.Has([]byte("testKey")) @@ -157,7 +157,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("testKey2")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") value, sum, err = smst.GetValueSum([]byte("testKey")) @@ -179,7 +179,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("foo")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") value, sum, err = smst.GetValueSum([]byte("testKey")) @@ -202,7 +202,7 @@ func TestSMST_TrieDeleteBasic(t *testing.T) { value, sum, err = smst.GetValueSum([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") require.Equal(t, uint64(0), sum, "getting deleted key") has, err = smst.Has([]byte("testKey")) @@ -441,7 +441,7 @@ func TestSMST_TotalSum(t *testing.T) { // Check root hash contains the correct hex sum root1 := smst.Root() - sumBz := root1[len(root1)-sumSize:] + sumBz := root1[len(root1)-sumSizeBytes:] rootSum := binary.BigEndian.Uint64(sumBz) 
require.NoError(t, err) diff --git a/smst_utils_test.go b/smst_utils_test.go index 21a9ac7..db2acd0 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -26,8 +26,8 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { if err := smst.SMST.Update(key, value, sum); err != nil { return err } - valueHash := smst.digestValue(value) - var sumBz [sumSize]byte + valueHash := smst.valueHash(value) + var sumBz [sumSizeBytes]byte binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) return smst.preimages.Set(valueHash, value) @@ -52,22 +52,22 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { if err != nil { if errors.Is(err, ErrKeyNotFound) { // If key isn't found, return default value and sum - return defaultValue, 0, nil + return defaultEmptyValue, 0, nil } // Otherwise percolate up any other error return nil, 0, err } - var sumBz [sumSize]byte - copy(sumBz[:], value[len(value)-sumSize:]) + var sumBz [sumSizeBytes]byte + copy(sumBz[:], value[len(value)-sumSizeBytes:]) storedSum := binary.BigEndian.Uint64(sumBz[:]) if storedSum != sum { return nil, 0, fmt.Errorf("sum mismatch for %s: got %d, expected %d", string(key), storedSum, sum) } - return value[:len(value)-sumSize], storedSum, nil + return value[:len(value)-sumSizeBytes], storedSum, nil } // Has returns true if the value at the given key is non-default, false otherwise. 
func (smst *SMSTWithStorage) Has(key []byte) (bool, error) { val, sum, err := smst.GetValueSum(key) - return !bytes.Equal(defaultValue, val) || sum != 0, err + return !bytes.Equal(defaultEmptyValue, val) || sum != 0, err } diff --git a/smt.go b/smt.go index 02d088b..120186f 100644 --- a/smt.go +++ b/smt.go @@ -7,59 +7,18 @@ import ( "github.com/pokt-network/smt/kvstore" ) -var ( - _ trieNode = (*innerNode)(nil) - _ trieNode = (*leafNode)(nil) - _ SparseMerkleTrie = (*SMT)(nil) -) - -type trieNode interface { - // when committing a node to disk, skip if already persisted - Persisted() bool - CachedDigest() []byte -} - -// A branch within the trie -type innerNode struct { - // Both child nodes are always non-nil - leftChild, rightChild trieNode - persisted bool - digest []byte -} - -// Stores data and full path -type leafNode struct { - path []byte - valueHash []byte - persisted bool - digest []byte -} - -// A compressed chain of singly-linked inner nodes -type extensionNode struct { - path []byte - // Offsets into path slice of bounds defining actual path segment. - // Note: assumes path is <=256 bits - pathBounds [2]byte - // Child is always an inner node, or lazy. 
- child trieNode - persisted bool - digest []byte -} - -// Represents an uncached, persisted node -type lazyNode struct { - digest []byte -} +// Ensure the `SMT` struct implements the `SparseMerkleTrie` interface +var _ SparseMerkleTrie = (*SMT)(nil) // SMT is a Sparse Merkle Trie object that implements the SparseMerkleTrie interface type SMT struct { TrieSpec + // Backing key-value store for the node nodes kvstore.MapStore // Last persisted root hash - savedRoot []byte - // Current state of trie - trie trieNode + rootHash []byte + // The current view of the SMT + root trieNode // Lists of per-operation orphan sets orphans []orphanNodes } @@ -72,10 +31,10 @@ type orphanNodes = [][]byte func NewSparseMerkleTrie( nodes kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMT { smt := SMT{ - TrieSpec: NewTrieSpec(hasher, false), + TrieSpec: newTrieSpec(hasher, false), nodes: nodes, } for _, option := range options { @@ -90,75 +49,96 @@ func ImportSparseMerkleTrie( nodes kvstore.MapStore, hasher hash.Hash, root []byte, - options ...Option, + options ...TrieSpecOption, ) *SMT { smt := NewSparseMerkleTrie(nodes, hasher, options...) - smt.trie = &lazyNode{root} - smt.savedRoot = root + smt.root = &lazyNode{root} + smt.rootHash = root return smt } -// Get returns the digest of the value stored at the given key +// Root returns the root hash of the trie +func (smt *SMT) Root() MerkleRoot { + return smt.digest(smt.root) +} + +// Get returns the hash (i.e. digest) of the leaf value stored at the given key func (smt *SMT) Get(key []byte) ([]byte, error) { path := smt.ph.Path(key) + // The leaf node whose value will be returned var leaf *leafNode var err error - for node, depth := &smt.trie, 0; ; depth++ { - *node, err = smt.resolveLazy(*node) + + // Loop throughout the entire trie to find the corresponding leaf for the + // given key. 
+ for currNode, depth := &smt.root, 0; ; depth++ { + *currNode, err = smt.resolveLazy(*currNode) if err != nil { return nil, err } - if *node == nil { + if *currNode == nil { break } - if n, ok := (*node).(*leafNode); ok { + if n, ok := (*currNode).(*leafNode); ok { if bytes.Equal(path, n.path) { leaf = n } break } - if ext, ok := (*node).(*extensionNode); ok { - if _, match := ext.match(path, depth); !match { + if extNode, ok := (*currNode).(*extensionNode); ok { + if _, fullMatch := extNode.boundsMatch(path, depth); !fullMatch { break } - depth += ext.length() - node = &ext.child - *node, err = smt.resolveLazy(*node) + depth += extNode.length() + currNode = &extNode.child + *currNode, err = smt.resolveLazy(*currNode) if err != nil { return nil, err } } - inner := (*node).(*innerNode) - if getPathBit(path, depth) == left { - node = &inner.leftChild + inner := (*currNode).(*innerNode) + if getPathBit(path, depth) == leftChildBit { + currNode = &inner.leftChild } else { - node = &inner.rightChild + currNode = &inner.rightChild } } if leaf == nil { - return defaultValue, nil + return defaultEmptyValue, nil } return leaf.valueHash, nil } -// Update sets the value for the given key, to the digest of the provided value -func (smt *SMT) Update(key []byte, value []byte) error { +// Update inserts the `value` for the given `key` into the SMT +func (smt *SMT) Update(key, value []byte) error { + // Convert the key into a path by computing its digest path := smt.ph.Path(key) - valueHash := smt.digestValue(value) + + // Convert the value into a hash by computing its digest + valueHash := smt.valueHash(value) + + // Update the trie with the new key-value pair var orphans orphanNodes - trie, err := smt.update(smt.trie, 0, path, valueHash, &orphans) + + // Compute the new root by inserting (path, valueHash) starting from the + // root of the tree in order to find the correct position of the new leaf. 
+ newRoot, err := smt.update(smt.root, 0, path, valueHash, &orphans) if err != nil { return err } - smt.trie = trie + smt.root = newRoot if len(orphans) > 0 { smt.orphans = append(smt.orphans, orphans) } return nil } +// Internal helper to the `Update` method func (smt *SMT) update( - node trieNode, depth int, path, value []byte, orphans *orphanNodes, + node trieNode, + depth int, + path, value []byte, + orphans *orphanNodes, ) (trieNode, error) { node, err := smt.resolveLazy(node) if err != nil { @@ -171,14 +151,16 @@ func (smt *SMT) update( return newLeaf, nil } if leaf, ok := node.(*leafNode); ok { - prefixlen := countCommonPrefixBits(path, leaf.path, depth) - if prefixlen == smt.depth() { // replace leaf if paths are equal + prefixLen := countCommonPrefixBits(path, leaf.path, depth) + // replace leaf if paths are equal + if prefixLen == smt.depth() { smt.addOrphan(orphans, node) return newLeaf, nil } - // We insert an "extension" representing multiple single-branch inner nodes + // Create a new innerNode where a previous leafNode was, branching + // based on the path bit at the current depth in the path. var newInner *innerNode - if getPathBit(path, prefixlen) == left { + if getPathBit(path, prefixLen) == leftChildBit { newInner = &innerNode{ leftChild: newLeaf, rightChild: leaf, @@ -189,9 +171,11 @@ func (smt *SMT) update( rightChild: newLeaf, } } - // Determine if we need to insert an extension or a branch + // Determine if we need to insert the new innerNode as the child + // of an extensionNode or a insert a the new innerNode in place of + // a pre-existing leafNode with a common prefix. last := &node - if depth < prefixlen { + if depth < prefixLen { // note: this keeps path slice alive - GC inefficiency? 
if depth > 0xff { panic("invalid depth") @@ -200,7 +184,7 @@ func (smt *SMT) update( child: newInner, path: path, pathBounds: [2]byte{ - byte(depth), byte(prefixlen), + byte(depth), byte(prefixLen), }, } // Dereference the last node to replace it with the extension node @@ -214,20 +198,24 @@ func (smt *SMT) update( smt.addOrphan(orphans, node) - if ext, ok := node.(*extensionNode); ok { + // If the node is an extensionNode split it by the path provided, we + // call update() on the results to place the newLeaf correctly. + if extNode, ok := node.(*extensionNode); ok { var branch *trieNode - node, branch, depth = ext.split(path, depth) + node, branch, depth = extNode.split(path) *branch, err = smt.update(*branch, depth, path, value, orphans) if err != nil { return node, err } - ext.setDirty() + extNode.setDirty() return node, nil } + // The node must be an innerNode. Depending on which side of the branch inner + // node the newLeaf should be added to, call update() accordingly. inner := node.(*innerNode) var child *trieNode - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { child = &inner.leftChild } else { child = &inner.rightChild @@ -244,11 +232,11 @@ func (smt *SMT) update( func (smt *SMT) Delete(key []byte) error { path := smt.ph.Path(key) var orphans orphanNodes - trie, err := smt.delete(smt.trie, 0, path, &orphans) + trie, err := smt.delete(smt.root, 0, path, &orphans) if err != nil { return err } - smt.trie = trie + smt.root = trie if len(orphans) > 0 { smt.orphans = append(smt.orphans, orphans) } @@ -275,30 +263,30 @@ func (smt *SMT) delete(node trieNode, depth int, path []byte, orphans *orphanNod smt.addOrphan(orphans, node) - if ext, ok := node.(*extensionNode); ok { - if _, match := ext.match(path, depth); !match { + if extNode, ok := node.(*extensionNode); ok { + if _, fullMatch := extNode.boundsMatch(path, depth); !fullMatch { return node, ErrKeyNotFound } - ext.child, err = smt.delete(ext.child, depth+ext.length(), 
path, orphans) + extNode.child, err = smt.delete(extNode.child, depth+extNode.length(), path, orphans) if err != nil { return node, err } - switch n := ext.child.(type) { + switch n := extNode.child.(type) { case *leafNode: return n, nil case *extensionNode: // Join this extension with the child smt.addOrphan(orphans, n) - n.pathBounds[0] = ext.pathBounds[0] + n.pathBounds[0] = extNode.pathBounds[0] node = n } - ext.setDirty() + extNode.setDirty() return node, nil } inner := node.(*innerNode) var child, sib *trieNode - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { child, sib = &inner.leftChild, &inner.rightChild } else { child, sib = &inner.rightChild, &inner.leftChild @@ -338,7 +326,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { var siblings []trieNode var sib trieNode - node := smt.trie + node := smt.root for depth := 0; depth < smt.depth(); depth++ { node, err = smt.resolveLazy(node) if err != nil { @@ -350,24 +338,24 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if _, ok := node.(*leafNode); ok { break } - if ext, ok := node.(*extensionNode); ok { - length, match := ext.match(path, depth) - if match { - for i := 0; i < length; i++ { + if extNode, ok := node.(*extensionNode); ok { + matchLen, fullMatch := extNode.boundsMatch(path, depth) + if fullMatch { + for i := 0; i < matchLen; i++ { siblings = append(siblings, nil) } - depth += length - node = ext.child + depth += matchLen + node = extNode.child node, err = smt.resolveLazy(node) if err != nil { return nil, err } } else { - node = ext.expand() + node = extNode.expand() } } inner := node.(*innerNode) - if getPathBit(path, depth) == left { + if getPathBit(path, depth) == leftChildBit { node, sib = inner.leftChild, inner.rightChild } else { node, sib = inner.rightChild, inner.leftChild @@ -383,7 +371,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if !bytes.Equal(leaf.path, path) { 
// This is a non-membership proof that involves showing a different leaf. // Add the leaf data to the proof. - leafData = encodeLeaf(leaf.path, leaf.valueHash) + leafData = encodeLeafNode(leaf.path, leaf.valueHash) } } // Hash siblings from bottom up. @@ -391,7 +379,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { for i := range siblings { var sideNode []byte sibling := siblings[len(siblings)-i-1] - sideNode = hashNode(smt.Spec(), sibling) + sideNode = smt.digest(sibling) sideNodes = append(sideNodes, sideNode) } @@ -404,7 +392,7 @@ func (smt *SMT) Prove(key []byte) (proof *SparseMerkleProof, err error) { if err != nil { return nil, err } - proof.SiblingData = serialize(smt.Spec(), sib) + proof.SiblingData = smt.encode(sib) } return proof, nil } @@ -425,7 +413,7 @@ func (smt *SMT) ProveClosest(path []byte) ( err error, // the error value encountered ) { // Ensure the path provided is the correct length for the path hasher. - if len(path) != smt.Spec().PathHasherSize() { + if len(path) != smt.Spec().ph.PathSize() { return nil, ErrInvalidClosestPath } @@ -443,7 +431,7 @@ func (smt *SMT) ProveClosest(path []byte) ( FlippedBits: make([]int, 0), } - node := smt.trie + node := smt.root depth := 0 // continuously traverse the trie until we hit a leaf node for depth < smt.depth() { @@ -483,21 +471,21 @@ func (smt *SMT) ProveClosest(path []byte) ( proof.Depth = depth break } - if ext, ok := node.(*extensionNode); ok { - length, match := ext.match(workingPath, depth) + if extNode, ok := node.(*extensionNode); ok { + matchLen, fullMatch := extNode.boundsMatch(workingPath, depth) // workingPath from depth to end of extension node's path bounds // is a perfect match - if !match { - node = ext.expand() + if !fullMatch { + node = extNode.expand() } else { // extension nodes represent a singly linked list of inner nodes // add nil siblings to represent the empty neighbours - for i := 0; i < length; i++ { + for i := 0; i < matchLen; i++ { siblings = 
append(siblings, nil) } - depth += length - depthDelta += length - node = ext.child + depth += matchLen + depthDelta += matchLen + node = extNode.child node, err = smt.resolveLazy(node) if err != nil { return nil, err @@ -509,7 +497,7 @@ func (smt *SMT) ProveClosest(path []byte) ( proof.Depth = depth break } - if getPathBit(workingPath, depth) == left { + if getPathBit(workingPath, depth) == leftChildBit { node, sib = inner.leftChild, inner.rightChild } else { node, sib = inner.rightChild, inner.leftChild @@ -521,7 +509,7 @@ func (smt *SMT) ProveClosest(path []byte) ( // Retrieve the closest path and value hash if found if node == nil { // trie was empty - proof.ClosestPath, proof.ClosestValueHash = placeholder(smt.Spec()), nil + proof.ClosestPath, proof.ClosestValueHash = smt.placeholder(), nil proof.ClosestProof = &SparseMerkleProof{} return proof, nil } @@ -536,7 +524,7 @@ func (smt *SMT) ProveClosest(path []byte) ( for i := range siblings { var sideNode []byte sibling := siblings[len(siblings)-i-1] - sideNode = hashNode(smt.Spec(), sibling) + sideNode = smt.digest(sibling) sideNodes = append(sideNodes, sideNode) } proof.ClosestProof = &SparseMerkleProof{ @@ -547,107 +535,122 @@ func (smt *SMT) ProveClosest(path []byte) ( if err != nil { return nil, err } - proof.ClosestProof.SiblingData = serialize(smt.Spec(), sib) + proof.ClosestProof.SiblingData = smt.encode(sib) } return proof, nil } -//nolint:unused -func (smt *SMT) recursiveLoad(hash []byte) (trieNode, error) { - return smt.resolve(hash, smt.recursiveLoad) -} - -// resolves a stub into a cached node +// resolveLazy resolves a lazy note into a cached node depending on the tree type func (smt *SMT) resolveLazy(node trieNode) (trieNode, error) { stub, ok := node.(*lazyNode) if !ok { return node, nil } - resolver := func(hash []byte) (trieNode, error) { - return &lazyNode{hash}, nil - } - ret, err := resolve(smt, stub.digest, resolver) - if err != nil { - return node, err + if smt.sumTrie { + return 
smt.resolveSumNode(stub.digest) } - return ret, nil + return smt.resolveNode(stub.digest) } -func (smt *SMT) resolve(hash []byte, resolver func([]byte) (trieNode, error), -) (ret trieNode, err error) { - if bytes.Equal(smt.th.placeholder(), hash) { - return - } - data, err := smt.nodes.Get(hash) - if err != nil { - return - } - if isLeaf(data) { - leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeaf(data, smt.ph) - return &leaf, nil - } - if isExtension(data) { - ext := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash := parseExtension(data, smt.ph) - ext.path = path - copy(ext.pathBounds[:], pathBounds) - ext.child, err = resolver(childHash) - if err != nil { - return - } - return &ext, nil +// resolveNode returns a trieNode (inner, leaf, or extension) based on what they +// keyHash points to. +func (smt *SMT) resolveNode(digest []byte) (trieNode, error) { + // Check if the keyHash is the empty zero value of an empty subtree + if bytes.Equal(smt.placeholder(), digest) { + return nil, nil } - leftHash, rightHash := smt.th.parseNode(data) - inner := innerNode{persisted: true, digest: hash} - inner.leftChild, err = resolver(leftHash) + + // Retrieve the encoded noe data + data, err := smt.nodes.Get(digest) if err != nil { - return + return nil, err } - inner.rightChild, err = resolver(rightHash) - if err != nil { - return + + return smt.parseTrieNode(data, digest) +} + +// parseTrieNode returns a trieNode (inner, leaf, or extension) based on the +// first byte of the data. 
+func (smt *SMT) parseTrieNode(data, digest []byte) (trieNode, error) { + if isLeafNode(data) { + path, valueHash := smt.parseLeafNode(data) + return &leafNode{ + path: path, + valueHash: valueHash, + persisted: true, + digest: digest, + }, nil + } else if isExtNode(data) { + pathBounds, path, childData := smt.parseExtNode(data) + return &extensionNode{ + path: path, + pathBounds: [2]byte(pathBounds), + child: &lazyNode{childData}, + persisted: true, + digest: digest, + }, nil + } else if isInnerNode(data) { + leftData, rightData := smt.th.parseInnerNode(data) + return &innerNode{ + leftChild: &lazyNode{leftData}, + rightChild: &lazyNode{rightData}, + persisted: true, + digest: digest, + }, nil + } else { + panic("invalid node type") } - return &inner, nil } -func (smt *SMT) resolveSum(hash []byte, resolver func([]byte) (trieNode, error), -) (ret trieNode, err error) { - if bytes.Equal(placeholder(smt.Spec()), hash) { - return +// resolveNode returns a trieNode (inner, leaf, or extension) based on what they +// keyHash points to. 
+func (smt *SMT) resolveSumNode(digest []byte) (trieNode, error) { + // Check if the keyHash is the empty zero value of an empty subtree + if bytes.Equal(smt.placeholder(), digest) { + return nil, nil } - data, err := smt.nodes.Get(hash) + + // Retrieve the encoded noe data + data, err := smt.nodes.Get(digest) if err != nil { return nil, err } - if isLeaf(data) { - leaf := leafNode{persisted: true, digest: hash} - leaf.path, leaf.valueHash = parseLeaf(data, smt.ph) - return &leaf, nil - } - if isExtension(data) { - ext := extensionNode{persisted: true, digest: hash} - pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) - ext.path = path - copy(ext.pathBounds[:], pathBounds) - ext.child, err = resolver(childHash) - if err != nil { - return - } - return &ext, nil - } - leftHash, rightHash := smt.th.parseSumNode(data) - inner := innerNode{persisted: true, digest: hash} - inner.leftChild, err = resolver(leftHash) - if err != nil { - return - } - inner.rightChild, err = resolver(rightHash) - if err != nil { - return + + return smt.parseSumTrieNode(data, digest) +} + +// parseTrieNode returns a trieNode (inner, leaf, or extension) based on the +// first byte of the data. 
+func (smt *SMT) parseSumTrieNode(data, digest []byte) (trieNode, error) { + if isLeafNode(data) { + path, valueHash := smt.parseLeafNode(data) + return &leafNode{ + path: path, + valueHash: valueHash, + persisted: true, + digest: digest, + }, nil + } else if isExtNode(data) { + pathBounds, path, childData, _ := smt.parseSumExtNode(data) + return &extensionNode{ + path: path, + pathBounds: [2]byte(pathBounds), + child: &lazyNode{childData}, + persisted: true, + digest: digest, + }, nil + } else if isInnerNode(data) { + leftData, rightData, _ := smt.th.parseSumInnerNode(data) + return &innerNode{ + leftChild: &lazyNode{leftData}, + rightChild: &lazyNode{rightData}, + persisted: true, + digest: digest, + }, nil + } else { + panic("invalid node type") } - return &inner, nil } // Commit persists all dirty nodes in the trie, deletes all orphaned @@ -662,10 +665,10 @@ func (smt *SMT) Commit() (err error) { } } smt.orphans = nil - if err = smt.commit(smt.trie); err != nil { + if err = smt.commit(smt.root); err != nil { return } - smt.savedRoot = smt.Root() + smt.rootHash = smt.Root() return } @@ -692,13 +695,8 @@ func (smt *SMT) commit(node trieNode) error { default: return nil } - preimage := serialize(smt.Spec(), node) - return smt.nodes.Set(hashNode(smt.Spec(), node), preimage) -} - -// Root returns the root hash of the trie -func (smt *SMT) Root() MerkleRoot { - return hashNode(smt.Spec(), smt.trie) + preimage := smt.encode(node) + return smt.nodes.Set(smt.digest(node), preimage) } func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { @@ -706,125 +704,3 @@ func (smt *SMT) addOrphan(orphans *[][]byte, node trieNode) { *orphans = append(*orphans, node.CachedDigest()) } } - -func (node *leafNode) Persisted() bool { return node.persisted } -func (node *innerNode) Persisted() bool { return node.persisted } -func (node *lazyNode) Persisted() bool { return true } -func (node *extensionNode) Persisted() bool { return node.persisted } - -func (node *leafNode) 
CachedDigest() []byte { return node.digest } -func (node *innerNode) CachedDigest() []byte { return node.digest } -func (node *lazyNode) CachedDigest() []byte { return node.digest } -func (node *extensionNode) CachedDigest() []byte { return node.digest } - -func (inner *innerNode) setDirty() { - inner.persisted = false - inner.digest = nil -} - -func (ext *extensionNode) length() int { return int(ext.pathBounds[1] - ext.pathBounds[0]) } - -func (ext *extensionNode) setDirty() { - ext.persisted = false - ext.digest = nil -} - -// Returns length of matching prefix, and whether it's a full match -func (ext *extensionNode) match(path []byte, depth int) (int, bool) { - if depth != ext.pathStart() { - panic("depth != path_begin") - } - for i := ext.pathStart(); i < ext.pathEnd(); i++ { - if getPathBit(ext.path, i) != getPathBit(path, i) { - return i - ext.pathStart(), false - } - } - return ext.length(), true -} - -//nolint:unused -func (ext *extensionNode) commonPrefix(path []byte) int { - count := 0 - for i := ext.pathStart(); i < ext.pathEnd(); i++ { - if getPathBit(ext.path, i) != getPathBit(path, i) { - break - } - count++ - } - return count -} - -func (ext *extensionNode) pathStart() int { return int(ext.pathBounds[0]) } -func (ext *extensionNode) pathEnd() int { return int(ext.pathBounds[1]) } - -// Splits the node in-place; returns replacement node, child node at the split, and split depth -func (ext *extensionNode) split(path []byte, depth int) (trieNode, *trieNode, int) { - if depth != ext.pathStart() { - panic("depth != path_begin") - } - index := ext.pathStart() - var myBit, branchBit int - for ; index < ext.pathEnd(); index++ { - myBit = getPathBit(ext.path, index) - branchBit = getPathBit(path, index) - if myBit != branchBit { - break - } - } - if index == ext.pathEnd() { - return ext, &ext.child, index - } - - child := ext.child - var branch innerNode - var head trieNode - var tail *trieNode - if myBit == left { - tail = &branch.leftChild - } else { - tail 
= &branch.rightChild - } - - // Split at first bit: chain starts with new node - if index == ext.pathStart() { - head = &branch - ext.pathBounds[0]++ // Shrink the extension from front - if ext.length() == 0 { - *tail = child - } else { - *tail = ext - } - } else { - // Split inside: chain ends at index - head = ext - ext.child = &branch - if index == ext.pathEnd()-1 { - *tail = child - } else { - *tail = &extensionNode{ - path: ext.path, - pathBounds: [2]byte{byte(index + 1), ext.pathBounds[1]}, - child: child, - } - } - ext.pathBounds[1] = byte(index) - } - var b trieNode = &branch - return head, &b, index -} - -// expand returns the inner node that represents the start of the singly -// linked list that this extension node represents -func (ext *extensionNode) expand() trieNode { - last := ext.child - for i := ext.pathEnd() - 1; i >= ext.pathStart(); i-- { - var next innerNode - if getPathBit(ext.path, i) == left { - next.leftChild = last - } else { - next.rightChild = last - } - last = &next - } - return last -} diff --git a/smt_example_test.go b/smt_example_test.go index 6d74980..2f7af1b 100644 --- a/smt_example_test.go +++ b/smt_example_test.go @@ -2,18 +2,19 @@ package smt_test import ( "crypto/sha256" - "fmt" + "testing" "github.com/pokt-network/smt" "github.com/pokt-network/smt/kvstore/simplemap" ) -func ExampleSMT() { - // Initialise a new in-memory key-value store to store the nodes of the trie +// TestExampleSMT is a test that aims to act as an example of how to use the SMST. 
+func TestExampleSMT(t *testing.T) { + // Initialize a new in-memory key-value store to store the nodes of the trie // (Note: the trie only stores hashed values, not raw value data) nodeStore := simplemap.NewSimpleMap() - // Initialise the trie + // Initialize the trie trie := smt.NewSparseMerkleTrie(nodeStore, sha256.New()) // Update the key "foo" with the value "bar" @@ -30,6 +31,7 @@ func ExampleSMT() { valid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("bar"), trie.Spec()) // Attempt to verify the Merkle proof for "foo"="baz" invalid, _ := smt.VerifyProof(proof, root, []byte("foo"), []byte("baz"), trie.Spec()) - fmt.Println(valid, invalid) + // Output: true false + t.Log(valid, invalid) } diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 1c353df..cf2bf89 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -32,7 +32,7 @@ func TestSMT_Proof_Operations(t *testing.T) { proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, base.th.placeholder(), []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.True(t, result) result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) @@ -66,7 +66,7 @@ func TestSMT_Proof_Operations(t *testing.T) { result, err = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey"), []byte("testValue"), base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey"), []byte("testValue"), base) require.NoError(t, err) require.False(t, result) @@ -79,17 +79,17 @@ func TestSMT_Proof_Operations(t *testing.T) { result, err = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), base) require.NoError(t, err) require.False(t, 
result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue"), base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey2"), []byte("testValue"), base) require.NoError(t, err) require.False(t, result) // Try proving a default value for a non-default leaf. - _, leafData := base.th.digestLeaf(base.ph.Path([]byte("testKey2")), base.digestValue([]byte("testValue"))) + _, leafData := base.th.digestLeafNode(base.ph.Path([]byte("testKey2")), base.valueHash([]byte("testValue"))) proof = &SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result, err = VerifyProof(proof, root, []byte("testKey2"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey2"), defaultEmptyValue, base) require.ErrorIs(t, err, ErrBadProof) require.False(t, result) @@ -97,13 +97,13 @@ func TestSMT_Proof_Operations(t *testing.T) { proof, err = smt.Prove([]byte("testKey3")) require.NoError(t, err) checkCompactEquivalence(t, proof, base) - result, err = VerifyProof(proof, root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(proof, root, []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.True(t, result) result, err = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), base) require.NoError(t, err) require.False(t, result) - result, err = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, base) + result, err = VerifyProof(randomizeProof(proof), root, []byte("testKey3"), defaultEmptyValue, base) require.NoError(t, err) require.False(t, result) } @@ -161,7 +161,7 @@ func TestSMT_Proof_ValidateBasic(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smt.Prove([]byte("testKey1")) - proof.SiblingData = base.th.digest(proof.SiblingData) + proof.SiblingData = base.th.digestData(proof.SiblingData) require.EqualError( t, proof.validateBasic(base), @@ -331,7 +331,7 @@ func 
TestSMT_ProveClosest_Empty(t *testing.T) { Path: path[:], FlippedBits: []int{0}, Depth: 0, - ClosestPath: placeholder(smt.Spec()), + ClosestPath: smt.placeholder(), ClosestProof: &SparseMerkleProof{}, }) diff --git a/smt_test.go b/smt_test.go index fd1adac..cf0e48a 100644 --- a/smt_test.go +++ b/smt_test.go @@ -15,7 +15,7 @@ import ( func NewSMTWithStorage( nodes, preimages kvstore.MapStore, hasher hash.Hash, - options ...Option, + options ...TrieSpecOption, ) *SMTWithStorage { return &SMTWithStorage{ SMT: NewSparseMerkleTrie(nodes, hasher, options...), @@ -34,7 +34,7 @@ func TestSMT_TrieUpdateBasic(t *testing.T) { // Test getting an empty key. value, err := smt.GetValue([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value) + require.Equal(t, defaultEmptyValue, value) has, err = smt.Has([]byte("testKey")) require.NoError(t, err) @@ -119,7 +119,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err := smt.GetValue([]byte("testKey")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") has, err := smt.Has([]byte("testKey")) require.NoError(t, err) @@ -142,7 +142,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("testKey2")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) @@ -162,7 +162,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("foo")) require.NoError(t, err) - require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) @@ -183,7 +183,7 @@ func TestSMT_TrieDeleteBasic(t *testing.T) { value, err = smt.GetValue([]byte("testKey")) require.NoError(t, err) - 
require.Equal(t, defaultValue, value, "getting deleted key") + require.Equal(t, defaultEmptyValue, value, "getting deleted key") has, err = smt.Has([]byte("testKey")) require.NoError(t, err) diff --git a/smt_utils_test.go b/smt_utils_test.go index 2912e26..dc65a95 100644 --- a/smt_utils_test.go +++ b/smt_utils_test.go @@ -24,7 +24,7 @@ func (smt *SMTWithStorage) Update(key, value []byte) error { if err := smt.SMT.Update(key, value); err != nil { return err } - valueHash := smt.digestValue(value) + valueHash := smt.valueHash(value) return smt.preimages.Set(valueHash, value) } @@ -46,7 +46,7 @@ func (smt *SMTWithStorage) GetValue(key []byte) ([]byte, error) { if err != nil { if errors.Is(err, ErrKeyNotFound) { // If key isn't found, return default value - value = defaultValue + value = defaultEmptyValue } else { // Otherwise percolate up any other error return nil, err @@ -59,7 +59,7 @@ func (smt *SMTWithStorage) GetValue(key []byte) ([]byte, error) { // otherwise. func (smt *SMTWithStorage) Has(key []byte) (bool, error) { val, err := smt.GetValue(key) - return !bytes.Equal(defaultValue, val), err + return !bytes.Equal(defaultEmptyValue, val), err } // ProveCompact generates a compacted Merkle proof for a key against the diff --git a/trie_spec.go b/trie_spec.go new file mode 100644 index 0000000..a9f047e --- /dev/null +++ b/trie_spec.go @@ -0,0 +1,272 @@ +package smt + +import ( + "encoding/binary" + "hash" +) + +// TrieSpec specifies the hashing functions used by a trie instance to encode +// leaf paths and stored values, and the corresponding maximum trie depth. 
+type TrieSpec struct { + th trieHasher + ph PathHasher + vh ValueHasher + sumTrie bool +} + +// newTrieSpec returns a new TrieSpec with the given hasher and sumTrie flag +func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec { + spec := TrieSpec{th: *NewTrieHasher(hasher)} + spec.ph = &pathHasher{spec.th} + spec.vh = &valueHasher{spec.th} + spec.sumTrie = sumTrie + return spec +} + +// Spec returns the TrieSpec associated with the given trie +func (spec *TrieSpec) Spec() *TrieSpec { + return spec +} + +// placeholder returns the default placeholder value depending on the trie type +func (spec *TrieSpec) placeholder() []byte { + if spec.sumTrie { + placeholder := spec.th.placeholder() + placeholder = append(placeholder, defaultEmptySum[:]...) + return placeholder + } + return spec.th.placeholder() +} + +// hashSize returns the hash size depending on the trie type +func (spec *TrieSpec) hashSize() int { + if spec.sumTrie { + return spec.th.hashSize() + sumSizeBytes + } + return spec.th.hashSize() +} + +// digestLeaf returns the hash and preimage of a leaf node depending on the trie type +func (spec *TrieSpec) digestLeaf(path, value []byte) ([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumLeafNode(path, value) + } + return spec.th.digestLeafNode(path, value) +} + +// digestNode returns the hash and preimage of a node depending on the trie type +func (spec *TrieSpec) digestInnerNode(left, right []byte) ([]byte, []byte) { + if spec.sumTrie { + return spec.th.digestSumInnerNode(left, right) + } + return spec.th.digestInnerNode(left, right) +} + +// digest hashes a node depending on the trie type +func (spec *TrieSpec) digest(node trieNode) []byte { + if spec.sumTrie { + return spec.digestSumNode(node) + } + return spec.digestNode(node) +} + +// encode serializes a node depending on the trie type +func (spec *TrieSpec) encode(node trieNode) []byte { + if spec.sumTrie { + return spec.encodeSumNode(node) + } + return spec.encodeNode(node) +} + +// 
hashPreimage hashes the serialised data provided depending on the trie type +func (spec *TrieSpec) hashPreimage(data []byte) []byte { + if spec.sumTrie { + return spec.hashSumSerialization(data) + } + return spec.hashSerialization(data) +} + +// Used for verification of serialized proof data +func (spec *TrieSpec) hashSerialization(data []byte) []byte { + if isExtNode(data) { + pathBounds, path, childHash := spec.parseExtNode(data) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return spec.digestNode(&ext) + } + return spec.th.digestData(data) +} + +// Used for verification of serialized proof data for sum trie nodes +func (spec *TrieSpec) hashSumSerialization(data []byte) []byte { + if isExtNode(data) { + pathBounds, path, childHash, _ := spec.parseSumExtNode(data) + ext := extensionNode{path: path, child: &lazyNode{childHash}} + copy(ext.pathBounds[:], pathBounds) + return spec.digestSumNode(&ext) + } + digest := spec.th.digestData(data) + digest = append(digest, data[len(data)-sumSizeBytes:]...) + return digest +} + +// depth returns the maximum depth of the trie. +// Since this tree is a binary tree, the depth is the number of bits in the path +// TODO_UPNEXT(@Olshansk):: Try to understand why we're not taking the log of the output +func (spec *TrieSpec) depth() int { + return spec.ph.PathSize() * 8 // path size is in bytes so multiply by 8 to get num bits +} + +// valueHash returns the hash of a value, or the value itself if no value hasher is specified. 
+func (spec *TrieSpec) valueHash(value []byte) []byte { + if spec.vh == nil { + return value + } + return spec.vh.HashValue(value) +} + +// encodeNode serializes a node into a byte slice +func (spec *TrieSpec) encodeNode(node trieNode) []byte { + switch n := node.(type) { + case *lazyNode: + panic("Encoding a lazyNode is not supported") + case *leafNode: + return encodeLeafNode(n.path, n.valueHash) + case *innerNode: + leftChild := spec.digestNode(n.leftChild) + rightChild := spec.digestNode(n.rightChild) + return encodeInnerNode(leftChild, rightChild) + case *extensionNode: + child := spec.digestNode(n.child) + return encodeExtensionNode(n.pathBounds, n.path, child) + default: + panic("Unknown node type") + } +} + +// digestNode hashes a node and returns its digest +func (spec *TrieSpec) digestNode(node trieNode) []byte { + if node == nil { + return spec.th.placeholder() + } + + var cachedDigest *[]byte + switch n := node.(type) { + case *lazyNode: + return n.digest + case *leafNode: + cachedDigest = &n.digest + case *innerNode: + cachedDigest = &n.digest + case *extensionNode: + if n.digest == nil { + n.digest = spec.digestNode(n.expand()) + } + return n.digest + } + if *cachedDigest == nil { + *cachedDigest = spec.th.digestData(spec.encodeNode(node)) + } + return *cachedDigest +} + +// encodeSumNode serializes a sum node and returns the preImage hash. 
+func (spec *TrieSpec) encodeSumNode(node trieNode) (preImage []byte) { + switch n := node.(type) { + case *lazyNode: + panic("encodeSumNode(lazyNode)") + case *leafNode: + return encodeLeafNode(n.path, n.valueHash) + case *innerNode: + leftChild := spec.digestSumNode(n.leftChild) + rightChild := spec.digestSumNode(n.rightChild) + return encodeSumInnerNode(leftChild, rightChild) + case *extensionNode: + child := spec.digestSumNode(n.child) + return encodeSumExtensionNode(n.pathBounds, n.path, child) + } + return nil +} + +// digestSumNode hashes a sum node returning its digest in the following form: [node hash]+[8 byte sum] +func (spec *TrieSpec) digestSumNode(node trieNode) []byte { + if node == nil { + return spec.placeholder() + } + var cache *[]byte + switch n := node.(type) { + case *lazyNode: + return n.digest + case *leafNode: + cache = &n.digest + case *innerNode: + cache = &n.digest + case *extensionNode: + if n.digest == nil { + n.digest = spec.digestSumNode(n.expand()) + } + return n.digest + } + if *cache == nil { + preImage := spec.encodeSumNode(node) + *cache = spec.th.digestData(preImage) + *cache = append(*cache, preImage[len(preImage)-sumSizeBytes:]...) 
+	}
+	return *cache
+}
+
+// parseLeafNode parses a leafNode into its components
+func (spec *TrieSpec) parseLeafNode(data []byte) (path, value []byte) {
+	// panics if not a leaf node
+	checkPrefix(data, leafNodePrefix)
+
+	path = data[prefixLen : prefixLen+spec.ph.PathSize()]
+	value = data[prefixLen+spec.ph.PathSize():]
+	return
+}
+
+// parseExtNode parses an extNode into its components
+func (spec *TrieSpec) parseExtNode(data []byte) (pathBounds, path, childData []byte) {
+	// panics if not an extension node
+	checkPrefix(data, extNodePrefix)
+
+	// +2 represents the length of the pathBounds
+	pathBounds = data[prefixLen : prefixLen+2]
+	path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()]
+	childData = data[prefixLen+2+spec.ph.PathSize():]
+	return
+}
+
+// parseSumLeafNode parses a leafNode and returns its weight as well
+// nolint: unused
+func (spec *TrieSpec) parseSumLeafNode(data []byte) (path, value []byte, weight uint64) {
+	// panics if not a leaf node
+	checkPrefix(data, leafNodePrefix)
+
+	path = data[prefixLen : prefixLen+spec.ph.PathSize()]
+	value = data[prefixLen+spec.ph.PathSize():]
+
+	// Decode the big-endian sum from the trailing sumSizeBytes of the value
+	var weightBz [sumSizeBytes]byte
+	copy(weightBz[:], value[len(value)-sumSizeBytes:])
+	weight = binary.BigEndian.Uint64(weightBz[:])
+
+	return
+}
+
+// parseSumExtNode parses the pathBounds, path, child data and sum from the encoded extension node data
+func (spec *TrieSpec) parseSumExtNode(data []byte) (pathBounds, path, childData []byte, sum uint64) {
+	// panics if not an extension node
+	checkPrefix(data, extNodePrefix)
+
+	// Decode the big-endian sum from the trailing sumSizeBytes of the node data
+	var sumBz [sumSizeBytes]byte
+	copy(sumBz[:], data[len(data)-sumSizeBytes:])
+	sum = binary.BigEndian.Uint64(sumBz[:])
+
+	// +2 represents the length of the pathBounds
+	pathBounds = data[prefixLen : prefixLen+2]
+	path = data[prefixLen+2 : prefixLen+2+spec.ph.PathSize()]
+	childData = data[prefixLen+2+spec.ph.PathSize() : 
len(data)-sumSizeBytes] + return +} diff --git a/types.go b/types.go index 61b0638..fd6dacf 100644 --- a/types.go +++ b/types.go @@ -1,33 +1,37 @@ package smt -import ( - "encoding/binary" - "hash" -) +// TODO_DISCUSS_CONSIDERIN_THE_FUTURE: +// 1. Should we rename all instances of digest to hash? +// > digest is the correct term for the output of a hashing function IIRC +// 2. Should we introduce a shared interface between SparseMerkleTrie and SparseMerkleSumTrie? +// > Sum() would have to be no-op but could be done +// 3. Should we rename Commit to FlushToDisk? +// > No because what if this is an in memory trie? const ( - left = 0 - sumSize = 8 + // The bit value use to distinguish an inner nodes left child and right child + leftChildBit = 0 ) var ( - defaultValue []byte - defaultSum [sumSize]byte + // defaultEmptyValue is the default value for a leaf node + defaultEmptyValue []byte + // defaultEmptySum is the default sum value for a leaf node + defaultEmptySum [sumSizeBytes]byte ) // MerkleRoot is a type alias for a byte slice returned from the Root method type MerkleRoot []byte -// Sum returns the uint64 sum of the merkle root, it checks the length of the -// merkle root and if it is no the same as the size of the SMST's expected -// root hash it will panic. -func (r MerkleRoot) Sum() uint64 { - if len(r)%32 == 0 { - panic("roo#sum: not a merkle sum trie") - } - var sumbz [sumSize]byte - copy(sumbz[:], []byte(r)[len([]byte(r))-sumSize:]) - return binary.BigEndian.Uint64(sumbz[:]) +// A high-level interface that captures the behaviour of all types of nodes +type trieNode interface { + // Persisted returns a boolean to determine whether or not the node + // has been persisted to disk or only held in memory. + // It can be used skip unnecessary iops if already persisted + Persisted() bool + + // The digest of the node, returning a cached value if available. + CachedDigest() []byte } // SparseMerkleTrie represents a Sparse Merkle Trie. 
@@ -77,137 +81,3 @@ type SparseMerkleSumTrie interface { // Spec returns the TrieSpec for the trie Spec() *TrieSpec } - -// TrieSpec specifies the hashing functions used by a trie instance to encode -// leaf paths and stored values, and the corresponding maximum trie depth. -type TrieSpec struct { - th trieHasher - ph PathHasher - vh ValueHasher - sumTrie bool -} - -func NewTrieSpec(hasher hash.Hash, sumTrie bool, opts ...Option) TrieSpec { - spec := TrieSpec{th: *newTrieHasher(hasher)} - spec.ph = &pathHasher{spec.th} - spec.vh = &valueHasher{spec.th} - spec.sumTrie = sumTrie - - for _, opt := range opts { - opt(&spec) - } - - return spec -} - -// Spec returns the TrieSpec associated with the given trie -func (spec *TrieSpec) Spec() *TrieSpec { return spec } - -// PathHasherSize returns the length (in bytes) of digests produced by the -// path hasher -func (spec *TrieSpec) PathHasherSize() int { return spec.ph.PathSize() } - -// ValueHasherSize returns the length (in bytes) of digests produced by the -// value hasher -func (spec *TrieSpec) ValueHasherSize() int { return spec.vh.ValueHashSize() } - -// TrieHasherSize returns the length (in bytes) of digests produced by the -// trie hasher -func (spec *TrieSpec) TrieHasherSize() int { return spec.th.hashSize() } - -func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 } -func (spec *TrieSpec) digestValue(data []byte) []byte { - if spec.vh == nil { - return data - } - return spec.vh.HashValue(data) -} - -func (spec *TrieSpec) serialize(node trieNode) (data []byte) { - switch n := node.(type) { - case *lazyNode: - panic("serialize(lazyNode)") - case *leafNode: - return encodeLeaf(n.path, n.valueHash) - case *innerNode: - lchild := spec.hashNode(n.leftChild) - rchild := spec.hashNode(n.rightChild) - return encodeInner(lchild, rchild) - case *extensionNode: - child := spec.hashNode(n.child) - return encodeExtension(n.pathBounds, n.path, child) - } - return nil -} - -func (spec *TrieSpec) hashNode(node 
trieNode) []byte { - if node == nil { - return spec.th.placeholder() - } - var cache *[]byte - switch n := node.(type) { - case *lazyNode: - return n.digest - case *leafNode: - cache = &n.digest - case *innerNode: - cache = &n.digest - case *extensionNode: - if n.digest == nil { - n.digest = spec.hashNode(n.expand()) - } - return n.digest - } - if *cache == nil { - *cache = spec.th.digest(spec.serialize(node)) - } - return *cache -} - -// sumSerialize serializes a node returning the preimage hash, its sum and any -// errors encountered -func (spec *TrieSpec) sumSerialize(node trieNode) (preimage []byte) { - switch n := node.(type) { - case *lazyNode: - panic("serialize(lazyNode)") - case *leafNode: - return encodeLeaf(n.path, n.valueHash) - case *innerNode: - lchild := spec.hashSumNode(n.leftChild) - rchild := spec.hashSumNode(n.rightChild) - preimage = encodeSumInner(lchild, rchild) - return preimage - case *extensionNode: - child := spec.hashSumNode(n.child) - return encodeSumExtension(n.pathBounds, n.path, child) - } - return nil -} - -// hashSumNode hashes a node returning its digest in the following form -// digest = [node hash]+[8 byte sum] -func (spec *TrieSpec) hashSumNode(node trieNode) []byte { - if node == nil { - return placeholder(spec) - } - var cache *[]byte - switch n := node.(type) { - case *lazyNode: - return n.digest - case *leafNode: - cache = &n.digest - case *innerNode: - cache = &n.digest - case *extensionNode: - if n.digest == nil { - n.digest = spec.hashSumNode(n.expand()) - } - return n.digest - } - if *cache == nil { - preimage := spec.sumSerialize(node) - *cache = spec.th.digest(preimage) - *cache = append(*cache, preimage[len(preimage)-sumSize:]...) 
- } - return *cache -} diff --git a/utils.go b/utils.go index 64e98cb..710e117 100644 --- a/utils.go +++ b/utils.go @@ -4,17 +4,6 @@ import ( "encoding/binary" ) -type nilPathHasher struct { - hashSize int -} - -func (n *nilPathHasher) Path(key []byte) []byte { return key[:n.hashSize] } -func (n *nilPathHasher) PathSize() int { return n.hashSize } - -func NewNilPathHasher(hashSize int) PathHasher { - return &nilPathHasher{hashSize: hashSize} -} - // getPathBit gets the bit at an offset (see position) in the data // provided relative to the most significant bit func getPathBit(data []byte, position int) int { @@ -116,94 +105,3 @@ func bytesToInt(bz []byte) int { u := binary.BigEndian.Uint64(b) return int(u) } - -// placeholder returns the default placeholder value depending on the trie type -func placeholder(spec *TrieSpec) []byte { - if spec.sumTrie { - placeholder := spec.th.placeholder() - placeholder = append(placeholder, defaultSum[:]...) - return placeholder - } - return spec.th.placeholder() -} - -// hashSize returns the hash size depending on the trie type -func hashSize(spec *TrieSpec) int { - if spec.sumTrie { - return spec.th.hashSize() + sumSize - } - return spec.th.hashSize() -} - -// digestLeaf returns the hash and preimage of a leaf node depending on the trie type -func digestLeaf(spec *TrieSpec, path, value []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumLeaf(path, value) - } - return spec.th.digestLeaf(path, value) -} - -// digestNode returns the hash and preimage of a node depending on the trie type -func digestNode(spec *TrieSpec, left, right []byte) ([]byte, []byte) { - if spec.sumTrie { - return spec.th.digestSumNode(left, right) - } - return spec.th.digestNode(left, right) -} - -// hashNode hashes a node depending on the trie type -func hashNode(spec *TrieSpec, node trieNode) []byte { - if spec.sumTrie { - return spec.hashSumNode(node) - } - return spec.hashNode(node) -} - -// serialize serializes a node depending on 
the trie type -func serialize(spec *TrieSpec, node trieNode) []byte { - if spec.sumTrie { - return spec.sumSerialize(node) - } - return spec.serialize(node) -} - -// hashPreimage hashes the serialised data provided depending on the trie type -func hashPreimage(spec *TrieSpec, data []byte) []byte { - if spec.sumTrie { - return hashSumSerialization(spec, data) - } - return hashSerialization(spec, data) -} - -// Used for verification of serialized proof data -func hashSerialization(smt *TrieSpec, data []byte) []byte { - if isExtension(data) { - pathBounds, path, childHash := parseExtension(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.hashNode(&ext) - } - return smt.th.digest(data) -} - -// Used for verification of serialized proof data for sum trie nodes -func hashSumSerialization(smt *TrieSpec, data []byte) []byte { - if isExtension(data) { - pathBounds, path, childHash, _ := parseSumExtension(data, smt.ph) - ext := extensionNode{path: path, child: &lazyNode{childHash}} - copy(ext.pathBounds[:], pathBounds) - return smt.hashSumNode(&ext) - } - digest := smt.th.digest(data) - digest = append(digest, data[len(data)-sumSize:]...) - return digest -} - -// resolve resolves a lazy node depending on the trie type -func resolve(smt *SMT, hash []byte, resolver func([]byte) (trieNode, error), -) (trieNode, error) { - if smt.sumTrie { - return smt.resolveSum(hash, resolver) - } - return smt.resolve(hash, resolver) -}