Skip to content

Commit

Permalink
fixup macos mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Dec 17, 2023
1 parent e4dc327 commit 7dc0a0c
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 171 deletions.
21 changes: 14 additions & 7 deletions mitmproxy-rs/src/server/local_redirector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ pub struct LocalRedirector {
spec: String,
}

impl LocalRedirector {
pub fn new(server: Server, conf_tx: mpsc::UnboundedSender<InterceptConf>) -> Self {
Self {
server,
conf_tx,
spec: "inactive".to_string(),
}
}
}

#[pymethods]
impl LocalRedirector {
/// Return a textual description of the given spec,
Expand Down Expand Up @@ -84,11 +94,7 @@ pub fn start_local_redirector(
let (server, conf_tx) =
Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;

Ok(LocalRedirector {
server,
conf_tx,
spec: "inactive".to_string(),
})
Ok(LocalRedirector::new(server, conf_tx))
})
}
#[cfg(target_os = "macos")]
Expand All @@ -115,8 +121,9 @@ pub fn start_local_redirector(
}
let conf = MacosConf;
pyo3_asyncio::tokio::future_into_py(py, async move {
let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?;
Ok(LocalRedirector { server, conf_tx })
let (server, conf_tx) =
Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;
Ok(LocalRedirector::new(server, conf_tx))
})
}
#[cfg(not(any(windows, target_os = "macos")))]
Expand Down
5 changes: 0 additions & 5 deletions mitmproxy-rs/src/server/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@ use std::net::SocketAddr;

use crate::util::{socketaddr_to_py, string_to_key};

#[cfg(target_os = "macos")]
use mitmproxy::packet_sources::macos::MacosConf;

use mitmproxy::packet_sources::wireguard::WireGuardConf;

use pyo3::prelude::*;
#[cfg(target_os = "macos")]
use std::path::Path;

use boringtun::x25519::PublicKey;

Expand Down
4 changes: 3 additions & 1 deletion mitmproxy-rs/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ impl PyInteropTask {
src_addr,
dst_addr,
tunnel_info,
command_tx,
} => {
let command_tx = command_tx.unwrap_or_else(|| self.transport_commands.clone());
// initialize new stream
let stream = Stream {
connection_id,
state: StreamState::Open,
command_tx: self.transport_commands.clone(),
command_tx,
peername: src_addr,
sockname: dst_addr,
tunnel_info,
Expand Down
10 changes: 5 additions & 5 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use anyhow::{anyhow, Result};
use internet_packet::InternetPacket;
use smoltcp::wire::{IpProtocol, Ipv4Packet, Ipv6Packet};
use tokio::sync::oneshot;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub enum TunnelInfo {
Expand Down Expand Up @@ -39,10 +39,10 @@ pub enum NetworkCommand {

pub struct ConnectionIdGenerator(usize);
impl ConnectionIdGenerator {
pub fn tcp() -> Self {
pub const fn tcp() -> Self {
Self(2)
}
pub fn udp() -> Self {
pub const fn udp() -> Self {
Self(1)
}
pub fn next_id(&mut self) -> ConnectionId {
Expand Down Expand Up @@ -85,6 +85,7 @@ pub enum TransportEvent {
src_addr: SocketAddr,
dst_addr: SocketAddr,
tunnel_info: TunnelInfo,
command_tx: Option<mpsc::UnboundedSender<TransportCommand>>,
},
}

Expand All @@ -98,14 +99,13 @@ pub enum TransportCommand {
}

impl TransportCommand {
pub fn is_tcp(&self) -> bool {
pub fn connection_id(&self) -> &ConnectionId {
match self {
TransportCommand::ReadData(id, _, _) => id,
TransportCommand::WriteData(id, _) => id,
TransportCommand::DrainWriter(id, _) => id,
TransportCommand::CloseConnection(id, _) => id,
}
.is_tcp()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/network/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl<'a> NetworkStack<'a> {
}

pub fn handle_transport_command(&mut self, command: TransportCommand) {
if command.is_tcp() {
if command.connection_id().is_tcp() {
self.tcp.handle_transport_command(command);
} else if let Some(packet) = self.udp.handle_transport_command(command) {
if self
Expand Down
2 changes: 1 addition & 1 deletion src/network/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl NetworkTask<'_> {
// wait for graceful shutdown
_ = self.shutdown.recv() => break 'task,
// wait for timeouts when the device is idle
_ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {},
_ = tokio::time::sleep(delay.unwrap()), if delay.is_some() => {},
// wait for py_tx channel capacity...
Ok(permit) = self.py_tx.reserve(), if !py_tx_available => {
py_tx_permit = Some(permit);
Expand Down
1 change: 1 addition & 0 deletions src/network/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl<'a> TcpHandler<'a> {
src_addr,
dst_addr,
tunnel_info,
command_tx: None,
};
permit.send(event);
}
Expand Down
6 changes: 3 additions & 3 deletions src/network/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ async fn udp_read_write(
connection_id,
src_addr: recv_src_addr,
dst_addr: recv_dst_addr,
tunnel_info: _,
..
} = event;

assert_eq!(src_addr, recv_src_addr);
Expand Down Expand Up @@ -494,7 +494,7 @@ async fn tcp_ipv4_connection() -> Result<()> {
connection_id: tcp_conn_id,
src_addr: tcp_src_sock,
dst_addr: tcp_dst_sock,
tunnel_info: _,
..
} = event;
assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into());
assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into());
Expand Down Expand Up @@ -671,7 +671,7 @@ async fn tcp_ipv6_connection() -> Result<()> {
connection_id: tcp_conn_id,
src_addr: tcp_src_sock,
dst_addr: tcp_dst_sock,
tunnel_info: _,
..
} = event;
assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into());
assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into());
Expand Down
13 changes: 7 additions & 6 deletions src/network/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ impl ConnectionState {
}
}

pub const UDP_TIMEOUT: Duration = Duration::from_secs(60);

pub struct UdpHandler {
connection_id_generator: ConnectionIdGenerator,
id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>,
Expand All @@ -73,12 +75,10 @@ pub struct UdpHandler {

impl UdpHandler {
pub fn new() -> Self {
let connections = LruCache::<ConnectionId, ConnectionState>::with_expiry_duration(
Duration::from_secs(60),
);
let id_lookup = LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(
Duration::from_secs(60),
);
let connections =
LruCache::<ConnectionId, ConnectionState>::with_expiry_duration(UDP_TIMEOUT);
let id_lookup =
LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(UDP_TIMEOUT);
Self {
connections,
id_lookup,
Expand Down Expand Up @@ -172,6 +172,7 @@ impl UdpHandler {
src_addr: packet.src_addr,
dst_addr: packet.dst_addr,
tunnel_info,
command_tx: None,
});
}
};
Expand Down
Loading

0 comments on commit 7dc0a0c

Please sign in to comment.