From 826fc2defbadfd3c5dd4430b6c13a6881309ab1d Mon Sep 17 00:00:00 2001 From: Vitaly Lavrov Date: Sun, 15 Sep 2024 14:41:56 +0200 Subject: [PATCH] Ping candidates and remove bad nodes --- cmd/src/main/scala/Main.scala | 37 ++---- .../torrentdam/bittorrent/dht/Client.scala | 85 +++++++++----- .../torrentdam/bittorrent/dht/Node.scala | 109 ++++++++++++------ .../bittorrent/dht/PeerDiscovery.scala | 56 ++++----- .../bittorrent/dht/PingRoutine.scala | 7 +- .../bittorrent/dht/QueryHandler.scala | 40 +++---- .../bittorrent/dht/RequestResponse.scala | 8 +- .../bittorrent/dht/RoutingTable.scala | 18 ++- .../dht/RoutingTableBootstrap.scala | 44 +++---- 9 files changed, 229 insertions(+), 175 deletions(-) diff --git a/cmd/src/main/scala/Main.scala b/cmd/src/main/scala/Main.scala index d8f2b9e..2ab0ff5 100644 --- a/cmd/src/main/scala/Main.scala +++ b/cmd/src/main/scala/Main.scala @@ -67,16 +67,12 @@ object Main async[ResourceIO] { given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await - val selfId = Resource.eval(NodeId.generate[IO]).await val selfPeerId = Resource.eval(PeerId.generate[IO]).await val infoHash = Resource.eval(infoHashFromString(infoHashOption)).await - val table = Resource.eval(RoutingTable[IO](selfId)).await - val node = Node(selfId, none, QueryHandler(selfId, table)).await - Resource.eval(RoutingTableBootstrap[IO](table, node.client)).await - val discovery = PeerDiscovery.make(table, node.client).await + val node = Node().await val swarm = Swarm( - discovery.discover(infoHash), + node.discovery.discover(infoHash), Connection.connect(selfPeerId, _, infoHash) ).await val metadata = DownloadMetadata(swarm).toResource.await @@ -132,8 +128,6 @@ object Main throw new Exception("Missing info-hash") given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await - - val selfId = Resource.eval(NodeId.generate[IO]).await val selfPeerId = Resource.eval(PeerId.generate[IO]).await val peerAddress = peerAddressOption.flatMap(SocketAddress.fromStringIp) val peers: Stream[IO, PeerInfo] = @@ -141,14 +135,9 @@ object Main case Some(peerAddress) => Stream.emit(PeerInfo(peerAddress)).covary[IO] case None => - val bootstrapNodeAddress = dhtNodeAddressOption - .map(SocketAddress.fromString(_).toList) - .getOrElse(RoutingTableBootstrap.PublicBootstrapNodes) - val table = Resource.eval(RoutingTable[IO](selfId)).await - val node = Node(selfId, none, QueryHandler(selfId, table)).await - Resource.eval(RoutingTableBootstrap(table, node.client, bootstrapNodeAddress)).await - val discovery = PeerDiscovery.make(table, node.client).await - discovery.discover(infoHash) + val bootstrapNodeAddress = dhtNodeAddressOption.flatMap(SocketAddress.fromString) + val node = Node(none, bootstrapNodeAddress).await + node.discovery.discover(infoHash) val swarm = Swarm(peers, peerInfo => Connection.connect(selfPeerId, peerInfo, infoHash)).await val metadata = torrentFile match @@ -278,11 +267,7 @@ object Main async[ResourceIO] { val port = Port.fromInt(portParam).liftTo[ResourceIO](new Exception("Invalid port")).await given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await - val selfId = Resource.eval(NodeId.generate[IO]).await - val table = Resource.eval(RoutingTable[IO](selfId)).await - val node = Node(selfId, Some(port), QueryHandler(selfId, table)).await - Resource.eval(RoutingTableBootstrap(table, node.client)).await - PingRoutine(table, node.client).runForever.background.await + Node(Some(port)).await }.useForever } } @@ -299,13 +284,13 @@ object Main val nodeAddress = SocketAddress.fromString(nodeAddressParam).liftTo[ResourceIO](new Exception("Invalid address")).await val nodeIpAddress = nodeAddress.resolve[IO].toResource.await given Random[IO] = Resource.eval(Random.scalaUtilRandom[IO]).await - val infoHash = infoHashFromString(infoHashParam).toResource.await val selfId = Resource.eval(NodeId.generate[IO]).await - val table = Resource.eval(RoutingTable[IO](selfId)).await - val node = Node(selfId, none, QueryHandler(selfId, table)).await + val infoHash = infoHashFromString(infoHashParam).toResource.await + val messageSocket = MessageSocket(none).await + val client = Client(selfId, messageSocket, QueryHandler.noop).await async[IO]: - val pong = node.client.ping(nodeIpAddress).await - val response = node.client.getPeers(NodeInfo(pong.id, nodeIpAddress), infoHash).await + val pong = client.ping(nodeIpAddress).await + val response = client.getPeers(NodeInfo(pong.id, nodeIpAddress), infoHash).await IO.println(response).await ExitCode.Success }.useEval diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala index a9d17de..367c5c3 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Client.scala @@ -1,77 +1,104 @@ package com.github.torrentdam.bittorrent.dht import cats.effect.kernel.Temporal -import cats.effect.Concurrent -import cats.effect.Resource -import cats.effect.Sync +import cats.effect.std.{Queue, Random} +import cats.effect.{Concurrent, IO, Resource, Sync} import cats.syntax.all.* import com.comcast.ip4s.* import com.github.torrentdam.bittorrent.InfoHash + import java.net.InetSocketAddress import org.legogroup.woof.given import org.legogroup.woof.Logger import scodec.bits.ByteVector -trait Client[F[_]] { +trait Client { + + def id: NodeId - def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): F[Either[Response.Nodes, Response.Peers]] + def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] - def findNodes(nodeInfo: NodeInfo, target: NodeId): F[Response.Nodes] + def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] - def ping(address: SocketAddress[IpAddress]): F[Response.Ping] + def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] - def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): F[Either[Response.Nodes, Response.SampleInfoHashes]] + def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] } object Client { - def apply[F[_]]( + def generateTransactionId(using random: Random[IO]): IO[ByteVector] = + val nextChar = random.nextAlphaNumeric + (nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get) + + def apply( selfId: NodeId, - sendQueryMessage: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit], - receiveResponse: F[(SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage])], - generateTransactionId: F[ByteVector] - )(using - F: Temporal[F], - logger: Logger[F] - ): Resource[F, Client[F]] = { - for { + messageSocket: MessageSocket, + queryHandler: QueryHandler[IO] + )(using Logger[IO], Random[IO]): Resource[IO, Client] = { + for + responses <- Resource.eval { + Queue.unbounded[IO, (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage)] + } requestResponse <- RequestResponse.make( generateTransactionId, - sendQueryMessage, - receiveResponse + messageSocket.writeMessage, + responses.take ) - } yield new Client[F] { + _ <- + messageSocket.readMessage + .flatMap { + case (a, m: Message.QueryMessage) => + Logger[IO].debug(s"Received $m") >> + queryHandler(a, m.query).flatMap { + case Some(response) => + val responseMessage = Message.ResponseMessage(m.transactionId, response) + Logger[IO].debug(s"Responding with $responseMessage") >> + messageSocket.writeMessage(a, responseMessage) + case None => + Logger[IO].debug(s"No response for $m") + } + case (a, m: Message.ResponseMessage) => responses.offer((a, m)) + case (a, m: Message.ErrorMessage) => responses.offer((a, m)) + } + .recoverWith { case e: Throwable => + Logger[IO].debug(s"Failed to read message: $e") + } + .foreverM + .background + yield new Client { + + def id: NodeId = selfId def getPeers( nodeInfo: NodeInfo, infoHash: InfoHash - ): F[Either[Response.Nodes, Response.Peers]] = + ): IO[Either[Response.Nodes, Response.Peers]] = requestResponse.sendQuery(nodeInfo.address, Query.GetPeers(selfId, infoHash)).flatMap { case nodes: Response.Nodes => nodes.asLeft.pure case peers: Response.Peers => peers.asRight.pure - case _ => F.raiseError(InvalidResponse()) + case _ => IO.raiseError(InvalidResponse()) } - def findNodes(nodeInfo: NodeInfo, target: NodeId): F[Response.Nodes] = + def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] = requestResponse.sendQuery(nodeInfo.address, Query.FindNode(selfId, target)).flatMap { case nodes: Response.Nodes => nodes.pure - case _ => Concurrent[F].raiseError(InvalidResponse()) + case _ => IO.raiseError(InvalidResponse()) } - def ping(address: SocketAddress[IpAddress]): F[Response.Ping] = + def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] = requestResponse.sendQuery(address, Query.Ping(selfId)).flatMap { case ping: Response.Ping => ping.pure - case _ => Concurrent[F].raiseError(InvalidResponse()) + case _ => IO.raiseError(InvalidResponse()) } - def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): F[Either[Response.Nodes, Response.SampleInfoHashes]] = + def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] = requestResponse.sendQuery(nodeInfo.address, Query.SampleInfoHashes(selfId, target)).flatMap { case response: Response.SampleInfoHashes => response.asRight[Response.Nodes].pure case response: Response.Nodes => response.asLeft[Response.SampleInfoHashes].pure - case _ => Concurrent[F].raiseError(InvalidResponse()) + case _ => IO.raiseError(InvalidResponse()) } } } - case class BootstrapError(message: String) extends Throwable(message) case class InvalidResponse() extends Throwable } diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala index 4a69f71..0eaeee9 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/Node.scala @@ -9,57 +9,94 @@ import cats.effect.Resource import cats.effect.Sync import cats.implicits.* import com.comcast.ip4s.* -import fs2.io.net.DatagramSocketGroup +import fs2.Stream +import com.github.torrentdam.bittorrent.InfoHash + import java.net.InetSocketAddress import org.legogroup.woof.given import org.legogroup.woof.Logger import scodec.bits.ByteVector -trait Node { - def client: Client[IO] -} +import scala.concurrent.duration.DurationInt + +class Node(val id: NodeId, val client: Client, val routingTable: RoutingTable[IO], val discovery: PeerDiscovery) object Node { def apply( - selfId: NodeId, - port: Option[Port], - queryHandler: QueryHandler[IO] + port: Option[Port] = None, + bootstrapNodeAddress: Option[SocketAddress[Host]] = None )(using random: Random[IO], logger: Logger[IO] ): Resource[IO, Node] = - - def generateTransactionId: IO[ByteVector] = - val nextChar = random.nextAlphaNumeric - (nextChar, nextChar).mapN((a, b) => ByteVector.encodeAscii(List(a, b).mkString).toOption.get) - for + selfId <- Resource.eval(NodeId.generate[IO]) messageSocket <- MessageSocket(port) - responses <- Resource.eval { - Queue.unbounded[IO, (SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage])] + routingTable <- RoutingTable[IO](selfId).toResource + queryingNodes <- Queue.unbounded[IO, NodeInfo].toResource + queryHandler = reportingQueryHandler(queryingNodes, QueryHandler.simple(selfId, routingTable)) + client <- Client(selfId, messageSocket, queryHandler) + insertingClient = new InsertingClient(client, routingTable) + bootstrapNodes = bootstrapNodeAddress.map(List(_)).getOrElse(RoutingTableBootstrap.PublicBootstrapNodes) + discovery = PeerDiscovery(routingTable, insertingClient) + _ <- RoutingTableBootstrap(routingTable, insertingClient, discovery, bootstrapNodes).toResource + _ <- PingRoutine(routingTable, client).runForever.background + _ <- pingCandidates(queryingNodes, client, routingTable).background + yield new Node(selfId, insertingClient, routingTable, discovery) + + private class InsertingClient(client: Client, routingTable: RoutingTable[IO]) extends Client { + + def id: NodeId = client.id + + def getPeers(nodeInfo: NodeInfo, infoHash: InfoHash): IO[Either[Response.Nodes, Response.Peers]] = + client.getPeers(nodeInfo, infoHash) <* routingTable.insert(nodeInfo) + + def findNodes(nodeInfo: NodeInfo, target: NodeId): IO[Response.Nodes] = + client.findNodes(nodeInfo, target).flatTap { response => + routingTable.insert(NodeInfo(response.id, nodeInfo.address)) } - client0 <- Client(selfId, messageSocket.writeMessage, responses.take, generateTransactionId) - _ <- - messageSocket.readMessage - .flatMap { - case (a, m: Message.QueryMessage) => - logger.debug(s"Received $m") >> - queryHandler(a, m.query).flatMap { response => - val responseMessage = Message.ResponseMessage(m.transactionId, response) - logger.debug(s"Responding with $responseMessage") >> - messageSocket.writeMessage(a, responseMessage) - } - case (a, m: Message.ResponseMessage) => responses.offer((a, m.asRight)) - case (a, m: Message.ErrorMessage) => responses.offer((a, m.asLeft)) - } - .recoverWith { case e: Throwable => - logger.trace(s"Failed to read message: $e") - } - .foreverM - .background - - yield new Node { - def client: Client[IO] = client0 + + def ping(address: SocketAddress[IpAddress]): IO[Response.Ping] = + client.ping(address).flatTap { response => + routingTable.insert(NodeInfo(response.id, address)) + } + + def sampleInfoHashes(nodeInfo: NodeInfo, target: NodeId): IO[Either[Response.Nodes, Response.SampleInfoHashes]] = + client.sampleInfoHashes(nodeInfo, target).flatTap { response => + routingTable.insert( + response match + case Left(response) => NodeInfo(response.id, nodeInfo.address) + case Right(response) => NodeInfo(response.id, nodeInfo.address) + ) + } + + override def toString: String = s"InsertingClient($client)" + } + + private def pingCandidate(node: NodeInfo, client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) = + routingTable.lookup(node.id).flatMap { + case Some(_) => IO.unit + case None => + Logger[IO].info(s"Pinging $node") *> + client.ping(node.address).timeout(5.seconds).attempt.flatMap { + case Right(_) => + Logger[IO].info(s"Got pong from $node -- insert as good") *> + routingTable.insert(node) + case Left(_) => IO.unit + } } + + private def pingCandidates(nodes: Queue[IO, NodeInfo], client: Client, routingTable: RoutingTable[IO])(using Logger[IO]) = + nodes.take.flatMap(pingCandidate(_, client, routingTable).attempt.void).foreverM + + + private def reportingQueryHandler(queue: Queue[IO, NodeInfo], next: QueryHandler[IO]): QueryHandler[IO] = (address, query) => + val nodeInfo = query match + case Query.Ping(id) => NodeInfo(id, address) + case Query.FindNode(id, _) => NodeInfo(id, address) + case Query.GetPeers(id, _) => NodeInfo(id, address) + case Query.AnnouncePeer(id, _, _) => NodeInfo(id, address) + case Query.SampleInfoHashes(id, _) => NodeInfo(id, address) + queue.offer(nodeInfo) *> next(address, query) } diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala index f98b745..2d0c795 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PeerDiscovery.scala @@ -23,42 +23,36 @@ trait PeerDiscovery { object PeerDiscovery { - def make( + def apply( routingTable: RoutingTable[IO], - dhtClient: Client[IO] + dhtClient: Client )(using logger: Logger[IO] - ): Resource[IO, PeerDiscovery] = - Resource.pure[IO, PeerDiscovery] { - - new PeerDiscovery { - - def discover(infoHash: InfoHash): Stream[IO, PeerInfo] = { - - Stream - .eval { - for { - _ <- logger.info("Start discovery") - initialNodes <- routingTable.findNodes(NodeId(infoHash.bytes)) - initialNodes <- initialNodes.take(100).toList.pure[IO] - _ <- logger.info(s"Got ${initialNodes.size} from routing table") - state <- DiscoveryState(initialNodes, infoHash) - } yield { - start( - infoHash, - dhtClient.getPeers, - state - ) - } - } - .flatten - .onFinalizeCase { - case Resource.ExitCase.Errored(e) => logger.error(s"Discovery failed with ${e.getMessage}") - case _ => IO.unit - } + ): PeerDiscovery = new { + def discover(infoHash: InfoHash): Stream[IO, PeerInfo] = { + Stream + .eval { + for { + _ <- logger.info("Start discovery") + initialNodes <- routingTable.findNodes(NodeId(infoHash.bytes)) + initialNodes <- initialNodes.take(100).toList.pure[IO] + _ <- logger.info(s"Got ${initialNodes.size} from routing table") + state <- DiscoveryState(initialNodes, infoHash) + } yield { + start( + infoHash, + dhtClient.getPeers, + state + ) + } + } + .flatten + .onFinalizeCase { + case Resource.ExitCase.Errored(e) => logger.error(s"Discovery failed with ${e.getMessage}") + case _ => IO.unit } - } } + } private[dht] def start( infoHash: InfoHash, diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PingRoutine.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PingRoutine.scala index 8389c81..8dce949 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PingRoutine.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/PingRoutine.scala @@ -7,10 +7,13 @@ import org.legogroup.woof.{Logger, given} import scala.concurrent.duration.DurationInt -class PingRoutine(table: RoutingTable[IO], client: Client[IO])(using logger: Logger[IO]): +class PingRoutine(table: RoutingTable[IO], client: Client)(using logger: Logger[IO]): def run: IO[Unit] = async[IO]: - val nodes = table.allNodes.await + val (nodes, desperateNodes) = table.allNodes.await.partition(_.badCount < 3) + if desperateNodes.nonEmpty then + logger.info(s"Removing ${desperateNodes.size} desperate nodes").await + desperateNodes.traverse_(node => table.remove(node.id)).await logger.info(s"Pinging ${nodes.size} nodes").await val queries = nodes.map { node => client.ping(node.address).timeout(5.seconds).attempt.map(_.bimap(_ => node.id, _ => node.id)) diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala index 5cf041e..ec8fa72 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/QueryHandler.scala @@ -6,44 +6,40 @@ import com.comcast.ip4s.* import com.github.torrentdam.bittorrent.PeerInfo trait QueryHandler[F[_]] { - def apply(address: SocketAddress[IpAddress], query: Query): F[Response] + def apply(address: SocketAddress[IpAddress], query: Query): F[Option[Response]] } object QueryHandler { + + def noop[F[_]: Monad]: QueryHandler[F] = (_, _) => none.pure[F] - def apply[F[_]: Monad](selfId: NodeId, routingTable: RoutingTable[F]): QueryHandler[F] = { (address, query) => + def simple[F[_]: Monad](selfId: NodeId, routingTable: RoutingTable[F]): QueryHandler[F] = { (address, query) => query match { - case Query.Ping(nodeId) => - routingTable.insert(NodeInfo(nodeId, address)).as { - Response.Ping(selfId): Response - } - case Query.FindNode(nodeId, target) => - routingTable.insert(NodeInfo(nodeId, address)) >> + case Query.Ping(_) => + Response.Ping(selfId).some.pure[F] + case Query.FindNode(_, target) => routingTable.findBucket(target).map { nodes => - Response.Nodes(selfId, nodes): Response + Response.Nodes(selfId, nodes).some } - case Query.GetPeers(nodeId, infoHash) => - routingTable.insert(NodeInfo(nodeId, address)) >> + case Query.GetPeers(_, infoHash) => routingTable.findPeers(infoHash).flatMap { case Some(peers) => - Response.Peers(selfId, peers.toList).pure[F].widen[Response] + Response.Peers(selfId, peers.toList).some.pure[F] case None => routingTable .findBucket(NodeId(infoHash.bytes)) .map { nodes => - Response.Nodes(selfId, nodes) + Response.Nodes(selfId, nodes).some } - .widen[Response] - } - case Query.AnnouncePeer(nodeId, infoHash, port) => - routingTable.insert(NodeInfo(nodeId, address)) >> - routingTable.addPeer(infoHash, PeerInfo(SocketAddress(address.host, Port.fromInt(port.toInt).get))).as { - Response.Ping(selfId): Response } + case Query.AnnouncePeer(_, infoHash, port) => + routingTable + .addPeer(infoHash, PeerInfo(SocketAddress(address.host, Port.fromInt(port.toInt).get))) + .as( + Response.Ping(selfId).some + ) case Query.SampleInfoHashes(_, _) => - (Response.SampleInfoHashes(selfId, None, List.empty): Response).pure[F] + Response.SampleInfoHashes(selfId, None, List.empty).some.pure[F] } } - - def fromFunction[F[_]](f: (SocketAddress[IpAddress], Query) => F[Response]): QueryHandler[F] = f(_, _) } diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala index 3b4a05b..e99f297 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RequestResponse.scala @@ -24,7 +24,7 @@ object RequestResponse { generateTransactionId: F[ByteVector], sendQuery: (SocketAddress[IpAddress], Message.QueryMessage) => F[Unit], receiveMessage: F[ - (SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage]) + (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage) ] )(using F: Temporal[F] @@ -57,16 +57,16 @@ object RequestResponse { private def receiveLoop[F[_]]( receive: F[ - (SocketAddress[IpAddress], Either[Message.ErrorMessage, Message.ResponseMessage]) + (SocketAddress[IpAddress], Message.ErrorMessage | Message.ResponseMessage) ], continue: (ByteVector, Either[Throwable, Response]) => F[Boolean] )(using F: Monad[F] ): F[Unit] = { val step = receive.map(_._2).flatMap { - case Right(Message.ResponseMessage(transactionId, response)) => + case Message.ResponseMessage(transactionId, response) => continue(transactionId, response.asRight) - case Left(Message.ErrorMessage(transactionId, details)) => + case Message.ErrorMessage(transactionId, details) => continue(transactionId, ErrorResponse(details).asLeft) } step.foreverM[Unit] diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala index 63ad8f8..3bb4899 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTable.scala @@ -16,6 +16,8 @@ import scala.annotation.tailrec trait RoutingTable[F[_]] { def insert(node: NodeInfo): F[Unit] + + def remove(nodeId: NodeId): F[Unit] def findNodes(nodeId: NodeId): F[LazyList[NodeInfo]] @@ -28,6 +30,8 @@ trait RoutingTable[F[_]] { def allNodes: F[LazyList[RoutingTable.Node]] def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] + + def lookup(nodeId: NodeId): F[Option[RoutingTable.Node]] } object RoutingTable { @@ -36,7 +40,7 @@ object RoutingTable { case Split(center: BigInt, lower: TreeNode, higher: TreeNode) case Bucket(from: BigInt, until: BigInt, nodes: ListMap[NodeId, Node]) - case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean): + case class Node(id: NodeId, address: SocketAddress[IpAddress], isGood: Boolean, badCount: Int = 0): def toNodeInfo: NodeInfo = NodeInfo(id, address) object TreeNode { @@ -62,7 +66,7 @@ object RoutingTable { else b.copy(higher = higher.insert(node, selfId)) case b @ Bucket(from, until, nodes) => - if nodes.size == MaxNodes + if nodes.size >= MaxNodes && !nodes.contains(selfId) then if selfId.int >= from && selfId.int < until then @@ -140,6 +144,9 @@ object RoutingTable { def insert(node: NodeInfo): F[Unit] = treeNodeRef.update(_.insert(node, selfId)) + + def remove(nodeId: NodeId): F[Unit] = + treeNodeRef.update(_.remove(nodeId)) def findNodes(nodeId: NodeId): F[LazyList[NodeInfo]] = treeNodeRef.get.map(_.findNodes(nodeId).filter(_.isGood).map(_.toNodeInfo)) @@ -164,10 +171,13 @@ object RoutingTable { def updateGoodness(good: Set[NodeId], bad: Set[NodeId]): F[Unit] = treeNodeRef.update( _.update(node => - if good.contains(node.id) then node.copy(isGood = true) - else if bad.contains(node.id) then node.copy(isGood = false) + if good.contains(node.id) then node.copy(isGood = true, badCount = 0) + else if bad.contains(node.id) then node.copy(isGood = false, badCount = node.badCount + 1) else node ) ) + + def lookup(nodeId: NodeId): F[Option[Node]] = + treeNodeRef.get.map(_.findBucket(nodeId).nodes.get(nodeId)) } } diff --git a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala index 21ebbd3..fb967d2 100644 --- a/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala +++ b/dht/src/main/scala/com/github/torrentdam/bittorrent/dht/RoutingTableBootstrap.scala @@ -3,8 +3,10 @@ package com.github.torrentdam.bittorrent.dht import cats.effect.kernel.Temporal import cats.implicits.* import cats.MonadError +import cats.effect.IO import cats.effect.implicits.* import com.comcast.ip4s.* +import com.github.torrentdam.bittorrent.InfoHash import org.legogroup.woof.given import org.legogroup.woof.Logger import fs2.Stream @@ -13,37 +15,37 @@ import scala.concurrent.duration.* object RoutingTableBootstrap { - def apply[F[_]]( - table: RoutingTable[F], - client: Client[F], + def apply( + table: RoutingTable[IO], + client: Client, + discovery: PeerDiscovery, bootstrapNodeAddress: List[SocketAddress[Host]] = PublicBootstrapNodes )(using - F: Temporal[F], - dns: Dns[F], - logger: Logger[F] - ): F[Unit] = + dns: Dns[IO], + logger: Logger[IO] + ): IO[Unit] = for { _ <- logger.info("Bootstrapping") - count <- resolveBootstrapNode(client, bootstrapNodeAddress) - .evalMap(table.insert) - .compile - .count - _ <- logger.info(s"Bootstrap completed with $count nodes") + count <- resolveNodes(client, bootstrapNodeAddress).compile.count + _ <- logger.info(s"Pinged $count bootstrap nodes") + _ <- logger.info("Discover self to fill up routing table") + _ <- discovery.discover(InfoHash(client.id.bytes)).take(10).compile.drain + nodeCount <- table.allNodes.map(_.size) + _ <- logger.info(s"Bootstrapping finished with $nodeCount nodes") } yield {} - private def resolveBootstrapNode[F[_]]( - client: Client[F], + private def resolveNodes( + client: Client, bootstrapNodeAddress: List[SocketAddress[Host]] )(using - F: Temporal[F], - dns: Dns[F], - logger: Logger[F] - ): Stream[F, NodeInfo] = - def tryThis(hostname: SocketAddress[Host]): Stream[F, NodeInfo] = + dns: Dns[IO], + logger: Logger[IO] + ): Stream[IO, NodeInfo] = + def tryThis(hostname: SocketAddress[Host]): Stream[IO, NodeInfo] = Stream.eval(logger.info(s"Trying to reach $hostname")) >> Stream .evals( - hostname.host.resolveAll[F] + hostname.host.resolveAll[IO] .recoverWith: e => logger.info(s"Failed to resolve $hostname $e").as(List.empty) ) @@ -61,7 +63,7 @@ object RoutingTableBootstrap { } Stream .emits(bootstrapNodeAddress) - .covary[F] + .covary[IO] .flatMap(tryThis) val PublicBootstrapNodes: List[SocketAddress[Host]] = List(