Skip to content

Commit

Permalink
refactor: SMST#Root(), #Sum(), & #Count()
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite committed Jul 12, 2024
1 parent 6c22c94 commit d329ba5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 37 deletions.
47 changes: 34 additions & 13 deletions root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const (
// MustSum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will panic.
func (r MerkleRoot) MustSum() uint64 {
func (r MerkleSumRoot) MustSum() uint64 {
sum, err := r.Sum()
if err != nil {
panic(err)
Expand All @@ -27,28 +27,49 @@ func (r MerkleRoot) MustSum() uint64 {
// Sum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will return an error.
func (r MerkleRoot) Sum() (uint64, error) {
if len(r)%SmtRootSizeBytes == 0 {
return 0, fmt.Errorf("root#sum: not a merkle sum trie")
func (r MerkleSumRoot) Sum() (uint64, error) {
if len(r) != SmstRootSizeBytes {
return 0, fmt.Errorf("MerkleSumRoot#Sum: not a merkle sum trie")
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return getSum(r), nil
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:]), nil
// MustCount returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree.
func (r MerkleSumRoot) MustCount() uint64 {
count, err := r.Count()
if err != nil {
panic(err)
}

return count
}

// Count returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree.
func (r MerkleRoot) Count() uint64 {
if len(r)%SmtRootSizeBytes == 0 {
panic("root#sum: not a merkle sum trie")
func (r MerkleSumRoot) Count() (uint64, error) {
if len(r) != SmstRootSizeBytes {
return 0, fmt.Errorf("MerkleSumRoot#Count: not a merkle sum trie")
}

_, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return getCount(r), nil
}

// getSum returns the sum of the node stored in the root.
func getSum(root []byte) uint64 {
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root)

var sumBz [sumSizeBytes]byte
copy(sumBz[:], root[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:])
}

// getCount returns the count of the node stored in the root.
func getCount(root []byte) uint64 {
_, firstCountByteIdx := getFirstMetaByteIdx(root)

var countBz [countSizeBytes]byte
copy(countBz[:], []byte(r)[firstCountByteIdx:])
copy(countBz[:], root[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
}
48 changes: 28 additions & 20 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package smt
import (
"bytes"
"encoding/binary"
"fmt"
"hash"

"github.com/pokt-network/smt/kvstore"
Expand Down Expand Up @@ -170,39 +171,46 @@ func (smst *SMST) Commit() error {
}

// Root returns the root hash of the trie with the total sum bytes appended
func (smst *SMST) Root() MerkleRoot {
return smst.SMT.Root() // [digest]+[binary sum]
func (smst *SMST) Root() MerkleSumRoot {
return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count]
}

// Sum returns the sum of the entire trie stored in the root.
// MustSum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will panic.
func (smst *SMST) Sum() uint64 {
rootDigest := []byte(smst.Root())
func (smst *SMST) MustSum() uint64 {
sum, err := smst.Sum()
if err != nil {
panic(err)
}
return sum
}

// Sum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will panic.
func (smst *SMST) Sum() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
return smst.Root().Sum()
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:])
// MustCount returns the number of non-empty nodes in the entire trie stored in the root.
func (smst *SMST) MustCount() uint64 {
count, err := smst.Count()
if err != nil {
panic(err)
}
return count
}

// Count returns the number of non-empty nodes in the entire trie stored in the root.
func (smst *SMST) Count() uint64 {
rootDigest := []byte(smst.Root())

func (smst *SMST) Count() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)

var countBz [countSizeBytes]byte
copy(countBz[:], rootDigest[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
return smst.Root().Count()
}

// getFirstMetaByteIdx returns the index of the first count byte and the first sum byte
Expand All @@ -211,5 +219,5 @@ func (smst *SMST) Count() uint64 {
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
firstCountByteIdx = len(data) - countSizeBytes
firstSumByteIdx = firstCountByteIdx - sumSizeBytes
return
return firstSumByteIdx, firstCountByteIdx
}
13 changes: 9 additions & 4 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ var (
defaultEmptyCount [countSizeBytes]byte
)

// MerkleRoot is a type alias for a byte slice returned from the Root method
// MerkleRoot is a type alias for a byte slice returned from SparseMerkleTrie#Root().
type MerkleRoot []byte

// MerkleSumRoot is a type alias for a byte slice returned from SparseMerkleSumTrie#Root().
type MerkleSumRoot []byte

// A high-level interface that captures the behaviour of all types of nodes
type trieNode interface {
// Persisted returns a boolean to determine whether or not the node
Expand Down Expand Up @@ -68,11 +71,13 @@ type SparseMerkleSumTrie interface {
// Get descends the trie to access a value. Returns nil if key is not present.
Get(key []byte) (data []byte, sum uint64, err error)
// Root computes the Merkle root digest.
Root() MerkleRoot
Root() MerkleSumRoot
// Sum computes the total sum of the Merkle trie
Sum() uint64
Sum() (uint64, error)
MustSum() uint64
// Count returns the total number of non-empty leaves in the trie
Count() uint64
Count() (uint64, error)
MustCount() uint64
// Prove computes a Merkle proof of inclusion or exclusion of a key.
Prove(key []byte) (*SparseMerkleProof, error)
// ProveClosest computes a Merkle proof of inclusion for a key in the trie
Expand Down

0 comments on commit d329ba5

Please sign in to comment.