diff --git a/merkle/merkle.go b/merkle/merkle.go index 833b2a0..06ed634 100644 --- a/merkle/merkle.go +++ b/merkle/merkle.go @@ -85,7 +85,7 @@ func (t *Tree) Put(label []byte, mapVal []byte) ([]byte, []byte, bool) { } dig := t.ctx.getHash(t.root) - proof := t.ctx.getProof(interiors, label) + proof := t.ctx.getProof(t.root, label) return dig, proof, false } @@ -95,11 +95,9 @@ func (t *Tree) Get(label []byte) ([]byte, []byte, bool, []byte, bool) { if uint64(len(label)) != cryptoffi.HashLen { return nil, nil, false, nil, true } - nodePath := getPath(t.root, label) - lastIdx := uint64(len(nodePath)) - 1 - lastNode := nodePath[lastIdx] + lastNode := getPath(t.root, label) dig := t.ctx.getHash(t.root) - proof := t.ctx.getProof(nodePath[:lastIdx], label) + proof := t.ctx.getProof(t.root, label) if lastNode == nil { return nil, dig, NonmembProofTy, proof, false } else { @@ -192,37 +190,28 @@ func (ctx *context) updInteriorHash(b []byte, n *node) []byte { return b0 } -// getPath fetches the maximal path to label, including the leaf node. +// getPath fetches including the leaf node. // if the path doesn't exist, it terminates in an empty node. -func getPath(root *node, label []byte) []*node { - var nodePath []*node - nodePath = append(nodePath, root) - if root == nil { - return nodePath - } - var isEmpty = false - for depth := uint64(0); depth < cryptoffi.HashLen && !isEmpty; depth++ { - currNode := nodePath[depth] - pos := label[depth] - nextNode := currNode.children[pos] - nodePath = append(nodePath, nextNode) - if nextNode == nil { - isEmpty = true - } +func getPath(root *node, label []byte) *node { + var currNode = root + for depth := uint64(0); depth < cryptoffi.HashLen && currNode != nil; depth++ { + pos := uint64(label[depth]) + currNode = currNode.children[pos] } - return nodePath + return currNode } -func (ctx *context) getProof(interiors []*node, label []byte) []byte { - interiorsLen := uint64(len(interiors)) - var proof = make([]byte, 0, interiorsLen*hashesPerProofDepth) - for depth := uint64(0); depth < interiorsLen; depth++ { - children := interiors[depth].children +func (ctx *context) getProof(root *node, label []byte) []byte { + var proof = make([]byte, 0, cryptoffi.HashLen*hashesPerProofDepth) + var currNode = root + for depth := uint64(0); depth < cryptoffi.HashLen; depth++ { + children := currNode.children // convert to uint64 bc otherwise pos+1 might overflow. pos := uint64(label[depth]) for _, n := range children[:pos] { proof = marshal.WriteBytes(proof, ctx.getHash(n)) } + currNode = currNode.children[pos] for _, n := range children[pos+1:] { proof = marshal.WriteBytes(proof, ctx.getHash(n)) }