diff --git a/pkg/crypto/protocol/proof_path.go b/pkg/crypto/protocol/proof_path.go index 61f7e23ce..1a48b40ee 100644 --- a/pkg/crypto/protocol/proof_path.go +++ b/pkg/crypto/protocol/proof_path.go @@ -6,20 +6,8 @@ import ( "github.com/pokt-network/smt" ) -// SMT specification used for the proof verification. -var ( - newHasher = sha256.New - SmtSpec smt.TrieSpec -) - -func init() { - // Use a spec that does not prehash values in the smst. This returns a nil value - // hasher for the proof verification in order to avoid hashing the value twice. - SmtSpec = smt.NewTrieSpec( - newHasher(), true, - smt.WithValueHasher(nil), - ) -} +// newHasher is the hash function used by the SMT specification. +var newHasher = sha256.New // GetPathForProof computes the path to be used for proof validation by hashing // the block hash and session id. @@ -31,3 +19,15 @@ func GetPathForProof(blockHash []byte, sessionId string) []byte { return hasher.Sum(nil) } + +// NewSMTSpec returns the SMT specification used for the proof verification. +// It uses a new hasher at every call to avoid concurrency issues that could be +// caused by a shared hasher. +func NewSMTSpec() *smt.TrieSpec { + trieSpec := smt.NewTrieSpec( + newHasher(), true, + smt.WithValueHasher(nil), + ) + + return &trieSpec +} diff --git a/x/proof/keeper/proof_validation.go b/x/proof/keeper/proof_validation.go index e4d53be2e..bf55d2e3c 100644 --- a/x/proof/keeper/proof_validation.go +++ b/x/proof/keeper/proof_validation.go @@ -106,14 +106,14 @@ func (k Keeper) EnsureWellFormedProof(ctx context.Context, proof *types.Proof) e } // SparseCompactMerkeClosestProof does not implement GetValueHash, so we need to decompact it. - sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec) + sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec()) if err != nil { logger.Error(fmt.Sprintf("failed to decompact sparse merkle closest proof due to error: %v", err)) return types.ErrProofInvalidProof.Wrapf("failed to decompact sparse erkle closest proof: %s", err) } // Get the relay request and response from the proof.GetClosestMerkleProof. - relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec) + relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec()) relay := &servicetypes.Relay{} if err = k.cdc.Unmarshal(relayBz, relay); err != nil { logger.Error(fmt.Sprintf("failed to unmarshal relay due to error: %v", err)) @@ -231,14 +231,14 @@ func (k Keeper) EnsureValidProofSignaturesAndClosestPath( // SparseCompactMerkeClosestProof was intentionally compacted to reduce its onchain state size // so it must be decompacted rather than just retrieving the value via GetValueHash (not implemented). - sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec) + sparseMerkleClosestProof, err := smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec()) if err != nil { logger.Error(fmt.Sprintf("failed to decompact sparse merkle closest proof due to error: %v", err)) return types.ErrProofInvalidProof.Wrapf("failed to decompact sparse merkle closest proof: %s", err) } // Get the relay request and response from the proof.GetClosestMerkleProof. - relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec) + relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec()) relay := &servicetypes.Relay{} if err = k.cdc.Unmarshal(relayBz, relay); err != nil { logger.Error(fmt.Sprintf("failed to unmarshal relay due to error: %v", err)) @@ -449,7 +449,7 @@ func verifyClosestProof( proof *smt.SparseMerkleClosestProof, claimRootHash []byte, ) error { - valid, err := smt.VerifyClosestProof(proof, claimRootHash, &protocol.SmtSpec) + valid, err := smt.VerifyClosestProof(proof, claimRootHash, protocol.NewSMTSpec()) if err != nil { return err } diff --git a/x/proof/keeper/proof_validation_test.go b/x/proof/keeper/proof_validation_test.go index 349dcd59c..f2701db84 100644 --- a/x/proof/keeper/proof_validation_test.go +++ b/x/proof/keeper/proof_validation_test.go @@ -611,10 +611,10 @@ func TestEnsureValidProof_Error(t *testing.T) { err = sparseCompactMerkleClosestProof.Unmarshal(proof.ClosestMerkleProof) require.NoError(t, err) var sparseMerkleClosestProof *smt.SparseMerkleClosestProof - sparseMerkleClosestProof, err = smt.DecompactClosestProof(sparseCompactMerkleClosestProof, &protocol.SmtSpec) + sparseMerkleClosestProof, err = smt.DecompactClosestProof(sparseCompactMerkleClosestProof, protocol.NewSMTSpec()) require.NoError(t, err) - relayBz := sparseMerkleClosestProof.GetValueHash(&protocol.SmtSpec) + relayBz := sparseMerkleClosestProof.GetValueHash(protocol.NewSMTSpec()) relayHashArr := protocol.GetRelayHashFromBytes(relayBz) relayHash := relayHashArr[:]