Skip to content

Commit

Permalink
[Session test fixes] Fix window bug and make E2E session test determi…
Browse files Browse the repository at this point in the history
…nistic (#607)

Make `make test_e2e_session` by:
- Fixing a bug in `GetClaimWindowCloseHeight` where we were not adding things correctly leading to incorrect windows
- Fix a flaky unit test in `session.feature` because the `Proof` would be deleted on-chain before the e2e testing client queried it for validation; switching to replayable events.
- Add an example of how we can query and listen for any on-chain event to improve future tests
  • Loading branch information
Olshansk authored Jun 14, 2024
1 parent d586eaa commit a656145
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 59 deletions.
3 changes: 1 addition & 2 deletions e2e/tests/session.feature
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ Feature: Session Namespace
Then the claim created by supplier "supplier1" for service "svc1" for application "app1" should be persisted on-chain
# TODO_BLOCKER(@bryanchriswhite): And a cosmos-sdk event (e.g. EventClaimCreated) should be emitted.
And the user should wait for the "proof" module "SubmitProof" Message to be submitted
# TODO_BLOCKER(@bryanchriswhite): And a cosmos-sdk event (e.g. EventClaimCreated) should be emitted.
Then the proof submitted by supplier "supplier1" for service "svc1" for application "app1" should be persisted on-chain
Then the claim created by supplier "supplier1" for service "anvil" for application "app1" should be successfully settled

# TODO_BLOCKER(@red-0ne): Make sure to implement and validate this test
# One way to exercise this behavior is to close the `RelayMiner` port to prevent
Expand Down
114 changes: 77 additions & 37 deletions e2e/tests/session_steps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/pokt-network/poktroll/pkg/observable/channel"
"github.com/pokt-network/poktroll/testutil/testclient"
prooftypes "github.com/pokt-network/poktroll/x/proof/types"
tokenomicstypes "github.com/pokt-network/poktroll/x/tokenomics/types"
)

