diff --git a/common/common.go b/common/common.go index 00b6f92d28..9f122b8688 100644 --- a/common/common.go +++ b/common/common.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -84,3 +85,14 @@ func VerifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderH return nil } + +// GetShardIDs returns a map of shard IDs based on the provided shard coordinator +func GetShardIDs(numShards uint32) map[uint32]struct{} { + shardIdentifiers := make(map[uint32]struct{}) + for i := uint32(0); i < numShards; i++ { + shardIdentifiers[i] = struct{}{} + } + shardIdentifiers[core.MetachainShardId] = struct{}{} + + return shardIdentifiers +} diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go index 512d9767d8..9f67dcbc24 100644 --- a/consensus/broadcast/delayedBroadcast.go +++ b/consensus/broadcast/delayedBroadcast.go @@ -585,7 +585,7 @@ func (dbb *delayedBlockBroadcaster) registerInterceptorsCallbackForShard( rootTopic string, cb func(topic string, hash []byte, data interface{}), ) error { - shardIDs := dbb.shardIdentifiers() + shardIDs := common.GetShardIDs(dbb.shardCoordinator.NumberOfShards()) for idx := range shardIDs { // interested only in cross shard data if idx == dbb.shardCoordinator.SelfId() { @@ -603,16 +603,6 @@ func (dbb *delayedBlockBroadcaster) registerInterceptorsCallbackForShard( return nil } -func (dbb *delayedBlockBroadcaster) shardIdentifiers() map[uint32]struct{} { - shardIdentifiers := make(map[uint32]struct{}) - for i := uint32(0); i < dbb.shardCoordinator.NumberOfShards(); i++ { - shardIdentifiers[i] = struct{}{} - } - shardIdentifiers[core.MetachainShardId] = struct{}{} - - return shardIdentifiers -} - func (dbb *delayedBlockBroadcaster) interceptedHeader(_ string, headerHash []byte, header interface{}) { headerHandler, ok := header.(data.HeaderHandler) if !ok { diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 0d9d13b4ce..8d67ec7a10 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -666,10 +666,51 @@ func (bp *baseProcessor) verifyFees(header data.HeaderHandler) error { return nil } +func (bp *baseProcessor) filterHeadersWithoutProofs() (map[string]*hdrInfo, error) { + removedNonces := make(map[uint32]map[uint64]struct{}) + noncesWithProofs := make(map[uint32]map[uint64]struct{}) + shardIDs := common.GetShardIDs(bp.shardCoordinator.NumberOfShards()) + for shard := range shardIDs { + removedNonces[shard] = make(map[uint64]struct{}) + noncesWithProofs[shard] = make(map[uint64]struct{}) + } + filteredHeadersInfo := make(map[string]*hdrInfo) + + for hdrHash, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) { + if bp.hasMissingProof(headerInfo, hdrHash) { + removedNonces[headerInfo.hdr.GetShardID()][headerInfo.hdr.GetNonce()] = struct{}{} + continue + } + + noncesWithProofs[headerInfo.hdr.GetShardID()][headerInfo.hdr.GetNonce()] = struct{}{} + filteredHeadersInfo[hdrHash] = bp.hdrsForCurrBlock.hdrHashAndInfo[hdrHash] + continue + } + + filteredHeadersInfo[hdrHash] = bp.hdrsForCurrBlock.hdrHashAndInfo[hdrHash] + } + + for shard, nonces := range removedNonces { + for nonce := range nonces { + if _, ok := noncesWithProofs[shard][nonce]; !ok { + return nil, fmt.Errorf("%w for shard %d and nonce %d", process.ErrMissingHeaderProof, shard, nonce) + } + } + } + + return filteredHeadersInfo, nil +} + func (bp *baseProcessor) computeHeadersForCurrentBlock(usedInBlock bool) (map[uint32][]data.HeaderHandler, error) { hdrsForCurrentBlock := make(map[uint32][]data.HeaderHandler) - for hdrHash, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { + hdrHashAndInfo, err := bp.filterHeadersWithoutProofs() + if err != nil { + return nil, err + } + + for hdrHash, headerInfo := range hdrHashAndInfo { if headerInfo.usedInBlock != usedInBlock { continue } @@ -747,7 +788,7 @@ func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool } func (bp *baseProcessor) hasMissingProof(headerInfo *hdrInfo, hdrHash string) bool { - isFlagEnabledForHeader := bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) && headerInfo.hdr.GetNonce() > 1 + isFlagEnabledForHeader := bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) && headerInfo.hdr.GetNonce() >= 1 if !isFlagEnabledForHeader { return false }