Skip to content

Commit

Permalink
chore: Address reivew change requests
Browse files Browse the repository at this point in the history
  • Loading branch information
red-0ne committed Jan 29, 2025
1 parent 832e985 commit 43a9a3d
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 229 deletions.
9 changes: 6 additions & 3 deletions api/poktroll/proof/event.pulsar.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

146 changes: 74 additions & 72 deletions api/poktroll/proof/types.pulsar.go

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion proto/poktroll/proof/event.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ message EventProofUpdated {
cosmos.base.v1beta1.Coin claimed_upokt = 6 [(gogoproto.jsontag) = "claimed_upokt"];
}

// Event emitted after a proof has been checked for validity.
// Event emitted after a proof has been checked for validity in the proof module's
// EndBlocker.
message EventProofValidityChecked {
poktroll.proof.Proof proof = 1 [(gogoproto.jsontag) = "proof"];
uint64 block_height = 2 [(gogoproto.jsontag) = "block_height"];
poktroll.proof.ClaimProofStatus proof_status = 3 [(gogoproto.jsontag) = "proof_status"];
// reason is the string representation of the error that led to the proof being
// marked as invalid (e.g. "invalid closest merkle proof", "invalid relay request signature")
string reason = 4 [(gogoproto.jsontag) = "reason"];
}
21 changes: 12 additions & 9 deletions proto/poktroll/proof/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ message Proof {

// Claim is the serialized object stored onchain for claims pending to be proven
message Claim {
// Address of the supplier's operator that submitted this claim.
string supplier_operator_address = 1 [(cosmos_proto.scalar) = "cosmos.AddressString"]; // the address of the supplier's operator that submitted this claim
// The session header of the session that this claim is for.

// Session header this claim is for.
poktroll.session.SessionHeader session_header = 2;
// Root hash returned from smt.SMST#Root().

// Root hash from smt.SMST#Root().
bytes root_hash = 3;
// Claim proof status captures the status of the proof for this claim.
// WARNING: This field MUST only be set by proofKeeper#EnsureValidProofSignaturesAndClosestPath
ClaimProofStatus proof_status = 4;

// Important: This field MUST only be set by proofKeeper#EnsureValidProofSignaturesAndClosestPath
ClaimProofStatus proof_validation_status = 4;
}

enum ProofRequirementReason {
Expand All @@ -47,10 +50,10 @@ enum ClaimProofStage {
EXPIRED = 3;
}

// ClaimProofStatus defines the status of the proof for a claim.
// The default value is NOT_FOUND, whether the proof is required or not.
// Status of proof validation for a claim
// Default is PENDING_VALIDATION regardless of proof requirement
enum ClaimProofStatus {
NOT_FOUND = 0;
VALID = 1;
PENDING_VALIDATION = 0;
VALIDATED = 1;
INVALID = 2;
}
2 changes: 1 addition & 1 deletion testutil/testtree/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,6 @@ func NewClaim(
SupplierOperatorAddress: supplierOperatorAddr,
SessionHeader: sessionHeader,
RootHash: rootHash,
ProofStatus: prooftypes.ClaimProofStatus_NOT_FOUND,
ProofValidationStatus: prooftypes.ClaimProofStatus_PENDING_VALIDATION,
}
}
23 changes: 10 additions & 13 deletions x/proof/keeper/msg_server_submit_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,18 @@ import (
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

// SubmitProof is the server handler to submit and store a proof onchain.
// A proof that's stored onchain is what leads to rewards (i.e. inflation)
// downstream, making this a critical part of the protocol.
// SubmitProof is the server message handler that stores a valid
// proof onchain, enabling downstream reward distribution.
//
// Note that the validation of the proof is done in `EnsureValidProofSignaturesAndClosestPath`.
// However, preliminary checks are done in the handler to prevent sybil or DoS attacks on
// full nodes by submitting malformed proofs.
// IMPORTANT: Full proof validation occurs in EnsureValidProofSignaturesAndClosestPath.
// This handler performs preliminary validation to prevent sybil/DoS attacks.
//
// We are playing a balance of security and efficiency here, where enough validation
// is done on proof submission, and exhaustive validation is done during the endblocker.
// There is a security & performance balance and tradeoff between the handler and end blocker:
// - Basic validation on submission (here)
// - Exhaustive validation in endblocker (EnsureValidProofSignaturesAndClosestPath)
//
// The entity sending the SubmitProof messages does not necessarily need
// to correspond to the supplier signing the proof. For example, a single entity
// could (theoretically) batch multiple proofs (signed by the corresponding supplier)
// into one transaction to save on transaction fees.
// Note: Proof submitter may differ from supplier signer, allowing batched submissions
// to optimize transaction fees.
func (k msgServer) SubmitProof(
ctx context.Context,
msg *types.MsgSubmitProof,
Expand Down Expand Up @@ -85,7 +82,7 @@ func (k msgServer) SubmitProof(
logger.Error(fmt.Sprintf("failed to ensure well-formed proof: %v", err))
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("checked the proof is well-formed")
logger.Info("ensured the proof is well-formed")

// Retrieve the claim associated with the proof.
// The claim should ALWAYS exist since the proof validation in EnsureWellFormedProof
Expand Down
30 changes: 14 additions & 16 deletions x/proof/keeper/proof_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (k Keeper) EnsureWellFormedProof(ctx context.Context, proof *types.Proof) e
logger.Debug("successfully validated relay mining difficulty")

// Retrieve the corresponding claim for the proof submitted
if err := k.validateClaimForProof(ctx, sessionHeader, supplierOperatorAddr); err != nil {
if err := k.validateSessionClaim(ctx, sessionHeader, supplierOperatorAddr); err != nil {
return err
}
logger.Debug("successfully retrieved and validated claim")
Expand Down Expand Up @@ -331,17 +331,16 @@ func (k Keeper) validateClosestPath(
return nil
}

// validateClaimForProof ensures that a claim corresponding to the given proof's
// session exists & has a matching supplier operator address and session header.
func (k Keeper) validateClaimForProof(
// validateSessionClaim ensures that the given session header and supplierOperatorAddress
// have a corresponding claim.
func (k Keeper) validateSessionClaim(
ctx context.Context,
sessionHeader *sessiontypes.SessionHeader,
supplierOperatorAddr string,
) error {
sessionId := sessionHeader.SessionId
// NB: no need to assert the testSessionId or supplier operator address as it is retrieved
// by respective values of the given proof. I.e., if the claim exists, then these
// values are guaranteed to match.

// Retrieve the claim corresponding to the session ID and supplier operator address.
foundClaim, found := k.GetClaim(ctx, sessionId, supplierOperatorAddr)
if !found {
return types.ErrProofClaimNotFound.Wrapf(
Expand All @@ -352,41 +351,40 @@ func (k Keeper) validateClaimForProof(
}

claimSessionHeader := foundClaim.GetSessionHeader()
proofSessionHeader := sessionHeader

// Ensure session start heights match.
if claimSessionHeader.GetSessionStartBlockHeight() != proofSessionHeader.GetSessionStartBlockHeight() {
if claimSessionHeader.GetSessionStartBlockHeight() != sessionHeader.GetSessionStartBlockHeight() {
return types.ErrProofInvalidSessionStartHeight.Wrapf(
"claim session start height %d does not match proof session start height %d",
claimSessionHeader.GetSessionStartBlockHeight(),
proofSessionHeader.GetSessionStartBlockHeight(),
sessionHeader.GetSessionStartBlockHeight(),
)
}

// Ensure session end heights match.
if claimSessionHeader.GetSessionEndBlockHeight() != proofSessionHeader.GetSessionEndBlockHeight() {
if claimSessionHeader.GetSessionEndBlockHeight() != sessionHeader.GetSessionEndBlockHeight() {
return types.ErrProofInvalidSessionEndHeight.Wrapf(
"claim session end height %d does not match proof session end height %d",
claimSessionHeader.GetSessionEndBlockHeight(),
proofSessionHeader.GetSessionEndBlockHeight(),
sessionHeader.GetSessionEndBlockHeight(),
)
}

// Ensure application addresses match.
if claimSessionHeader.GetApplicationAddress() != proofSessionHeader.GetApplicationAddress() {
if claimSessionHeader.GetApplicationAddress() != sessionHeader.GetApplicationAddress() {
return types.ErrProofInvalidAddress.Wrapf(
"claim application address %q does not match proof application address %q",
claimSessionHeader.GetApplicationAddress(),
proofSessionHeader.GetApplicationAddress(),
sessionHeader.GetApplicationAddress(),
)
}

// Ensure service IDs match.
if claimSessionHeader.GetServiceId() != proofSessionHeader.GetServiceId() {
if claimSessionHeader.GetServiceId() != sessionHeader.GetServiceId() {
return types.ErrProofInvalidService.Wrapf(
"claim service ID %q does not match proof service ID %q",
claimSessionHeader.GetServiceId(),
proofSessionHeader.GetServiceId(),
sessionHeader.GetServiceId(),
)
}

Expand Down
82 changes: 41 additions & 41 deletions x/proof/keeper/validate_proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,46 @@ import (
"github.com/pokt-network/poktroll/x/proof/types"
)

// numCPU is the number of CPU cores available on the machine.
// It is initialized in the init function to prevent runtime.NumCPU from being called
// multiple times in the ValidateSubmittedProofs function.
// proofValidationTaskCoordinator is a helper struct to coordinate parallel proof
// validation tasks.
type proofValidationTaskCoordinator struct {
// sem is a semaphore to limit the number of concurrent goroutines.
sem chan struct{}

// wg is a wait group to wait for all goroutines to finish before returning.
wg *sync.WaitGroup

// processedProofs is a map of supplier operator addresses to the session IDs
// whose proofs that have been processed.
processedProofs map[string][]string

// numValidProofs and numInvalidProofs are counters to keep track of proof validation results.
numValidProofs,
numInvalidProofs uint64

// coordinatorMu protects the coordinator fields.
coordinatorMu *sync.Mutex
}

// numCPU caches runtime.NumCPU() to avoid being retrieved on every ValidateSubmittedProofs call.
var numCPU int

func init() {
// Initialize the number of CPU cores available on the machine.
numCPU = runtime.NumCPU()
}

// ValidateSubmittedProofs concurrently validates block proofs.
// It marks their corresponding claims as valid or invalid based on the proof validation.
// It removes them from the store once they are processed.
// ValidateSubmittedProofs performs concurrent proof validation, updating claims'
// proof validation states and removing processed proofs from storage.
func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInvalidProofs uint64, err error) {
logger := k.Logger().With("method", "ValidateSubmittedProofs")

logger.Info(fmt.Sprintf("Number of CPU cores used for parallel proof validation: %d\n", numCPU))

// Iterate over proofs using an proofIterator to prevent memory issues from bulk fetching.
// Iterate over proofs using an iterator to prevent OOM issues caused by bulk fetching.
proofIterator := k.GetAllProofsIterator(ctx)

coordinator := &proofValidationTaskCoordinator{
proofValidationCoordinator := &proofValidationTaskCoordinator{
// Parallelize proof validation across CPU cores since they are independent from one another.
// Use semaphores to limit concurrent goroutines and prevent memory issues.
sem: make(chan struct{}, numCPU),
Expand All @@ -48,23 +66,23 @@ func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInv

// Acquire a semaphore to limit the number of goroutines.
// This will block if the sem channel is full.
coordinator.sem <- struct{}{}
proofValidationCoordinator.sem <- struct{}{}

// Increment the wait group to wait for proof validation to finish.
coordinator.wg.Add(1)
proofValidationCoordinator.wg.Add(1)

go k.validateProof(ctx, proofBz, coordinator)
go k.validateProof(ctx, proofBz, proofValidationCoordinator)
}

// Wait for all goroutines to finish before returning.
coordinator.wg.Wait()
proofValidationCoordinator.wg.Wait()

// Close the proof iterator before deleting the processed proofs.
proofIterator.Close()

// Delete all the processed proofs from the store since they are no longer needed.
logger.Info("removing processed proofs from the store")
for supplierOperatorAddr, processedProofs := range coordinator.processedProofs {
for supplierOperatorAddr, processedProofs := range proofValidationCoordinator.processedProofs {
for _, sessionId := range processedProofs {
k.RemoveProof(ctx, sessionId, supplierOperatorAddr)
logger.Info(fmt.Sprintf(
Expand All @@ -75,10 +93,10 @@ func (k Keeper) ValidateSubmittedProofs(ctx sdk.Context) (numValidProofs, numInv
}
}

return coordinator.numValidProofs, coordinator.numInvalidProofs, nil
return proofValidationCoordinator.numValidProofs, proofValidationCoordinator.numInvalidProofs, nil
}

// validateProof validates a proof before removing it from the store.
// validateProof validates a proof submitted by a supplier.
// It marks the corresponding claim as valid or invalid based on the proof validation.
// It is meant to be called concurrently by multiple goroutines to parallelize
// proof validation.
Expand All @@ -101,6 +119,9 @@ func (k Keeper) validateProof(
// proofBz is not expected to fail unmarshalling since it is should have
// passed EnsureWellFormedProof validation in MsgSubmitProof handler.
// Panic if it fails unmarshalling.
// If a failure occurs, it indicates either a bug in the code or data corruption.
// In either case, panicking is an appropriate response since both panics and
// returning an error would halt block production.
k.cdc.MustUnmarshal(proofBz, &proof)

sessionHeader := proof.GetSessionHeader()
Expand All @@ -116,8 +137,8 @@ func (k Keeper) validateProof(

// Retrieve the corresponding claim for the proof submitted so it can be
// used in the proof validation below.
// EnsureWellFormedProof has already validated that the claim referenced by the
// proof exists and has a matching session header.
// EnsureWellFormedProof which is called in MsgSubmitProof handler has already validated
// that the claim referenced by the proof exists and has a matching session header.
claim, claimFound := k.GetClaim(ctx, sessionHeader.GetSessionId(), supplierOperatorAddr)
if !claimFound {
// DEV_NOTE: This should never happen since EnsureWellFormedProof has already checked
Expand All @@ -128,7 +149,7 @@ func (k Keeper) validateProof(
logger.Debug("successfully retrieved claim")

// Set the proof status to valid by default.
proofStatus := types.ClaimProofStatus_VALID
proofStatus := types.ClaimProofStatus_VALIDATED
// Set the invalidity reason to an empty string by default.
invalidProofCause := ""

Expand Down Expand Up @@ -160,12 +181,12 @@ func (k Keeper) validateProof(
coordinator.coordinatorMu.Lock()
defer coordinator.coordinatorMu.Unlock()

// Update the claim to reflect its corresponding the proof validation result.
// Update the claim to reflect the validation result of the associated proof.
//
// It will be used later by the SettlePendingClaims routine to determine whether:
// 1. The claim should be settled or not
// 2. The corresponding supplier should be slashed or not
claim.ProofStatus = proofStatus
claim.ProofValidationStatus = proofStatus
k.UpsertClaim(ctx, claim)

// Collect the processed proofs info to delete them after the proofIterator is closed
Expand All @@ -183,24 +204,3 @@ func (k Keeper) validateProof(
coordinator.numValidProofs++
}
}

// proofValidationTaskCoordinator is a helper struct to coordinate parallel proof
// validation tasks.
type proofValidationTaskCoordinator struct {
// sem is a semaphore to limit the number of concurrent goroutines.
sem chan struct{}

// wg is a wait group to wait for all goroutines to finish before returning.
wg *sync.WaitGroup

// processedProofs is a map of supplier operator addresses to the session IDs
// of proofs that have been processed.
processedProofs map[string][]string

// numValidProofs and numInvalidProofs are counters for the number of valid and invalid proofs.
numValidProofs,
numInvalidProofs uint64

// coordinatorMu protects the coordinator fields.
coordinatorMu *sync.Mutex
}
Loading

0 comments on commit 43a9a3d

Please sign in to comment.