Skip to content

Commit

Permalink
add UDP client
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Dec 17, 2023
1 parent 44abe5d commit d147d98
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 89 deletions.
31 changes: 27 additions & 4 deletions mitmproxy-rs/mitmproxy_rs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ from __future__ import annotations

from pathlib import Path
from typing import Awaitable, Callable, Any, Literal
from typing import final, overload
from typing import final, overload, TypeVar


T = TypeVar("T")

# WireGuard

async def start_wireguard_server(
Expand Down Expand Up @@ -43,7 +45,7 @@ class LocalRedirector:
async def wait_closed(self) -> None: ...


# UDP Server
# UDP

async def start_udp_server(
host: str,
Expand All @@ -58,6 +60,12 @@ class UdpServer:
async def wait_closed(self) -> None: ...
def __repr__(self) -> str: ...

async def open_udp_connection(
host: str,
port: int,
*,
local_addr: tuple[str, int] | None = None,
) -> Stream: ...

# TCP / UDP

Expand All @@ -73,9 +81,24 @@ class Stream:
async def wait_closed(self) -> None: ...

@overload
def get_extra_info(self, name: Literal["transport_protocol"], default: Any = None) -> Literal["tcp", "udp"]: ...
def get_extra_info(self, name: Literal["transport_protocol"], default: None = None) -> Literal["tcp", "udp"]: ...
@overload
def get_extra_info(self, name: Literal["transport_protocol"], default: T) -> Literal["tcp", "udp"] | T: ...
@overload
def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: None = None) -> tuple[str, int]: ...
@overload
def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: T) -> tuple[str, int] | T: ...
@overload
def get_extra_info(self, name: Literal["pid"], default: None = None) -> int: ...
@overload
def get_extra_info(self, name: Literal["pid"], default: T) -> int | T: ...
@overload
def get_extra_info(self, name: Literal["process_name"], default: None = None) -> str: ...
@overload
def get_extra_info(self, name: Literal["process_name"], default: T) -> str | T: ...
@overload
def get_extra_info(self, name: str, default: T) -> T: ...

def get_extra_info(self, name: str, default: Any = None) -> Any: ...
def __repr__(self) -> str: ...


Expand Down
3 changes: 3 additions & 0 deletions mitmproxy-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod process_info;
mod server;
mod stream;
mod task;
mod udp_client;
mod util;

static LOGGER_INITIALIZED: Lazy<RwLock<bool>> = Lazy::new(|| RwLock::new(false));
Expand Down Expand Up @@ -52,6 +53,8 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(server::start_udp_server, m)?)?;
m.add_class::<server::UdpServer>()?;

m.add_function(wrap_pyfunction!(udp_client::open_udp_connection, m)?)?;

m.add_function(wrap_pyfunction!(process_info::active_executables, m)?)?;
m.add_class::<process_info::Process>()?;
m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?;
Expand Down
12 changes: 3 additions & 9 deletions mitmproxy-rs/src/server/udp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::net::SocketAddr;

use mitmproxy::packet_sources::udp::UdpConf;

use pyo3::prelude::*;

use crate::server::base::Server;

use crate::util::socketaddr_to_py;

/// A running UDP server.
Expand Down Expand Up @@ -62,17 +63,10 @@ pub fn start_udp_server(
port: u16,
handle_udp_stream: PyObject,
) -> PyResult<&PyAny> {
let is_unspecified = host.is_empty();
let conf = UdpConf { host, port };
let handle_tcp_stream = py.None();
pyo3_asyncio::tokio::future_into_py(py, async move {
let (server, mut local_addr) =
Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;
// Work around Windows limitation, see packet_sources/udp.rs
if is_unspecified && local_addr == SocketAddr::from((Ipv4Addr::LOCALHOST, port)) {
local_addr.set_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
}

let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;
Ok(UdpServer { server, local_addr })
})
}
56 changes: 40 additions & 16 deletions mitmproxy-rs/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::net::SocketAddr;

use once_cell::sync::Lazy;

use pyo3::exceptions::PyKeyError;
use pyo3::{exceptions::PyOSError, intern, prelude::*, types::PyBytes};