const (
Expand All @@ -34,7 +35,7 @@ const (
// This is used by an events replay client to subscribe to tx events from the supplier.
// See: https://docs.cosmos.network/v0.47/learn/advanced/events#subscribing-to-events
txSenderEventSubscriptionQueryFmt = "tm.event='Tx' AND message.sender='%s'"
// newBlockEventSubscriptionQuery is the format string which yields a
// newBlockEventSubscriptionQuery is the query string which yields a
// subscription query to listen for on-chain new block events.
newBlockEventSubscriptionQuery = "tm.event='NewBlock'"
// eventsReplayClientBufferSize is the buffer size for the events replay client
Expand All @@ -57,11 +58,14 @@ const (
)

func (s *suite) TheUserShouldWaitForTheModuleMessageToBeSubmitted(module, message string) {
s.waitForTxResultEvent(fmt.Sprintf("/poktroll.%s.Msg%s", module, message))
msgType := fmt.Sprintf("/poktroll.%s.Msg%s", module, message)
s.waitForTxResultEvent(msgType)
}

func (s *suite) TheUserShouldWaitForTheModuleEventToBeBroadcast(module, message string) {
s.waitForNewBlockEvent(fmt.Sprintf("poktroll.%s.Event%s", module, message))
eventType := fmt.Sprintf("poktroll.%s.Event%s", module, message)
isExpectedEventFn := func(event *abci.Event) bool { return event.Type == eventType }
s.waitForNewBlockEvent(isExpectedEventFn)
}

func (s *suite) TheClaimCreatedBySupplierForServiceForApplicationShouldBePersistedOnchain(supplierName, serviceId, appName string) {
Expand Down Expand Up @@ -147,36 +151,28 @@ func (s *suite) TheSupplierHasServicedASessionWithRelaysForServiceForApplication
)
}

func (s *suite) TheProofSubmittedBySupplierForServiceForApplicationShouldBePersistedOnchain(supplierName, serviceId, appName string) {
ctx := context.Background()

// Retrieve all on-chain proofs for supplierName
allProofsRes, err := s.proofQueryClient.AllProofs(ctx, &prooftypes.QueryAllProofsRequest{
Filter: &prooftypes.QueryAllProofsRequest_SupplierAddress{
SupplierAddress: accNameToAddrMap[supplierName],
},
})
require.NoError(s, err)
require.NotNil(s, allProofsRes)

// Assert that the number of proofs has increased by one.
preExistingProofs, ok := s.scenarioState[preExistingProofsKey].([]prooftypes.Proof)
require.True(s, ok, "preExistingProofsKey not found in scenarioState")
// NB: We are avoiding the use of require.Len here because it provides unreadable output
// TODO_TECHDEBT: Due to the speed of the blocks of the LocalNet validator, along with the small number
// of blocks per session, multiple proofs may be created throughout the duration of the test. Until
// these values are appropriately adjusted, we assert on an increase in proofs rather than +1.
require.Greater(s, len(allProofsRes.Proofs), len(preExistingProofs), "number of proofs must have increased")

// TODO_UPNEXT(@bryanchriswhite): assert that the root hash of the proof contains the correct
// SMST sum. The sum can be retrieved via the `GetSum` function exposed
// by the SMT.

// TODO_IMPROVE: add assertions about serviceId and appName and/or incorporate
// them into the scenarioState key(s).
func (s *suite) TheClaimCreatedBySupplierForServiceForApplicationShouldBeSuccessfullySettled(supplierName, serviceId, appName string) {
app, ok := accNameToAppMap[appName]
require.True(s, ok, "application %s not found", appName)

supplier, ok := accNameToSupplierMap[supplierName]
require.True(s, ok, "supplier %s not found", supplierName)

isValidClaimSettledEvent := func(event *abci.Event) bool {
if event.Type != "poktroll.tokenomics.EventClaimSettled" {
return false
}
claimSettledEvent := s.abciToClaimSettledEvent(event)
claim := claimSettledEvent.Claim
require.Equal(s, app.Address, claim.SessionHeader.ApplicationAddress)
require.Equal(s, supplier.Address, claim.SupplierAddress)
require.Equal(s, serviceId, claim.SessionHeader.Service.Id)
require.Greater(s, claimSettledEvent.ComputeUnits, uint64(0), "compute units should be greater than 0")
s.Logf("Claim settled for %d compute units w/ proof requirement: %t\n", claimSettledEvent.ComputeUnits, claimSettledEvent.ProofRequired)
return true
}

proof := allProofsRes.Proofs[0]
require.Equal(s, accNameToAddrMap[supplierName], proof.SupplierAddress)
s.waitForNewBlockEvent(isValidClaimSettledEvent)
}

func (s *suite) sendRelaysForSession(
Expand Down Expand Up @@ -241,7 +237,13 @@ func (s *suite) waitForTxResultEvent(targetAction string) {
}
}

func (s *suite) waitForNewBlockEvent(targetEvent string) {
// waitForNewBlockEvent waits for an event to be observed whose type and data
// match the conditions specified by isEventMatchFn.
// isEventMatchFn is a function that receives an abci.Event and returns a boolean
// indicating whether the event matches the desired conditions.
func (s *suite) waitForNewBlockEvent(
isEventMatchFn func(*abci.Event) bool,
) {
ctx, done := context.WithCancel(context.Background())

newBlockEventsReplayClientState, ok := s.scenarioState[newBlockEventReplayClientKey]
Expand All @@ -262,10 +264,9 @@ func (s *suite) waitForNewBlockEvent(targetEvent string) {
// Range over each event's attributes to find the "action" attribute
// and compare its value to that of the action provided.
for _, event := range newBlockEvent.Data.Value.ResultFinalizeBlock.Events {
// TODO_IMPROVE: We can pass in a function to do even more granular
// checks on the event. For example, for a Claim Settlement event,
// Checks on the event. For example, for a Claim Settlement event,
// we can parse the claim and verify the compute units.
if event.Type == targetEvent {
if isEventMatchFn(&event) {
done()
return
}
Expand All @@ -275,8 +276,47 @@ func (s *suite) waitForNewBlockEvent(targetEvent string) {

select {
case <-time.After(eventTimeout):
s.Fatalf("timed out waiting for NewBlock event %q", targetEvent)
s.Fatalf("timed out waiting for NewBlock event")
case <-ctx.Done():
s.Log("Success; message detected before timeout.")
}
}

// abciToClaimSettledEvent converts an abci.Event to a tokenomics.EventClaimSettled
//

func (s *suite) abciToClaimSettledEvent(event *abci.Event) *tokenomicstypes.EventClaimSettled {
var claimSettledEvent tokenomicstypes.EventClaimSettled

// TODO_TECHDEBT: Investigate why `cosmostypes.ParseTypedEvent(*event)` throws
// an error where cosmostypes is imported from "github.com/cosmos/cosmos-sdk/types"
// resulting in the following error:
// 'json: error calling MarshalJSON for type json.RawMessage: invalid character 'E' looking for beginning of value'
// typedEvent, err := cosmostypes.ParseTypedEvent(*event)

for _, attr := range event.Attributes {
switch string(attr.Key) {
case "claim":
var claim prooftypes.Claim
if err := s.cdc.UnmarshalJSON([]byte(attr.Value), &claim); err != nil {
s.Fatalf("Failed to unmarshal claim: %v", err)
}
claimSettledEvent.Claim = &claim
case "compute_units":
value := string(attr.Value)
value = value[1 : len(value)-1] // Remove surrounding quotes
computeUnits, err := strconv.ParseUint(value, 10, 64)
if err != nil {
s.Fatalf("Failed to parse compute_units: %v", err)
}
claimSettledEvent.ComputeUnits = computeUnits
case "proof_required":
proofRequired, err := strconv.ParseBool(string(attr.Value))
if err != nil {
s.Fatalf("Failed to parse proof_required: %v", err)
}
claimSettledEvent.ProofRequired = proofRequired
}
}
return &claimSettledEvent
}
3 changes: 1 addition & 2 deletions pkg/client/tx/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ func UnmarshalTxResult(txResultBz []byte) (*abci.TxResult, error) {
}

var cometTxEvent CometTxEvent

// Try to deserialize the provided bytes into a TxResult.
// Try to deserialize the provided bytes into a CometTxEvent.
if err := json.Unmarshal(rpcResponse.Result, &cometTxEvent); err != nil {
return nil, events.ErrEventsUnmarshalEvent.Wrap(err.Error())
}
Expand Down
24 changes: 12 additions & 12 deletions x/proof/keeper/msg_server_submit_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (k msgServer) SubmitProof(ctx context.Context, msg *types.MsgSubmitProof) (
logger.Info("queried and validated the session header")

// Re-hydrate message session header with the on-chain session header.
// This corrects for discrepencies between unvalidated fields in the session header
// This corrects for discrepancies between unvalidated fields in the session header
// which can be derived from known values (e.g. session end height).
msg.SessionHeader = onChainSession.GetHeader()

Expand Down Expand Up @@ -165,39 +165,39 @@ func (k msgServer) SubmitProof(ctx context.Context, msg *types.MsgSubmitProof) (
if err := relayReq.ValidateBasic(); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully validated relay request")
logger.Debug("successfully validated relay request")

// Basic validation of the relay response.
relayRes := relay.GetRes()
if err := relayRes.ValidateBasic(); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully validated relay response")
logger.Debug("successfully validated relay response")

// Verify that the relay request session header matches the proof session header.
if err := compareSessionHeaders(msg.GetSessionHeader(), relayReq.Meta.GetSessionHeader()); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully compared relay request session header")
logger.Debug("successfully compared relay request session header")

// Verify that the relay response session header matches the proof session header.
if err := compareSessionHeaders(msg.GetSessionHeader(), relayRes.Meta.GetSessionHeader()); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully compared relay response session header")
logger.Debug("successfully compared relay response session header")

// Verify the relay request's signature.
// TODO_BLOCKER(@red-0ne): Fetch the correct ring for the session this relay is from.
if err := k.ringClient.VerifyRelayRequestSignature(ctx, relayReq); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully verified relay request signature")
logger.Debug("successfully verified relay request signature")

// Verify the relay response's signature.
if err := relayRes.VerifySupplierSignature(supplierPubKey); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully verified relay response signature")
logger.Debug("successfully verified relay response signature")

// Get the proof module's governance parameters.
params := k.GetParams(ctx)
Expand All @@ -206,14 +206,14 @@ func (k msgServer) SubmitProof(ctx context.Context, msg *types.MsgSubmitProof) (
if err := validateMiningDifficulty(relayBz, params.MinRelayDifficultyBits); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully validated relay mining difficulty")
logger.Debug("successfully validated relay mining difficulty")

// Validate that path the proof is submitted for matches the expected one
// based on the pseudo-random on-chain data associated with the header.
if err := k.validateClosestPath(ctx, sparseMerkleClosestProof, msg.GetSessionHeader()); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully validated proof path")
logger.Debug("successfully validated proof path")

// Verify the relay's difficulty.
if err := validateMiningDifficulty(relayBz, params.MinRelayDifficultyBits); err != nil {
Expand All @@ -226,21 +226,21 @@ func (k msgServer) SubmitProof(ctx context.Context, msg *types.MsgSubmitProof) (
if err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully retrieved and validated claim")
logger.Debug("successfully retrieved and validated claim")

// Verify the proof's closest merkle proof.
if err := verifyClosestProof(sparseMerkleClosestProof, claim.GetRootHash()); err != nil {
return nil, status.Error(codes.FailedPrecondition, err.Error())
}
logger.Info("successfully verified closest merkle proof")
logger.Debug("successfully verified closest merkle proof")

// Construct and insert proof after all validation.
proof := types.Proof{
SupplierAddress: supplierAddr,
SessionHeader: msg.GetSessionHeader(),
ClosestMerkleProof: msg.GetProof(),
}
logger.Info(fmt.Sprintf("queried and validated the claim for session ID %q", sessionHeader.SessionId))
logger.Debug(fmt.Sprintf("queried and validated the claim for session ID %q", sessionHeader.SessionId))

// TODO_BLOCKER(@Olshansk): check if this proof already exists and return an
// appropriate error in any case where the supplier should no longer be able
Expand Down
13 changes: 7 additions & 6 deletions x/shared/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

// TODO_DOCUMENT(@bryanchriswhite): Move this into the documentation: https://github.com/pokt-network/poktroll/pull/571#discussion_r1630923625

// SessionGracePeriodBlocks is the number of blocks after the session ends before the
// "session grace period" is considered to have elapsed.
//
Expand Down Expand Up @@ -37,8 +39,9 @@ func GetSessionEndHeight(sharedParams *sharedtypes.Params, queryHeight int64) in
}

numBlocksPerSession := int64(sharedParams.GetNumBlocksPerSession())
sessionStartHeight := GetSessionStartHeight(sharedParams, queryHeight)

return GetSessionStartHeight(sharedParams, queryHeight) + numBlocksPerSession - 1
return sessionStartHeight + numBlocksPerSession - 1
}

// GetSessionNumber returns the session number of the session containing queryHeight,
Expand Down Expand Up @@ -85,11 +88,9 @@ func GetClaimWindowOpenHeight(sharedParams *sharedtypes.Params, queryHeight int6
// GetClaimWindowCloseHeight returns the block height at which the claim window of
// the session that includes queryHeight closes, for the provided sharedParams.
func GetClaimWindowCloseHeight(sharedParams *sharedtypes.Params, queryHeight int64) int64 {
sessionEndHeight := GetSessionEndHeight(sharedParams, queryHeight)
sessionGracePeriodEndHeight := GetSessionGracePeriodEndHeight(sharedParams, sessionEndHeight)
return GetClaimWindowOpenHeight(sharedParams, queryHeight) +
sessionGracePeriodEndHeight +
int64(sharedParams.GetClaimWindowCloseOffsetBlocks())
claimWindowOpenHeight := GetClaimWindowOpenHeight(sharedParams, queryHeight)
claimWindowCloseOffsetBlocks := int64(sharedParams.GetClaimWindowCloseOffsetBlocks())
return claimWindowOpenHeight + claimWindowCloseOffsetBlocks
}

// GetProofWindowOpenHeight returns the block height at which the claim window of
Expand Down

0 comments on commit a656145

Please sign in to comment.