Skip to content

Commit

Permalink
[Code Health] refactor: random number generation (#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite authored Jun 26, 2024
1 parent 1587a30 commit 6a41dd4
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 135 deletions.
34 changes: 34 additions & 0 deletions pkg/crypto/rand/float.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package rand

import (
"bytes"
"encoding/binary"
"math/rand"

"github.com/cometbft/cometbft/crypto"
)

// SeededFloat32 generates a deterministic float32 between 0 and 1 given a seed.
//
// TODO_MAINNET: To support other language implementations of the protocol, the
// pseudo-random number generator used here should be language-agnostic (i.e. not
// golang specific).
func SeededFloat32(seedParts ...[]byte) (float32, error) {
seedHashInputBz := bytes.Join(append([][]byte{}, seedParts...), nil)
seedHash := crypto.Sha256(seedHashInputBz)
seed, _ := binary.Varint(seedHash)

// Construct a pseudo-random number generator with the seed.
pseudoRand := rand.New(rand.NewSource(seed))

// Generate a random uint32.
randUint32 := pseudoRand.Uint32()

// Clamp the random float32 between [0,1]. This is achieved by dividing the random uint32
// by the most significant digit of a float32, which is 2^32, guaranteeing an output between
// 0 and 1, inclusive.
oneMostSignificantDigitFloat32 := float32(1 << 32)
randClampedFloat32 := float32(randUint32) / oneMostSignificantDigitFloat32

return randClampedFloat32, nil
}
54 changes: 54 additions & 0 deletions pkg/crypto/rand/float_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package rand_test

import (
"encoding/binary"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"

poktrand "github.com/pokt-network/poktroll/pkg/crypto/rand"
prooftypes "github.com/pokt-network/poktroll/x/proof/types"
)

func TestSeededFloat32(t *testing.T) {
probability := prooftypes.DefaultProofRequestProbability
tolerance := 0.01
confidence := 0.99

sampleSize := poktrand.RequiredSampleSize(float64(probability), tolerance, confidence)

var numTrueSamples atomic.Int64

// Sample concurrently to save time.
wg := sync.WaitGroup{}
for idx := int64(0); idx < sampleSize; idx++ {
wg.Add(1)
go func() {
idxBz := make([]byte, binary.MaxVarintLen64)
binary.PutVarint(idxBz, idx)
randFloat, err := poktrand.SeededFloat32(idxBz)
require.NoError(t, err)

if randFloat < 0 || randFloat > 1 {
t.Fatalf("secureRandFloat64() returned out of bounds value: %f", randFloat)
}

if randFloat <= probability {
numTrueSamples.Add(1)
}
wg.Done()
}()
}
wg.Wait()

expectedNumTrueSamples := float32(sampleSize) * probability
expectedNumFalseSamples := float32(sampleSize) * (1 - probability)
toleranceSamples := tolerance * float64(sampleSize)

// Check that the number of samples for each outcome is within the expected range.
numFalseSamples := sampleSize - numTrueSamples.Load()
require.InDeltaf(t, expectedNumTrueSamples, numTrueSamples.Load(), toleranceSamples, "true samples")
require.InDeltaf(t, expectedNumFalseSamples, numFalseSamples, toleranceSamples, "false samples")
}
23 changes: 23 additions & 0 deletions pkg/crypto/rand/integer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package rand

import (
"bytes"
"encoding/binary"
"math/rand"

"github.com/cometbft/cometbft/crypto"
)

// SeededInt63 generates a deterministic non-negative int64 by seeding a random
// source with the hash of seedParts.
//
// TODO_MAINNET: To support other language implementations of the protocol, the
// pseudo-random number generator used here should be language-agnostic (i.e. not
// golang specific).
func SeededInt63(seedParts ...[]byte) int64 {
seedHashInputBz := bytes.Join(append([][]byte{}, seedParts...), nil)
seedHash := crypto.Sha256(seedHashInputBz)
seed, _ := binary.Varint(seedHash)

return rand.NewSource(seed).Int63()
}
58 changes: 5 additions & 53 deletions x/tokenomics/keeper/random_test.go → pkg/crypto/rand/samples.go
Original file line number Diff line number Diff line change
@@ -1,56 +1,8 @@
package keeper
package rand

import (
"math"
"sync"
"sync/atomic"
"testing"
import "math"

"github.com/stretchr/testify/require"

prooftypes "github.com/pokt-network/poktroll/x/proof/types"
)

func TestRandProbability(t *testing.T) {
probability := prooftypes.DefaultProofRequestProbability
tolerance := 0.01
confidence := 0.99

sampleSize := requiredSampleSize(float64(probability), tolerance, confidence)

var numTrueSamples atomic.Int64

// Sample concurrently to save time.
wg := sync.WaitGroup{}
for i := 0; i < sampleSize; i++ {
wg.Add(1)
go func() {
rand, err := randProbability(int64(i))
require.NoError(t, err)

if rand < 0 || rand > 1 {
t.Fatalf("secureRandFloat64() returned out of bounds value: %f", rand)
}

if rand <= probability {
numTrueSamples.Add(1)
}
wg.Done()
}()
}
wg.Wait()

expectedNumTrueSamples := float32(sampleSize) * probability
expectedNumFalseSamples := float32(sampleSize) * (1 - probability)
toleranceSamples := tolerance * float64(sampleSize)

// Check that the number of samples for each outcome is within the expected range.
numFalseSamples := int64(sampleSize) - numTrueSamples.Load()
require.InDeltaf(t, expectedNumTrueSamples, numTrueSamples.Load(), toleranceSamples, "true samples")
require.InDeltaf(t, expectedNumFalseSamples, numFalseSamples, toleranceSamples, "false samples")
}

// requiredSampleSize calculates the number of samples needed to achieve a desired confidence level
// RequiredSampleSize calculates the number of samples needed to achieve a desired confidence level
// for a given probability and error threshold.
// Arguments:
// - probability: the estimated proportion of the population (e.g., 0.5 for 50%).
Expand All @@ -64,7 +16,7 @@ func TestRandProbability(t *testing.T) {
//
// The function uses the standard formula for sample size determination for estimating a proportion.
// For more details, see: https://en.wikipedia.org/wiki/Sample_size_determination#Estimation_of_a_proportion
func requiredSampleSize(probability, errThreshold, confidence float64) int {
func RequiredSampleSize(probability, errThreshold, confidence float64) int64 {
// Calculate the z-score corresponding to the desired confidence level.
// The z-score represents the number of standard deviations a data point
// is from the mean in a standard normal distribution. For a given confidence
Expand All @@ -78,7 +30,7 @@ func requiredSampleSize(probability, errThreshold, confidence float64) int {
// Calculate the number of trials needed
n := (z * z * probability * (1 - probability)) / (errThreshold * errThreshold)

return int(math.Ceil(n))
return int64(math.Ceil(n))
}

// normInv returns the inverse of the standard normal cumulative distribution function (CDF),
Expand Down
2 changes: 1 addition & 1 deletion pkg/relayer/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,5 @@ type SessionTree interface {
// It returns an error if it has already been marked as such.
StartClaiming() error

SupplierAddress() *cosmostypes.AccAddress
GetSupplierAddress() *cosmostypes.AccAddress
}
4 changes: 2 additions & 2 deletions pkg/relayer/session/claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ func (rs *relayerSessionsManager) newMapClaimSessionsFn(
// Map key is the supplier address.
sessionClaims := map[string][]*relayer.SessionClaim{}
for _, session := range sessionTrees {
supplierAddr := session.SupplierAddress().String()
supplierAddr := session.GetSupplierAddress().String()

sessionClaims[supplierAddr] = append(sessionClaims[supplierAddr], &relayer.SessionClaim{
RootHash: session.GetClaimRoot(),
SessionHeader: session.GetSessionHeader(),
SupplierAddress: *session.SupplierAddress(),
SupplierAddress: *session.GetSupplierAddress(),
})
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/relayer/session/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ func (rs *relayerSessionsManager) newMapProveSessionsFn(
// Map key is the supplier address.
sessionProofs := map[string][]*relayer.SessionProof{}
for _, session := range sessionTrees {
supplierAddr := session.SupplierAddress().String()
supplierAddr := session.GetSupplierAddress().String()
sessionProofs[supplierAddr] = append(sessionProofs[supplierAddr], &relayer.SessionProof{
ProofBz: session.GetProofBz(),
SessionHeader: session.GetSessionHeader(),
SupplierAddress: *session.SupplierAddress(),
SupplierAddress: *session.GetSupplierAddress(),
})
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/relayer/session/sessiontree.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,6 @@ func (st *sessionTree) StartClaiming() error {
}

// SupplierAddress returns a CosmosSDK address of the supplier this sessionTree belongs to.
func (st *sessionTree) SupplierAddress() *cosmostypes.AccAddress {
func (st *sessionTree) GetSupplierAddress() *cosmostypes.AccAddress {
return st.supplierAddress
}
12 changes: 12 additions & 0 deletions x/proof/types/claim.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package types
import (
"fmt"

"github.com/cometbft/cometbft/crypto"

"github.com/pokt-network/smt"
)

Expand Down Expand Up @@ -40,3 +42,13 @@ func (claim *Claim) GetNumRelays() (numRelays uint64, err error) {

return smt.MerkleRoot(claim.GetRootHash()).Count(), nil
}

// GetHash returns the SHA-256 hash of the serialized claim.
func (claim *Claim) GetHash() ([]byte, error) {
claimBz, err := claim.Marshal()
if err != nil {
return nil, err
}

return crypto.Sha256(claimBz), nil
}
33 changes: 13 additions & 20 deletions x/tokenomics/keeper/proof_requirement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package keeper_test

import (
"math/rand"
"sync"
"sync/atomic"
"testing"

"cosmossdk.io/log"
cosmostypes "github.com/cosmos/cosmos-sdk/types"
"github.com/stretchr/testify/require"

poktrand "github.com/pokt-network/poktroll/pkg/crypto/rand"
"github.com/pokt-network/poktroll/testutil/keeper"
tetsproof "github.com/pokt-network/poktroll/testutil/proof"
"github.com/pokt-network/poktroll/testutil/sample"
Expand All @@ -23,47 +23,40 @@ func init() {
}

func TestKeeper_IsProofRequired(t *testing.T) {
// TODO_UPNEXT(#618): reuse requiredSampleSize()
t.SkipNow()

// Set expectedCompute units to be below the proof requirement threshold to only
// exercise the probabilistic branch of the #isProofRequired() logic.
expectedComputeUnits := prooftypes.DefaultProofRequirementThreshold - 1
keepers, ctx := keeper.NewTokenomicsModuleKeepers(t, log.NewNopLogger())
sdkCtx := cosmostypes.UnwrapSDKContext(ctx)

var (
sampleSize = 15000
probability = prooftypes.DefaultProofRequestProbability
tolerance = 0.01
confidence = 0.99

numTrueSamples atomic.Int64
)

// Sample concurrently to save time.
wg := sync.WaitGroup{}
for i := 0; i < sampleSize; i++ {
wg.Add(1)
go func() {
claim := tetsproof.ClaimWithRandomHash(t, sample.AccAddress(), sample.AccAddress(), expectedComputeUnits)
sampleSize := poktrand.RequiredSampleSize(float64(probability), tolerance, confidence)

// NB: Not possible to sample concurrently, this causes a race condition due to the keeper's gas meter.
for i := int64(0); i < sampleSize; i++ {
claim := tetsproof.ClaimWithRandomHash(t, sample.AccAddress(), sample.AccAddress(), expectedComputeUnits)

isRequired, err := keepers.Keeper.IsProofRequiredForClaim(sdkCtx, &claim)
require.NoError(t, err)
isRequired, err := keepers.Keeper.IsProofRequiredForClaim(sdkCtx, &claim)
require.NoError(t, err)

if isRequired {
numTrueSamples.Add(1)
}
wg.Done()
}()
if isRequired {
numTrueSamples.Add(1)
}
}
wg.Wait()

expectedNumTrueSamples := float32(sampleSize) * probability
expectedNumFalseSamples := float32(sampleSize) * (1 - probability)
toleranceSamples := tolerance * float64(sampleSize)

// Check that the number of samples for each outcome is within the expected range.
numFalseSamples := int64(sampleSize) - numTrueSamples.Load()
numFalseSamples := sampleSize - numTrueSamples.Load()
require.InDeltaf(t, expectedNumTrueSamples, numTrueSamples.Load(), toleranceSamples, "true samples")
require.InDeltaf(t, expectedNumFalseSamples, numFalseSamples, toleranceSamples, "false samples")
}
Loading

0 comments on commit 6a41dd4

Please sign in to comment.