use tokio::sync::{
Expand All @@ -11,7 +12,7 @@ use tokio::sync::{

use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo};

use crate::util::{event_queue_unavailable, get_tunnel_info, socketaddr_to_py};
use crate::util::{event_queue_unavailable, socketaddr_to_py};

#[derive(Debug)]
pub enum StreamState {
Expand All @@ -28,7 +29,7 @@ pub enum StreamState {
pub struct Stream {
pub connection_id: ConnectionId,
pub state: StreamState,
pub event_tx: mpsc::UnboundedSender<TransportCommand>,
pub command_tx: mpsc::UnboundedSender<TransportCommand>,
pub peername: SocketAddr,
pub sockname: SocketAddr,
pub tunnel_info: TunnelInfo,
Expand All @@ -49,7 +50,7 @@ impl Stream {
StreamState::Open | StreamState::HalfClosed => {
let (tx, rx) = oneshot::channel();

self.event_tx
self.command_tx
.send(TransportCommand::ReadData(self.connection_id, n, tx))
.ok(); // if this fails tx is dropped and rx.await will error.

Expand Down Expand Up @@ -77,7 +78,7 @@ impl Stream {
fn write(&self, data: Vec<u8>) -> PyResult<()> {
match self.state {
StreamState::Open => self
.event_tx
.command_tx
.send(TransportCommand::WriteData(self.connection_id, data))
.map_err(event_queue_unavailable),
StreamState::HalfClosed => Err(PyOSError::new_err("connection closed")),
Expand All @@ -92,7 +93,7 @@ impl Stream {
fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
let (tx, rx) = oneshot::channel();

self.event_tx
self.command_tx
.send(TransportCommand::DrainWriter(self.connection_id, tx))
.map_err(event_queue_unavailable)?;

Expand All @@ -111,7 +112,7 @@ impl Stream {
match self.state {
StreamState::Open => {
self.state = StreamState::HalfClosed;
self.event_tx
self.command_tx
.send(TransportCommand::CloseConnection(self.connection_id, true))
.map_err(event_queue_unavailable)
}
Expand All @@ -128,7 +129,7 @@ impl Stream {
match self.state {
StreamState::Open | StreamState::HalfClosed => {
self.state = StreamState::Closed;
self.event_tx
self.command_tx
.send(TransportCommand::CloseConnection(self.connection_id, false))
.map_err(event_queue_unavailable)
}
Expand All @@ -155,22 +156,45 @@ impl Stream {
/// - Always available: `transport_protocol`, `peername`, `sockname`
/// - WireGuard mode: `original_dst`, `original_src`
/// - Local redirector mode: `pid`, `process_name`, `remote_endpoint`
#[pyo3(text_signature = "(self, name, default=None)")]
fn get_extra_info(
&self,
py: Python,
name: String,
default: Option<PyObject>,
) -> PyResult<PyObject> {
match name.as_str() {
"transport_protocol" => Ok(PyObject::from(if self.connection_id.is_tcp() {
intern!(py, "tcp")
} else {
intern!(py, "udp")
})),
"peername" => Ok(socketaddr_to_py(py, self.peername)),
"sockname" => Ok(socketaddr_to_py(py, self.sockname)),
_ => get_tunnel_info(&self.tunnel_info, py, name, default),
"transport_protocol" => {
if self.connection_id.is_tcp() {
return Ok(PyObject::from(intern!(py, "tcp")));
} else {
return Ok(PyObject::from(intern!(py, "udp")));
}
}
"peername" => return Ok(socketaddr_to_py(py, self.peername)),
"sockname" => return Ok(socketaddr_to_py(py, self.sockname)),
_ => (),
}
match &self.tunnel_info {
TunnelInfo::WireGuard { src_addr, dst_addr } => match name.as_str() {
"original_src" => return Ok(socketaddr_to_py(py, *src_addr)),
"original_dst" => return Ok(socketaddr_to_py(py, *dst_addr)),
_ => (),
},
TunnelInfo::LocalRedirector {
pid,
process_name,
remote_endpoint,
} => match name.as_str() {
"pid" => return Ok(pid.into_py(py)),
"process_name" => return Ok(process_name.clone().into_py(py)),
"remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)),
_ => (),
},
TunnelInfo::Udp {} => (),
}
match default {
Some(x) => Ok(x),
None => Err(PyKeyError::new_err(name)),
}
}

Expand Down
2 changes: 1 addition & 1 deletion mitmproxy-rs/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl PyInteropTask {
let stream = Stream {
connection_id,
state: StreamState::Open,
event_tx: self.transport_commands.clone(),
command_tx: self.transport_commands.clone(),
peername: src_addr,
sockname: dst_addr,
tunnel_info,
Expand Down
151 changes: 151 additions & 0 deletions mitmproxy-rs/src/udp_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use anyhow::Context;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

use anyhow::Result;
use pyo3::prelude::*;
use tokio::net::{lookup_host, UdpSocket};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
use tokio::sync::oneshot;

use crate::stream::{Stream, StreamState};
use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo};
use mitmproxy::MAX_PACKET_SIZE;

/// Start a UDP client that is configured with the given parameters:
///
/// - `host`: The host address.
/// - `port`: The listen port.
/// - `local_addr`: The local address to bind to.
#[pyfunction]
#[pyo3(signature = (host, port, *, local_addr = None))]
pub fn open_udp_connection(
py: Python<'_>,
host: String,
port: u16,
local_addr: Option<(String, u16)>,
) -> PyResult<&PyAny> {
pyo3_asyncio::tokio::future_into_py(py, async move {
let socket = udp_connect(host, port, local_addr).await?;

let peername = socket.peer_addr()?;
let sockname = socket.local_addr()?;

let (command_tx, command_rx) = unbounded_channel();

tokio::spawn(
UdpClientTask {
socket,
transport_commands_rx: command_rx,
}
.run(),
);

let stream = Stream {
connection_id: ConnectionId::unassigned(),
state: StreamState::Open,
command_tx,
peername,
sockname,
tunnel_info: TunnelInfo::Udp,
};

Ok(stream)
})
}

/// Open an UDP socket from bind_to to host:port.
/// This is a bit trickier than expected because we want to support IPv4 and IPv6.
async fn udp_connect(
host: String,
port: u16,
local_addr: Option<(String, u16)>,
) -> Result<UdpSocket> {
let addrs: Vec<SocketAddr> = lookup_host((host.as_str(), port))
.await
.with_context(|| format!("unable to resolve hostname: {}", host))?
.collect();

if let Some((host, port)) = local_addr {
let socket = UdpSocket::bind((host.as_str(), port))
.await
.with_context(|| format!("unable to bind to ({}, {})", host, port))?;
socket
.connect(addrs.as_slice())
.await
.context("unable to connect to remote address")?;
Ok(socket)
} else {
if let Ok(socket) =
UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)).await
{
if socket.connect(addrs.as_slice()).await.is_ok() {
return Ok(socket);
}
}
let socket = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
.await
.context("unable to bind to 127.0.0.1:0")?;
socket
.connect(addrs.as_slice())
.await
.context("unable to connect to remote address")?;
Ok(socket)
}
}

#[derive(Debug)]
pub struct UdpClientTask {
socket: UdpSocket,
transport_commands_rx: UnboundedReceiver<TransportCommand>,
}

impl UdpClientTask {
pub async fn run(mut self) {
let mut udp_buf = [0; MAX_PACKET_SIZE];

// this here isn't perfect because we block the entire transport_commands_rx channel if we
// cannot send (so we also block receiving new packets), but that's hopefully good enough.
let mut packet_needs_sending = false;
let mut packet_payload = Vec::new();

let mut packet_tx: Option<oneshot::Sender<Vec<u8>>> = None;

loop {
tokio::select! {
// wait for transport_events_tx channel capacity...
Ok(len) = self.socket.recv(&mut udp_buf), if packet_tx.is_some() => {
packet_tx
.take()
.unwrap()
.send(udp_buf[..len].to_vec())
.ok();
},
// send_to is cancel safe, so we can use that for backpressure.
_ = self.socket.send(&packet_payload), if packet_needs_sending => {
packet_needs_sending = false;
},
Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => {
match command {
TransportCommand::ReadData(_,_,tx) => {
packet_tx = Some(tx);
},
TransportCommand::WriteData(_, data) => {
packet_payload = data;
packet_needs_sending = true;
},
TransportCommand::DrainWriter(_,tx) => {
tx.send(()).ok();
},
TransportCommand::CloseConnection(_, half_close) => {
if !half_close {
break;
}
},
}
}
else => break,
}
}
log::debug!("UDP client task shutting down.");
}
}
Loading

0 comments on commit d147d98

Please sign in to comment.