From d147d986134ad7ef4c8a345b62c9cdc8398d27f4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 17:20:27 +0100 Subject: [PATCH] add UDP client --- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 31 ++++- mitmproxy-rs/src/lib.rs | 3 + mitmproxy-rs/src/server/udp.rs | 12 +- mitmproxy-rs/src/stream.rs | 56 ++++++--- mitmproxy-rs/src/task.rs | 2 +- mitmproxy-rs/src/udp_client.rs | 151 +++++++++++++++++++++++++ mitmproxy-rs/src/util.rs | 34 +----- src/messages.rs | 2 +- src/packet_sources/udp.rs | 33 ++---- 9 files changed, 235 insertions(+), 89 deletions(-) create mode 100644 mitmproxy-rs/src/udp_client.rs diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index a2875b11..d39a1e6b 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -2,9 +2,11 @@ from __future__ import annotations from pathlib import Path from typing import Awaitable, Callable, Any, Literal -from typing import final, overload +from typing import final, overload, TypeVar +T = TypeVar("T") + # WireGuard async def start_wireguard_server( @@ -43,7 +45,7 @@ class LocalRedirector: async def wait_closed(self) -> None: ... -# UDP Server +# UDP async def start_udp_server( host: str, @@ -58,6 +60,12 @@ class UdpServer: async def wait_closed(self) -> None: ... def __repr__(self) -> str: ... +async def open_udp_connection( + host: str, + port: int, + *, + local_addr: tuple[str, int] | None = None, +) -> Stream: ... # TCP / UDP @@ -73,9 +81,24 @@ class Stream: async def wait_closed(self) -> None: ... @overload - def get_extra_info(self, name: Literal["transport_protocol"], default: Any = None) -> Literal["tcp", "udp"]: ... + def get_extra_info(self, name: Literal["transport_protocol"], default: None = None) -> Literal["tcp", "udp"]: ... + @overload + def get_extra_info(self, name: Literal["transport_protocol"], default: T) -> Literal["tcp", "udp"] | T: ... + @overload + def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: None = None) -> tuple[str, int]: ... + @overload + def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: T) -> tuple[str, int] | T: ... + @overload + def get_extra_info(self, name: Literal["pid"], default: None = None) -> int: ... + @overload + def get_extra_info(self, name: Literal["pid"], default: T) -> int | T: ... + @overload + def get_extra_info(self, name: Literal["process_name"], default: None = None) -> str: ... + @overload + def get_extra_info(self, name: Literal["process_name"], default: T) -> str | T: ... + @overload + def get_extra_info(self, name: str, default: T) -> T: ... - def get_extra_info(self, name: str, default: Any = None) -> Any: ... def __repr__(self) -> str: ... diff --git a/mitmproxy-rs/src/lib.rs b/mitmproxy-rs/src/lib.rs index add36a15..b4f3e746 100644 --- a/mitmproxy-rs/src/lib.rs +++ b/mitmproxy-rs/src/lib.rs @@ -9,6 +9,7 @@ mod process_info; mod server; mod stream; mod task; +mod udp_client; mod util; static LOGGER_INITIALIZED: Lazy> = Lazy::new(|| RwLock::new(false)); @@ -52,6 +53,8 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(server::start_udp_server, m)?)?; m.add_class::()?; + m.add_function(wrap_pyfunction!(udp_client::open_udp_connection, m)?)?; + m.add_function(wrap_pyfunction!(process_info::active_executables, m)?)?; m.add_class::()?; m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?; diff --git a/mitmproxy-rs/src/server/udp.rs b/mitmproxy-rs/src/server/udp.rs index 9b6d1ab3..8bebce36 100644 --- a/mitmproxy-rs/src/server/udp.rs +++ b/mitmproxy-rs/src/server/udp.rs @@ -1,10 +1,11 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::net::SocketAddr; use mitmproxy::packet_sources::udp::UdpConf; use pyo3::prelude::*; use crate::server::base::Server; + use crate::util::socketaddr_to_py; /// A running UDP server. @@ -62,17 +63,10 @@ pub fn start_udp_server( port: u16, handle_udp_stream: PyObject, ) -> PyResult<&PyAny> { - let is_unspecified = host.is_empty(); let conf = UdpConf { host, port }; let handle_tcp_stream = py.None(); pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, mut local_addr) = - Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - // Work around Windows limitation, see packet_sources/udp.rs - if is_unspecified && local_addr == SocketAddr::from((Ipv4Addr::LOCALHOST, port)) { - local_addr.set_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - } - + let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(UdpServer { server, local_addr }) }) } diff --git a/mitmproxy-rs/src/stream.rs b/mitmproxy-rs/src/stream.rs index 0bffbc8d..a59b6076 100644 --- a/mitmproxy-rs/src/stream.rs +++ b/mitmproxy-rs/src/stream.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use once_cell::sync::Lazy; +use pyo3::exceptions::PyKeyError; use pyo3::{exceptions::PyOSError, intern, prelude::*, types::PyBytes}; use tokio::sync::{ @@ -11,7 +12,7 @@ use tokio::sync::{ use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; -use crate::util::{event_queue_unavailable, get_tunnel_info, socketaddr_to_py}; +use crate::util::{event_queue_unavailable, socketaddr_to_py}; #[derive(Debug)] pub enum StreamState { @@ -28,7 +29,7 @@ pub enum StreamState { pub struct Stream { pub connection_id: ConnectionId, pub state: StreamState, - pub event_tx: mpsc::UnboundedSender, + pub command_tx: mpsc::UnboundedSender, pub peername: SocketAddr, pub sockname: SocketAddr, pub tunnel_info: TunnelInfo, @@ -49,7 +50,7 @@ impl Stream { StreamState::Open | StreamState::HalfClosed => { let (tx, rx) = oneshot::channel(); - self.event_tx + self.command_tx .send(TransportCommand::ReadData(self.connection_id, n, tx)) .ok(); // if this fails tx is dropped and rx.await will error. @@ -77,7 +78,7 @@ impl Stream { fn write(&self, data: Vec) -> PyResult<()> { match self.state { StreamState::Open => self - .event_tx + .command_tx .send(TransportCommand::WriteData(self.connection_id, data)) .map_err(event_queue_unavailable), StreamState::HalfClosed => Err(PyOSError::new_err("connection closed")), @@ -92,7 +93,7 @@ impl Stream { fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let (tx, rx) = oneshot::channel(); - self.event_tx + self.command_tx .send(TransportCommand::DrainWriter(self.connection_id, tx)) .map_err(event_queue_unavailable)?; @@ -111,7 +112,7 @@ impl Stream { match self.state { StreamState::Open => { self.state = StreamState::HalfClosed; - self.event_tx + self.command_tx .send(TransportCommand::CloseConnection(self.connection_id, true)) .map_err(event_queue_unavailable) } @@ -128,7 +129,7 @@ impl Stream { match self.state { StreamState::Open | StreamState::HalfClosed => { self.state = StreamState::Closed; - self.event_tx + self.command_tx .send(TransportCommand::CloseConnection(self.connection_id, false)) .map_err(event_queue_unavailable) } @@ -155,7 +156,6 @@ impl Stream { /// - Always available: `transport_protocol`, `peername`, `sockname` /// - WireGuard mode: `original_dst`, `original_src` /// - Local redirector mode: `pid`, `process_name`, `remote_endpoint` - #[pyo3(text_signature = "(self, name, default=None)")] fn get_extra_info( &self, py: Python, @@ -163,14 +163,38 @@ impl Stream { default: Option, ) -> PyResult { match name.as_str() { - "transport_protocol" => Ok(PyObject::from(if self.connection_id.is_tcp() { - intern!(py, "tcp") - } else { - intern!(py, "udp") - })), - "peername" => Ok(socketaddr_to_py(py, self.peername)), - "sockname" => Ok(socketaddr_to_py(py, self.sockname)), - _ => get_tunnel_info(&self.tunnel_info, py, name, default), + "transport_protocol" => { + if self.connection_id.is_tcp() { + return Ok(PyObject::from(intern!(py, "tcp"))); + } else { + return Ok(PyObject::from(intern!(py, "udp"))); + } + } + "peername" => return Ok(socketaddr_to_py(py, self.peername)), + "sockname" => return Ok(socketaddr_to_py(py, self.sockname)), + _ => (), + } + match &self.tunnel_info { + TunnelInfo::WireGuard { src_addr, dst_addr } => match name.as_str() { + "original_src" => return Ok(socketaddr_to_py(py, *src_addr)), + "original_dst" => return Ok(socketaddr_to_py(py, *dst_addr)), + _ => (), + }, + TunnelInfo::LocalRedirector { + pid, + process_name, + remote_endpoint, + } => match name.as_str() { + "pid" => return Ok(pid.into_py(py)), + "process_name" => return Ok(process_name.clone().into_py(py)), + "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), + _ => (), + }, + TunnelInfo::Udp {} => (), + } + match default { + Some(x) => Ok(x), + None => Err(PyKeyError::new_err(name)), } } diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 413086c5..9937c240 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -68,7 +68,7 @@ impl PyInteropTask { let stream = Stream { connection_id, state: StreamState::Open, - event_tx: self.transport_commands.clone(), + command_tx: self.transport_commands.clone(), peername: src_addr, sockname: dst_addr, tunnel_info, diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs new file mode 100644 index 00000000..56e13466 --- /dev/null +++ b/mitmproxy-rs/src/udp_client.rs @@ -0,0 +1,151 @@ +use anyhow::Context; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use anyhow::Result; +use pyo3::prelude::*; +use tokio::net::{lookup_host, UdpSocket}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; +use tokio::sync::oneshot; + +use crate::stream::{Stream, StreamState}; +use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; +use mitmproxy::MAX_PACKET_SIZE; + +/// Start a UDP client that is configured with the given parameters: +/// +/// - `host`: The host address. +/// - `port`: The listen port. +/// - `local_addr`: The local address to bind to. +#[pyfunction] +#[pyo3(signature = (host, port, *, local_addr = None))] +pub fn open_udp_connection( + py: Python<'_>, + host: String, + port: u16, + local_addr: Option<(String, u16)>, +) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + let socket = udp_connect(host, port, local_addr).await?; + + let peername = socket.peer_addr()?; + let sockname = socket.local_addr()?; + + let (command_tx, command_rx) = unbounded_channel(); + + tokio::spawn( + UdpClientTask { + socket, + transport_commands_rx: command_rx, + } + .run(), + ); + + let stream = Stream { + connection_id: ConnectionId::unassigned(), + state: StreamState::Open, + command_tx, + peername, + sockname, + tunnel_info: TunnelInfo::Udp, + }; + + Ok(stream) + }) +} + +/// Open an UDP socket from bind_to to host:port. +/// This is a bit trickier than expected because we want to support IPv4 and IPv6. +async fn udp_connect( + host: String, + port: u16, + local_addr: Option<(String, u16)>, +) -> Result { + let addrs: Vec = lookup_host((host.as_str(), port)) + .await + .with_context(|| format!("unable to resolve hostname: {}", host))? + .collect(); + + if let Some((host, port)) = local_addr { + let socket = UdpSocket::bind((host.as_str(), port)) + .await + .with_context(|| format!("unable to bind to ({}, {})", host, port))?; + socket + .connect(addrs.as_slice()) + .await + .context("unable to connect to remote address")?; + Ok(socket) + } else { + if let Ok(socket) = + UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)).await + { + if socket.connect(addrs.as_slice()).await.is_ok() { + return Ok(socket); + } + } + let socket = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) + .await + .context("unable to bind to 127.0.0.1:0")?; + socket + .connect(addrs.as_slice()) + .await + .context("unable to connect to remote address")?; + Ok(socket) + } +} + +#[derive(Debug)] +pub struct UdpClientTask { + socket: UdpSocket, + transport_commands_rx: UnboundedReceiver, +} + +impl UdpClientTask { + pub async fn run(mut self) { + let mut udp_buf = [0; MAX_PACKET_SIZE]; + + // this here isn't perfect because we block the entire transport_commands_rx channel if we + // cannot send (so we also block receiving new packets), but that's hopefully good enough. + let mut packet_needs_sending = false; + let mut packet_payload = Vec::new(); + + let mut packet_tx: Option>> = None; + + loop { + tokio::select! { + // wait for transport_events_tx channel capacity... + Ok(len) = self.socket.recv(&mut udp_buf), if packet_tx.is_some() => { + packet_tx + .take() + .unwrap() + .send(udp_buf[..len].to_vec()) + .ok(); + }, + // send_to is cancel safe, so we can use that for backpressure. + _ = self.socket.send(&packet_payload), if packet_needs_sending => { + packet_needs_sending = false; + }, + Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { + match command { + TransportCommand::ReadData(_,_,tx) => { + packet_tx = Some(tx); + }, + TransportCommand::WriteData(_, data) => { + packet_payload = data; + packet_needs_sending = true; + }, + TransportCommand::DrainWriter(_,tx) => { + tx.send(()).ok(); + }, + TransportCommand::CloseConnection(_, half_close) => { + if !half_close { + break; + } + }, + } + } + else => break, + } + } + log::debug!("UDP client task shutting down."); + } +} diff --git a/mitmproxy-rs/src/util.rs b/mitmproxy-rs/src/util.rs index 195a1fee..bce1b7bd 100644 --- a/mitmproxy-rs/src/util.rs +++ b/mitmproxy-rs/src/util.rs @@ -3,8 +3,8 @@ use anyhow::{anyhow, Result}; use data_encoding::BASE64; #[cfg(target_os = "macos")] use mitmproxy::macos; -use mitmproxy::messages::TunnelInfo; -use pyo3::exceptions::{PyKeyError, PyOSError}; + +use pyo3::exceptions::PyOSError; use pyo3::types::{PyString, PyTuple}; use pyo3::{exceptions::PyValueError, prelude::*}; use rand_core::OsRng; @@ -114,33 +114,3 @@ pub fn remove_cert() -> PyResult<()> { "OS proxy mode is only available on macos", )) } - -pub(crate) fn get_tunnel_info( - tunnel: &TunnelInfo, - py: Python, - name: String, - default: Option, -) -> PyResult { - match tunnel { - TunnelInfo::WireGuard { src_addr, dst_addr } => match name.as_str() { - "original_src" => return Ok(socketaddr_to_py(py, *src_addr)), - "original_dst" => return Ok(socketaddr_to_py(py, *dst_addr)), - _ => (), - }, - TunnelInfo::LocalRedirector { - pid, - process_name, - remote_endpoint, - } => match name.as_str() { - "pid" => return Ok(pid.into_py(py)), - "process_name" => return Ok(process_name.clone().into_py(py)), - "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), - _ => (), - }, - TunnelInfo::Udp {} => (), - } - match default { - Some(x) => Ok(x), - None => Err(PyKeyError::new_err(name)), - } -} diff --git a/src/messages.rs b/src/messages.rs index ae1d9c15..efbbe547 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -19,7 +19,7 @@ pub enum TunnelInfo { /// an unresolved remote_endpoint instead. remote_endpoint: Option<(String, u16)>, }, - Udp {}, + Udp, } /// Events that are sent by WireGuard to the TCP stack. diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs index 869153eb..da228b9d 100644 --- a/src/packet_sources/udp.rs +++ b/src/packet_sources/udp.rs @@ -1,6 +1,6 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{Ipv4Addr, SocketAddr}; -use anyhow::Result; +use anyhow::{Context, Result}; use async_trait::async_trait; use tokio::sync::mpsc::{Permit, UnboundedReceiver}; @@ -34,32 +34,13 @@ impl PacketSourceConf for UdpConf { transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, ) -> Result<(Self::Task, Self::Data)> { - // bind to UDP socket(s) - - let socket_addrs = if self.host.is_empty() { - vec![ - // Windows quirks: We need to bind to 127.0.0.1 explicitly for IPv4. - #[cfg(windows)] - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), self.port), - #[cfg(not(windows))] - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), self.port), - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.port), - ] - } else { - vec![SocketAddr::new(self.host.parse()?, self.port)] - }; - - let socket = UdpSocket::bind(socket_addrs.as_slice()).await?; + // bind to UDP socket. Note that UdpSocket::bind accepts ToSocketAddrs, but will only ever bind to one address! + let socket = UdpSocket::bind((self.host.as_str(), self.port)) + .await + .with_context(|| format!("Failed to bind UDP socket to {}:{}", self.host, self.port))?; let local_addr = socket.local_addr()?; - log::debug!( - "UDP server listening on {} ...", - socket_addrs - .iter() - .map(|addr| addr.to_string()) - .collect::>() - .join(" and ") - ); + log::debug!("UDP server listening on {} ...", local_addr); Ok(( UdpTask {