Skip to content

Commit

Permalink
fixup macos mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Dec 17, 2023
1 parent e4dc327 commit b31620d
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class TransparentProxyProvider: NETransparentProxyProvider {
guard let remoteEndpoint = tcp_flow.remoteEndpoint as? NWHostEndpoint else {
throw TransparentProxyError.noRemoteEndpoint
}
log.debug("remoteEndpoint: \(String(describing: remoteEndpoint), privacy: .public)")
// log.debug("remoteEndpoint: \(String(describing: remoteEndpoint), privacy: .public)")
// It would be nice if we could also include info on the local endpoint here, but that's not exposed.
message = MitmproxyIpc_NewFlow.with {
$0.tcp = MitmproxyIpc_TcpFlow.with {
Expand Down
21 changes: 14 additions & 7 deletions mitmproxy-rs/src/server/local_redirector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ pub struct LocalRedirector {
spec: String,
}

impl LocalRedirector {
pub fn new(server: Server, conf_tx: mpsc::UnboundedSender<InterceptConf>) -> Self {
Self {
server,
conf_tx,
spec: "inactive".to_string(),
}
}
}

#[pymethods]
impl LocalRedirector {
/// Return a textual description of the given spec,
Expand Down Expand Up @@ -84,11 +94,7 @@ pub fn start_local_redirector(
let (server, conf_tx) =
Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;

Ok(LocalRedirector {
server,
conf_tx,
spec: "inactive".to_string(),
})
Ok(LocalRedirector::new(server, conf_tx))
})
}
#[cfg(target_os = "macos")]
Expand All @@ -115,8 +121,9 @@ pub fn start_local_redirector(
}
let conf = MacosConf;
pyo3_asyncio::tokio::future_into_py(py, async move {
let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?;
Ok(LocalRedirector { server, conf_tx })
let (server, conf_tx) =
Server::init(conf, handle_tcp_stream, handle_udp_stream).await?;
Ok(LocalRedirector::new(server, conf_tx))
})
}
#[cfg(not(any(windows, target_os = "macos")))]
Expand Down
5 changes: 0 additions & 5 deletions mitmproxy-rs/src/server/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@ use std::net::SocketAddr;

use crate::util::{socketaddr_to_py, string_to_key};

#[cfg(target_os = "macos")]
use mitmproxy::packet_sources::macos::MacosConf;

use mitmproxy::packet_sources::wireguard::WireGuardConf;

use pyo3::prelude::*;
#[cfg(target_os = "macos")]
use std::path::Path;

use boringtun::x25519::PublicKey;

Expand Down
4 changes: 3 additions & 1 deletion mitmproxy-rs/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ impl PyInteropTask {
src_addr,
dst_addr,
tunnel_info,
command_tx,
} => {
let command_tx = command_tx.unwrap_or_else(|| self.transport_commands.clone());
// initialize new stream
let stream = Stream {
connection_id,
state: StreamState::Open,
command_tx: self.transport_commands.clone(),
command_tx,
peername: src_addr,
sockname: dst_addr,
tunnel_info,
Expand Down
10 changes: 5 additions & 5 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use anyhow::{anyhow, Result};
use internet_packet::InternetPacket;
use smoltcp::wire::{IpProtocol, Ipv4Packet, Ipv6Packet};
use tokio::sync::oneshot;
use tokio::sync::{mpsc, oneshot};

#[derive(Debug, Clone)]
pub enum TunnelInfo {
Expand Down Expand Up @@ -39,10 +39,10 @@ pub enum NetworkCommand {

pub struct ConnectionIdGenerator(usize);
impl ConnectionIdGenerator {
pub fn tcp() -> Self {
pub const fn tcp() -> Self {
Self(2)
}
pub fn udp() -> Self {
pub const fn udp() -> Self {
Self(1)
}
pub fn next_id(&mut self) -> ConnectionId {
Expand Down Expand Up @@ -85,6 +85,7 @@ pub enum TransportEvent {
src_addr: SocketAddr,
dst_addr: SocketAddr,
tunnel_info: TunnelInfo,
command_tx: Option<mpsc::UnboundedSender<TransportCommand>>,
},
}

Expand All @@ -98,14 +99,13 @@ pub enum TransportCommand {
}

impl TransportCommand {
pub fn is_tcp(&self) -> bool {
pub fn connection_id(&self) -> &ConnectionId {
match self {
TransportCommand::ReadData(id, _, _) => id,
TransportCommand::WriteData(id, _) => id,
TransportCommand::DrainWriter(id, _) => id,
TransportCommand::CloseConnection(id, _) => id,
}
.is_tcp()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/network/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl<'a> NetworkStack<'a> {
}

pub fn handle_transport_command(&mut self, command: TransportCommand) {
if command.is_tcp() {
if command.connection_id().is_tcp() {
self.tcp.handle_transport_command(command);
} else if let Some(packet) = self.udp.handle_transport_command(command) {
if self
Expand Down
2 changes: 1 addition & 1 deletion src/network/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl NetworkTask<'_> {
// wait for graceful shutdown
_ = self.shutdown.recv() => break 'task,
// wait for timeouts when the device is idle
_ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {},
_ = tokio::time::sleep(delay.unwrap()), if delay.is_some() => {},
// wait for py_tx channel capacity...
Ok(permit) = self.py_tx.reserve(), if !py_tx_available => {
py_tx_permit = Some(permit);
Expand Down
1 change: 1 addition & 0 deletions src/network/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl<'a> TcpHandler<'a> {
src_addr,
dst_addr,
tunnel_info,
command_tx: None,
};
permit.send(event);
}
Expand Down
6 changes: 3 additions & 3 deletions src/network/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ async fn udp_read_write(
connection_id,
src_addr: recv_src_addr,
dst_addr: recv_dst_addr,
tunnel_info: _,
..
} = event;

assert_eq!(src_addr, recv_src_addr);
Expand Down Expand Up @@ -494,7 +494,7 @@ async fn tcp_ipv4_connection() -> Result<()> {
connection_id: tcp_conn_id,
src_addr: tcp_src_sock,
dst_addr: tcp_dst_sock,
tunnel_info: _,
..
} = event;
assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into());
assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into());
Expand Down Expand Up @@ -671,7 +671,7 @@ async fn tcp_ipv6_connection() -> Result<()> {
connection_id: tcp_conn_id,
src_addr: tcp_src_sock,
dst_addr: tcp_dst_sock,
tunnel_info: _,
..
} = event;
assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into());
assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into());
Expand Down
13 changes: 7 additions & 6 deletions src/network/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ impl ConnectionState {
}
}

pub const UDP_TIMEOUT: Duration = Duration::from_secs(60);

pub struct UdpHandler {
connection_id_generator: ConnectionIdGenerator,
id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>,
Expand All @@ -73,12 +75,10 @@ pub struct UdpHandler {

impl UdpHandler {
pub fn new() -> Self {
let connections = LruCache::<ConnectionId, ConnectionState>::with_expiry_duration(
Duration::from_secs(60),
);
let id_lookup = LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(
Duration::from_secs(60),
);
let connections =
LruCache::<ConnectionId, ConnectionState>::with_expiry_duration(UDP_TIMEOUT);
let id_lookup =
LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(UDP_TIMEOUT);
Self {
connections,
id_lookup,
Expand Down Expand Up @@ -172,6 +172,7 @@ impl UdpHandler {
src_addr: packet.src_addr,
dst_addr: packet.dst_addr,
tunnel_info,
command_tx: None,
});
}
};
Expand Down
Loading

0 comments on commit b31620d

Please sign in to comment.