Skip to content

Commit

Permalink
[Session Module] refactor: session module fetches on-chain params (#557)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Olshansky <[email protected]>
  • Loading branch information
bryanchriswhite and Olshansk authored May 30, 2024
1 parent 4003ab3 commit 72d3e47
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 17 deletions.
1 change: 1 addition & 0 deletions testutil/keeper/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func NewProofModuleKeepers(t testing.TB, opts ...ProofKeepersOpt) (_ *ProofModul
sessionmocks.NewMockBankKeeper(ctrl),
appKeeper,
supplierKeeper,
sharedKeeper,
)
require.NoError(t, sessionKeeper.SetParams(ctx, sessiontypes.DefaultParams()))

Expand Down
13 changes: 13 additions & 0 deletions testutil/keeper/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func SessionKeeper(t testing.TB) (keeper.Keeper, context.Context) {

mockAppKeeper := defaultAppKeeperMock(t)
mockSupplierKeeper := defaultSupplierKeeperMock(t)
mockSharedKeeper := defaultSharedKeeperMock(t)

k := keeper.NewKeeper(
cdc,
Expand All @@ -144,6 +145,7 @@ func SessionKeeper(t testing.TB) (keeper.Keeper, context.Context) {
mockBankKeeper,
mockAppKeeper,
mockSupplierKeeper,
mockSharedKeeper,
)

// TODO_TECHDEBT: See the comment at the bottom of this file explaining
Expand Down Expand Up @@ -216,6 +218,17 @@ func defaultSupplierKeeperMock(t testing.TB) types.SupplierKeeper {
return mockSupplierKeeper
}

func defaultSharedKeeperMock(t testing.TB) types.SharedKeeper {
t.Helper()
ctrl := gomock.NewController(t)

mockSharedKeeper := mocks.NewMockSharedKeeper(ctrl)
mockSharedKeeper.EXPECT().GetParams(gomock.Any()).
Return(sharedtypes.DefaultParams()).
AnyTimes()
return mockSharedKeeper
}

// TODO_TECHDEBT: Figure out how to vary the supplierKeep on a per test basis with exposing `SupplierKeeper publically`

// type option[V any] func(k *keeper.Keeper)
Expand Down
1 change: 1 addition & 0 deletions testutil/keeper/tokenomics.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ func NewTokenomicsModuleKeepers(
bankKeeper,
appKeeper,
supplierKeeper,
sharedKeeper,
)
require.NoError(t, sessionKeeper.SetParams(ctx, sessiontypes.DefaultParams()))

Expand Down
4 changes: 2 additions & 2 deletions testutil/testclient/testqueryclients/sessionquerier.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewTestSessionQueryClient(
serviceId string,
blockHeight int64,
) (session *sessiontypes.Session, err error) {
sessionId, _ := sessionkeeper.GetSessionId(address, serviceId, blockHashBz, blockHeight)
sessionId, _ := sessionkeeper.GetSessionIdWithDefaultParams(address, serviceId, blockHashBz, blockHeight)

session, ok := sessionsMap[sessionId]
if !ok {
Expand All @@ -73,7 +73,7 @@ func AddToExistingSessions(
) {
t.Helper()

sessionId, _ := sessionkeeper.GetSessionId(appAddress, serviceId, blockHashBz, blockHeight)
sessionId, _ := sessionkeeper.GetSessionIdWithDefaultParams(appAddress, serviceId, blockHashBz, blockHeight)

session := sessiontypes.Session{
Header: &sessiontypes.SessionHeader{
Expand Down
2 changes: 1 addition & 1 deletion testutil/testproxy/relayerproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func GenerateRelayRequest(
payload []byte,
) *servicetypes.RelayRequest {
appAddress := GetAddressFromPrivateKey(test, privKey)
sessionId, _ := sessionkeeper.GetSessionId(appAddress, serviceId, blockHashBz, blockHeight)
sessionId, _ := sessionkeeper.GetSessionIdWithDefaultParams(appAddress, serviceId, blockHashBz, blockHeight)

return &servicetypes.RelayRequest{
Meta: servicetypes.RelayRequestMetadata{
Expand Down
3 changes: 3 additions & 0 deletions x/session/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type (
bankKeeper types.BankKeeper
applicationKeeper types.ApplicationKeeper
supplierKeeper types.SupplierKeeper
sharedKeeper types.SharedKeeper
}
)

Expand All @@ -41,6 +42,7 @@ func NewKeeper(
bankKeeper types.BankKeeper,
applicationKeeper types.ApplicationKeeper,
supplierKeeper types.SupplierKeeper,
sharedKeeper types.SharedKeeper,
) Keeper {
if _, err := sdk.AccAddressFromBech32(authority); err != nil {
panic(fmt.Sprintf("invalid authority address: %s", authority))
Expand All @@ -56,6 +58,7 @@ func NewKeeper(
bankKeeper: bankKeeper,
applicationKeeper: applicationKeeper,
supplierKeeper: supplierKeeper,
sharedKeeper: sharedKeeper,
}
}

Expand Down
58 changes: 45 additions & 13 deletions x/session/keeper/session_hydrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ func (k Keeper) hydrateSessionMetadata(ctx context.Context, sh *sessionHydrator)
)
}

// TODO_UPNEXT(#517): Refactor session module to use current on-chain shared
// parameters instead of their corresponding constant stand-ins.

sh.session.NumBlocksPerSession = shared.NumBlocksPerSession
sh.session.SessionNumber = shared.GetSessionNumberWithDefaultParams(sh.blockHeight)

sh.sessionHeader.SessionStartBlockHeight = shared.GetSessionStartHeightWithDefaultParams(sh.blockHeight)
sh.sessionHeader.SessionEndBlockHeight = shared.GetSessionEndHeightWithDefaultParams(sh.blockHeight)
// TODO_BLOCKER(#543): If the num_blocks_per_session param has ever been changed,
// this function may cause unexpected behavior for historical sessions.
sharedParams := k.sharedKeeper.GetParams(ctx)
sh.session.NumBlocksPerSession = int64(sharedParams.NumBlocksPerSession)
sh.session.SessionNumber = shared.GetSessionNumber(&sharedParams, sh.blockHeight)

sh.sessionHeader.SessionStartBlockHeight = shared.GetSessionStartHeight(&sharedParams, sh.blockHeight)
sh.sessionHeader.SessionEndBlockHeight = shared.GetSessionEndHeight(&sharedParams, sh.blockHeight)
return nil
}

Expand All @@ -121,7 +121,8 @@ func (k Keeper) hydrateSessionID(ctx context.Context, sh *sessionHydrator) error
return types.ErrSessionHydration.Wrapf("invalid service: %v", sh.sessionHeader.Service)
}

sh.sessionHeader.SessionId, sh.sessionIDBz = GetSessionId(
sh.sessionHeader.SessionId, sh.sessionIDBz = k.GetSessionId(
ctx,
sh.sessionHeader.ApplicationAddress,
sh.sessionHeader.Service.Id,
prevHashBz,
Expand Down Expand Up @@ -269,7 +270,37 @@ func sha3Hash(bz []byte) []byte {
// GetSessionId returns the string and bytes representation of the sessionId
// given the application public key, service ID, block hash, and block height
// that is used to get the session start block height.
func (k Keeper) GetSessionId(
ctx context.Context,
appPubKey,
serviceId string,
blockHashBz []byte,
blockHeight int64,
) (sessionId string, sessionIdBz []byte) {
sharedParams := k.sharedKeeper.GetParams(ctx)
return GetSessionId(&sharedParams, appPubKey, serviceId, blockHashBz, blockHeight)
}

// GetSessionIdWithDefaultParams returns the string and bytes representation of the
// sessionId for the session containing blockHeight, given the default shared on-chain
// parameters, application public key, service ID, and block hash.
//
// TODO_TECHDEBT(#517): Move this to a shared testutil.
func GetSessionIdWithDefaultParams(
appPubKey,
serviceId string,
blockHashBz []byte,
blockHeight int64,
) (sessionId string, sessionIdBz []byte) {
sharedParams := sharedtypes.DefaultParams()
return GetSessionId(&sharedParams, appPubKey, serviceId, blockHashBz, blockHeight)
}

// GetSessionId returns the string and bytes representation of the sessionId for the
// session containing blockHeight, given the shared on-chain parameters, application
// public key, service ID, and block hash.
func GetSessionId(
sharedParams *sharedtypes.Params,
appPubKey,
serviceId string,
blockHashBz []byte,
Expand All @@ -278,7 +309,7 @@ func GetSessionId(
appPubKeyBz := []byte(appPubKey)
serviceIdBz := []byte(serviceId)

blockHeightBz := getSessionStartBlockHeightBz(blockHeight)
blockHeightBz := getSessionStartBlockHeightBz(sharedParams, blockHeight)
sessionIdBz = concatWithDelimiter(
SessionIDComponentDelimiter,
blockHashBz,
Expand All @@ -292,9 +323,10 @@ func GetSessionId(
}

// getSessionStartBlockHeightBz returns the bytes representation of the session
// start block height given the block height.
func getSessionStartBlockHeightBz(blockHeight int64) []byte {
sessionStartBlockHeight := shared.GetSessionStartHeightWithDefaultParams(blockHeight)
// start height for the session containing blockHeight, given the shared on-chain
// parameters.
func getSessionStartBlockHeightBz(sharedParams *sharedtypes.Params, blockHeight int64) []byte {
sessionStartBlockHeight := shared.GetSessionStartHeight(sharedParams, blockHeight)
sessionStartBlockHeightBz := make([]byte, 8)
binary.LittleEndian.PutUint64(sessionStartBlockHeightBz, uint64(sessionStartBlockHeight))
return sessionStartBlockHeightBz
Expand Down
2 changes: 2 additions & 0 deletions x/session/module/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ type ModuleInputs struct {
BankKeeper types.BankKeeper
ApplicationKeeper types.ApplicationKeeper
SupplierKeeper types.SupplierKeeper
SharedKeeper types.SharedKeeper
}

type ModuleOutputs struct {
Expand All @@ -214,6 +215,7 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
in.BankKeeper,
in.ApplicationKeeper,
in.SupplierKeeper,
in.SharedKeeper,
)
m := NewAppModule(
in.Cdc,
Expand Down
8 changes: 7 additions & 1 deletion x/session/types/expected_keepers.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:generate mockgen -destination ../../../testutil/session/mocks/expected_keepers_mock.go -package mocks . AccountKeeper,BankKeeper,ApplicationKeeper,SupplierKeeper
//go:generate mockgen -destination ../../../testutil/session/mocks/expected_keepers_mock.go -package mocks . AccountKeeper,BankKeeper,ApplicationKeeper,SupplierKeeper,SharedKeeper

package types

Expand Down Expand Up @@ -26,6 +26,12 @@ type ApplicationKeeper interface {
GetApplication(ctx context.Context, address string) (app apptypes.Application, found bool)
}

// SupplierKeeper defines the expected interface needed to retrieve suppliers
type SupplierKeeper interface {
GetAllSuppliers(ctx context.Context) (suppliers []sharedtypes.Supplier)
}

// SharedKeeper defines the expected interface needed to retrieve shared parameters
type SharedKeeper interface {
GetParams(ctx context.Context) (params sharedtypes.Params)
}

0 comments on commit 72d3e47

Please sign in to comment.