diff --git a/app/app.go b/app/app.go index b34d4db68..8dc0ace54 100644 --- a/app/app.go +++ b/app/app.go @@ -575,16 +575,6 @@ func New( ) serviceModule := servicemodule.NewAppModule(appCodec, app.ServiceKeeper, app.AccountKeeper, app.BankKeeper) - app.SupplierKeeper = *suppliermodulekeeper.NewKeeper( - appCodec, - keys[suppliermoduletypes.StoreKey], - keys[suppliermoduletypes.MemStoreKey], - app.GetSubspace(suppliermoduletypes.ModuleName), - - app.BankKeeper, - ) - supplierModule := suppliermodule.NewAppModule(appCodec, app.SupplierKeeper, app.AccountKeeper, app.BankKeeper) - app.GatewayKeeper = *gatewaymodulekeeper.NewKeeper( appCodec, keys[gatewaymoduletypes.StoreKey], @@ -607,6 +597,28 @@ func New( ) applicationModule := applicationmodule.NewAppModule(appCodec, app.ApplicationKeeper, app.AccountKeeper, app.BankKeeper) + // TODO_TECHDEBT: Evaluate if this NB goes away after we upgrade to cosmos 0.5x + // NB: there is a circular dependency between the supplier and session keepers. + // Because the keepers are values (as opposed to pointers), they are copied + // when passed into their respective module constructor functions. For this + // reason, the existing pattern of ignite-generated keeper/module construction + // must be broken for these keepers and modules. + // + // Order of operations: + // 1. Construct supplier keeper + // 2. Construct session keeper + // 3. Provide session keeper to supplier keeper via custom #SupplySessionKeeper method. + // 4. Construct supplier module + // 5. Construct session module + app.SupplierKeeper = *suppliermodulekeeper.NewKeeper( + appCodec, + keys[suppliermoduletypes.StoreKey], + keys[suppliermoduletypes.MemStoreKey], + app.GetSubspace(suppliermoduletypes.ModuleName), + + app.BankKeeper, + ) + app.SessionKeeper = *sessionmodulekeeper.NewKeeper( appCodec, keys[sessionmoduletypes.StoreKey], @@ -616,6 +628,10 @@ func New( app.ApplicationKeeper, app.SupplierKeeper, ) + + app.SupplierKeeper.SupplySessionKeeper(app.SessionKeeper) + + supplierModule := suppliermodule.NewAppModule(appCodec, app.SupplierKeeper, app.AccountKeeper, app.BankKeeper) sessionModule := sessionmodule.NewAppModule(appCodec, app.SessionKeeper, app.AccountKeeper, app.BankKeeper) // this line is used by starport scaffolding # stargate/app/keeperDefinition diff --git a/e2e/tests/init_test.go b/e2e/tests/init_test.go index 8935e1240..81a12795d 100644 --- a/e2e/tests/init_test.go +++ b/e2e/tests/init_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pokt-network/poktroll/app" + "github.com/pokt-network/poktroll/testutil/testclient" apptypes "github.com/pokt-network/poktroll/x/application/types" sessiontypes "github.com/pokt-network/poktroll/x/session/types" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" @@ -59,9 +60,11 @@ func TestMain(m *testing.M) { type suite struct { gocuke.TestingT - pocketd *pocketdBin - scenarioState map[string]any // temporary state for each scenario - cdc codec.Codec + // TODO_TECHDEBT: rename to `poktrolld`. 
+ pocketd *pocketdBin + scenarioState map[string]any // temporary state for each scenario + cdc codec.Codec + supplierQueryClient suppliertypes.QueryClient } func (s *suite) Before() { @@ -71,6 +74,10 @@ func (s *suite) Before() { s.buildAddrMap() s.buildAppMap() s.buildSupplierMap() + + flagSet := testclient.NewLocalnetFlagSet(s) + clientCtx := testclient.NewLocalnetClientCtx(s, flagSet) + s.supplierQueryClient = suppliertypes.NewQueryClient(clientCtx) } // TestFeatures runs the e2e tests specified in any .features files in this directory @@ -79,6 +86,7 @@ func TestFeatures(t *testing.T) { gocuke.NewRunner(t, &suite{}).Path(flagFeaturesPath).Run() } +// TODO_TECHDEBT: rename `pocketd` to `poktrolld`. func (s *suite) TheUserHasThePocketdBinaryInstalled() { s.TheUserRunsTheCommand("help") } diff --git a/e2e/tests/session.feature b/e2e/tests/session.feature new file mode 100644 index 000000000..3163eb471 --- /dev/null +++ b/e2e/tests/session.feature @@ -0,0 +1,10 @@ +Feature: Session Namespace + + Scenario: Supplier completes claim/proof lifecycle for a valid session + Given the user has the pocketd binary installed + When the supplier "supplier1" has serviced a session with "5" relays for service "svc1" for application "app1" + And after the supplier creates a claim for the session for service "svc1" for application "app1" + Then the claim created by supplier "supplier1" for service "svc1" for application "app1" should be persisted on-chain +# TODO_IMPROVE: ... +# And an event should be emitted... +# TODO_INCOMPLETE: add step(s) for proof validation. diff --git a/e2e/tests/session_steps_test.go b/e2e/tests/session_steps_test.go new file mode 100644 index 000000000..2c016f706 --- /dev/null +++ b/e2e/tests/session_steps_test.go @@ -0,0 +1,162 @@ +//go:build e2e + +package e2e + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + abci "github.com/cometbft/cometbft/abci/types" + "github.com/stretchr/testify/require" + + eventsquery "github.com/pokt-network/poktroll/pkg/client/events_query" + "github.com/pokt-network/poktroll/pkg/either" + "github.com/pokt-network/poktroll/pkg/observable" + "github.com/pokt-network/poktroll/pkg/observable/channel" + "github.com/pokt-network/poktroll/testutil/testclient" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" +) + +const ( + createClaimTimeoutDuration = 10 * time.Second + eitherEventsReplayBufferSize = 100 + msgClaimSenderQueryFmt = "tm.event='Tx' AND message.sender='%s'" + testServiceId = "anvil" + eitherEventsBzReplayObsKey = "eitherEventsBzReplayObsKey" + preExistingClaimsKey = "preExistingClaimsKey" +) + +func (s *suite) AfterTheSupplierCreatesAClaimForTheSessionForServiceForApplication(serviceId, appName string) { + var ctx, done = context.WithCancel(context.Background()) + + // TODO_CONSIDERATION: if this test suite gets more complex, it might make + // sense to refactor this key into a function that takes serviceId and appName + // as arguments and returns the key. + eitherEventsBzReplayObs := s.scenarioState[eitherEventsBzReplayObsKey].(observable.ReplayObservable[either.Bytes]) + + // TODO(#220): refactor to use EventsReplayClient once available. + channel.ForEach[either.Bytes]( + ctx, eitherEventsBzReplayObs, + func(_ context.Context, eitherEventBz either.Bytes) { + eventBz, err := eitherEventBz.ValueOrError() + require.NoError(s, err) + + if strings.Contains(string(eventBz), "jsonrpc") { + return + } + + // Unmarshal event data into a TxEventResponse object. 
+ txEvent := &abci.TxResult{} + err = json.Unmarshal(eventBz, txEvent) + require.NoError(s, err) + + var found bool + for _, event := range txEvent.Result.Events { + for _, attribute := range event.Attributes { + if attribute.Key == "action" { + require.Equal( + s, "/pocket.supplier.MsgCreateClaim", + attribute.Value, + ) + found = true + break + } + } + if found { + break + } + } + require.Truef(s, found, "unable to find event action attribute") + + done() + }, + ) + + select { + case <-ctx.Done(): + case <-time.After(createClaimTimeoutDuration): + s.Fatal("timed out waiting for claim to be created") + } +} + +func (s *suite) TheClaimCreatedBySupplierForServiceForApplicationShouldBePersistedOnchain(supplierName, serviceId, appName string) { + ctx := context.Background() + + claimsRes, err := s.supplierQueryClient.AllClaims(ctx, &suppliertypes.QueryAllClaimsRequest{ + Filter: &suppliertypes.QueryAllClaimsRequest_SupplierAddress{ + SupplierAddress: accNameToAddrMap[supplierName], + }, + }) + require.NoError(s, err) + require.NotNil(s, claimsRes) + + // Assert that the number of claims has increased by one. + preExistingClaims := s.scenarioState[preExistingClaimsKey].([]suppliertypes.Claim) + require.Len(s, claimsRes.Claim, len(preExistingClaims)+1) + + // TODO_IMPROVE: assert that the root hash of the claim contains the correct + // SMST sum. The sum can be retrieved by parsing the last 8 bytes as a + // binary-encoded uint64; e.g. something like: + // `binary.Uvarint(claim.RootHash[len(claim.RootHash-8):])` + + // TODO_IMPROVE: add assertions about serviceId and appName and/or incorporate + // them into the scenarioState key(s). + + claim := claimsRes.Claim[0] + require.Equal(s, accNameToAddrMap[supplierName], claim.SupplierAddress) +} + +func (s *suite) TheSupplierHasServicedASessionWithRelaysForServiceForApplication(supplierName, relayCountStr, serviceId, appName string) { + ctx := context.Background() + + relayCount, err := strconv.Atoi(relayCountStr) + require.NoError(s, err) + + // Query for any existing claims so that we can compensate for them in the + // future assertions about changes in on-chain claims. + claimsRes, err := s.supplierQueryClient.AllClaims(ctx, &suppliertypes.QueryAllClaimsRequest{}) + require.NoError(s, err) + s.scenarioState[preExistingClaimsKey] = claimsRes.Claim + + // Construct an events query client to listen for tx events from the supplier. + msgSenderQuery := fmt.Sprintf(msgClaimSenderQueryFmt, accNameToAddrMap[supplierName]) + + // TODO_TECHDEBT(#220): refactor to use EventsReplayClient once available. 
+ eventsQueryClient := eventsquery.NewEventsQueryClient(testclient.CometLocalWebsocketURL) + eitherEventsBzObs, err := eventsQueryClient.EventsBytes(ctx, msgSenderQuery) + require.NoError(s, err) + + eitherEventsBytesObs := observable.Observable[either.Bytes](eitherEventsBzObs) + eitherEventsBzRelayObs := channel.ToReplayObservable(ctx, eitherEventsReplayBufferSize, eitherEventsBytesObs) + s.scenarioState[eitherEventsBzReplayObsKey] = eitherEventsBzRelayObs + + s.sendRelaysForSession( + appName, + supplierName, + testServiceId, + relayCount, + ) +} + +func (s *suite) sendRelaysForSession( + appName string, + supplierName string, + serviceId string, + relayLimit int, +) { + s.TheApplicationIsStakedForService(appName, serviceId) + s.TheSupplierIsStakedForService(supplierName, serviceId) + s.TheSessionForApplicationAndServiceContainsTheSupplier(appName, serviceId, supplierName) + + // TODO_IMPROVE/TODO_COMMUNITY: hard-code a default set of RPC calls to iterate over for coverage. + data := `{"jsonrpc":"2.0","method":"eth_blockNumber","params":[],"id":1}` + + for i := 0; i < relayLimit; i++ { + s.TheApplicationSendsTheSupplierARequestForServiceWithData(appName, supplierName, serviceId, data) + s.TheApplicationReceivesASuccessfulRelayResponseSignedBy(appName, supplierName) + } +} diff --git a/pkg/client/tx/client.go b/pkg/client/tx/client.go index 39f5208e0..0317e39eb 100644 --- a/pkg/client/tx/client.go +++ b/pkg/client/tx/client.go @@ -9,8 +9,8 @@ import ( "fmt" "sync" + "cosmossdk.io/api/tendermint/abci" "cosmossdk.io/depinject" - abciTypes "github.com/cometbft/cometbft/abci/types" comettypes "github.com/cometbft/cometbft/types" cosmostypes "github.com/cosmos/cosmos-sdk/types" "go.uber.org/multierr" @@ -87,11 +87,10 @@ type ( // TxEvent is used to deserialize incoming websocket messages from // the transactions subscription. -type TxEvent struct { - // Tx is the binary representation of the tx hash. - Tx []byte `json:"tx"` - Events []abciTypes.Event `json:"events"` -} +// +// TODO_CONSIDERATION: either expose this via an interface and unexport this type, +// or remove it altogether. +type TxEvent = abci.TxResult // NewTxClient attempts to construct a new TxClient using the given dependencies // and options. 
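
With TxEvent now aliased to abci.TxResult in pkg/client/tx/client.go, consumers of the raw websocket event bytes decode the full TxResult rather than the old local struct, as the e2e step in session_steps_test.go above does. A minimal sketch of that decode-and-scan pattern follows; the helper name findMsgAction and the package name are illustrative (not part of this change), and cometbft's abci.TxResult JSON shape is assumed, matching the e2e test's usage.

package e2esketch

import (
	"encoding/json"

	abci "github.com/cometbft/cometbft/abci/types"
)

// findMsgAction reports whether the decoded tx result contains an event with an
// "action" attribute equal to the given value, e.g. "/pocket.supplier.MsgCreateClaim".
func findMsgAction(eventBz []byte, action string) (bool, error) {
	// Unmarshal the raw event bytes into a TxResult, mirroring the e2e step.
	txResult := &abci.TxResult{}
	if err := json.Unmarshal(eventBz, txResult); err != nil {
		return false, err
	}
	// Scan all emitted events for a matching "action" attribute.
	for _, event := range txResult.Result.Events {
		for _, attribute := range event.Attributes {
			if attribute.Key == "action" && attribute.Value == action {
				return true, nil
			}
		}
	}
	return false, nil
}
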
diff --git a/testutil/keeper/supplier.go b/testutil/keeper/supplier.go index d54095fd6..8bd600c27 100644 --- a/testutil/keeper/supplier.go +++ b/testutil/keeper/supplier.go @@ -15,12 +15,16 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" - mocks "github.com/pokt-network/poktroll/testutil/supplier/mocks" + "github.com/pokt-network/poktroll/testutil/supplier" + "github.com/pokt-network/poktroll/testutil/supplier/mocks" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" "github.com/pokt-network/poktroll/x/supplier/keeper" "github.com/pokt-network/poktroll/x/supplier/types" ) -func SupplierKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { +func SupplierKeeper(t testing.TB, sessionByAppAddr supplier.SessionsByAppAddress) (*keeper.Keeper, sdk.Context) { + t.Helper() + storeKey := sdk.NewKVStoreKey(types.StoreKey) memStoreKey := storetypes.NewMemoryStoreKey(types.MemStoreKey) @@ -38,6 +42,36 @@ func SupplierKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { mockBankKeeper.EXPECT().DelegateCoinsFromAccountToModule(gomock.Any(), gomock.Any(), types.ModuleName, gomock.Any()).AnyTimes() mockBankKeeper.EXPECT().UndelegateCoinsFromModuleToAccount(gomock.Any(), types.ModuleName, gomock.Any(), gomock.Any()).AnyTimes() + mockSessionKeeper := mocks.NewMockSessionKeeper(ctrl) + mockSessionKeeper.EXPECT(). + GetSession(gomock.AssignableToTypeOf(sdk.Context{}), gomock.Any()). + DoAndReturn( + func( + ctx sdk.Context, + req *sessiontypes.QueryGetSessionRequest, + ) (*sessiontypes.QueryGetSessionResponse, error) { + session, ok := sessionByAppAddr[req.GetApplicationAddress()] + require.Truef(t, ok, "application address not provided during mock construction: %q", req.ApplicationAddress) + + return &sessiontypes.QueryGetSessionResponse{ + Session: &sessiontypes.Session{ + Header: &sessiontypes.SessionHeader{ + ApplicationAddress: session.GetApplication().GetAddress(), + Service: req.GetService(), + SessionStartBlockHeight: 1, + SessionId: session.GetSessionId(), + SessionEndBlockHeight: 5, + }, + SessionId: session.GetSessionId(), + SessionNumber: 1, + NumBlocksPerSession: session.GetNumBlocksPerSession(), + Application: session.GetApplication(), + Suppliers: session.GetSuppliers(), + }, + }, nil + }, + ).AnyTimes() + paramsSubspace := typesparams.NewSubspace(cdc, types.Amino, storeKey, @@ -52,6 +86,7 @@ func SupplierKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { mockBankKeeper, ) + k.SupplySessionKeeper(mockSessionKeeper) ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) diff --git a/testutil/network/network.go b/testutil/network/network.go index d05c283be..11a5f18ab 100644 --- a/testutil/network/network.go +++ b/testutil/network/network.go @@ -126,6 +126,27 @@ func DefaultApplicationModuleGenesisState(t *testing.T, n int) *apptypes.Genesis return state } +// ApplicationModuleGenesisStateWithAccount generates a GenesisState object with +// a single application for each of the given addresses. 
+func ApplicationModuleGenesisStateWithAddresses(t *testing.T, addresses []string) *apptypes.GenesisState { + t.Helper() + state := apptypes.DefaultGenesis() + for _, addr := range addresses { + application := apptypes.Application{ + Address: addr, + Stake: &sdk.Coin{Denom: "upokt", Amount: sdk.NewInt(10000)}, + ServiceConfigs: []*sharedtypes.ApplicationServiceConfig{ + { + Service: &sharedtypes.Service{Id: "svc1"}, + }, + }, + } + state.ApplicationList = append(state.ApplicationList, application) + } + + return state +} + // DefaultGatewayModuleGenesisState generates a GenesisState object with a given number of gateways. // It returns the populated GenesisState object. func DefaultGatewayModuleGenesisState(t *testing.T, n int) *gatewaytypes.GenesisState { @@ -173,8 +194,9 @@ func DefaultSupplierModuleGenesisState(t *testing.T, n int) *suppliertypes.Genes return state } -// SupplierModuleGenesisStateWithAccount generates a GenesisState object with a single supplier with the given address. -func SupplierModuleGenesisStateWithAccounts(t *testing.T, addresses []string) *suppliertypes.GenesisState { +// SupplierModuleGenesisStateWithAddresses generates a GenesisState object with +// a single supplier for each of the given addresses. +func SupplierModuleGenesisStateWithAddresses(t *testing.T, addresses []string) *suppliertypes.GenesisState { t.Helper() state := suppliertypes.DefaultGenesis() for _, addr := range addresses { diff --git a/testutil/supplier/fixtures.go b/testutil/supplier/fixtures.go new file mode 100644 index 000000000..b13f1d74d --- /dev/null +++ b/testutil/supplier/fixtures.go @@ -0,0 +1,113 @@ +package supplier + +import ( + "testing" + + sdktypes "github.com/cosmos/cosmos-sdk/types" + + apptypes "github.com/pokt-network/poktroll/x/application/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" +) + +const ( + testSessionNumber = 1 + testBlockHeight = 1 + testBlocksPerSession = 4 + testSessionId = "mock_session_id" +) + +// SessionsByAppAddress is a map of session fixtures where the key is the +// application address and the value is the session fixture. +type SessionsByAppAddress map[string]sessiontypes.Session + +// AppSupplierPair is a pairing of an application and a supplier address. +type AppSupplierPair struct { + AppAddr string + SupplierAddr string +} + +// NewSessionFixturesWithPairings creates a map of session fixtures where the key +// is the application address and the value is the session fixture. App/supplier +// addresses are expected to be provided in alternating order (as pairs). The same +// app and/or supplier may be given more than once but only distinct pairs will +// be added to the session fixtures map. +func NewSessionFixturesWithPairings( + t *testing.T, + service *sharedtypes.Service, + appSupplierPairs ...AppSupplierPair, +) SessionsByAppAddress { + t.Helper() + + // Initialize the session fixtures map. + sessionFixturesByAppAddr := make(SessionsByAppAddress) + + // Iterate over the app and supplier address pairs (two indices at a time), + // and create a session fixture for each app address. 
+ for _, appSupplierPair := range appSupplierPairs { + application := newApplication(t, appSupplierPair.AppAddr, service) + supplier := newSupplier(t, appSupplierPair.SupplierAddr, service) + + if session, ok := sessionFixturesByAppAddr[appSupplierPair.AppAddr]; ok { + session.Suppliers = append(session.Suppliers, supplier) + continue + } + + sessionFixturesByAppAddr[appSupplierPair.AppAddr] = sessiontypes.Session{ + Header: &sessiontypes.SessionHeader{ + ApplicationAddress: appSupplierPair.AppAddr, + Service: service, + SessionStartBlockHeight: testBlockHeight, + SessionId: testSessionId, + SessionEndBlockHeight: testBlockHeight + testBlocksPerSession, + }, + SessionId: testSessionId, + SessionNumber: testSessionNumber, + NumBlocksPerSession: testBlocksPerSession, + Application: application, + Suppliers: []*sharedtypes.Supplier{ + newSupplier(t, appSupplierPair.SupplierAddr, service), + }, + } + } + + return sessionFixturesByAppAddr +} + +// newSuppliers configures a supplier for the services provided and nil endpoints. +func newSupplier(t *testing.T, addr string, services ...*sharedtypes.Service) *sharedtypes.Supplier { + t.Helper() + + serviceConfigs := make([]*sharedtypes.SupplierServiceConfig, len(services)) + for i, service := range services { + serviceConfigs[i] = &sharedtypes.SupplierServiceConfig{ + Service: service, + Endpoints: nil, + } + } + + return &sharedtypes.Supplier{ + Address: addr, + Stake: &sdktypes.Coin{}, + Services: serviceConfigs, + } +} + +// newApplication configures an application for the services provided. +func newApplication(t *testing.T, addr string, services ...*sharedtypes.Service) *apptypes.Application { + t.Helper() + + serviceConfigs := make([]*sharedtypes.ApplicationServiceConfig, len(services)) + for i, service := range services { + serviceConfigs[i] = &sharedtypes.ApplicationServiceConfig{ + Service: service, + } + } + + return &apptypes.Application{ + Address: addr, + Stake: &sdktypes.Coin{}, + ServiceConfigs: serviceConfigs, + DelegateeGatewayAddresses: nil, + } +} diff --git a/testutil/testclient/localnet.go b/testutil/testclient/localnet.go index 61d5c0ad8..c0374a7ba 100644 --- a/testutil/testclient/localnet.go +++ b/testutil/testclient/localnet.go @@ -1,11 +1,10 @@ package testclient import ( - "testing" - "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/flags" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/regen-network/gocuke" "github.com/spf13/pflag" "github.com/stretchr/testify/require" @@ -13,8 +12,19 @@ import ( "github.com/pokt-network/poktroll/cmd/pocketd/cmd" ) -// CometLocalWebsocketURL provides a default URL pointing to the localnet websocket endpoint. -const CometLocalWebsocketURL = "ws://localhost:36657/websocket" +const ( + // CometLocalTCPURL provides a default URL pointing to the localnet TCP endpoint. + // + // TODO_IMPROVE: It would be nice if the value could be set correctly based + // on whether the test using it is running in tilt or not. + CometLocalTCPURL = "tcp://sequencer-poktroll-sequencer:36657" + + // CometLocalWebsocketURL provides a default URL pointing to the localnet websocket endpoint. + // + // TODO_IMPROVE: It would be nice if the value could be set correctly based + // on whether the test using it is running in tilt or not. + CometLocalWebsocketURL = "ws://sequencer-poktroll-sequencer:36657/websocket" +) // EncodingConfig encapsulates encoding configurations for the Pocket application. 
var EncodingConfig = app.MakeEncodingConfig() @@ -35,7 +45,9 @@ func init() { // // Returns: // - A pointer to a populated client.Context instance suitable for localnet usage. -func NewLocalnetClientCtx(t *testing.T, flagSet *pflag.FlagSet) *client.Context { +func NewLocalnetClientCtx(t gocuke.TestingT, flagSet *pflag.FlagSet) *client.Context { + t.Helper() + homedir := app.DefaultNodeHome clientCtx := client.Context{}. WithCodec(EncodingConfig.Marshaler). @@ -57,9 +69,13 @@ func NewLocalnetClientCtx(t *testing.T, flagSet *pflag.FlagSet) *client.Context // // Returns: // - A flag set populated with flags tailored for localnet environments. -func NewLocalnetFlagSet(t *testing.T) *pflag.FlagSet { +func NewLocalnetFlagSet(t gocuke.TestingT) *pflag.FlagSet { + t.Helper() + mockFlagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) - mockFlagSet.String(flags.FlagNode, "tcp://127.0.0.1:36657", "use localnet poktrolld node") + // TODO_IMPROVE: It would be nice if the value could be set correctly based + // on whether the test using it is running in tilt or not. + mockFlagSet.String(flags.FlagNode, CometLocalTCPURL, "use localnet poktrolld node") mockFlagSet.String(flags.FlagHome, "", "use localnet poktrolld node") mockFlagSet.String(flags.FlagKeyringBackend, "test", "use test keyring") err := mockFlagSet.Parse([]string{}) diff --git a/x/session/keeper/session_hydrator.go b/x/session/keeper/session_hydrator.go index 972b5d60c..5035b192f 100644 --- a/x/session/keeper/session_hydrator.go +++ b/x/session/keeper/session_hydrator.go @@ -156,7 +156,11 @@ func (k Keeper) hydrateSessionSuppliers(ctx sdk.Context, sh *sessionHydrator) er suppliers := k.supplierKeeper.GetAllSupplier(ctx) candidateSuppliers := make([]*sharedtypes.Supplier, 0) - for _, supplier := range suppliers { + for _, s := range suppliers { + // NB: Allocate a new heap variable as s is a value and we're appending + // to a slice of pointers; otherwise, we'd be appending new pointers to + // the same memory address containing the last supplier in the loop. 
+ supplier := s // TODO_OPTIMIZE: If `supplier.Services` was a map[string]struct{}, we could eliminate `slices.Contains()`'s loop for _, supplierServiceConfig := range supplier.Services { if supplierServiceConfig.Service.Id == sh.sessionHeader.Service.Id { diff --git a/x/supplier/client/cli/helpers_test.go b/x/supplier/client/cli/helpers_test.go index d6066c28a..9c8089868 100644 --- a/x/supplier/client/cli/helpers_test.go +++ b/x/supplier/client/cli/helpers_test.go @@ -2,17 +2,39 @@ package cli_test import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" "strconv" "testing" + "cosmossdk.io/math" + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/client/flags" + "github.com/cosmos/cosmos-sdk/codec" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/crypto/keyring" + testcli "github.com/cosmos/cosmos-sdk/testutil/cli" + sdktypes "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" "github.com/pokt-network/poktroll/cmd/pocketd/cmd" "github.com/pokt-network/poktroll/testutil/network" + "github.com/pokt-network/poktroll/testutil/testkeyring" + apptypes "github.com/pokt-network/poktroll/x/application/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" + "github.com/pokt-network/poktroll/x/supplier/client/cli" "github.com/pokt-network/poktroll/x/supplier/types" ) +// TODO_TECHDEBT: This should not be hardcoded once the num blocks per session is configurable. +const ( + numBlocksPerSession = 4 + testServiceId = "svc1" +) + // Dummy variable to avoid unused import error. var _ = strconv.IntSize @@ -32,3 +54,202 @@ func networkWithSupplierObjects(t *testing.T, n int) (*network.Network, []shared cfg.GenesisState[types.ModuleName] = buf return network.New(t, cfg), supplierGenesisState.SupplierList } + +// TODO_CONSIDERATION: perhaps this (and/or other similar helpers) can be refactored +// into something more generic and moved into a shared testutil package. +// TODO_TECHDEBT: refactor; this function has more than a single responsibility, +// which should be to configure and start the test network. The genesis state, +// accounts, and claims set up logic can probably be factored out and/or reduced. +func networkWithClaimObjects( + t *testing.T, + sessionCount int, + supplierCount int, + appCount int, +) (net *network.Network, claims []types.Claim) { + t.Helper() + + // Initialize a network config. + cfg := network.DefaultConfig() + + // Construct an in-memory keyring so that it can be populated and used prior + // to network start. + kr := keyring.NewInMemory(cfg.Codec) + // Populate the in-memmory keyring with as many pre-generated accounts as + // we expect to need for the test (i.e. appCount + supplierCount). + testkeyring.CreatePreGeneratedKeyringAccounts(t, kr, supplierCount+appCount) + + // Use the pre-generated accounts iterator to populate the supplier and + // application accounts and addresses lists for use in genesis state construction. + preGeneratedAccts := testkeyring.PreGeneratedAccounts().Clone() + + // Create a supplier for each session in numClaimsSessions and an app for each + // claim in numClaimsPerSession. 
+ supplierAccts := make([]*testkeyring.PreGeneratedAccount, supplierCount) + supplierAddrs := make([]string, supplierCount) + for i := range supplierAccts { + account, ok := preGeneratedAccts.Next() + require.True(t, ok) + supplierAccts[i] = account + supplierAddrs[i] = account.Address.String() + } + appAccts := make([]*testkeyring.PreGeneratedAccount, appCount) + appAddrs := make([]string, appCount) + for i := range appAccts { + account, ok := preGeneratedAccts.Next() + require.True(t, ok) + appAccts[i] = account + appAddrs[i] = account.Address.String() + } + + // Construct supplier and application module genesis states given the account addresses. + supplierGenesisState := network.SupplierModuleGenesisStateWithAddresses(t, supplierAddrs) + supplierGenesisBuffer, err := cfg.Codec.MarshalJSON(supplierGenesisState) + require.NoError(t, err) + appGenesisState := network.ApplicationModuleGenesisStateWithAddresses(t, appAddrs) + appGenesisBuffer, err := cfg.Codec.MarshalJSON(appGenesisState) + require.NoError(t, err) + + // Add supplier and application module genesis states to the network config. + cfg.GenesisState[types.ModuleName] = supplierGenesisBuffer + cfg.GenesisState[apptypes.ModuleName] = appGenesisBuffer + + // Construct the network with the configuration. + net = network.New(t, cfg) + // Only the first validator's client context is populated. + // (see: https://pkg.go.dev/github.com/cosmos/cosmos-sdk/testutil/network#pkg-overview) + ctx := net.Validators[0].ClientCtx + // Overwrite the client context's keyring with the in-memory one that contains + // our pre-generated accounts. + ctx = ctx.WithKeyring(kr) + + // Initialize all the accounts + sequenceIndex := 1 + for _, supplierAcct := range supplierAccts { + network.InitAccountWithSequence(t, net, supplierAcct.Address, sequenceIndex) + sequenceIndex++ + } + for _, appAcct := range appAccts { + network.InitAccountWithSequence(t, net, appAcct.Address, sequenceIndex) + sequenceIndex++ + } + // need to wait for the account to be initialized in the next block + require.NoError(t, net.WaitForNextBlock()) + + // Create sessionCount * numClaimsPerSession claims for the supplier + sessionEndHeight := int64(1) + for sessionIdx := 0; sessionIdx < sessionCount; sessionIdx++ { + sessionEndHeight += numBlocksPerSession + for _, appAcct := range appAccts { + for _, supplierAcct := range supplierAccts { + claim := createClaim( + t, net, ctx, + supplierAcct.Address.String(), + sessionEndHeight, + appAcct.Address.String(), + ) + claims = append(claims, *claim) + // TODO_TECHDEBT(#196): Move this outside of the forloop so that the test iteration is faster + require.NoError(t, net.WaitForNextBlock()) + } + } + } + + return net, claims +} + +// encodeSessionHeader returns a base64 encoded string of a json +// serialized session header. 
+func encodeSessionHeader( + t *testing.T, + appAddr string, + sessionId string, + sessionStartHeight int64, +) string { + t.Helper() + + argSessionHeader := &sessiontypes.SessionHeader{ + ApplicationAddress: appAddr, + SessionStartBlockHeight: sessionStartHeight, + SessionId: sessionId, + SessionEndBlockHeight: sessionStartHeight + numBlocksPerSession, + Service: &sharedtypes.Service{Id: testServiceId}, + } + cdc := codec.NewProtoCodec(codectypes.NewInterfaceRegistry()) + sessionHeaderBz := cdc.MustMarshalJSON(argSessionHeader) + return base64.StdEncoding.EncodeToString(sessionHeaderBz) +} + +// createClaim sends a tx using the test CLI to create an on-chain claim +func createClaim( + t *testing.T, + net *network.Network, + ctx client.Context, + supplierAddr string, + sessionEndHeight int64, + appAddress string, +) *types.Claim { + t.Helper() + + rootHash := []byte("root_hash") + sessionStartHeight := sessionEndHeight - numBlocksPerSession + sessionId := getSessionId(t, net, appAddress, supplierAddr, sessionStartHeight) + sessionHeaderEncoded := encodeSessionHeader(t, appAddress, sessionId, sessionStartHeight) + rootHashEncoded := base64.StdEncoding.EncodeToString(rootHash) + + args := []string{ + sessionHeaderEncoded, + rootHashEncoded, + fmt.Sprintf("--%s=%s", flags.FlagFrom, supplierAddr), + fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), + fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), + fmt.Sprintf("--%s=%s", flags.FlagFees, sdktypes.NewCoins(sdktypes.NewCoin(net.Config.BondDenom, math.NewInt(10))).String()), + } + + responseRaw, err := testcli.ExecTestCLICmd(ctx, cli.CmdCreateClaim(), args) + require.NoError(t, err) + var responseJson map[string]interface{} + err = json.Unmarshal(responseRaw.Bytes(), &responseJson) + require.NoError(t, err) + require.Equal(t, float64(0), responseJson["code"], "code is not 0 in the response: %v", responseJson) + + // TODO_TECHDEBT: Forward the actual claim in the response once the response is updated to return it. + return &types.Claim{ + SupplierAddress: supplierAddr, + SessionId: sessionId, + SessionEndBlockHeight: uint64(sessionEndHeight), + RootHash: rootHash, + } +} + +// getSessionId sends a query using the test CLI to get a session for the inputs provided. +// It is assumed that the supplierAddr will be in that session based on the test design, but this +// is insured in this function before it's successfully returned. 
+func getSessionId( + t *testing.T, + net *network.Network, + appAddr string, + supplierAddr string, + sessionStartHeight int64, +) string { + t.Helper() + ctx := context.TODO() + + sessionQueryClient := sessiontypes.NewQueryClient(net.Validators[0].ClientCtx) + res, err := sessionQueryClient.GetSession(ctx, &sessiontypes.QueryGetSessionRequest{ + ApplicationAddress: appAddr, + Service: &sharedtypes.Service{Id: testServiceId}, + BlockHeight: sessionStartHeight, + }) + require.NoError(t, err) + + var found bool + for _, supplier := range res.GetSession().GetSuppliers() { + if supplier.GetAddress() == supplierAddr { + found = true + break + } + } + require.Truef(t, found, "supplier address %s not found in session", supplierAddr) + + return res.Session.SessionId +} diff --git a/x/supplier/client/cli/query_claim_test.go b/x/supplier/client/cli/query_claim_test.go index c36abf15a..ac18e94b7 100644 --- a/x/supplier/client/cli/query_claim_test.go +++ b/x/supplier/client/cli/query_claim_test.go @@ -1,147 +1,31 @@ package cli_test import ( - "encoding/base64" - "encoding/json" "fmt" "testing" - sdkmath "cosmossdk.io/math" tmcli "github.com/cometbft/cometbft/libs/cli" - "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/flags" - "github.com/cosmos/cosmos-sdk/codec" - cdctypes "github.com/cosmos/cosmos-sdk/codec/types" - "github.com/cosmos/cosmos-sdk/testutil" clitestutil "github.com/cosmos/cosmos-sdk/testutil/cli" - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/pokt-network/poktroll/testutil/network" "github.com/pokt-network/poktroll/testutil/nullify" - "github.com/pokt-network/poktroll/testutil/sample" - sessiontypes "github.com/pokt-network/poktroll/x/session/types" - sharedtypes "github.com/pokt-network/poktroll/x/shared/types" "github.com/pokt-network/poktroll/x/supplier/client/cli" "github.com/pokt-network/poktroll/x/supplier/types" ) -// TODO_TECHDEBT: This should not be hardcoded once the num blocks per session is configurable -const numBlocksPerSession = 4 - -func encodeSessionHeader(t *testing.T, sessionId string, sessionEndHeight int64) string { - t.Helper() - - argSessionHeader := &sessiontypes.SessionHeader{ - ApplicationAddress: sample.AccAddress(), - SessionStartBlockHeight: sessionEndHeight - numBlocksPerSession, - SessionId: sessionId, - SessionEndBlockHeight: sessionEndHeight, - Service: &sharedtypes.Service{Id: "anvil"}, // hardcoded for simplicity - } - cdc := codec.NewProtoCodec(cdctypes.NewInterfaceRegistry()) - sessionHeaderBz := cdc.MustMarshalJSON(argSessionHeader) - return base64.StdEncoding.EncodeToString(sessionHeaderBz) -} - -func createClaim( - t *testing.T, - net *network.Network, - ctx client.Context, - supplierAddr string, - sessionId string, - sessionEndHeight int64, -) *types.Claim { - t.Helper() - - rootHash := []byte("root_hash") - sessionHeaderEncoded := encodeSessionHeader(t, sessionId, sessionEndHeight) - rootHashEncoded := base64.StdEncoding.EncodeToString(rootHash) - - args := []string{ - sessionHeaderEncoded, - rootHashEncoded, - fmt.Sprintf("--%s=%s", flags.FlagFrom, supplierAddr), - fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation), - fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync), - fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(net.Config.BondDenom, sdkmath.NewInt(10))).String()), - } - - responseRaw, err := clitestutil.ExecTestCLICmd(ctx, cli.CmdCreateClaim(), args) - 
require.NoError(t, err) - var responseJson map[string]interface{} - err = json.Unmarshal(responseRaw.Bytes(), &responseJson) - require.NoError(t, err) - require.Equal(t, float64(0), responseJson["code"], "code is not 0 in the response: %v", responseJson) - - return &types.Claim{ - SupplierAddress: supplierAddr, - SessionId: sessionId, - SessionEndBlockHeight: uint64(sessionEndHeight), - RootHash: rootHash, - } -} - -func networkWithClaimObjects( - t *testing.T, - numSessions int, - numClaimsPerSession int, -) (net *network.Network, claims []types.Claim) { - t.Helper() - - // Prepare the network - cfg := network.DefaultConfig() - net = network.New(t, cfg) - ctx := net.Validators[0].ClientCtx - - // Prepare the keyring for the supplier account - kr := ctx.Keyring - accounts := testutil.CreateKeyringAccounts(t, kr, numClaimsPerSession) - ctx = ctx.WithKeyring(kr) - - // Initialize all the accounts - for i, account := range accounts { - signatureSequenceNumber := i + 1 - network.InitAccountWithSequence(t, net, account.Address, signatureSequenceNumber) - } - // need to wait for the account to be initialized in the next block - require.NoError(t, net.WaitForNextBlock()) - - addresses := make([]string, len(accounts)) - for i, account := range accounts { - addresses[i] = account.Address.String() - } - - // Create one supplier - supplierGenesisState := network.SupplierModuleGenesisStateWithAccounts(t, addresses) - buf, err := cfg.Codec.MarshalJSON(supplierGenesisState) - require.NoError(t, err) - cfg.GenesisState[types.ModuleName] = buf - - // Create numSessions * numClaimsPerSession claims for the supplier - sessionEndHeight := int64(1) - for sessionNum := 0; sessionNum < numSessions; sessionNum++ { - sessionEndHeight += numBlocksPerSession - sessionId := fmt.Sprintf("session_id%d", sessionNum) - for claimNum := 0; claimNum < numClaimsPerSession; claimNum++ { - supplierAddr := addresses[claimNum] - claim := createClaim(t, net, ctx, supplierAddr, sessionId, sessionEndHeight) - claims = append(claims, *claim) - // TODO_TECHDEBT(#196): Move this outside of the forloop so that the test iteration is faster - require.NoError(t, net.WaitForNextBlock()) - } - } - - return net, claims -} - func TestClaim_Show(t *testing.T) { - numSessions := 1 - numClaimsPerSession := 2 + sessionCount := 1 + supplierCount := 3 + appCount := 3 - net, claims := networkWithClaimObjects(t, numSessions, numClaimsPerSession) + net, claims := networkWithClaimObjects( + t, sessionCount, + appCount, + supplierCount, + ) ctx := net.Validators[0].ClientCtx common := []string{ @@ -208,11 +92,19 @@ func TestClaim_Show(t *testing.T) { } func TestClaim_List(t *testing.T) { - numSessions := 2 - numClaimsPerSession := 5 - totalClaims := numSessions * numClaimsPerSession - - net, claims := networkWithClaimObjects(t, numSessions, numClaimsPerSession) + sessionCount := 2 + supplierCount := 4 + appCount := 3 + serviceCount := 1 + // Each supplier will submit a claim for each app x service combination (per session). 
+ numClaimsPerSession := supplierCount * appCount * serviceCount + totalClaims := sessionCount * numClaimsPerSession + + net, claims := networkWithClaimObjects( + t, sessionCount, + supplierCount, + appCount, + ) ctx := net.Validators[0].ClientCtx prepareArgs := func(next []byte, offset, limit uint64, total bool) []string { @@ -287,11 +179,11 @@ func TestClaim_List(t *testing.T) { var resp types.QueryAllClaimsResponse require.NoError(t, net.Config.Codec.UnmarshalJSON(out.Bytes(), &resp)) - require.Equal(t, numSessions, int(resp.Pagination.Total)) require.ElementsMatch(t, nullify.Fill(expectedClaims), nullify.Fill(resp.Claim), ) + require.Equal(t, sessionCount*appCount, int(resp.Pagination.Total)) }) t.Run("BySession", func(t *testing.T) { @@ -312,11 +204,11 @@ func TestClaim_List(t *testing.T) { var resp types.QueryAllClaimsResponse require.NoError(t, net.Config.Codec.UnmarshalJSON(out.Bytes(), &resp)) - require.Equal(t, numClaimsPerSession, int(resp.Pagination.Total)) require.ElementsMatch(t, nullify.Fill(expectedClaims), nullify.Fill(resp.Claim), ) + require.Equal(t, supplierCount, int(resp.Pagination.Total)) }) t.Run("ByHeight", func(t *testing.T) { diff --git a/x/supplier/genesis_test.go b/x/supplier/genesis_test.go index ba9b521f8..da9245955 100644 --- a/x/supplier/genesis_test.go +++ b/x/supplier/genesis_test.go @@ -59,7 +59,7 @@ func TestGenesis(t *testing.T) { // this line is used by starport scaffolding # genesis/test/state } - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) supplier.InitGenesis(ctx, *k, genesisState) got := supplier.ExportGenesis(ctx, *k) require.NotNil(t, got) diff --git a/x/supplier/keeper/claim_test.go b/x/supplier/keeper/claim_test.go index c60403a45..4e38a43d0 100644 --- a/x/supplier/keeper/claim_test.go +++ b/x/supplier/keeper/claim_test.go @@ -31,7 +31,7 @@ func createNClaims(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.Claim } func TestClaim_Get(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) for _, claim := range claims { foundClaim, isClaimFound := keeper.GetClaim(ctx, @@ -46,7 +46,7 @@ func TestClaim_Get(t *testing.T) { } } func TestClaim_Remove(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) for _, claim := range claims { keeper.RemoveClaim(ctx, @@ -62,7 +62,7 @@ func TestClaim_Remove(t *testing.T) { } func TestClaim_GetAll(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) // Get all the claims and check if they match @@ -74,7 +74,7 @@ func TestClaim_GetAll(t *testing.T) { } func TestClaim_GetAll_ByAddress(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) // Get all claims for a given address @@ -86,7 +86,7 @@ func TestClaim_GetAll_ByAddress(t *testing.T) { } func TestClaim_GetAll_ByHeight(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) // Get all claims for a given ending session block height @@ -98,7 +98,7 @@ func TestClaim_GetAll_ByHeight(t *testing.T) { } func TestClaim_GetAll_BySession(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := 
keepertest.SupplierKeeper(t, nil) claims := createNClaims(keeper, ctx, 10) // Get all claims for a given ending session block height diff --git a/x/supplier/keeper/keeper.go b/x/supplier/keeper/keeper.go index d77218231..709c94c30 100644 --- a/x/supplier/keeper/keeper.go +++ b/x/supplier/keeper/keeper.go @@ -19,7 +19,8 @@ type ( memKey storetypes.StoreKey paramstore paramtypes.Subspace - bankKeeper types.BankKeeper + bankKeeper types.BankKeeper + sessionKeeper types.SessionKeeper } ) @@ -49,3 +50,11 @@ func NewKeeper( func (k Keeper) Logger(ctx sdk.Context) log.Logger { return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName)) } + +// TODO_TECHDEBT: Evaluate if this is still necessary after the upgrade to cosmos 0.5x +// SupplySessionKeeper assigns the session keeper dependency of this supplier +// keeper. This MUST be done as a separate step from construction because there +// is a circular dependency between the supplier and session keepers. +func (k *Keeper) SupplySessionKeeper(sessionKeeper types.SessionKeeper) { + k.sessionKeeper = sessionKeeper +} diff --git a/x/supplier/keeper/msg_server_create_claim.go b/x/supplier/keeper/msg_server_create_claim.go index 7a7722294..7d911f95a 100644 --- a/x/supplier/keeper/msg_server_create_claim.go +++ b/x/supplier/keeper/msg_server_create_claim.go @@ -5,10 +5,11 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/pokt-network/poktroll/x/supplier/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) -func (k msgServer) CreateClaim(goCtx context.Context, msg *types.MsgCreateClaim) (*types.MsgCreateClaimResponse, error) { +func (k msgServer) CreateClaim(goCtx context.Context, msg *suppliertypes.MsgCreateClaim) (*suppliertypes.MsgCreateClaimResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) logger := k.Logger(ctx).With("method", "CreateClaim") @@ -16,44 +17,85 @@ func (k msgServer) CreateClaim(goCtx context.Context, msg *types.MsgCreateClaim) return nil, err } - claim := types.Claim{ - SupplierAddress: msg.SupplierAddress, - SessionId: msg.SessionHeader.SessionId, - SessionEndBlockHeight: uint64(msg.SessionHeader.SessionEndBlockHeight), - RootHash: msg.RootHash, + sessionReq := &sessiontypes.QueryGetSessionRequest{ + ApplicationAddress: msg.GetSessionHeader().GetApplicationAddress(), + Service: msg.GetSessionHeader().GetService(), + BlockHeight: msg.GetSessionHeader().GetSessionStartBlockHeight(), + } + sessionRes, err := k.Keeper.sessionKeeper.GetSession(goCtx, sessionReq) + if err != nil { + return nil, err } - k.Keeper.InsertClaim(ctx, claim) - logger.Info("created claim for supplier %s at session ending height %d", claim.SupplierAddress, claim.SessionEndBlockHeight) - logger.Info("TODO_INCOMPLETE: Handling actual claim business logic %s", claim.SessionId) + logger. + With( + "session_id", sessionRes.GetSession().GetSessionId(), + "session_end_height", msg.GetSessionHeader().GetSessionEndBlockHeight(), + "supplier", msg.GetSupplierAddress(), + ). 
+ Debug("got sessionId for claim") - /* - TODO_INCOMPLETE: Handling the message + if sessionRes.Session.SessionId != msg.SessionHeader.SessionId { + return nil, suppliertypes.ErrSupplierInvalidSessionId.Wrapf( + "claimed sessionRes ID does not match on-chain sessionRes ID; expected %q, got %q", + sessionRes.Session.SessionId, + msg.SessionHeader.SessionId, + ) + } + + var found bool + for _, supplier := range sessionRes.GetSession().GetSuppliers() { + if supplier.Address == msg.GetSupplierAddress() { + found = true + break + } + } - ## Validation + if !found { + return nil, suppliertypes.ErrSupplierNotFound.Wrapf( + "supplier address %q in session ID %q", + msg.GetSupplierAddress(), + sessionRes.GetSession().GetSessionId(), + ) + } - ### Session validation - 1. [ ] claimed session ID matches on-chain session ID - 2. [ ] this supplier is in the session's suppliers list + logger. + With( + "session_id", sessionRes.GetSession().GetSessionId(), + "session_end_height", msg.GetSessionHeader().GetSessionEndBlockHeight(), + "supplier", msg.GetSupplierAddress(), + ). + Debug("validated claim") - ### Msg distribution validation (depends on session validation) + /* + TODO_INCOMPLETE: + + ### Msg distribution validation (depends on sessionRes validation) 1. [ ] governance-based earliest block offset 2. [ ] pseudo-randomize earliest block offset ### Claim validation - 1. [ ] session validation + 1. [x] sessionRes validation 2. [ ] msg distribution validation - - ## Persistence - 1. [ ] create claim message - - supplier address - - session header - - claim - 2. [ ] last block height commitment; derives: - - last block committed hash, must match proof path - - session ID (?) */ - _ = ctx - return &types.MsgCreateClaimResponse{}, nil + // Construct and insert claim after all validation. + claim := suppliertypes.Claim{ + SupplierAddress: msg.GetSupplierAddress(), + SessionId: msg.GetSessionHeader().GetSessionId(), + SessionEndBlockHeight: uint64(msg.GetSessionHeader().GetSessionEndBlockHeight()), + RootHash: msg.RootHash, + } + k.Keeper.InsertClaim(ctx, claim) + + logger. + With( + "session_id", claim.GetSessionId(), + "session_end_height", claim.GetSessionEndBlockHeight(), + "supplier", claim.GetSupplierAddress(), + ). + Debug("created claim") + + // TODO: return the claim in the response. 
+ return &suppliertypes.MsgCreateClaimResponse{}, nil } diff --git a/x/supplier/keeper/msg_server_create_claim_test.go b/x/supplier/keeper/msg_server_create_claim_test.go new file mode 100644 index 000000000..70ac48c79 --- /dev/null +++ b/x/supplier/keeper/msg_server_create_claim_test.go @@ -0,0 +1,121 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + keepertest "github.com/pokt-network/poktroll/testutil/keeper" + "github.com/pokt-network/poktroll/testutil/sample" + "github.com/pokt-network/poktroll/testutil/supplier" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" + "github.com/pokt-network/poktroll/x/supplier/keeper" + "github.com/pokt-network/poktroll/x/supplier/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" +) + +const testServiceId = "svc1" + +func TestMsgServer_CreateClaim_Success(t *testing.T) { + appSupplierPair := supplier.AppSupplierPair{ + AppAddr: sample.AccAddress(), + SupplierAddr: sample.AccAddress(), + } + service := &sharedtypes.Service{Id: testServiceId} + sessionFixturesByAddr := supplier.NewSessionFixturesWithPairings(t, service, appSupplierPair) + + supplierKeeper, sdkCtx := keepertest.SupplierKeeper(t, sessionFixturesByAddr) + srv := keeper.NewMsgServerImpl(*supplierKeeper) + ctx := sdk.WrapSDKContext(sdkCtx) + + claimMsg := newTestClaimMsg(t) + claimMsg.SupplierAddress = appSupplierPair.SupplierAddr + claimMsg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr + + createClaimRes, err := srv.CreateClaim(ctx, claimMsg) + require.NoError(t, err) + require.NotNil(t, createClaimRes) + + claimRes, err := supplierKeeper.AllClaims(sdkCtx, &types.QueryAllClaimsRequest{}) + require.NoError(t, err) + + claims := claimRes.GetClaim() + require.Lenf(t, claims, 1, "expected 1 claim, got %d", len(claims)) + require.Equal(t, claimMsg.SessionHeader.SessionId, claims[0].SessionId) + require.Equal(t, claimMsg.SupplierAddress, claims[0].SupplierAddress) + require.Equal(t, uint64(claimMsg.SessionHeader.GetSessionEndBlockHeight()), claims[0].SessionEndBlockHeight) + require.Equal(t, claimMsg.RootHash, claims[0].RootHash) +} + +func TestMsgServer_CreateClaim_Error(t *testing.T) { + service := &sharedtypes.Service{Id: testServiceId} + appSupplierPair := supplier.AppSupplierPair{ + AppAddr: sample.AccAddress(), + SupplierAddr: sample.AccAddress(), + } + sessionFixturesByAppAddr := supplier.NewSessionFixturesWithPairings(t, service, appSupplierPair) + + supplierKeeper, sdkCtx := keepertest.SupplierKeeper(t, sessionFixturesByAppAddr) + srv := keeper.NewMsgServerImpl(*supplierKeeper) + ctx := sdk.WrapSDKContext(sdkCtx) + + tests := []struct { + desc string + claimMsgFn func(t *testing.T) *types.MsgCreateClaim + expectedErr error + }{ + { + desc: "on-chain session ID must match claim msg session ID", + claimMsgFn: func(t *testing.T) *types.MsgCreateClaim { + msg := newTestClaimMsg(t) + msg.SupplierAddress = appSupplierPair.SupplierAddr + msg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr + msg.SessionHeader.SessionId = "invalid_session_id" + + return msg + }, + expectedErr: types.ErrSupplierInvalidSessionId, + }, + { + desc: "claim msg supplier address must be in the session", + claimMsgFn: func(t *testing.T) *types.MsgCreateClaim { + msg := newTestClaimMsg(t) + msg.SessionHeader.ApplicationAddress = appSupplierPair.AppAddr + + // Overwrite supplier address to one not included in the 
session fixtures. + msg.SupplierAddress = sample.AccAddress() + + return msg + }, + expectedErr: types.ErrSupplierNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + createClaimRes, err := srv.CreateClaim(ctx, tt.claimMsgFn(t)) + require.ErrorIs(t, err, tt.expectedErr) + require.Nil(t, createClaimRes) + }) + } +} + +func newTestClaimMsg(t *testing.T) *suppliertypes.MsgCreateClaim { + t.Helper() + + return suppliertypes.NewMsgCreateClaim( + sample.AccAddress(), + &sessiontypes.SessionHeader{ + ApplicationAddress: sample.AccAddress(), + SessionStartBlockHeight: 1, + SessionId: "mock_session_id", + Service: &sharedtypes.Service{ + Id: "svc1", + Name: "svc1", + }, + }, + []byte{0, 0, 0, 0}, + ) +} diff --git a/x/supplier/keeper/msg_server_stake_supplier_test.go b/x/supplier/keeper/msg_server_stake_supplier_test.go index cd6158a81..bbc74cd6e 100644 --- a/x/supplier/keeper/msg_server_stake_supplier_test.go +++ b/x/supplier/keeper/msg_server_stake_supplier_test.go @@ -14,7 +14,7 @@ import ( ) func TestMsgServer_StakeSupplier_SuccessfulCreateAndUpdate(t *testing.T) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) srv := keeper.NewMsgServerImpl(*k) wctx := sdk.WrapSDKContext(ctx) @@ -92,7 +92,7 @@ func TestMsgServer_StakeSupplier_SuccessfulCreateAndUpdate(t *testing.T) { } func TestMsgServer_StakeSupplier_FailRestakingDueToInvalidServices(t *testing.T) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) srv := keeper.NewMsgServerImpl(*k) wctx := sdk.WrapSDKContext(ctx) @@ -173,7 +173,7 @@ func TestMsgServer_StakeSupplier_FailRestakingDueToInvalidServices(t *testing.T) } func TestMsgServer_StakeSupplier_FailLoweringStake(t *testing.T) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) srv := keeper.NewMsgServerImpl(*k) wctx := sdk.WrapSDKContext(ctx) diff --git a/x/supplier/keeper/msg_server_test.go b/x/supplier/keeper/msg_server_test.go index 7e4d01f27..b337090e3 100644 --- a/x/supplier/keeper/msg_server_test.go +++ b/x/supplier/keeper/msg_server_test.go @@ -13,7 +13,7 @@ import ( ) func setupMsgServer(t testing.TB) (types.MsgServer, context.Context) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) return keeper.NewMsgServerImpl(*k), sdk.WrapSDKContext(ctx) } diff --git a/x/supplier/keeper/msg_server_unstake_supplier_test.go b/x/supplier/keeper/msg_server_unstake_supplier_test.go index 993201971..2999a5e3c 100644 --- a/x/supplier/keeper/msg_server_unstake_supplier_test.go +++ b/x/supplier/keeper/msg_server_unstake_supplier_test.go @@ -14,7 +14,7 @@ import ( ) func TestMsgServer_UnstakeSupplier_Success(t *testing.T) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) srv := keeper.NewMsgServerImpl(*k) wctx := sdk.WrapSDKContext(ctx) @@ -68,7 +68,7 @@ func TestMsgServer_UnstakeSupplier_Success(t *testing.T) { } func TestMsgServer_UnstakeSupplier_FailIfNotStaked(t *testing.T) { - k, ctx := keepertest.SupplierKeeper(t) + k, ctx := keepertest.SupplierKeeper(t, nil) srv := keeper.NewMsgServerImpl(*k) wctx := sdk.WrapSDKContext(ctx) diff --git a/x/supplier/keeper/params_test.go b/x/supplier/keeper/params_test.go index 5a7e866fe..38f2e91b9 100644 --- a/x/supplier/keeper/params_test.go +++ b/x/supplier/keeper/params_test.go @@ -10,7 +10,7 @@ import ( ) func TestGetParams(t *testing.T) { - k, ctx := testkeeper.SupplierKeeper(t) + k, ctx := testkeeper.SupplierKeeper(t, nil) params := 
types.DefaultParams() k.SetParams(ctx, params) diff --git a/x/supplier/keeper/proof_test.go b/x/supplier/keeper/proof_test.go index c3fb4e753..a082a05bf 100644 --- a/x/supplier/keeper/proof_test.go +++ b/x/supplier/keeper/proof_test.go @@ -5,11 +5,12 @@ import ( "testing" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + keepertest "github.com/pokt-network/poktroll/testutil/keeper" "github.com/pokt-network/poktroll/testutil/nullify" "github.com/pokt-network/poktroll/x/supplier/keeper" "github.com/pokt-network/poktroll/x/supplier/types" - "github.com/stretchr/testify/require" ) // Prevent strconv unused error @@ -26,7 +27,7 @@ func createNProofs(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.Proof } func TestProofGet(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) items := createNProofs(keeper, ctx, 10) for _, item := range items { rst, found := keeper.GetProof(ctx, @@ -40,7 +41,7 @@ func TestProofGet(t *testing.T) { } } func TestProofRemove(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) items := createNProofs(keeper, ctx, 10) for _, item := range items { keeper.RemoveProof(ctx, @@ -54,7 +55,7 @@ func TestProofRemove(t *testing.T) { } func TestProofGetAll(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) items := createNProofs(keeper, ctx, 10) require.ElementsMatch(t, nullify.Fill(items), diff --git a/x/supplier/keeper/query_claim_test.go b/x/supplier/keeper/query_claim_test.go index dcae244e1..1957362fa 100644 --- a/x/supplier/keeper/query_claim_test.go +++ b/x/supplier/keeper/query_claim_test.go @@ -16,7 +16,7 @@ import ( ) func TestClaim_QuerySingle(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) claims := createNClaims(keeper, ctx, 2) tests := []struct { @@ -105,7 +105,7 @@ func TestClaim_QuerySingle(t *testing.T) { } func TestClaim_QueryPaginated(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) claims := createNClaims(keeper, ctx, 10) diff --git a/x/supplier/keeper/query_params_test.go b/x/supplier/keeper/query_params_test.go index d9b909305..37df8b949 100644 --- a/x/supplier/keeper/query_params_test.go +++ b/x/supplier/keeper/query_params_test.go @@ -11,7 +11,7 @@ import ( ) func TestParamsQuery(t *testing.T) { - keeper, ctx := testkeeper.SupplierKeeper(t) + keeper, ctx := testkeeper.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) params := types.DefaultParams() keeper.SetParams(ctx, params) diff --git a/x/supplier/keeper/query_proof_test.go b/x/supplier/keeper/query_proof_test.go index d932b43cb..a082d6e10 100644 --- a/x/supplier/keeper/query_proof_test.go +++ b/x/supplier/keeper/query_proof_test.go @@ -6,19 +6,20 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" - keepertest "github.com/pokt-network/poktroll/testutil/keeper" - "github.com/pokt-network/poktroll/testutil/nullify" - "github.com/pokt-network/poktroll/x/supplier/types" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + keepertest "github.com/pokt-network/poktroll/testutil/keeper" + "github.com/pokt-network/poktroll/testutil/nullify" + "github.com/pokt-network/poktroll/x/supplier/types" ) // 
Prevent strconv unused error var _ = strconv.IntSize func TestProofQuerySingle(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) msgs := createNProofs(keeper, ctx, 2) tests := []struct { @@ -70,7 +71,7 @@ func TestProofQuerySingle(t *testing.T) { } func TestProofQueryPaginated(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) msgs := createNProofs(keeper, ctx, 5) diff --git a/x/supplier/keeper/query_supplier_test.go b/x/supplier/keeper/query_supplier_test.go index ad28a377b..bfd4f8019 100644 --- a/x/supplier/keeper/query_supplier_test.go +++ b/x/supplier/keeper/query_supplier_test.go @@ -19,7 +19,7 @@ import ( var _ = strconv.IntSize func TestSupplierQuerySingle(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) msgs := createNSupplier(keeper, ctx, 2) tests := []struct { @@ -71,7 +71,7 @@ func TestSupplierQuerySingle(t *testing.T) { } func TestSupplierQueryPaginated(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) wctx := sdk.WrapSDKContext(ctx) msgs := createNSupplier(keeper, ctx, 5) diff --git a/x/supplier/keeper/supplier_test.go b/x/supplier/keeper/supplier_test.go index e7b901489..03b1e470a 100644 --- a/x/supplier/keeper/supplier_test.go +++ b/x/supplier/keeper/supplier_test.go @@ -50,7 +50,7 @@ func createNSupplier(keeper *keeper.Keeper, ctx sdk.Context, n int) []sharedtype } func TestSupplierGet(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) suppliers := createNSupplier(keeper, ctx, 10) for _, supplier := range suppliers { supplierFound, isSupplierFound := keeper.GetSupplier(ctx, @@ -64,7 +64,7 @@ func TestSupplierGet(t *testing.T) { } } func TestSupplierRemove(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) suppliers := createNSupplier(keeper, ctx, 10) for _, supplier := range suppliers { keeper.RemoveSupplier(ctx, @@ -78,7 +78,7 @@ func TestSupplierRemove(t *testing.T) { } func TestSupplierGetAll(t *testing.T) { - keeper, ctx := keepertest.SupplierKeeper(t) + keeper, ctx := keepertest.SupplierKeeper(t, nil) suppliers := createNSupplier(keeper, ctx, 10) require.ElementsMatch(t, nullify.Fill(suppliers), diff --git a/x/supplier/types/expected_keepers.go b/x/supplier/types/expected_keepers.go index e92e7d60e..f6f8b626c 100644 --- a/x/supplier/types/expected_keepers.go +++ b/x/supplier/types/expected_keepers.go @@ -1,10 +1,14 @@ package types -//go:generate mockgen -destination ../../../testutil/supplier/mocks/expected_keepers_mock.go -package mocks . AccountKeeper,BankKeeper +//go:generate mockgen -destination ../../../testutil/supplier/mocks/expected_keepers_mock.go -package mocks . 
AccountKeeper,BankKeeper,SessionKeeper

 import (
+	"context"
+
 	sdk "github.com/cosmos/cosmos-sdk/types"
 	"github.com/cosmos/cosmos-sdk/x/auth/types"
+
+	sessiontypes "github.com/pokt-network/poktroll/x/session/types"
 )

 // AccountKeeper defines the expected account keeper used for simulations (noalias)
@@ -18,3 +22,7 @@ type BankKeeper interface {
 	DelegateCoinsFromAccountToModule(ctx sdk.Context, senderAddr sdk.AccAddress, recipientModule string, amt sdk.Coins) error
 	UndelegateCoinsFromModuleToAccount(ctx sdk.Context, senderModule string, recipientAddr sdk.AccAddress, amt sdk.Coins) error
 }
+
+type SessionKeeper interface {
+	GetSession(context.Context, *sessiontypes.QueryGetSessionRequest) (*sessiontypes.QueryGetSessionResponse, error)
+}
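
The TODO_IMPROVE note in session_steps_test.go suggests asserting on the SMST sum carried in the trailing 8 bytes of the claim root hash. A minimal sketch of that extraction follows, assuming the sum is appended as a fixed-width big-endian uint64; the exact encoding depends on the SMT library in use, and the helper name smstSum and package name are illustrative only.

package claimsketch

import (
	"encoding/binary"
	"fmt"
)

// smstSum reads the sum appended to an SMST root hash, assuming the final
// 8 bytes hold a big-endian uint64 (encoding is an assumption, not confirmed
// by this change).
func smstSum(rootHash []byte) (uint64, error) {
	if len(rootHash) < 8 {
		return 0, fmt.Errorf("root hash too short: %d bytes", len(rootHash))
	}
	return binary.BigEndian.Uint64(rootHash[len(rootHash)-8:]), nil
}
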