diff --git a/pkg/client/query/accquerier.go b/pkg/client/query/accquerier.go index 932db5836..6d987ca6d 100644 --- a/pkg/client/query/accquerier.go +++ b/pkg/client/query/accquerier.go @@ -2,7 +2,6 @@ package query import ( "context" - "sync" "cosmossdk.io/depinject" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" @@ -11,6 +10,7 @@ import ( grpc "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" ) var _ client.AccountQueryClient = (*accQuerier)(nil) @@ -21,11 +21,10 @@ var _ client.AccountQueryClient = (*accQuerier)(nil) type accQuerier struct { clientConn grpc.ClientConn accountQuerier accounttypes.QueryClient + logger polylog.Logger - // accountCache is a cache of accounts that have already been queried. - // TODO_TECHDEBT: Add a size limit to the cache and consider an LRU cache. - accountCache map[string]types.AccountI - accountCacheMu sync.Mutex + // accountsCache caches accountQueryClient.Account requests + accountsCache KeyValueCache[types.AccountI] } // NewAccountQuerier returns a new instance of a client.AccountQueryClient by @@ -34,11 +33,13 @@ type accQuerier struct { // Required dependencies: // - clientCtx func NewAccountQuerier(deps depinject.Config) (client.AccountQueryClient, error) { - aq := &accQuerier{accountCache: make(map[string]types.AccountI)} + aq := &accQuerier{} if err := depinject.Inject( deps, &aq.clientConn, + &aq.logger, + &aq.accountsCache, ); err != nil { return nil, err } @@ -53,13 +54,16 @@ func (aq *accQuerier) GetAccount( ctx context.Context, address string, ) (types.AccountI, error) { - aq.accountCacheMu.Lock() - defer aq.accountCacheMu.Unlock() + logger := aq.logger.With("query_client", "account", "method", "GetAccount") - if foundAccount, isAccountFound := aq.accountCache[address]; isAccountFound { - return foundAccount, nil + // Check if the account is present in the cache. + if account, found := aq.accountsCache.Get(address); found { + logger.Debug().Msgf("cache hit for key: %s", address) + return account, nil } + logger.Debug().Msgf("cache miss for key: %s", address) + // Query the blockchain for the account record req := &accounttypes.QueryAccountRequest{Address: address} res, err := aq.accountQuerier.Account(ctx, req) @@ -81,8 +85,8 @@ func (aq *accQuerier) GetAccount( return nil, ErrQueryPubKeyNotFound } - aq.accountCache[address] = fetchedAccount - + // Cache the fetched account for future queries. + aq.accountsCache.Set(address, fetchedAccount) return fetchedAccount, nil } diff --git a/pkg/client/query/appquerier.go b/pkg/client/query/appquerier.go index 356ce674c..1ec9b2b66 100644 --- a/pkg/client/query/appquerier.go +++ b/pkg/client/query/appquerier.go @@ -7,6 +7,7 @@ import ( grpc "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" apptypes "github.com/pokt-network/poktroll/x/application/types" ) @@ -18,6 +19,12 @@ var _ client.ApplicationQueryClient = (*appQuerier)(nil) type appQuerier struct { clientConn grpc.ClientConn applicationQuerier apptypes.QueryClient + logger polylog.Logger + + // applicationsCache caches applicationQueryClient.Application requests + applicationsCache KeyValueCache[apptypes.Application] + // paramsCache caches applicationQueryClient.Params requests + paramsCache ParamsCache[apptypes.Params] } // NewApplicationQuerier returns a new instance of a client.ApplicationQueryClient @@ -31,6 +38,9 @@ func NewApplicationQuerier(deps depinject.Config) (client.ApplicationQueryClient if err := depinject.Inject( deps, &aq.clientConn, + &aq.logger, + &aq.applicationsCache, + &aq.paramsCache, ); err != nil { return nil, err } @@ -45,17 +55,33 @@ func (aq *appQuerier) GetApplication( ctx context.Context, appAddress string, ) (apptypes.Application, error) { + logger := aq.logger.With("query_client", "application", "method", "GetApplication") + + // Check if the application is present in the cache. + if app, found := aq.applicationsCache.Get(appAddress); found { + logger.Debug().Msgf("cache hit for key: %s", appAddress) + return app, nil + } + + logger.Debug().Msgf("cache miss for key: %s", appAddress) + req := apptypes.QueryGetApplicationRequest{Address: appAddress} res, err := aq.applicationQuerier.Application(ctx, &req) if err != nil { return apptypes.Application{}, apptypes.ErrAppNotFound.Wrapf("app address: %s [%v]", appAddress, err) } + + // Cache the application. + aq.applicationsCache.Set(appAddress, res.Application) return res.Application, nil } // GetAllApplications returns all staked applications func (aq *appQuerier) GetAllApplications(ctx context.Context) ([]apptypes.Application, error) { req := apptypes.QueryAllApplicationsRequest{} + // TODO_OPTIMIZE: Fill the cache with all applications and mark it as + // having been filled, such that subsequent calls to this function will + // return the cached value. res, err := aq.applicationQuerier.AllApplications(ctx, &req) if err != nil { return []apptypes.Application{}, err @@ -65,10 +91,23 @@ func (aq *appQuerier) GetAllApplications(ctx context.Context) ([]apptypes.Applic // GetParams returns the application module parameters func (aq *appQuerier) GetParams(ctx context.Context) (*apptypes.Params, error) { + logger := aq.logger.With("query_client", "application", "method", "GetParams") + + // Check if the application module parameters are present in the cache. + if params, found := aq.paramsCache.Get(); found { + logger.Debug().Msg("cache hit") + return ¶ms, nil + } + + logger.Debug().Msg("cache miss") + req := apptypes.QueryParamsRequest{} res, err := aq.applicationQuerier.Params(ctx, &req) if err != nil { return nil, err } + + // Update the cache with the newly retrieved application module parameters. + aq.paramsCache.Set(res.Params) return &res.Params, nil } diff --git a/pkg/client/query/bankquerier.go b/pkg/client/query/bankquerier.go index ca28a4998..728addc26 100644 --- a/pkg/client/query/bankquerier.go +++ b/pkg/client/query/bankquerier.go @@ -10,6 +10,8 @@ import ( "github.com/pokt-network/poktroll/app/volatile" "github.com/pokt-network/poktroll/pkg/client" + querytypes "github.com/pokt-network/poktroll/pkg/client/query/types" + "github.com/pokt-network/poktroll/pkg/polylog" ) var _ client.BankQueryClient = (*bankQuerier)(nil) @@ -19,6 +21,10 @@ var _ client.BankQueryClient = (*bankQuerier)(nil) type bankQuerier struct { clientConn grpc.ClientConn bankQuerier banktypes.QueryClient + logger polylog.Logger + + // balancesCache caches bankQueryClient.GetBalance requests + balancesCache KeyValueCache[querytypes.Balance] } // NewBankQuerier returns a new instance of a client.BankQueryClient by @@ -32,6 +38,8 @@ func NewBankQuerier(deps depinject.Config) (client.BankQueryClient, error) { if err := depinject.Inject( deps, &bq.clientConn, + &bq.logger, + &bq.balancesCache, ); err != nil { return nil, err } @@ -46,6 +54,16 @@ func (bq *bankQuerier) GetBalance( ctx context.Context, address string, ) (*sdk.Coin, error) { + logger := bq.logger.With("query_client", "bank", "method", "GetBalance") + + // Check if the account balance is present in the cache. + if balance, found := bq.balancesCache.Get(address); found { + logger.Debug().Msgf("cache hit for key: %s", address) + return balance, nil + } + + logger.Debug().Msgf("cache miss for key: %s", address) + // Query the blockchain for the balance record req := &banktypes.QueryBalanceRequest{Address: address, Denom: volatile.DenomuPOKT} res, err := bq.bankQuerier.Balance(ctx, req) @@ -53,5 +71,7 @@ func (bq *bankQuerier) GetBalance( return nil, ErrQueryBalanceNotFound.Wrapf("address: %s [%s]", address, err) } + // Cache the balance for future queries + bq.balancesCache.Set(address, res.Balance) return res.Balance, nil } diff --git a/pkg/client/query/cache/cache_test.go b/pkg/client/query/cache/cache_test.go new file mode 100644 index 000000000..19c9945a0 --- /dev/null +++ b/pkg/client/query/cache/cache_test.go @@ -0,0 +1,83 @@ +package cache_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/pokt-network/poktroll/pkg/client/query/cache" +) + +func TestKeyValueCache(t *testing.T) { + kvCache := cache.NewKeyValueCache[any]() + + // Test Get on an empty cache + _, found := kvCache.Get("key") + require.False(t, found) + + // Set a value in the cache + kvCache.Set("key", "value") + + // Test Get on a non-empty cache + value, found := kvCache.Get("key") + require.True(t, found) + require.Equal(t, "value", value) + + // Test Delete on a non-empty cache + kvCache.Delete("key") + + // Test Get on a deleted key + _, found = kvCache.Get("key") + require.False(t, found) + + // Set multiple values in the cache + kvCache.Set("key1", "value1") + kvCache.Set("key2", "value2") + + // Test Clear on a non-empty cache + kvCache.Clear() + + // Test Get on an empty cache + _, found = kvCache.Get("key1") + require.False(t, found) + + _, found = kvCache.Get("key2") + require.False(t, found) + + // Delete a non-existing key + kvCache.Delete("key1") + + // Test Get on a deleted key + _, found = kvCache.Get("key1") + require.False(t, found) + + // Test Clear on an empty cache + kvCache.Clear() + + // Test Get on an empty cache + _, found = kvCache.Get("key2") + require.False(t, found) +} + +func TestParamsCache(t *testing.T) { + paramsCache := cache.NewParamsCache[any]() + + // Test Get on an empty cache + _, found := paramsCache.Get() + require.False(t, found) + + // Set a value in the cache + paramsCache.Set("value") + + // Test Get on a non-empty cache + value, found := paramsCache.Get() + require.True(t, found) + require.Equal(t, "value", value) + + // Test Clear on a non-empty cache + paramsCache.Clear() + + // Test Get on an empty cache + _, found = paramsCache.Get() + require.False(t, found) +} diff --git a/pkg/client/query/cache/kvcache.go b/pkg/client/query/cache/kvcache.go new file mode 100644 index 000000000..0ce19c808 --- /dev/null +++ b/pkg/client/query/cache/kvcache.go @@ -0,0 +1,61 @@ +package cache + +import ( + "sync" + + "github.com/pokt-network/poktroll/pkg/client/query" +) + +var _ query.KeyValueCache[any] = (*keyValueCache[any])(nil) + +// keyValueCache is a simple in-memory key-value cache implementation. +// It is safe for concurrent use. +type keyValueCache[V any] struct { + cacheMu sync.RWMutex + valuesMap map[string]V +} + +// NewKeyValueCache returns a new instance of a KeyValueCache. +func NewKeyValueCache[T any]() query.KeyValueCache[T] { + return &keyValueCache[T]{ + valuesMap: make(map[string]T), + } +} + +// Get returns the value for the given key. +// A boolean is returned as the second value to indicate if the key was found in the cache. +func (c *keyValueCache[V]) Get(key string) (value V, found bool) { + c.cacheMu.RLock() + defer c.cacheMu.RUnlock() + + value, found = c.valuesMap[key] + return value, found +} + +// Set sets the value for the given key. +// TODO_CONSIDERATION: Add a method to set many values and indicate whether it +// is the result of a GetAll operation. This would allow us to know whether the +// cache is populated with all the possible values, so any other GetAll operation +// could be returned from the cache. +func (c *keyValueCache[V]) Set(key string, value V) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + c.valuesMap[key] = value +} + +// Delete deletes the value for the given key. +func (c *keyValueCache[V]) Delete(key string) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + delete(c.valuesMap, key) +} + +// Clear empties the whole cache. +func (c *keyValueCache[V]) Clear() { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + c.valuesMap = make(map[string]V) +} diff --git a/pkg/client/query/cache/options.go b/pkg/client/query/cache/options.go new file mode 100644 index 000000000..83f21ad30 --- /dev/null +++ b/pkg/client/query/cache/options.go @@ -0,0 +1,38 @@ +package cache + +import ( + "context" + + "cosmossdk.io/depinject" + + "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/observable/channel" +) + +// Cache is an interface that defines the common methods for a cache object. +type Cache interface { + Clear() +} + +// CacheOption is a function type for the option functions that can customize +// the cache behavior. +type CacheOption[C Cache] func(context.Context, depinject.Config, C) error + +// WithNewBlockCacheClearing is a cache option that clears the cache every time +// a new block is observed. +func WithNewBlockCacheClearing[C Cache](ctx context.Context, deps depinject.Config, cache C) error { + var blockClient client.BlockClient + if err := depinject.Inject(deps, &blockClient); err != nil { + return err + } + + channel.ForEach( + ctx, + blockClient.CommittedBlocksSequence(ctx), + func(ctx context.Context, block client.Block) { + cache.Clear() + }, + ) + + return nil +} diff --git a/pkg/client/query/cache/paramscache.go b/pkg/client/query/cache/paramscache.go new file mode 100644 index 000000000..0d726c018 --- /dev/null +++ b/pkg/client/query/cache/paramscache.go @@ -0,0 +1,51 @@ +package cache + +import ( + "sync" + + "github.com/pokt-network/poktroll/pkg/client/query" +) + +var _ query.ParamsCache[any] = (*paramsCache[any])(nil) + +// paramsCache is a simple in-memory cache implementation for query parameters. +// It does not involve key-value pairs, but only stores a single value. +type paramsCache[T any] struct { + cacheMu sync.RWMutex + found bool + value T +} + +// NewParamsCache returns a new instance of a ParamsCache. +func NewParamsCache[T any]() query.ParamsCache[T] { + return ¶msCache[T]{} +} + +// Get returns the value stored in the cache. +// A boolean is returned as the second value to indicate if the value was found in the cache. +func (c *paramsCache[T]) Get() (value T, found bool) { + c.cacheMu.RLock() + defer c.cacheMu.RUnlock() + + return c.value, c.found +} + +// Set sets the value in the cache. +func (c *paramsCache[T]) Set(value T) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + c.found = true + c.value = value +} + +// Clear empties the cache. +func (c *paramsCache[T]) Clear() { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + c.found = false + + var zero T + c.value = zero +} diff --git a/pkg/client/query/interface.go b/pkg/client/query/interface.go new file mode 100644 index 000000000..8d8fb541a --- /dev/null +++ b/pkg/client/query/interface.go @@ -0,0 +1,17 @@ +package query + +// ParamsCache is an interface for a simple in-memory cache implementation for query parameters. +// It does not involve key-value pairs, but only stores a single value. +type ParamsCache[T any] interface { + Get() (T, bool) + Set(T) + Clear() +} + +// KeyValueCache is an interface for a simple in-memory key-value cache implementation. +type KeyValueCache[V any] interface { + Get(string) (V, bool) + Set(string, V) + Delete(string) + Clear() +} diff --git a/pkg/client/query/proofquerier.go b/pkg/client/query/proofquerier.go index 6751dc995..c75b6b7ec 100644 --- a/pkg/client/query/proofquerier.go +++ b/pkg/client/query/proofquerier.go @@ -7,6 +7,7 @@ import ( "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" prooftypes "github.com/pokt-network/poktroll/x/proof/types" ) @@ -15,6 +16,10 @@ import ( type proofQuerier struct { clientConn grpc.ClientConn proofQuerier prooftypes.QueryClient + logger polylog.Logger + + // paramsCache caches proofQuerier.Params requests + paramsCache ParamsCache[prooftypes.Params] } // NewProofQuerier returns a new instance of a client.ProofQueryClient by @@ -28,6 +33,8 @@ func NewProofQuerier(deps depinject.Config) (client.ProofQueryClient, error) { if err := depinject.Inject( deps, &querier.clientConn, + &querier.logger, + &querier.paramsCache, ); err != nil { return nil, err } @@ -41,10 +48,23 @@ func NewProofQuerier(deps depinject.Config) (client.ProofQueryClient, error) { func (pq *proofQuerier) GetParams( ctx context.Context, ) (client.ProofParams, error) { + logger := pq.logger.With("query_client", "proof", "method", "GetParams") + + // Get the params from the cache if they exist. + if params, found := pq.paramsCache.Get(); found { + logger.Debug().Msg("cache hit") + return ¶ms, nil + } + + logger.Debug().Msg("cache miss") + req := &prooftypes.QueryParamsRequest{} res, err := pq.proofQuerier.Params(ctx, req) if err != nil { return nil, err } + + // Update the cache with the newly retrieved params. + pq.paramsCache.Set(res.Params) return &res.Params, nil } diff --git a/pkg/client/query/querycache_test.go b/pkg/client/query/querycache_test.go new file mode 100644 index 000000000..98b1d34fc --- /dev/null +++ b/pkg/client/query/querycache_test.go @@ -0,0 +1,350 @@ +package query_test + +import ( + "context" + "net" + "testing" + + "cosmossdk.io/depinject" + cosmostypes "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/client/query" + "github.com/pokt-network/poktroll/pkg/client/query/cache" + querytypes "github.com/pokt-network/poktroll/pkg/client/query/types" + "github.com/pokt-network/poktroll/pkg/polylog" + "github.com/pokt-network/poktroll/testutil/mockclient" + "github.com/pokt-network/poktroll/testutil/sample" + "github.com/pokt-network/poktroll/testutil/testclient/testqueryclients" + apptypes "github.com/pokt-network/poktroll/x/application/types" + prooftypes "github.com/pokt-network/poktroll/x/proof/types" + servicetypes "github.com/pokt-network/poktroll/x/service/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" +) + +const numCalls = 4 + +// QueryCacheTestSuite runs all the tests for the query clients that cache their responses. +type QueryCacheTestSuite struct { + suite.Suite + + queryClients *queryClients + queryServers *queryServers + + listener *bufconn.Listener + grpcServer *grpc.Server + grpcClientConn *grpc.ClientConn +} + +func (s *QueryCacheTestSuite) SetupTest() { + ctx := context.Background() + logger := polylog.Ctx(ctx) + + // Create the gRPC server for the query clients + s.grpcServer, s.listener, s.queryServers = createGRPCServer(s.T()) + + // Create a gRPC client connection to the gRPC server + s.grpcClientConn = createGRPCClienConn(s.T(), s.listener) + + // Create a depinject.Config with the cache dependencies + deps := supplyCacheDeps() + + // Create a new depinject config with a supplied gRPC client connection and logger + // needed by the query clients. + deps = depinject.Configs(deps, depinject.Supply(s.grpcClientConn, logger)) + + // Create the query clients under test. + s.queryClients = createQueryClients(s.T(), deps) +} + +func (s *QueryCacheTestSuite) TearDownTest() { + s.grpcServer.Stop() +} + +func TestQueryClientCache(t *testing.T) { + suite.Run(t, &QueryCacheTestSuite{}) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_ServiceQuerier() { + ctx := context.Background() + + // Call the GetService method numCalls times and assert that the service server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.service.GetService(ctx, "serviceId") + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.service.ServiceCallCounter.CallCount()) + + // Call the GetServiceRelayDifficulty method numCalls times and assert that the service + // server is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.service.GetServiceRelayDifficulty(ctx, "serviceId") + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.service.RelayMiningDifficultyCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_ApplicationQuerier() { + ctx := context.Background() + appAddress := sample.AccAddress() + + // Call the GetApplication method numCalls times and assert that the application server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.application.GetApplication(ctx, appAddress) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.application.AppCallCounter.CallCount()) + + // Call the GetParams method numCalls times and assert that the application server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.application.GetParams(ctx) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.application.ParamsCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_SupplierQuerier() { + ctx := context.Background() + supplierAddress := sample.AccAddress() + + // Call the GetSupplier method numCalls times and assert that the supplier server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.supplier.GetSupplier(ctx, supplierAddress) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.supplier.SupplierCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_SessionQuerier() { + ctx := context.Background() + appAddress := sample.AccAddress() + serviceId := "serviceId" + blockHeight := int64(1) + + // Call the GetSession method numCalls times and assert that the session server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.session.GetSession(ctx, appAddress, serviceId, blockHeight) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.session.SessionCallCounter.CallCount()) + + // Call the GetParams method numCalls times and assert that the session server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.session.GetParams(ctx) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.session.ParamsCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_SharedQuerier() { + ctx := context.Background() + + // Call the GetParams method numCalls times and assert that the shared server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.shared.GetParams(ctx) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.shared.ParamsCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_ProofQuerier() { + ctx := context.Background() + + // Call the GetParams method numCalls times and assert that the proof server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.proof.GetParams(ctx) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.proof.ParamsCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_BankQuerier() { + ctx := context.Background() + accountAddress := sample.AccAddress() + + // Call the GetBalance method numCalls times and assert that the bank server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.bank.GetBalance(ctx, accountAddress) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.bank.BalanceCallCounter.CallCount()) +} + +func (s *QueryCacheTestSuite) TestKeyValueCache_AccountQuerier() { + ctx := context.Background() + accountAddress := sample.AccAddress() + + // Call the GetAccount method numCalls times and assert that the account server + // is reached only once. + for i := 0; i < numCalls; i++ { + _, err := s.queryClients.account.GetAccount(ctx, accountAddress) + require.NoError(s.T(), err) + } + require.Equal(s.T(), 1, s.queryServers.account.AccountCallCounter.CallCount()) +} + +// supplyCacheDeps supplies all the cache dependencies required by the query clients. +func supplyCacheDeps() depinject.Config { + return depinject.Supply( + cache.NewKeyValueCache[sharedtypes.Service](), + cache.NewKeyValueCache[servicetypes.RelayMiningDifficulty](), + cache.NewKeyValueCache[apptypes.Application](), + cache.NewKeyValueCache[sharedtypes.Supplier](), + cache.NewKeyValueCache[*sessiontypes.Session](), + cache.NewKeyValueCache[querytypes.Balance](), + cache.NewKeyValueCache[querytypes.BlockHash](), + + cache.NewParamsCache[sharedtypes.Params](), + cache.NewParamsCache[apptypes.Params](), + cache.NewParamsCache[sessiontypes.Params](), + cache.NewParamsCache[prooftypes.Params](), + + cache.NewKeyValueCache[cosmostypes.AccountI](), + ) +} + +// createQueryClients creates all the query clients that cache their responses +// and are being tested in this test suite. +func createQueryClients(t *testing.T, deps depinject.Config) *queryClients { + var err error + queryClients := &queryClients{} + + queryClients.service, err = query.NewServiceQuerier(deps) + require.NoError(t, err) + + queryClients.application, err = query.NewApplicationQuerier(deps) + require.NoError(t, err) + + queryClients.supplier, err = query.NewSupplierQuerier(deps) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + cometClientMock := mockclient.NewMockCometRPC(ctrl) + + deps = depinject.Configs(deps, depinject.Supply(cometClientMock)) + + queryClients.shared, err = query.NewSharedQuerier(deps) + require.NoError(t, err) + + // Supply the shared query client which the session query client depends on. + deps = depinject.Configs(deps, depinject.Supply(queryClients.shared)) + queryClients.session, err = query.NewSessionQuerier(deps) + require.NoError(t, err) + + queryClients.proof, err = query.NewProofQuerier(deps) + require.NoError(t, err) + + queryClients.bank, err = query.NewBankQuerier(deps) + require.NoError(t, err) + + queryClients.account, err = query.NewAccountQuerier(deps) + require.NoError(t, err) + + return queryClients +} + +// queryClients contains all the query clients that cache their responses and +// being tested in this test suite. +type queryClients struct { + service client.ServiceQueryClient + application client.ApplicationQueryClient + supplier client.SupplierQueryClient + session client.SessionQueryClient + shared client.SharedQueryClient + proof client.ProofQueryClient + + bank client.BankQueryClient + account client.AccountQueryClient +} + +// queryServers contains all the mock gRPC query servers that the query clients +// in the test suite are calling. +type queryServers struct { + service *testqueryclients.MockServiceQueryServer + application *testqueryclients.MockApplicationQueryServer + supplier *testqueryclients.MockSupplierQueryServer + session *testqueryclients.MockSessionQueryServer + shared *testqueryclients.MockSharedQueryServer + proof *testqueryclients.MockProofQueryServer + + bank *testqueryclients.MockBankQueryServer + account *testqueryclients.MockAccountQueryServer +} + +// createGRPCServer creates a gRPC server with all the mock query servers +// The gRPC server uses a bufconn.Listener to avoid port conflicts in concurrent tests. +func createGRPCServer(t *testing.T) (*grpc.Server, *bufconn.Listener, *queryServers) { + // Create the gRPC server + grpcServer := grpc.NewServer() + listener := bufconn.Listen(1024 * 1024) + queryServers := &queryServers{} + + // Register all the mock query servers used in the test with the gRPC server. + + queryServers.service = &testqueryclients.MockServiceQueryServer{} + servicetypes.RegisterQueryServer(grpcServer, queryServers.service) + + queryServers.application = &testqueryclients.MockApplicationQueryServer{} + apptypes.RegisterQueryServer(grpcServer, queryServers.application) + + queryServers.supplier = &testqueryclients.MockSupplierQueryServer{} + suppliertypes.RegisterQueryServer(grpcServer, queryServers.supplier) + + queryServers.session = &testqueryclients.MockSessionQueryServer{} + sessiontypes.RegisterQueryServer(grpcServer, queryServers.session) + + queryServers.shared = &testqueryclients.MockSharedQueryServer{} + sharedtypes.RegisterQueryServer(grpcServer, queryServers.shared) + + queryServers.proof = &testqueryclients.MockProofQueryServer{} + prooftypes.RegisterQueryServer(grpcServer, queryServers.proof) + + queryServers.bank = &testqueryclients.MockBankQueryServer{} + banktypes.RegisterQueryServer(grpcServer, queryServers.bank) + + queryServers.account = &testqueryclients.MockAccountQueryServer{} + authtypes.RegisterQueryServer(grpcServer, queryServers.account) + + // Start the gRPC server + go func() { + err := grpcServer.Serve(listener) + require.NoError(t, err) + }() + + return grpcServer, listener, queryServers +} + +// createGRPCClienConn creates a gRPC client connection to the bufconn.Listener. +func createGRPCClienConn(t *testing.T, listener *bufconn.Listener) *grpc.ClientConn { + dialer := func(context.Context, string) (net.Conn, error) { + return listener.Dial() + } + + grpcClientConn, err := grpc.NewClient( + "passthrough://bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + return grpcClientConn +} diff --git a/pkg/client/query/servicequerier.go b/pkg/client/query/servicequerier.go index 1f5ef2d2a..6cbab4af5 100644 --- a/pkg/client/query/servicequerier.go +++ b/pkg/client/query/servicequerier.go @@ -7,6 +7,7 @@ import ( "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" servicetypes "github.com/pokt-network/poktroll/x/service/types" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) @@ -19,6 +20,12 @@ var _ client.ServiceQueryClient = (*serviceQuerier)(nil) type serviceQuerier struct { clientConn grpc.ClientConn serviceQuerier servicetypes.QueryClient + logger polylog.Logger + + // servicesCache caches serviceQueryClient.Service requests + servicesCache KeyValueCache[sharedtypes.Service] + // relayMiningDifficultyCache caches serviceQueryClient.RelayMiningDifficulty requests + relayMiningDifficultyCache KeyValueCache[servicetypes.RelayMiningDifficulty] } // NewServiceQuerier returns a new instance of a client.ServiceQueryClient by @@ -32,6 +39,9 @@ func NewServiceQuerier(deps depinject.Config) (client.ServiceQueryClient, error) if err := depinject.Inject( deps, &servq.clientConn, + &servq.logger, + &servq.servicesCache, + &servq.relayMiningDifficultyCache, ); err != nil { return nil, err } @@ -47,6 +57,16 @@ func (servq *serviceQuerier) GetService( ctx context.Context, serviceId string, ) (sharedtypes.Service, error) { + logger := servq.logger.With("query_client", "service", "method", "GetService") + + // Check if the service is present in the cache. + if service, found := servq.servicesCache.Get(serviceId); found { + logger.Debug().Msgf("cache hit for key: %s", serviceId) + return service, nil + } + + logger.Debug().Msgf("cache miss for key: %s", serviceId) + req := &servicetypes.QueryGetServiceRequest{ Id: serviceId, } @@ -58,6 +78,9 @@ func (servq *serviceQuerier) GetService( serviceId, err, ) } + + // Cache the service for future use. + servq.servicesCache.Set(serviceId, res.Service) return res.Service, nil } @@ -67,6 +90,16 @@ func (servq *serviceQuerier) GetServiceRelayDifficulty( ctx context.Context, serviceId string, ) (servicetypes.RelayMiningDifficulty, error) { + logger := servq.logger.With("query_client", "service", "method", "GetServiceRelayDifficulty") + + // Check if the relay mining difficulty is present in the cache. + if relayMiningDifficulty, found := servq.relayMiningDifficultyCache.Get(serviceId); found { + logger.Debug().Msgf("cache hit for key: %s", serviceId) + return relayMiningDifficulty, nil + } + + logger.Debug().Msgf("cache miss for key: %s", serviceId) + req := &servicetypes.QueryGetRelayMiningDifficultyRequest{ ServiceId: serviceId, } @@ -76,5 +109,7 @@ func (servq *serviceQuerier) GetServiceRelayDifficulty( return servicetypes.RelayMiningDifficulty{}, err } + // Cache the relay mining difficulty for future use. + servq.relayMiningDifficultyCache.Set(serviceId, res.RelayMiningDifficulty) return res.RelayMiningDifficulty, nil } diff --git a/pkg/client/query/sessionquerier.go b/pkg/client/query/sessionquerier.go index fdf6c42e9..2eaf64471 100644 --- a/pkg/client/query/sessionquerier.go +++ b/pkg/client/query/sessionquerier.go @@ -2,12 +2,15 @@ package query import ( "context" + "fmt" "cosmossdk.io/depinject" "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) var _ client.SessionQueryClient = (*sessionQuerier)(nil) @@ -16,8 +19,15 @@ var _ client.SessionQueryClient = (*sessionQuerier)(nil) // querying of onchain session information through a single exposed method // which returns an sessiontypes.Session struct type sessionQuerier struct { - clientConn grpc.ClientConn - sessionQuerier sessiontypes.QueryClient + clientConn grpc.ClientConn + sessionQuerier sessiontypes.QueryClient + sharedQueryClient client.SharedQueryClient + logger polylog.Logger + + // sessionsCache caches sessionQueryClient.GetSession requests + sessionsCache KeyValueCache[*sessiontypes.Session] + // paramsCache caches sessionQueryClient.Params requests + paramsCache ParamsCache[sessiontypes.Params] } // NewSessionQuerier returns a new instance of a client.SessionQueryClient by @@ -31,6 +41,10 @@ func NewSessionQuerier(deps depinject.Config) (client.SessionQueryClient, error) if err := depinject.Inject( deps, &sessq.clientConn, + &sessq.sharedQueryClient, + &sessq.logger, + &sessq.sessionsCache, + &sessq.paramsCache, ); err != nil { return nil, err } @@ -48,6 +62,24 @@ func (sessq *sessionQuerier) GetSession( serviceId string, blockHeight int64, ) (*sessiontypes.Session, error) { + logger := sessq.logger.With("query_client", "session", "method", "GetSession") + + // Get the shared parameters to calculate the session start height. + // Use the session start height as the canonical height to be used in the cache key. + sharedParams, err := sessq.sharedQueryClient.GetParams(ctx) + if err != nil { + return nil, err + } + sessionCacheKey := getSessionCacheKey(sharedParams, appAddress, serviceId, blockHeight) + + // Check if the session is present in the cache. + if session, found := sessq.sessionsCache.Get(sessionCacheKey); found { + logger.Debug().Msgf("cache hit for key: %s", sessionCacheKey) + return session, nil + } + + logger.Debug().Msgf("cache miss for key: %s", sessionCacheKey) + req := &sessiontypes.QueryGetSessionRequest{ ApplicationAddress: appAddress, ServiceId: serviceId, @@ -60,15 +92,45 @@ func (sessq *sessionQuerier) GetSession( appAddress, serviceId, blockHeight, err, ) } + + // Cache the session using the session key. + sessq.sessionsCache.Set(sessionCacheKey, res.Session) return res.Session, nil } // GetParams queries & returns the session module onchain parameters. func (sessq *sessionQuerier) GetParams(ctx context.Context) (*sessiontypes.Params, error) { + logger := sessq.logger.With("query_client", "session", "method", "GetParams") + + // Check if the params are present in the cache. + if params, found := sessq.paramsCache.Get(); found { + logger.Debug().Msg("cache hit") + return ¶ms, nil + } + + logger.Debug().Msg("cache miss") + req := &sessiontypes.QueryParamsRequest{} res, err := sessq.sessionQuerier.Params(ctx, req) if err != nil { return nil, ErrQuerySessionParams.Wrapf("[%v]", err) } + + // Cache the params for future queries. + sessq.paramsCache.Set(res.Params) return &res.Params, nil } + +// getSessionCacheKey constructs the cache key for a session. +func getSessionCacheKey( + sharedParams *sharedtypes.Params, + appAddress, + serviceId string, + blockHeight int64, +) string { + // Using the session start height as the canonical height ensures that the cache + // does not duplicate entries for the same session given different block heights + // of the same session. + sessionStartHeight := sharedtypes.GetSessionStartHeight(sharedParams, blockHeight) + return fmt.Sprintf("%s/%s/%d", appAddress, serviceId, sessionStartHeight) +} diff --git a/pkg/client/query/sharedquerier.go b/pkg/client/query/sharedquerier.go index bbe67b0de..81af662c1 100644 --- a/pkg/client/query/sharedquerier.go +++ b/pkg/client/query/sharedquerier.go @@ -2,11 +2,14 @@ package query import ( "context" + "strconv" "cosmossdk.io/depinject" "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + querytypes "github.com/pokt-network/poktroll/pkg/client/query/types" + "github.com/pokt-network/poktroll/pkg/polylog" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) @@ -19,6 +22,12 @@ type sharedQuerier struct { clientConn grpc.ClientConn sharedQuerier sharedtypes.QueryClient blockQuerier client.BlockQueryClient + logger polylog.Logger + + // blockHashCache caches blockQuerier.Block requests + blockHashCache KeyValueCache[querytypes.BlockHash] + // paramsCache caches sharedQueryClient.Params requests + paramsCache ParamsCache[sharedtypes.Params] } // NewSharedQuerier returns a new instance of a client.SharedQueryClient by @@ -33,7 +42,10 @@ func NewSharedQuerier(deps depinject.Config) (client.SharedQueryClient, error) { if err := depinject.Inject( deps, &querier.clientConn, + &querier.logger, &querier.blockQuerier, + &querier.blockHashCache, + &querier.paramsCache, ); err != nil { return nil, err } @@ -49,11 +61,24 @@ func NewSharedQuerier(deps depinject.Config) (client.SharedQueryClient, error) { // Once `ModuleParamsClient` is implemented, use its replay observable's `#Last()` method // to get the most recently (asynchronously) observed (and cached) value. func (sq *sharedQuerier) GetParams(ctx context.Context) (*sharedtypes.Params, error) { + logger := sq.logger.With("query_client", "shared", "method", "GetParams") + + // Get the params from the cache if they exist. + if params, found := sq.paramsCache.Get(); found { + logger.Debug().Msg("cache hit") + return ¶ms, nil + } + + logger.Debug().Msg("cache miss") + req := &sharedtypes.QueryParamsRequest{} res, err := sq.sharedQuerier.Params(ctx, req) if err != nil { return nil, ErrQuerySessionParams.Wrapf("[%v]", err) } + + // Update the cache with the newly retrieved params. + sq.paramsCache.Set(res.Params) return &res.Params, nil } @@ -119,6 +144,8 @@ func (sq *sharedQuerier) GetSessionGracePeriodEndHeight( // TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params. // Instead, we should be using the value that the params had for the session which includes queryHeight. func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) { + logger := sq.logger.With("query_client", "shared", "method", "GetEarliestSupplierClaimCommitHeight") + sharedParams, err := sq.GetParams(ctx) if err != nil { return 0, err @@ -127,13 +154,25 @@ func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Contex // Fetch the block at the proof window open height. Its hash is used as part // of the seed to the pseudo-random number generator. claimWindowOpenHeight := sharedtypes.GetClaimWindowOpenHeight(sharedParams, queryHeight) - claimWindowOpenBlock, err := sq.blockQuerier.Block(ctx, &claimWindowOpenHeight) - if err != nil { - return 0, err - } - // NB: Byte slice representation of block hashes don't need to be normalized. - claimWindowOpenBlockHash := claimWindowOpenBlock.BlockID.Hash.Bytes() + // Check if the block hash is already in the cache. + blockHashCacheKey := getBlockHashKacheKey(claimWindowOpenHeight) + claimWindowOpenBlockHash, found := sq.blockHashCache.Get(blockHashCacheKey) + if !found { + logger.Debug().Msgf("cache miss for blockHeight: %s", blockHashCacheKey) + + claimWindowOpenBlock, err := sq.blockQuerier.Block(ctx, &claimWindowOpenHeight) + if err != nil { + return 0, err + } + + // Cache the block hash for future use. + // NB: Byte slice representation of block hashes don't need to be normalized. + claimWindowOpenBlockHash = claimWindowOpenBlock.BlockID.Hash.Bytes() + sq.blockHashCache.Set(blockHashCacheKey, claimWindowOpenBlockHash) + } else { + logger.Debug().Msgf("cache hit for blockHeight: %s", blockHashCacheKey) + } return sharedtypes.GetEarliestSupplierClaimCommitHeight( sharedParams, @@ -152,23 +191,38 @@ func (sq *sharedQuerier) GetEarliestSupplierClaimCommitHeight(ctx context.Contex // TODO_MAINNET(@bryanchriswhite, #543): We also don't really want to use the current value of the params. // Instead, we should be using the value that the params had for the session which includes queryHeight. func (sq *sharedQuerier) GetEarliestSupplierProofCommitHeight(ctx context.Context, queryHeight int64, supplierOperatorAddr string) (int64, error) { + logger := sq.logger.With("query_client", "shared", "method", "GetEarliestSupplierProofCommitHeight") + sharedParams, err := sq.GetParams(ctx) if err != nil { return 0, err } - // Fetch the block at the proof window open height. Its hash is used as part - // of the seed to the pseudo-random number generator. - proofWindowOpenHeight := sharedtypes.GetProofWindowOpenHeight(sharedParams, queryHeight) - proofWindowOpenBlock, err := sq.blockQuerier.Block(ctx, &proofWindowOpenHeight) - if err != nil { - return 0, err + blockHashCacheKey := getBlockHashKacheKey(queryHeight) + proofWindowOpenBlockHash, found := sq.blockHashCache.Get(blockHashCacheKey) + + if !found { + logger.Debug().Msgf("cache miss for blockHeight: %s", blockHashCacheKey) + + // Fetch the block at the proof window open height. Its hash is used as part + // of the seed to the pseudo-random number generator. + proofWindowOpenHeight := sharedtypes.GetProofWindowOpenHeight(sharedParams, queryHeight) + proofWindowOpenBlock, err := sq.blockQuerier.Block(ctx, &proofWindowOpenHeight) + if err != nil { + return 0, err + } + + // Cache the block hash for future use. + proofWindowOpenBlockHash = proofWindowOpenBlock.BlockID.Hash.Bytes() + sq.blockHashCache.Set(blockHashCacheKey, proofWindowOpenBlockHash) + } else { + logger.Debug().Msgf("cache hit for blockHeight: %s", blockHashCacheKey) } return sharedtypes.GetEarliestSupplierProofCommitHeight( sharedParams, queryHeight, - proofWindowOpenBlock.BlockID.Hash, + proofWindowOpenBlockHash, supplierOperatorAddr, ), nil } @@ -187,3 +241,8 @@ func (sq *sharedQuerier) GetComputeUnitsToTokensMultiplier(ctx context.Context) } return sharedParams.GetComputeUnitsToTokensMultiplier(), nil } + +// getBlockHashKacheKey constructs the cache key for a block hash. +func getBlockHashKacheKey(height int64) string { + return strconv.FormatInt(height, 10) +} diff --git a/pkg/client/query/supplierquerier.go b/pkg/client/query/supplierquerier.go index 927f2b335..7b4351e99 100644 --- a/pkg/client/query/supplierquerier.go +++ b/pkg/client/query/supplierquerier.go @@ -7,6 +7,7 @@ import ( "github.com/cosmos/gogoproto/grpc" "github.com/pokt-network/poktroll/pkg/client" + "github.com/pokt-network/poktroll/pkg/polylog" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" ) @@ -17,6 +18,10 @@ import ( type supplierQuerier struct { clientConn grpc.ClientConn supplierQuerier suppliertypes.QueryClient + logger polylog.Logger + + // suppliersCache caches supplierQueryClient.Supplier requests + suppliersCache KeyValueCache[sharedtypes.Supplier] } // NewSupplierQuerier returns a new instance of a client.SupplierQueryClient by @@ -30,6 +35,8 @@ func NewSupplierQuerier(deps depinject.Config) (client.SupplierQueryClient, erro if err := depinject.Inject( deps, &supq.clientConn, + &supq.logger, + &supq.suppliersCache, ); err != nil { return nil, err } @@ -44,6 +51,16 @@ func (supq *supplierQuerier) GetSupplier( ctx context.Context, operatorAddress string, ) (sharedtypes.Supplier, error) { + logger := supq.logger.With("query_client", "supplier", "method", "GetSupplier") + + // Check if the supplier is present in the cache. + if supplier, found := supq.suppliersCache.Get(operatorAddress); found { + logger.Debug().Msgf("cache hit for key: %s", operatorAddress) + return supplier, nil + } + + logger.Debug().Msgf("cache miss for key: %s", operatorAddress) + req := &suppliertypes.QueryGetSupplierRequest{OperatorAddress: operatorAddress} res, err := supq.supplierQuerier.Supplier(ctx, req) if err != nil { @@ -52,5 +69,8 @@ func (supq *supplierQuerier) GetSupplier( operatorAddress, err, ) } + + // Cache the supplier for future use. + supq.suppliersCache.Set(operatorAddress, res.Supplier) return res.Supplier, nil } diff --git a/pkg/client/query/types/balance.go b/pkg/client/query/types/balance.go new file mode 100644 index 000000000..5976a083d --- /dev/null +++ b/pkg/client/query/types/balance.go @@ -0,0 +1,12 @@ +package types + +import ( + cosmostypes "github.com/cosmos/cosmos-sdk/types" +) + +// Balance represents a pointer to a Cosmos SDK Coin, specifically used for bank balance queries. +// It is deliberately defined as a distinct type (not a type alias) to ensure clear dependency +// injection and to differentiate it from other coin caches in the system. This type helps +// maintain separation of concerns between different types of coin-related data in the caching +// layer. +type Balance *cosmostypes.Coin diff --git a/pkg/client/query/types/blockhash.go b/pkg/client/query/types/blockhash.go new file mode 100644 index 000000000..2bf12f9f7 --- /dev/null +++ b/pkg/client/query/types/blockhash.go @@ -0,0 +1,8 @@ +package types + +// BlockHash represents a byte slice, specifically used for bank balance query caches. +// It is deliberately defined as a distinct type (not a type alias) to ensure clear +// dependency injection and to differentiate it from other byte slice caches in the system. +// This type helps maintain separation of concerns between different types of +// byte slice data in the caching layer. +type BlockHash []byte diff --git a/pkg/deps/config/suppliers.go b/pkg/deps/config/suppliers.go index 26f04043e..c734a51eb 100644 --- a/pkg/deps/config/suppliers.go +++ b/pkg/deps/config/suppliers.go @@ -17,6 +17,7 @@ import ( "github.com/pokt-network/poktroll/pkg/client/delegation" "github.com/pokt-network/poktroll/pkg/client/events" "github.com/pokt-network/poktroll/pkg/client/query" + "github.com/pokt-network/poktroll/pkg/client/query/cache" querytypes "github.com/pokt-network/poktroll/pkg/client/query/types" "github.com/pokt-network/poktroll/pkg/client/supplier" "github.com/pokt-network/poktroll/pkg/client/tx" @@ -507,3 +508,45 @@ func newSupplyTxClientsFn( return depinject.Configs(deps, depinject.Supply(txClient)), nil } + +// NewSupplyKeyValueCacheFn returns a function which constructs a KeyValueCache of type T. +// It take a list of cache options that can be used to configure the cache. +func NewSupplyKeyValueCacheFn[T any](opts ...cache.CacheOption[query.KeyValueCache[T]]) SupplierFn { + return func( + ctx context.Context, + deps depinject.Config, + _ *cobra.Command, + ) (depinject.Config, error) { + kvCache := cache.NewKeyValueCache[T]() + + // Apply the cache options + for _, opt := range opts { + if err := opt(ctx, deps, kvCache); err != nil { + return nil, err + } + } + + return depinject.Configs(deps, depinject.Supply(kvCache)), nil + } +} + +// NewSupplyParamsCacheFn returns a function which constructs a ParamsCache of type T. +// It take a list of cache options that can be used to configure the cache. +func NewSupplyParamsCacheFn[T any](opts ...cache.CacheOption[query.ParamsCache[T]]) SupplierFn { + return func( + ctx context.Context, + deps depinject.Config, + _ *cobra.Command, + ) (depinject.Config, error) { + paramsCache := cache.NewParamsCache[T]() + + // Apply the cache options + for _, opt := range opts { + if err := opt(ctx, deps, paramsCache); err != nil { + return nil, err + } + } + + return depinject.Configs(deps, depinject.Supply(paramsCache)), nil + } +} diff --git a/pkg/relayer/cmd/cmd.go b/pkg/relayer/cmd/cmd.go index 574f405b4..9d9947cc9 100644 --- a/pkg/relayer/cmd/cmd.go +++ b/pkg/relayer/cmd/cmd.go @@ -12,9 +12,12 @@ import ( cosmosclient "github.com/cosmos/cosmos-sdk/client" cosmosflags "github.com/cosmos/cosmos-sdk/client/flags" cosmostx "github.com/cosmos/cosmos-sdk/client/tx" + cosmostypes "github.com/cosmos/cosmos-sdk/types" "github.com/spf13/cobra" "github.com/pokt-network/poktroll/cmd/signals" + "github.com/pokt-network/poktroll/pkg/client/query/cache" + querytypes "github.com/pokt-network/poktroll/pkg/client/query/types" "github.com/pokt-network/poktroll/pkg/client/tx" txtypes "github.com/pokt-network/poktroll/pkg/client/tx/types" "github.com/pokt-network/poktroll/pkg/deps/config" @@ -25,6 +28,11 @@ import ( "github.com/pokt-network/poktroll/pkg/relayer/miner" "github.com/pokt-network/poktroll/pkg/relayer/proxy" "github.com/pokt-network/poktroll/pkg/relayer/session" + apptypes "github.com/pokt-network/poktroll/x/application/types" + prooftypes "github.com/pokt-network/poktroll/x/proof/types" + servicetypes "github.com/pokt-network/poktroll/x/service/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) // We're `explicitly omitting default` so the relayer crashes if these aren't specified. @@ -198,7 +206,29 @@ func setupRelayerDependencies( config.NewSupplyQueryClientContextFn(queryNodeGRPCUrl), // leaf config.NewSupplyTxClientContextFn(queryNodeGRPCUrl, txNodeRPCUrl), // leaf config.NewSupplyDelegationClientFn(), // leaf - config.NewSupplySharedQueryClientFn(), // leaf + + // Setup the params caches and configure them to clear on new blocks. + // TODO_TECHDEBT: Consider a flag to change client queriers caching behavior. + // This would allow to easily switch between caching and non-caching queriers + // for benchmarking purposes. + config.NewSupplyParamsCacheFn[sharedtypes.Params](cache.WithNewBlockCacheClearing), + config.NewSupplyParamsCacheFn[apptypes.Params](cache.WithNewBlockCacheClearing), + config.NewSupplyParamsCacheFn[sessiontypes.Params](cache.WithNewBlockCacheClearing), + config.NewSupplyParamsCacheFn[prooftypes.Params](cache.WithNewBlockCacheClearing), + + // Setup the key-value caches for poktroll types and configure them to clear on new blocks. + config.NewSupplyKeyValueCacheFn[sharedtypes.Service](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[servicetypes.RelayMiningDifficulty](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[apptypes.Application](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[sharedtypes.Supplier](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[*sessiontypes.Session](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[querytypes.BlockHash](cache.WithNewBlockCacheClearing), + config.NewSupplyKeyValueCacheFn[querytypes.Balance](cache.WithNewBlockCacheClearing), + + // Setup the key-value for cosmos types and configure them to clear on new blocks. + config.NewSupplyKeyValueCacheFn[cosmostypes.AccountI](cache.WithNewBlockCacheClearing), + + config.NewSupplySharedQueryClientFn(), // leaf config.NewSupplyServiceQueryClientFn(), config.NewSupplyApplicationQuerierFn(), config.NewSupplySessionQuerierFn(), diff --git a/pkg/relayer/session/session_test.go b/pkg/relayer/session/session_test.go index 03cbf00ed..9aa382036 100644 --- a/pkg/relayer/session/session_test.go +++ b/pkg/relayer/session/session_test.go @@ -73,9 +73,10 @@ func requireProofCountEqualsExpectedValueFromProofParams(t *testing.T, proofPara } supplierOperatorAddress := sample.AccAddress() // Set the supplier operator balance to be able to submit the expected number of proofs. - claimAndFeeGasCost := session.ClamAndProofGasCost.Amount.Int64() - numExpectedProofs := int64(2) - supplierOperatorBalance := claimAndFeeGasCost * numExpectedProofs + feePerProof := prooftypes.DefaultParams().ProofSubmissionFee.Amount.Int64() + gasCost := session.ClamAndProofGasCost.Amount.Int64() + proofCost := feePerProof + gasCost + supplierOperatorBalance := proofCost supplierClientMap := testsupplier.NewClaimProofSupplierClientMap(ctx, t, supplierOperatorAddress, proofCount) blockPublishCh, minedRelaysPublishCh := setupDependencies(t, ctx, supplierClientMap, emptyBlockHash, proofParams, supplierOperatorBalance) diff --git a/testutil/integration/suites/application.go b/testutil/integration/suites/application.go index 86b22fccf..337494129 100644 --- a/testutil/integration/suites/application.go +++ b/testutil/integration/suites/application.go @@ -10,6 +10,8 @@ import ( "github.com/pokt-network/poktroll/app/volatile" "github.com/pokt-network/poktroll/pkg/client" "github.com/pokt-network/poktroll/pkg/client/query" + "github.com/pokt-network/poktroll/pkg/client/query/cache" + "github.com/pokt-network/poktroll/pkg/polylog" apptypes "github.com/pokt-network/poktroll/x/application/types" sharedtypes "github.com/pokt-network/poktroll/x/shared/types" ) @@ -25,7 +27,11 @@ type ApplicationModuleSuite struct { // GetAppQueryClient constructs and returns a query client for the application // module of the integration app. func (s *ApplicationModuleSuite) GetAppQueryClient() client.ApplicationQueryClient { - deps := depinject.Supply(s.GetApp().QueryHelper()) + appCache := cache.NewKeyValueCache[apptypes.Application]() + appParamsCache := cache.NewParamsCache[apptypes.Params]() + logger := polylog.Ctx(s.GetApp().QueryHelper().Ctx) + + deps := depinject.Supply(s.GetApp().QueryHelper(), appCache, appParamsCache, logger) appQueryClient, err := query.NewApplicationQuerier(deps) require.NoError(s.T(), err) diff --git a/testutil/testclient/testqueryclients/grpcserver.go b/testutil/testclient/testqueryclients/grpcserver.go new file mode 100644 index 000000000..e12b285fd --- /dev/null +++ b/testutil/testclient/testqueryclients/grpcserver.go @@ -0,0 +1,168 @@ +package testqueryclients + +import ( + "context" + + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" + cosmostypes "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + + apptypes "github.com/pokt-network/poktroll/x/application/types" + prooftypes "github.com/pokt-network/poktroll/x/proof/types" + servicetypes "github.com/pokt-network/poktroll/x/service/types" + sessiontypes "github.com/pokt-network/poktroll/x/session/types" + sharedtypes "github.com/pokt-network/poktroll/x/shared/types" + suppliertypes "github.com/pokt-network/poktroll/x/supplier/types" +) + +// callCounter is a simple struct that keeps track of the number of times a method is called +type callCounter struct { + callCount int +} + +func (c *callCounter) CallCount() int { + return c.callCount +} + +func (c *callCounter) Increment() { + c.callCount++ +} + +// MockServiceQueryServer is a mock implementation of the servicetypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockServiceQueryServer struct { + servicetypes.UnimplementedQueryServer + ServiceCallCounter callCounter + RelayMiningDifficultyCallCounter callCounter +} + +func (m *MockServiceQueryServer) Service(ctx context.Context, req *servicetypes.QueryGetServiceRequest) (*servicetypes.QueryGetServiceResponse, error) { + m.ServiceCallCounter.Increment() + return &servicetypes.QueryGetServiceResponse{}, nil +} + +func (m *MockServiceQueryServer) RelayMiningDifficulty(ctx context.Context, req *servicetypes.QueryGetRelayMiningDifficultyRequest) (*servicetypes.QueryGetRelayMiningDifficultyResponse, error) { + m.RelayMiningDifficultyCallCounter.Increment() + return &servicetypes.QueryGetRelayMiningDifficultyResponse{}, nil +} + +// MockApplicationQueryServer is a mock implementation of the apptypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockApplicationQueryServer struct { + apptypes.UnimplementedQueryServer + AppCallCounter callCounter + ParamsCallCounter callCounter +} + +func (m *MockApplicationQueryServer) Application(ctx context.Context, req *apptypes.QueryGetApplicationRequest) (*apptypes.QueryGetApplicationResponse, error) { + m.AppCallCounter.Increment() + return &apptypes.QueryGetApplicationResponse{}, nil +} + +func (m *MockApplicationQueryServer) Params(ctx context.Context, req *apptypes.QueryParamsRequest) (*apptypes.QueryParamsResponse, error) { + m.ParamsCallCounter.Increment() + return &apptypes.QueryParamsResponse{}, nil +} + +// MockSupplierQueryServer is a mock implementation of the suppliertypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockSupplierQueryServer struct { + suppliertypes.UnimplementedQueryServer + SupplierCallCounter callCounter +} + +func (m *MockSupplierQueryServer) Supplier(ctx context.Context, req *suppliertypes.QueryGetSupplierRequest) (*suppliertypes.QueryGetSupplierResponse, error) { + m.SupplierCallCounter.Increment() + return &suppliertypes.QueryGetSupplierResponse{}, nil +} + +// MockSessionQueryServer is a mock implementation of the sessiontypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockSessionQueryServer struct { + sessiontypes.UnimplementedQueryServer + SessionCallCounter callCounter + ParamsCallCounter callCounter +} + +func (m *MockSessionQueryServer) GetSession(ctx context.Context, req *sessiontypes.QueryGetSessionRequest) (*sessiontypes.QueryGetSessionResponse, error) { + m.SessionCallCounter.Increment() + return &sessiontypes.QueryGetSessionResponse{ + Session: &sessiontypes.Session{}, + }, nil +} + +func (m *MockSessionQueryServer) Params(ctx context.Context, req *sessiontypes.QueryParamsRequest) (*sessiontypes.QueryParamsResponse, error) { + m.ParamsCallCounter.Increment() + return &sessiontypes.QueryParamsResponse{}, nil +} + +// MockSharedQueryServer is a mock implementation of the sharedtypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockSharedQueryServer struct { + sharedtypes.UnimplementedQueryServer + ParamsCallCounter callCounter +} + +func (m *MockSharedQueryServer) Params(ctx context.Context, req *sharedtypes.QueryParamsRequest) (*sharedtypes.QueryParamsResponse, error) { + m.ParamsCallCounter.Increment() + return &sharedtypes.QueryParamsResponse{ + Params: sharedtypes.Params{ + NumBlocksPerSession: 10, + }, + }, nil +} + +// MockProofQueryServer is a mock implementation of the prooftypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockProofQueryServer struct { + prooftypes.UnimplementedQueryServer + ParamsCallCounter callCounter +} + +func (m *MockProofQueryServer) Params(ctx context.Context, req *prooftypes.QueryParamsRequest) (*prooftypes.QueryParamsResponse, error) { + m.ParamsCallCounter.Increment() + return &prooftypes.QueryParamsResponse{}, nil +} + +// MockBankQueryServer is a mock implementation of the banktypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockBankQueryServer struct { + banktypes.UnimplementedQueryServer + BalanceCallCounter callCounter +} + +func (m *MockBankQueryServer) Balance(ctx context.Context, req *banktypes.QueryBalanceRequest) (*banktypes.QueryBalanceResponse, error) { + m.BalanceCallCounter.Increment() + return &banktypes.QueryBalanceResponse{ + Balance: &cosmostypes.Coin{}, + }, nil +} + +// MockAccountQueryServer is a mock implementation of the authtypes.QueryServer interface +// that keeps track of the number of times each method is called. +type MockAccountQueryServer struct { + authtypes.UnimplementedQueryServer + AccountCallCounter callCounter +} + +func (m *MockAccountQueryServer) Account(ctx context.Context, req *authtypes.QueryAccountRequest) (*authtypes.QueryAccountResponse, error) { + m.AccountCallCounter.Increment() + pubKey := secp256k1.GenPrivKey().PubKey() + + account := &authtypes.BaseAccount{} + err := account.SetPubKey(pubKey) + if err != nil { + return nil, err + } + + accountAny, err := codectypes.NewAnyWithValue(account) + if err != nil { + return nil, err + } + + return &authtypes.QueryAccountResponse{ + Account: accountAny, + }, nil +}