diff --git a/mitmproxy-rs/src/server/local_redirector.rs b/mitmproxy-rs/src/server/local_redirector.rs index 08a932a8..b9df16e1 100644 --- a/mitmproxy-rs/src/server/local_redirector.rs +++ b/mitmproxy-rs/src/server/local_redirector.rs @@ -22,6 +22,16 @@ pub struct LocalRedirector { spec: String, } +impl LocalRedirector { + pub fn new(server: Server, conf_tx: mpsc::UnboundedSender) -> Self { + Self { + server, + conf_tx, + spec: "inactive".to_string(), + } + } +} + #[pymethods] impl LocalRedirector { /// Return a textual description of the given spec, @@ -84,11 +94,7 @@ pub fn start_local_redirector( let (server, conf_tx) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - Ok(LocalRedirector { - server, - conf_tx, - spec: "inactive".to_string(), - }) + Ok(LocalRedirector::new(server, conf_tx)) }) } #[cfg(target_os = "macos")] @@ -115,8 +121,9 @@ pub fn start_local_redirector( } let conf = MacosConf; pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?; - Ok(LocalRedirector { server, conf_tx }) + let (server, conf_tx) = + Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + Ok(LocalRedirector::new(server, conf_tx)) }) } #[cfg(not(any(windows, target_os = "macos")))] diff --git a/mitmproxy-rs/src/server/wireguard.rs b/mitmproxy-rs/src/server/wireguard.rs index 4f9e7361..aebd94ac 100644 --- a/mitmproxy-rs/src/server/wireguard.rs +++ b/mitmproxy-rs/src/server/wireguard.rs @@ -2,14 +2,9 @@ use std::net::SocketAddr; use crate::util::{socketaddr_to_py, string_to_key}; -#[cfg(target_os = "macos")] -use mitmproxy::packet_sources::macos::MacosConf; - use mitmproxy::packet_sources::wireguard::WireGuardConf; use pyo3::prelude::*; -#[cfg(target_os = "macos")] -use std::path::Path; use boringtun::x25519::PublicKey; diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 9937c240..394603d6 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -63,12 +63,14 @@ impl PyInteropTask { src_addr, dst_addr, tunnel_info, + command_tx, } => { + let command_tx = command_tx.unwrap_or_else(|| self.transport_commands.clone()); // initialize new stream let stream = Stream { connection_id, state: StreamState::Open, - command_tx: self.transport_commands.clone(), + command_tx, peername: src_addr, sockname: dst_addr, tunnel_info, diff --git a/src/messages.rs b/src/messages.rs index efbbe547..2e496c77 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -4,7 +4,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use anyhow::{anyhow, Result}; use internet_packet::InternetPacket; use smoltcp::wire::{IpProtocol, Ipv4Packet, Ipv6Packet}; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone)] pub enum TunnelInfo { @@ -39,10 +39,10 @@ pub enum NetworkCommand { pub struct ConnectionIdGenerator(usize); impl ConnectionIdGenerator { - pub fn tcp() -> Self { + pub const fn tcp() -> Self { Self(2) } - pub fn udp() -> Self { + pub const fn udp() -> Self { Self(1) } pub fn next_id(&mut self) -> ConnectionId { @@ -85,6 +85,7 @@ pub enum TransportEvent { src_addr: SocketAddr, dst_addr: SocketAddr, tunnel_info: TunnelInfo, + command_tx: Option>, }, } @@ -98,14 +99,13 @@ pub enum TransportCommand { } impl TransportCommand { - pub fn is_tcp(&self) -> bool { + pub fn connection_id(&self) -> &ConnectionId { match self { TransportCommand::ReadData(id, _, _) => id, TransportCommand::WriteData(id, _) => id, TransportCommand::DrainWriter(id, _) => id, TransportCommand::CloseConnection(id, _) => id, } - .is_tcp() } } diff --git a/src/network/core.rs b/src/network/core.rs index 8575cec2..c9f62a36 100644 --- a/src/network/core.rs +++ b/src/network/core.rs @@ -89,7 +89,7 @@ impl<'a> NetworkStack<'a> { } pub fn handle_transport_command(&mut self, command: TransportCommand) { - if command.is_tcp() { + if command.connection_id().is_tcp() { self.tcp.handle_transport_command(command); } else if let Some(packet) = self.udp.handle_transport_command(command) { if self diff --git a/src/network/task.rs b/src/network/task.rs index 11a43ef2..6c06e57c 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -93,7 +93,7 @@ impl NetworkTask<'_> { // wait for graceful shutdown _ = self.shutdown.recv() => break 'task, // wait for timeouts when the device is idle - _ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {}, + _ = tokio::time::sleep(delay.unwrap()), if delay.is_some() => {}, // wait for py_tx channel capacity... Ok(permit) = self.py_tx.reserve(), if !py_tx_available => { py_tx_permit = Some(permit); diff --git a/src/network/tcp.rs b/src/network/tcp.rs index ddd26b59..43e5d02e 100644 --- a/src/network/tcp.rs +++ b/src/network/tcp.rs @@ -147,6 +147,7 @@ impl<'a> TcpHandler<'a> { src_addr, dst_addr, tunnel_info, + command_tx: None, }; permit.send(event); } diff --git a/src/network/tests.rs b/src/network/tests.rs index bec43ca0..b655b220 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -356,7 +356,7 @@ async fn udp_read_write( connection_id, src_addr: recv_src_addr, dst_addr: recv_dst_addr, - tunnel_info: _, + .. } = event; assert_eq!(src_addr, recv_src_addr); @@ -494,7 +494,7 @@ async fn tcp_ipv4_connection() -> Result<()> { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, - tunnel_info: _, + .. } = event; assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into()); assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into()); @@ -671,7 +671,7 @@ async fn tcp_ipv6_connection() -> Result<()> { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, - tunnel_info: _, + .. } = event; assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into()); assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into()); diff --git a/src/network/udp.rs b/src/network/udp.rs index 8e45d6f1..a3a54b2c 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -65,6 +65,8 @@ impl ConnectionState { } } +pub const UDP_TIMEOUT: Duration = Duration::from_secs(60); + pub struct UdpHandler { connection_id_generator: ConnectionIdGenerator, id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, @@ -73,12 +75,10 @@ pub struct UdpHandler { impl UdpHandler { pub fn new() -> Self { - let connections = LruCache::::with_expiry_duration( - Duration::from_secs(60), - ); - let id_lookup = LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration( - Duration::from_secs(60), - ); + let connections = + LruCache::::with_expiry_duration(UDP_TIMEOUT); + let id_lookup = + LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(UDP_TIMEOUT); Self { connections, id_lookup, @@ -172,6 +172,7 @@ impl UdpHandler { src_addr: packet.src_addr, dst_addr: packet.dst_addr, tunnel_info, + command_tx: None, }); } }; diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index 7d758c0f..9a5536e9 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -1,7 +1,6 @@ -use std::collections::HashMap; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use crate::messages::{ConnectionId, TransportCommand, TransportEvent, TunnelInfo}; +use crate::messages::{ConnectionIdGenerator, TransportCommand, TransportEvent, TunnelInfo}; use crate::intercept_conf::InterceptConf; use crate::ipc; @@ -12,7 +11,7 @@ use async_trait::async_trait; use futures_util::SinkExt; use futures_util::StreamExt; -use prost::bytes::{Buf, BytesMut}; +use prost::bytes::BytesMut; use prost::Message; use std::process::Stdio; @@ -78,7 +77,7 @@ impl PacketSourceConf for MacosConf { async fn build( self, transport_events_tx: Sender, - transport_commands_rx: UnboundedReceiver, + _transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, ) -> Result<(Self::Task, Self::Data)> { let listener_addr = format!("/tmp/mitmproxy-{}", std::process::id()); @@ -100,11 +99,7 @@ impl PacketSourceConf for MacosConf { control_channel, listener, connections: JoinSet::new(), - connection_by_id: HashMap::new(), - connection_by_addr: HashMap::new(), - next_connection_id: 0, transport_events_tx, - transport_commands_rx, conf_rx, shutdown, }, @@ -116,12 +111,8 @@ impl PacketSourceConf for MacosConf { pub struct MacOsTask { control_channel: UnixStream, listener: UnixListener, - connections: JoinSet)>>, - connection_by_id: HashMap>, - connection_by_addr: HashMap>, - next_connection_id: ConnectionId, + connections: JoinSet>, transport_events_tx: Sender, - transport_commands_rx: UnboundedReceiver, conf_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, } @@ -131,9 +122,6 @@ impl PacketSourceTask for MacOsTask { async fn run(mut self) -> Result<()> { let mut control_channel = Framed::new(self.control_channel, LengthDelimitedCodec::new()); - let (register_addr_tx, mut register_addr_rx) = - unbounded_channel::(); - loop { tokio::select! { // wait for graceful shutdown @@ -144,66 +132,24 @@ impl PacketSourceTask for MacOsTask { }, Some(task) = self.connections.join_next() => { match task { - Ok(Ok((cid, src_addr))) => { - self.connection_by_id.remove(&cid); - if let Some(src_addr) = src_addr { - self.connection_by_addr.remove(&src_addr); - } - }, + Ok(Ok(())) => (), Ok(Err(e)) => log::error!("Connection task failure: {e:?}"), Err(e) => log::error!("Connection task panic: {e:?}"), } }, - Some(RegisterConnectionSocketAddr(cid, addr, done)) = register_addr_rx.recv() => { - let tx = self.connection_by_id.get(&cid).unwrap().clone(); - self.connection_by_addr.insert(addr, tx); - done.send(()).expect("ok channel dead"); - }, l = self.listener.accept() => { match l { Ok((stream, _)) => { - let (conn_tx, conn_rx) = unbounded_channel(); - let connection_id = { - self.next_connection_id += 1; - self.next_connection_id - }; - self.connections.spawn( - ConnectionTask::new(connection_id, stream, conn_rx, self.transport_events_tx.clone(), register_addr_tx.clone()) - .run() - ); - self.connection_by_id.insert( - connection_id, - conn_tx + let task = ConnectionTask::new( + stream, + self.transport_events_tx.clone(), + self.shutdown.resubscribe(), ); + self.connections.spawn(task.run()); }, Err(e) => log::error!("Error accepting connection from macos-redirector: {}", e) } }, - Some(cmd) = self.transport_commands_rx.recv() => { - match &cmd { - TransportCommand::ReadData(connection_id, _, _) - | TransportCommand::WriteData(connection_id, _) - | TransportCommand::DrainWriter(connection_id, _) - | TransportCommand::CloseConnection(connection_id, _) => { - let Some(conn_tx) = self.connection_by_id.get(connection_id) else { - log::error!("Received command for unknown connection: {:?}", &cmd); - continue; - }; - conn_tx.send(cmd).ok(); - }, - TransportCommand::SendDatagram { - data: _, - src_addr, - dst_addr, - } => { - let Some(conn_tx) = self.connection_by_addr.get(dst_addr) else { - log::error!("Received command for unknown address: src={:?} dst={:?}", src_addr, dst_addr); - continue; - }; - conn_tx.send(cmd).ok(); - }, - } - } // pipe through changes to the intercept list Some(conf) = self.conf_rx.recv() => { let msg = ipc::InterceptConf::from(conf); @@ -220,37 +166,25 @@ impl PacketSourceTask for MacOsTask { } } -struct RegisterConnectionSocketAddr(ConnectionId, SocketAddr, oneshot::Sender<()>); - struct ConnectionTask { - id: ConnectionId, stream: UnixStream, - commands: UnboundedReceiver, events: Sender, - read_tx: Option<(usize, oneshot::Sender>)>, - drain_tx: Option>, - register_addr: UnboundedSender, + shutdown: broadcast::Receiver<()>, } impl ConnectionTask { pub fn new( - id: ConnectionId, stream: UnixStream, - commands: UnboundedReceiver, events: Sender, - register_addr: UnboundedSender, + shutdown: broadcast::Receiver<()>, ) -> Self { Self { - id, stream, - commands, events, - read_tx: None, - drain_tx: None, - register_addr, + shutdown, } } - async fn run(mut self) -> Result<(ConnectionId, Option)> { + async fn run(mut self) -> Result<()> { let new_flow = { let len = self .stream @@ -276,7 +210,7 @@ impl ConnectionTask { } } - async fn handle_udp(mut self, flow: UdpFlow) -> Result<(ConnectionId, Option)> { + async fn handle_udp(mut self, flow: UdpFlow) -> Result<()> { // For UDP connections, we pass length-delimited protobuf messages over the unix socket // in both directions. let mut write_buf = BytesMut::new(); @@ -292,134 +226,149 @@ impl ConnectionTask { remote_endpoint: None, } }; - let local_addr = { + let src_addr = { let Some(addr) = &flow.local_address else { bail!("no local address") }; SocketAddr::try_from(addr)? }; + let mut first_packet_dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let (command_tx, mut command_rx) = unbounded_channel(); - // Send our socket address to the main macos task and wait until it has been processed. - let (done_tx, done_rx) = oneshot::channel(); - self.register_addr - .send(RegisterConnectionSocketAddr(self.id, local_addr, done_tx))?; - done_rx.await?; + let mut first_packet = Some((tunnel_info, src_addr, command_tx)); + + let mut read_data: Option> = None; + let mut read_tx: Option>> = None; loop { tokio::select! { - packet = stream.next() => { - let Some(packet) = packet else { - break; - }; + _ = self.shutdown.recv() => break, + Some(packet) = stream.next(), if read_data.is_none() => { let packet = ipc::UdpPacket::decode( packet.context("IPC read error")? ).context("invalid IPC message")?; - let dst_addr = { let Some(dst_addr) = &packet.remote_address else { bail!("no remote addr") }; SocketAddr::try_from(dst_addr).context("invalid socket address")? }; - todo!(); - if let Err(e) = self.events.try_send(TransportEvent::DatagramReceived { - data: packet.data, - src_addr: local_addr, - dst_addr, - tunnel_info: tunnel_info.clone(), - }) { - log::debug!("Failed to send UDP packet: {}", e); + // We can only send ConnectionEstablished once we know the destination address. + if let Some((tunnel_info, src_addr, command_tx)) = first_packet.take() { + first_packet_dst = dst_addr; + self.events.send(TransportEvent::ConnectionEstablished { + connection_id: ConnectionIdGenerator::udp().next_id(), + src_addr, + dst_addr, + tunnel_info, + command_tx: Some(command_tx), + }).await?; + } else if first_packet_dst != dst_addr { + bail!("UDP packet destinations do not match: {first_packet_dst} -> {dst_addr}") + } + if let Some(tx) = read_tx.take() { + tx.send(packet.data).ok(); + } else { + read_data = Some(packet.data); } }, - command = self.commands.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { - todo!(); - TransportCommand::SendDatagram { data, src_addr, dst_addr } => { - assert_eq!(dst_addr, local_addr); + TransportCommand::ReadData(_, _, tx) => { + if let Some(data) = read_data.take() { + tx.send(data).ok(); + } else { + if read_tx.is_some() { + bail!("Concurrent readers are not supported."); + } + read_tx = Some(tx); + } + }, + TransportCommand::WriteData(_, data) => { let packet = ipc::UdpPacket { data, remote_address: Some(src_addr.into()), }; write_buf.reserve(packet.encoded_len()); packet.encode(&mut write_buf)?; - stream.send(write_buf.split().freeze()).await?; + // Awaiting here isn't ideal because it blocks reading, but what to do. + stream.send(write_buf.split().freeze()).await.ok(); }, - TransportCommand::ReadData(_, _, _) | - TransportCommand::WriteData(_, _) | - TransportCommand::DrainWriter(_, _) | - TransportCommand::CloseConnection(_, _) => { - bail!("UDP connection received TCP event: {command:?}"); + TransportCommand::DrainWriter(_, tx) => { + tx.send(()).ok(); + }, + TransportCommand::CloseConnection(_, half_close) => { + if !half_close { + break; + } } } } + else => break, } } - Ok((self.id, Some(local_addr))) + Ok(()) } - async fn handle_tcp(mut self, flow: TcpFlow) -> Result<(ConnectionId, Option)> { + async fn handle_tcp(mut self, flow: TcpFlow) -> Result<()> { let mut write_buf = BytesMut::new(); + let mut drain_tx: Option> = None; + let mut read_tx: Option<(usize, oneshot::Sender>)> = None; + + let (command_tx, mut command_rx) = unbounded_channel(); let remote = flow.remote_address.expect("no remote address"); let src_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); - let dst_addr = match SocketAddr::try_from(&remote) { - Ok(addr) => addr, - Err(_) => SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), + let dst_addr = SocketAddr::try_from(&remote) + .unwrap_or_else(|_| SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))); + let tunnel_info = TunnelInfo::LocalRedirector { + pid: flow.tunnel_info.as_ref().map(|t| t.pid).unwrap_or(0), + process_name: flow.tunnel_info.and_then(|t| t.process_name), + remote_endpoint: Some((remote.host, remote.port as u16)), }; - let remote_endpoint = Some((remote.host, remote.port as u16)); self.events .send(TransportEvent::ConnectionEstablished { - connection_id: self.id, + connection_id: ConnectionIdGenerator::tcp().next_id(), src_addr, dst_addr, - tunnel_info: TunnelInfo::LocalRedirector { - pid: flow.tunnel_info.as_ref().map(|t| t.pid).unwrap_or(0), - process_name: flow.tunnel_info.and_then(|t| t.process_name), - remote_endpoint, - }, + tunnel_info, + command_tx: Some(command_tx), }) .await?; loop { tokio::select! { + _ = self.shutdown.recv() => break, Ok(()) = self.stream.writable(), if !write_buf.is_empty() => { self.stream.write_buf(&mut write_buf).await.context("failed to write to socket from buf")?; if write_buf.is_empty() { - if let Some(tx) = self.drain_tx.take() { + if let Some(tx) = drain_tx.take() { tx.send(()).ok(); } } }, - Ok(()) = self.stream.readable(), if self.read_tx.is_some() => { - let (n, tx) = self.read_tx.take().unwrap(); + Ok(()) = self.stream.readable(), if read_tx.is_some() => { + let (n, tx) = read_tx.take().unwrap(); let mut data = Vec::with_capacity(n); self.stream.read_buf(&mut data).await.context("failed to read from socket")?; tx.send(data).ok(); }, - command = self.commands.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { TransportCommand::ReadData(_, n, tx) => { - assert!(self.read_tx.is_none()); - self.read_tx = Some((n as usize, tx)); + assert!(read_tx.is_none()); + read_tx = Some((n as usize, tx)); }, TransportCommand::WriteData(_, data) => { - let mut c = std::io::Cursor::new(data); - self.stream.write_buf(&mut c).await.context("failed to write to socket")?; - write_buf.extend_from_slice(c.chunk()); + write_buf.extend_from_slice(data.as_slice()); }, TransportCommand::DrainWriter(_, tx) => { - assert!(self.drain_tx.is_none()); + assert!(drain_tx.is_none()); if write_buf.is_empty() { tx.send(()).ok(); } else { - self.drain_tx = Some(tx); + drain_tx = Some(tx); } }, TransportCommand::CloseConnection(_, half_close) => { @@ -430,9 +379,10 @@ impl ConnectionTask { } } } - } + }, + else => break, } } - Ok((self.id, None)) + Ok(()) } }