Skip to content

Commit

Permalink
fix: Use fix smt verification concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
red-0ne committed Jan 29, 2025
1 parent b552bf9 commit 307ff89
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
28 changes: 14 additions & 14 deletions pkg/crypto/protocol/proof_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,8 @@ import (
"github.com/pokt-network/smt"
)

// SMT specification used for the proof verification.
var (
newHasher = sha256.New
SmtSpec smt.TrieSpec
)

func init() {
// Use a spec that does not prehash values in the smst. This returns a nil value
// hasher for the proof verification in order to avoid hashing the value twice.
SmtSpec = smt.NewTrieSpec(
newHasher(), true,
smt.WithValueHasher(nil),
)
}
// newHasher is the hash function used by the SMT specification.
var newHasher = sha256.New

// GetPathForProof computes the path to be used for proof validation by hashing
// the block hash and session id.
Expand All @@ -31,3 +19,15 @@ func GetPathForProof(blockHash []byte, sessionId string) []byte {

return hasher.Sum(nil)
}

// NewSMTSpec returns the SMT specification used for the proof verification.
// It uses a new hasher at every call to avoid concurrency issues that could be
// caused by a shared hasher.
func NewSMTSpec() *smt.TrieSpec {
trieSpec := smt.NewTrieSpec(
newHasher(), true,
smt.WithValueHasher(nil),
)

return &trieSpec
}
10 changes: 5 additions & 5 deletions x/proof/keeper/proof_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ func (k Keeper) EnsureWellFormedProof(ctx context.Context, proof *types.Proof) e
}

// SparseCompactMerkeClosestProof does not implement GetValueHash, so we need to decompact it.
sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec)
sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec())
if err != nil {
logger.Error(fmt.Sprintf("failed to decompact sparse merkle closest proof due to error: %v", err))
return types.ErrProofInvalidProof.Wrapf("failed to decompact sparse erkle closest proof: %s", err)
}

// Get the relay request and response from the proof.GetClosestMerkleProof.
relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec)
relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec())
relay := &servicetypes.Relay{}
if err = k.cdc.Unmarshal(relayBz, relay); err != nil {
logger.Error(fmt.Sprintf("failed to unmarshal relay due to error: %v", err))
Expand Down Expand Up @@ -231,14 +231,14 @@ func (k Keeper) EnsureValidProofSignaturesAndClosestPath(

// SparseCompactMerkeClosestProof was intentionally compacted to reduce its onchain state size
// so it must be decompacted rather than just retrieving the value via GetValueHash (not implemented).
sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec)
sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec())
if err != nil {
logger.Error(fmt.Sprintf("failed to decompact sparse merkle closest proof due to error: %v", err))
return types.ErrProofInvalidProof.Wrapf("failed to decompact sparse merkle closest proof: %s", err)
}

// Get the relay request and response from the proof.GetClosestMerkleProof.
relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec)
relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec())
relay := &servicetypes.Relay{}
if err = k.cdc.Unmarshal(relayBz, relay); err != nil {
logger.Error(fmt.Sprintf("failed to unmarshal relay due to error: %v", err))
Expand Down Expand Up @@ -449,7 +449,7 @@ func verifyClosestProof(
proof *smt.SparseMerkleClosestProof,
claimRootHash []byte,
) error {
valid, err := smt.VerifyClosestProof(proof, claimRootHash, &protocol.SmtSpec)
valid, err := smt.VerifyClosestProof(proof, claimRootHash, protocol.NewSMTSpec())
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions x/proof/keeper/proof_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,10 @@ func TestEnsureValidProof_Error(t *testing.T) {
err = sparseCompactMerkleClosestProof.Unmarshal(proof.ClosestMerkleProof)
require.NoError(t, err)
var sparseMerkleClosestProof *smt.SparseMerkleClosestProof
sparseMerkleClosestProof, err = smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec)
sparseMerkleClosestProof, err = smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec())
require.NoError(t, err)

relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec)
relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec())
relayHashArr := protocol.GetRelayHashFromBytes(relayBz)
relayHash := relayHashArr[:]

Expand Down

0 comments on commit 307ff89

Please sign in to comment.