diff --git a/pkg/crypto/rings/client.go b/pkg/crypto/rings/client.go
index 8373be876..4bb96413b 100644
--- a/pkg/crypto/rings/client.go
+++ b/pkg/crypto/rings/client.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"slices"
+	"sync"
 
 	"cosmossdk.io/depinject"
 	ring_secp256k1 "github.com/athanorlabs/go-dleq/secp256k1"
@@ -21,9 +22,7 @@ import (
 
 var _ crypto.RingClient = (*ringClient)(nil)
 
-// ringClient is an implementation of the RingClient interface that uses the
-// client.ApplicationQueryClient to get application's delegation information
-// needed to construct the ring for signing relay requests.
+// ringClient implements the RingClient interface.
 type ringClient struct {
 	logger polylog.Logger
 
@@ -36,6 +35,9 @@ type ringClient struct {
 
 	// sharedQuerier is used to fetch the shared module's parameters.
 	sharedQuerier client.SharedQueryClient
+
+	// Mutex to protect concurrent access to shared resources
+	mu sync.RWMutex
 }
 
 // NewRingClient returns a new ring client constructed from the given dependencies.
@@ -213,6 +215,10 @@ func (rc *ringClient) addressesToPubKeys(
 	ctx context.Context,
 	addresses []string,
 ) ([]cryptotypes.PubKey, error) {
+	// Lock for the entire operation since we're doing multiple queries
+	rc.mu.Lock()
+	defer rc.mu.Unlock()
+
 	pubKeys := make([]cryptotypes.PubKey, len(addresses))
 	for i, addr := range addresses {
 		acc, err := rc.accountQuerier.GetPubKeyFromAddress(ctx, addr)
diff --git a/x/proof/keeper/validate_proofs.go b/x/proof/keeper/validate_proofs.go
index e98b5a17c..2f707b9a0 100644
--- a/x/proof/keeper/validate_proofs.go
+++ b/x/proof/keeper/validate_proofs.go
@@ -3,6 +3,7 @@ package keeper
 import (
 	"context"
 	"fmt"
+	"runtime"
 	"sync"
 
 	sdk "github.com/cosmos/cosmos-sdk/types"
@@ -17,7 +18,7 @@ var numCPU int
 
 func init() {
 	// Initialize the number of CPU cores available on the machine.
-	numCPU = 1 //runtime.NumCPU()
+	numCPU = runtime.NumCPU()
 }
 
 // ValidateSubmittedProofs concurrently validates block proofs.
@@ -31,16 +32,7 @@ func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInv
 	// Iterate over proofs using an proofIterator to prevent memory issues from bulk fetching.
 	proofIterator := k.GetAllProofsIterator(ctx)
 
-	coordinator := &proofValidationTaskCoordinator{
-		// Parallelize proof validation across CPU cores since they are independent from one another.
-		// Use semaphores to limit concurrent goroutines and prevent memory issues.
-		sem: make(chan struct{}, numCPU),
-		// Use a wait group to wait for all goroutines to finish before returning.
-		wg: &sync.WaitGroup{},
-
-		processedProofs: make(map[string][]string),
-		coordinatorMu:   &sync.Mutex{},
-	}
+	coordinator := newProofValidationTaskCoordinator(numCPU)
 
 	for ; proofIterator.Valid(); proofIterator.Next() {
 		proofBz := proofIterator.Value()
@@ -63,6 +55,7 @@ func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInv
 
 	// Delete all the processed proofs from the store since they are no longer needed.
 	logger.Info("removing processed proofs from the store")
+	coordinator.mu.Lock()
 	for supplierOperatorAddr, processedProofs := range coordinator.processedProofs {
 		for _, sessionId := range processedProofs {
 			k.RemoveProof(ctx, sessionId, supplierOperatorAddr)
@@ -73,8 +66,11 @@ func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInv
 			))
 		}
 	}
+	numValidProofs = coordinator.numValidProofs
+	numInvalidProofs = coordinator.numInvalidProofs
+	coordinator.mu.Unlock()
 
-	return coordinator.numValidProofs, coordinator.numInvalidProofs, nil
+	return numValidProofs, numInvalidProofs, nil
 }
 
 // validateProof validates a proof before removing it from the store.
@@ -155,10 +151,8 @@ func (k Keeper) validateProof(
 		return
 	}
 
-	// Protect the subsequent operations from concurrent access.
-	coordinator.coordinatorMu.Lock()
-	defer coordinator.coordinatorMu.Unlock()
-
+	// Update all shared state under a single lock
+	coordinator.mu.Lock()
 	// Update the claim to reflect its corresponding the proof validation result.
 	//
 	// It will be used later by the SettlePendingClaims routine to determine whether:
@@ -167,20 +161,19 @@ func (k Keeper) validateProof(
 	claim.ProofStatus = proofStatus
 	k.UpsertClaim(ctx, claim)
 
-	// Collect the processed proofs info to delete them after the proofIterator is closed
-	// to prevent iterator invalidation.
-	coordinator.processedProofs[supplierOperatorAddr] = append(
-		coordinator.processedProofs[supplierOperatorAddr],
-		sessionHeader.GetSessionId(),
-	)
-
+	// Update the counters
 	if proofStatus == types.ClaimProofStatus_INVALID {
-		// Increment the number of invalid proofs.
 		coordinator.numInvalidProofs++
 	} else {
-		// Increment the number of valid proofs.
 		coordinator.numValidProofs++
 	}
+
+	// Update processed proofs
+	coordinator.processedProofs[supplierOperatorAddr] = append(
+		coordinator.processedProofs[supplierOperatorAddr],
+		sessionHeader.GetSessionId(),
+	)
+	coordinator.mu.Unlock()
 }
 
 // proofValidationTaskCoordinator is a helper struct to coordinate parallel proof
@@ -197,9 +190,19 @@ type proofValidationTaskCoordinator struct {
 	processedProofs map[string][]string
 
 	// numValidProofs and numInvalidProofs are counters for the number of valid and invalid proofs.
-	numValidProofs,
+	numValidProofs   uint64
 	numInvalidProofs uint64
 
-	// coordinatorMu protects the coordinator fields.
-	coordinatorMu *sync.Mutex
+	// mu protects all shared state (processedProofs and counters)
+	mu *sync.Mutex
 }
+
+// newProofValidationTaskCoordinator creates a new proofValidationTaskCoordinator
+func newProofValidationTaskCoordinator(numWorkers int) *proofValidationTaskCoordinator {
+	return &proofValidationTaskCoordinator{
+		sem:             make(chan struct{}, numWorkers),
+		wg:              &sync.WaitGroup{},
+		processedProofs: make(map[string][]string),
+		mu:              &sync.Mutex{},
+	}
+}