From 42fd6b827812597ecf47fcea90abd32620018252 Mon Sep 17 00:00:00 2001 From: Simon-Pierre Vivier Date: Thu, 30 Jan 2025 08:46:34 -0500 Subject: [PATCH] feat: waku sync shard matching check (#3259) --- tests/waku_store_sync/sync_utils.nim | 4 ++ tests/waku_store_sync/test_protocol.nim | 34 +++++++++++++++++ waku/node/waku_node.nim | 12 +++++- waku/waku_store_sync/codec.nim | 49 ++++++++++++++++++++++++- waku/waku_store_sync/common.nim | 3 ++ waku/waku_store_sync/reconciliation.nim | 31 ++++++++++++---- 6 files changed, 124 insertions(+), 9 deletions(-) diff --git a/tests/waku_store_sync/sync_utils.nim b/tests/waku_store_sync/sync_utils.nim index aa56ff2e51..20a6bdfb19 100644 --- a/tests/waku_store_sync/sync_utils.nim +++ b/tests/waku_store_sync/sync_utils.nim @@ -25,10 +25,14 @@ proc newTestWakuRecon*( idsRx: AsyncQueue[SyncID], wantsTx: AsyncQueue[(PeerId, Fingerprint)], needsTx: AsyncQueue[(PeerId, Fingerprint)], + cluster: uint16 = 1, + shards: seq[uint16] = @[0, 1, 2, 3, 4, 5, 6, 7], ): Future[SyncReconciliation] {.async.} = let peerManager = PeerManager.new(switch) let res = await SyncReconciliation.new( + cluster = cluster, + shards = shards, peerManager = peerManager, wakuArchive = nil, relayJitter = 0.seconds, diff --git a/tests/waku_store_sync/test_protocol.nim b/tests/waku_store_sync/test_protocol.nim index d3ffa187f3..f507ad95bc 100644 --- a/tests/waku_store_sync/test_protocol.nim +++ b/tests/waku_store_sync/test_protocol.nim @@ -157,6 +157,40 @@ suite "Waku Sync: reconciliation": localWants.contains((clientPeerInfo.peerId, hash3)) == true localWants.contains((serverPeerInfo.peerId, hash2)) == true + asyncTest "sync 2 nodes different shards": + let + msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic) + msg2 = fakeWakuMessage(ts = now() + 1, contentTopic = DefaultContentTopic) + msg3 = fakeWakuMessage(ts = now() + 2, contentTopic = DefaultContentTopic) + hash1 = computeMessageHash(DefaultPubsubTopic, msg1) + hash2 = computeMessageHash(DefaultPubsubTopic, msg2) + hash3 = computeMessageHash(DefaultPubsubTopic, msg3) + + server.messageIngress(hash1, msg1) + server.messageIngress(hash2, msg2) + client.messageIngress(hash1, msg1) + client.messageIngress(hash3, msg3) + + check: + remoteNeeds.contains((serverPeerInfo.peerId, hash3)) == false + remoteNeeds.contains((clientPeerInfo.peerId, hash2)) == false + localWants.contains((clientPeerInfo.peerId, hash3)) == false + localWants.contains((serverPeerInfo.peerId, hash2)) == false + + server = await newTestWakuRecon( + serverSwitch, idsChannel, localWants, remoteNeeds, shards = @[0.uint16, 1, 2, 3] + ) + client = await newTestWakuRecon( + clientSwitch, idsChannel, localWants, remoteNeeds, shards = @[4.uint16, 5, 6, 7] + ) + + var syncRes = await client.storeSynchronization(some(serverPeerInfo)) + assert syncRes.isOk(), $syncRes.error + + check: + remoteNeeds.len == 0 + localWants.len == 0 + asyncTest "sync 2 nodes same hashes": let msg1 = fakeWakuMessage(ts = now(), contentTopic = DefaultContentTopic) diff --git a/waku/node/waku_node.nim b/waku/node/waku_node.nim index 1e3b2f1272..5b4f9900ff 100644 --- a/waku/node/waku_node.nim +++ b/waku/node/waku_node.nim @@ -216,9 +216,19 @@ proc mountStoreSync*( let wantsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100) let needsChannel = newAsyncQueue[(PeerId, WakuMessageHash)](100) + var cluster: uint16 + var shards: seq[uint16] + let enrRes = node.enr.toTyped() + if enrRes.isOk(): + let shardingRes = enrRes.get().relaySharding() + if shardingRes.isSome(): + let relayShard = 
shardingRes.get() + cluster = relayShard.clusterID + shards = relayShard.shardIds + let recon = ?await SyncReconciliation.new( - node.peerManager, node.wakuArchive, storeSyncRange.seconds, + cluster, shards, node.peerManager, node.wakuArchive, storeSyncRange.seconds, storeSyncInterval.seconds, storeSyncRelayJitter.seconds, idsChannel, wantsChannel, needsChannel, ) diff --git a/waku/waku_store_sync/codec.nim b/waku/waku_store_sync/codec.nim index ee0b926a3a..815ed9d618 100644 --- a/waku/waku_store_sync/codec.nim +++ b/waku/waku_store_sync/codec.nim @@ -52,6 +52,18 @@ proc deltaEncode*(value: RangesData): seq[byte] = i = 0 j = 0 + # encode cluster + buf = uint64(value.cluster).toBytes(Leb128) + output &= @buf + + # encode shards + buf = uint64(value.shards.len).toBytes(Leb128) + output &= @buf + + for shard in value.shards: + buf = uint64(shard).toBytes(Leb128) + output &= @buf + # the first range is implicit but must be explicit when encoded let (bound, _) = value.ranges[0] @@ -209,6 +221,38 @@ proc getReconciled(idx: var int, buffer: seq[byte]): Result[bool, string] = return ok(recon) +proc getCluster(idx: var int, buffer: seq[byte]): Result[uint16, string] = + if idx + VarIntLen > buffer.len: + return err("Cannot decode cluster") + + let slice = buffer[idx ..< idx + VarIntLen] + let (val, len) = uint64.fromBytes(slice, Leb128) + idx += len + + return ok(uint16(val)) + +proc getShards(idx: var int, buffer: seq[byte]): Result[seq[uint16], string] = + if idx + VarIntLen > buffer.len: + return err("Cannot decode shards count") + + let slice = buffer[idx ..< idx + VarIntLen] + let (val, len) = uint64.fromBytes(slice, Leb128) + idx += len + let shardsLen = val + + var shards: seq[uint16] + for i in 0 ..< shardsLen: + if idx + VarIntLen > buffer.len: + return err("Cannot decode shard value. 
idx: " & $i) + + let slice = buffer[idx ..< idx + VarIntLen] + let (val, len) = uint64.fromBytes(slice, Leb128) + idx += len + + shards.add(uint16(val)) + + return ok(shards) + proc deltaDecode*( itemSet: var ItemSet, buffer: seq[byte], setLength: int ): Result[int, string] = @@ -242,7 +286,7 @@ proc getItemSet( return ok(itemSet) proc deltaDecode*(T: type RangesData, buffer: seq[byte]): Result[T, string] = - if buffer.len == 1: + if buffer.len <= 1: return ok(RangesData()) var @@ -250,6 +294,9 @@ proc deltaDecode*(T: type RangesData, buffer: seq[byte]): Result[T, string] = lastTime = Timestamp(0) idx = 0 + payload.cluster = ?getCluster(idx, buffer) + payload.shards = ?getShards(idx, buffer) + lastTime = ?getTimestamp(idx, buffer) # implicit first hash is always 0 diff --git a/waku/waku_store_sync/common.nim b/waku/waku_store_sync/common.nim index 2795450786..e2eac0f853 100644 --- a/waku/waku_store_sync/common.nim +++ b/waku/waku_store_sync/common.nim @@ -26,6 +26,9 @@ type ItemSet = 2 RangesData* = object + cluster*: uint16 + shards*: seq[uint16] + ranges*: seq[(Slice[SyncID], RangeType)] fingerprints*: seq[Fingerprint] # Range type fingerprint stored here in order itemSets*: seq[ItemSet] # Range type itemset stored here in order diff --git a/waku/waku_store_sync/reconciliation.nim b/waku/waku_store_sync/reconciliation.nim index 9ac81c6677..10e8aed52c 100644 --- a/waku/waku_store_sync/reconciliation.nim +++ b/waku/waku_store_sync/reconciliation.nim @@ -1,7 +1,7 @@ {.push raises: [].} import - std/[sequtils, options], + std/[sequtils, options, packedsets], stew/byteutils, results, chronicles, @@ -37,6 +37,9 @@ logScope: const DefaultStorageCap = 50_000 type SyncReconciliation* = ref object of LPProtocol + cluster: uint16 + shards: PackedSet[uint16] + peerManager: PeerManager wakuArchive: WakuArchive @@ -114,16 +117,24 @@ proc processRequest( var hashToRecv: seq[WakuMessageHash] hashToSend: seq[WakuMessageHash] + sendPayload: RangesData + rawPayload: seq[byte] + + # Only process the ranges IF the shards and cluster matches + if self.cluster == recvPayload.cluster and + recvPayload.shards.toPackedSet() == self.shards: + sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv) - let sendPayload = self.storage.processPayload(recvPayload, hashToSend, hashToRecv) + sendPayload.cluster = self.cluster + sendPayload.shards = self.shards.toSeq() - for hash in hashToSend: - await self.remoteNeedsTx.addLast((conn.peerId, hash)) + for hash in hashToSend: + await self.remoteNeedsTx.addLast((conn.peerId, hash)) - for hash in hashToRecv: - await self.localWantstx.addLast((conn.peerId, hash)) + for hash in hashToRecv: + await self.localWantstx.addLast((conn.peerId, hash)) - let rawPayload = sendPayload.deltaEncode() + rawPayload = sendPayload.deltaEncode() total_bytes_exchanged.observe( rawPayload.len, labelValues = [Reconciliation, Sending] @@ -162,6 +173,8 @@ proc initiate( fingerprint = self.storage.computeFingerprint(bounds) initPayload = RangesData( + cluster: self.cluster, + shards: self.shards.toSeq(), ranges: @[(bounds, RangeType.Fingerprint)], fingerprints: @[fingerprint], itemSets: @[], @@ -261,6 +274,8 @@ proc initFillStorage( proc new*( T: type SyncReconciliation, + cluster: uint16, + shards: seq[uint16], peerManager: PeerManager, wakuArchive: WakuArchive, syncRange: timer.Duration = DefaultSyncRange, @@ -279,6 +294,8 @@ proc new*( SeqStorage.new(res.get()) var sync = SyncReconciliation( + cluster: cluster, + shards: shards.toPackedSet(), peerManager: peerManager, storage: 
storage, syncRange: syncRange,
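
The sketches below are illustrative notes, not part of the patch. First, the mountStoreSync change derives the reconciliation protocol's cluster and shard list from the node's ENR relay sharding, falling back to cluster 0 and an empty shard list when the ENR carries no sharding info. A minimal standalone sketch of that fallback, where RelayShards and shardingConfig are hypothetical stand-ins for the real ENR types and helpers:

import std/options

type RelayShards = object
  clusterID: uint16
  shardIds: seq[uint16]

proc shardingConfig(sharding: Option[RelayShards]): (uint16, seq[uint16]) =
  ## Mirrors the fallback in mountStoreSync: use the ENR relay sharding
  ## when present, otherwise cluster 0 and no shards.
  var cluster: uint16
  var shards: seq[uint16]
  if sharding.isSome():
    let rs = sharding.get()
    cluster = rs.clusterID
    shards = rs.shardIds
  (cluster, shards)

when isMainModule:
  doAssert shardingConfig(none(RelayShards)) == (0'u16, newSeq[uint16]())
  doAssert shardingConfig(some(RelayShards(clusterID: 1, shardIds: @[0'u16, 7]))) ==
    (1'u16, @[0'u16, 7])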
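
In codec.nim the cluster id and shard list are now prepended to the delta-encoded payload as LEB128 varints: cluster, shard count, then each shard. A rough standalone round-trip sketch of just that header, assuming stew/leb128 is available as in the patch; encodeShardHeader and decodeShardHeader are illustrative names, not the patch's procs, and bounds checks are omitted for brevity:

import stew/leb128

proc encodeShardHeader(cluster: uint16, shards: seq[uint16]): seq[byte] =
  ## Cluster id, shard count, then each shard, all as LEB128 varints,
  ## written ahead of the existing range encoding.
  var buf = uint64(cluster).toBytes(Leb128)
  result &= @buf
  buf = uint64(shards.len).toBytes(Leb128)
  result &= @buf
  for shard in shards:
    buf = uint64(shard).toBytes(Leb128)
    result &= @buf

proc decodeShardHeader(buffer: seq[byte]): (uint16, seq[uint16]) =
  ## Reads the header back, advancing an index past each varint.
  var idx = 0
  let (cluster, cLen) = uint64.fromBytes(buffer[idx .. ^1], Leb128)
  idx += cLen
  let (count, nLen) = uint64.fromBytes(buffer[idx .. ^1], Leb128)
  idx += nLen
  var shards: seq[uint16]
  for _ in 0 ..< count:
    let (shard, sLen) = uint64.fromBytes(buffer[idx .. ^1], Leb128)
    idx += sLen
    shards.add(uint16(shard))
  (uint16(cluster), shards)

when isMainModule:
  let header = encodeShardHeader(1, @[0'u16, 1, 2, 3])
  doAssert decodeShardHeader(header) == (1'u16, @[0'u16, 1, 2, 3])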
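
In reconciliation.nim, processRequest now only processes the received ranges when the remote payload's cluster matches and its shard list, converted to a PackedSet, equals the local shard set; otherwise it replies with an empty RangesData and queues no wants or needs. A minimal sketch of that predicate, with shardsMatch as a hypothetical helper name:

import std/packedsets

proc shardsMatch(localCluster: uint16, localShards: PackedSet[uint16],
    remoteCluster: uint16, remoteShards: seq[uint16]): bool =
  ## True only when both the cluster ids and the full shard sets are
  ## identical on both sides, as required by the new gate.
  localCluster == remoteCluster and remoteShards.toPackedSet() == localShards

when isMainModule:
  let local = [0'u16, 1, 2, 3].toPackedSet()
  doAssert shardsMatch(1, local, 1, @[0'u16, 1, 2, 3])
  doAssert not shardsMatch(1, local, 1, @[4'u16, 5, 6, 7])
  doAssert not shardsMatch(1, local, 2, @[0'u16, 1, 2, 3])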