From 09b9c77f0bdfb9f455d127f34e44ef57750df400 Mon Sep 17 00:00:00 2001 From: jlizen Date: Mon, 30 Dec 2024 19:29:29 +0000 Subject: [PATCH] update to reflect new changes to vacation api, add enum for whether to process packets async --- .github/workflows/CI.yml | 5 +++-- Cargo.lock | 20 ++++++++--------- Cargo.toml | 10 +++------ src/client.rs | 5 ++++- src/common/async_session.rs | 6 ++++-- src/common/handshake.rs | 19 ++++++++-------- src/common/mod.rs | 30 ++++++++++++++++++++------ src/common/test_stream.rs | 43 ++++++++++++++++++++----------------- src/lib.rs | 8 +++---- tests/async_session.rs | 6 +++--- 10 files changed, 86 insertions(+), 66 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8ed5b77..137fbfc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -62,8 +62,9 @@ jobs: run: | cargo test --locked --all cargo test --locked -p tokio-rustls --features early-data --test early-data - # we run all test suites against this feature since it shifts the default behavior globally - cargo test --locked -p tokio-rustls --features compute-heavy-future-executor + # we run all test suites against this feature + # to capture any regressions that come from changes to the handshake future state machine + cargo test --locked -p tokio-rustls --features vacation lints: name: Lints diff --git a/Cargo.lock b/Cargo.lock index 21625b2..daa8757 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,15 +207,6 @@ dependencies = [ "cc", ] -[[package]] -name = "compute-heavy-future-executor" -version = "0.1.0" -dependencies = [ - "log", - "num_cpus", - "tokio", -] - [[package]] name = "deranged" version = "0.3.11" @@ -828,13 +819,13 @@ name = "tokio-rustls" version = "0.26.1" dependencies = [ "argh", - "compute-heavy-future-executor", "futures-util", "lazy_static", "pin-project-lite", "rcgen", "rustls", "tokio", + "vacation", "webpki-roots", ] @@ -850,6 +841,15 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "vacation" +version = "0.1.0" +dependencies = [ + "log", + "num_cpus", + "tokio", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index b8302a4..fcc259d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,7 @@ rust-version = "1.70" exclude = ["/.github", "/examples", "/scripts"] [dependencies] -# implicitly enables the tokio feature for compute-heavy-future-executor -# (defaulting to strategy of spawn_blocking w/ concurrency conctorl) -compute-heavy-future-executor = { version = "0.1", optional = true} +vacation = { version = "0.1", optional = true, default-features = false } pin-project-lite = { version = "0.2.15", optional = true } rustls = { version = "0.23.15", default-features = false, features = ["std"] } tokio = "1.0" @@ -24,7 +22,7 @@ tokio = "1.0" default = ["logging", "tls12", "aws_lc_rs"] aws_lc_rs = ["rustls/aws_lc_rs"] aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-` -compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"] +vacation = ["dep:vacation", "pin-project-lite"] early-data = [] fips = ["rustls/fips"] logging = ["rustls/logging"] @@ -37,7 +35,5 @@ futures-util = "0.3.1" lazy_static = "1.1" rcgen = { version = "0.13", features = ["pem"] } tokio = { version = "1.0", features = ["full"] } +vacation = { version = "0.1", features = ["tokio"] } webpki-roots = "0.26" - -[patch.crates-io] -compute-heavy-future-executor = { path = "../compute-heavy-future-executor" } \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 13dbc41..6fa21c5 100644 --- a/src/client.rs +++ b/src/client.rs @@ -288,6 +288,8 @@ fn poll_handle_early_data( where IO: AsyncRead + AsyncWrite + Unpin, { + use crate::common::PacketProcessingMode; + if let TlsState::EarlyData(pos, data) = state { use std::io::Write; @@ -321,7 +323,8 @@ where // complete handshake while stream.session.is_handshaking() { - ready!(stream.handshake(cx, false))?; + // TODO: also model as using `vacation` executor + ready!(stream.handshake(cx, PacketProcessingMode::Sync))?; } // write early data (fallback) diff --git a/src/common/async_session.rs b/src/common/async_session.rs index 4ca8fab..fa5f613 100644 --- a/src/common/async_session.rs +++ b/src/common/async_session.rs @@ -17,7 +17,7 @@ use super::{Stream, TlsState}; /// Full result of sync closure type SessionResult = Result, io::Error)>; /// Executor result wrapping sync closure result -type SyncExecutorResult = Result, compute_heavy_future_executor::Error>; +type SyncExecutorResult = Result, vacation::Error>; /// Future wrapping waiting on executor type SessionFuture = Box> + Unpin + Send>; @@ -53,7 +53,9 @@ where )), }; - let future = compute_heavy_future_executor::execute_sync(closure); + // TODO: if we ever start also delegating non-handshake byte processing, make this chance of blocking + // variable and set by caller + let future = vacation::execute_sync(closure, vacation::ChanceOfBlocking::High); Self { future: Box::new(Box::pin(future)), diff --git a/src/common/handshake.rs b/src/common/handshake.rs index e571c72..d9e8715 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -8,9 +8,9 @@ use rustls::server::AcceptedAlert; use rustls::{ConnectionCommon, SideData}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::common::{Stream, SyncWriteAdapter, TlsState}; +use crate::common::{PacketProcessingMode, Stream, SyncWriteAdapter, TlsState}; -#[cfg(feature = "compute-heavy-future-executor")] +#[cfg(feature = "vacation")] use super::async_session::AsyncSession; pub(crate) trait IoSession { @@ -34,7 +34,7 @@ pub(crate) trait IoSession { pub(crate) enum MidHandshake { Handshaking(IS), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] AsyncSession(AsyncSession), End, SendAlert { @@ -61,7 +61,7 @@ where let mut stream = match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => stream, - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] MidHandshake::AsyncSession(mut async_session) => { let pinned = Pin::new(&mut async_session); let session_result = ready!(pinned.poll(cx)); @@ -94,7 +94,7 @@ where ( $e:expr ) => { match $e { Poll::Ready(Ok(_)) => (), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => { // TODO: downcast to decide on closure, for now we only do this for // process_packets @@ -132,12 +132,11 @@ where }; } - while tls_stream.session.is_handshaking() { - #[cfg(feature = "compute-heavy-future-executor")] - try_poll!(tls_stream.handshake(cx, true)); - #[cfg(not(feature = "compute-heavy-future-executor"))] - try_poll!(tls_stream.handshake(cx, false)); + #[cfg(feature = "vacation")] + try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Async)); + #[cfg(not(feature = "vacation"))] + try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Sync)); } try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); diff --git a/src/common/mod.rs b/src/common/mod.rs index 4793c07..7df16d9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -6,7 +6,7 @@ use std::task::{Context, Poll}; use rustls::{ConnectionCommon, SideData}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -#[cfg(feature = "compute-heavy-future-executor")] +#[cfg(feature = "vacation")] mod async_session; mod handshake; pub(crate) use handshake::{IoSession, MidHandshake}; @@ -21,6 +21,14 @@ pub enum TlsState { FullyShutdown, } +/// Whether to delegate the call to the `vacation` executor, +/// only kicks in if `vacation` feature is enabled +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum PacketProcessingMode { + Async, + Sync, +} + impl TlsState { #[inline] pub fn shutdown_read(&mut self) { @@ -92,7 +100,11 @@ where } #[allow(unused_variables)] - pub fn read_io(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll> { + pub fn read_io( + &mut self, + cx: &mut Context, + packet_processing_mode: PacketProcessingMode, + ) -> Poll> { let mut reader = SyncReadAdapter { io: self.io, cx }; let n: usize = match self.session.read_tls(&mut reader) { @@ -101,8 +113,8 @@ where Err(err) => return Poll::Ready(Err(err)), }; - #[cfg(feature = "compute-heavy-future-executor")] - if process_packets_async { + #[cfg(feature = "vacation")] + if packet_processing_mode == PacketProcessingMode::Async { // TODO: stop modeling errors as IO, use enum on types of async session processing return Poll::Ready(Err(io::Error::new( io::ErrorKind::WouldBlock, @@ -131,7 +143,11 @@ where } } - pub fn handshake(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll> { + pub fn handshake( + &mut self, + cx: &mut Context, + packet_processing_mode: PacketProcessingMode, + ) -> Poll> { let mut wrlen = 0; let mut rdlen = 0; @@ -164,7 +180,7 @@ where } while !self.eof && self.session.wants_read() { - match self.read_io(cx, process_packets_async) { + match self.read_io(cx, packet_processing_mode) { Poll::Ready(Ok(0)) => self.eof = true, Poll::Ready(Ok(n)) => rdlen += n, Poll::Pending => { @@ -208,7 +224,7 @@ where // read a packet while !self.eof && self.session.wants_read() { - match self.read_io(cx, false) { + match self.read_io(cx, PacketProcessingMode::Sync) { Poll::Ready(Ok(0)) => { break; } diff --git a/src/common/test_stream.rs b/src/common/test_stream.rs index 9e5391d..24e3dde 100644 --- a/src/common/test_stream.rs +++ b/src/common/test_stream.rs @@ -9,7 +9,7 @@ use rustls::pki_types::ServerName; use rustls::{ClientConnection, Connection, ServerConnection}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; -use super::Stream; +use super::{PacketProcessingMode, Stream}; struct Good<'a>(&'a mut Connection); @@ -229,12 +229,13 @@ async fn stream_handshake() -> io::Result<()> { { let mut good = Good(&mut server); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?; + let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, false)).await?; // finish server handshake + poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?; + // finish server handshake } assert!(!server.is_handshaking()); @@ -253,12 +254,12 @@ async fn stream_buffered_handshake() -> io::Result<()> { { let mut good = BufWriter::new(Good(&mut server)); let mut stream = Stream::new(&mut good, &mut client); - let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?; + let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; assert!(r > 0); assert!(w > 0); - poll_fn(|cx| stream.handshake(cx, false)).await?; // finish server handshake + poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // finish server handshake } assert!(!server.is_handshaking()); @@ -275,7 +276,7 @@ async fn stream_handshake_eof() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.handshake(&mut cx, false); + let r = stream.handshake(&mut cx, PacketProcessingMode::Sync); assert_eq!( r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) @@ -292,7 +293,7 @@ async fn stream_handshake_write_eof() -> io::Result<()> { let mut stream = Stream::new(&mut io, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.handshake(&mut cx, false); + let r = stream.handshake(&mut cx, PacketProcessingMode::Sync); assert_eq!( r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::WriteZero)) @@ -310,7 +311,7 @@ async fn stream_handshake_regression_issues_77() -> io::Result<()> { let mut stream = Stream::new(&mut bad, &mut client); let mut cx = Context::from_waker(noop_waker_ref()); - let r = stream.handshake(&mut cx, false); + let r = stream.handshake(&mut cx, PacketProcessingMode::Sync); assert_eq!( r.map_err(|err| err.kind()), Poll::Ready(Err(io::ErrorKind::InvalidData)) @@ -366,8 +367,9 @@ async fn async_process_packets() -> io::Result<()> { let mut stream = Stream::new(&mut good, &mut client); // if feature is enabled, we expect a blocking response on process packets throughout the handshake, - #[cfg(feature = "compute-heavy-future-executor")] - { let result = poll_fn(|cx| stream.handshake(cx, true)).await; + #[cfg(feature = "vacation")] + { + let result = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await; assert_eq!( result.err().map(|e| e.kind()), @@ -375,22 +377,23 @@ async fn async_process_packets() -> io::Result<()> { ); // finish the handshake without delegating to async session - poll_fn(|cx| stream.handshake(cx, false)).await?; // client handshake - poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, true)).await?; // server handshake + poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // client handshake + poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?; + // server handshake } - // if feature is disabled, we expect normal handling - #[cfg(not(feature = "compute-heavy-future-executor"))] + // if feature is disabled, we expect normal handling even if async is passed in + #[cfg(not(feature = "vacation"))] { { - let (r, w) = poll_fn(|cx| stream.handshake(cx, true)).await?; // client handshake - + let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?; // client handshake + assert!(r > 0); assert!(w > 0); - - poll_fn(|cx| stream.handshake(cx, true)).await?; // server handshake + + poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?; + // server handshake } - } // once handshake is done, there is no longer blocking sending data over the stream @@ -426,7 +429,7 @@ fn do_handshake( let mut stream = Stream::new(&mut good, client); while stream.session.is_handshaking() { - ready!(stream.handshake(cx, false))?; + ready!(stream.handshake(cx, PacketProcessingMode::Sync))?; } while stream.session.wants_write() { diff --git a/src/lib.rs b/src/lib.rs index 30d6fec..760912e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -381,7 +381,7 @@ where pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] MidHandshake::AsyncSession(sess) => Some(sess.get_ref()), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), @@ -392,7 +392,7 @@ where pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] MidHandshake::AsyncSession(sess) => Some(sess.get_mut()), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), @@ -413,7 +413,7 @@ where pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] MidHandshake::AsyncSession(sess) => Some(sess.get_ref()), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), @@ -424,7 +424,7 @@ where pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), - #[cfg(feature = "compute-heavy-future-executor")] + #[cfg(feature = "vacation")] MidHandshake::AsyncSession(sess) => Some(sess.get_mut()), MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), diff --git a/tests/async_session.rs b/tests/async_session.rs index 62247e1..a34e796 100644 --- a/tests/async_session.rs +++ b/tests/async_session.rs @@ -1,5 +1,5 @@ -#![cfg(feature = "compute-heavy-future-executor")] -//! Using the `compute-heavy-future-executor` feature shifts the global behavior +#![cfg(feature = "vacation")] +//! Using the `vacation` feature shifts the global behavior //! of processing bytes + establishing handshakes. So all other test suites running are validating //! parity of processing. //! @@ -14,13 +14,13 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::thread; -use compute_heavy_future_executor::{global_sync_strategy_builder, CustomExecutorSyncClosure}; use futures_util::{future::Future, ready}; use rustls::pki_types::ServerName; use rustls::{self, ClientConfig, ServerConnection, Stream}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tokio_rustls::{client::TlsStream, TlsConnector}; +use vacation::{global_sync_strategy_builder, CustomExecutorSyncClosure}; struct Read1(T);