Skip to content

Commit

Permalink
update to reflect new changes to vacation api, add enum for whether t…
Browse files Browse the repository at this point in the history
…o process packets async
  • Loading branch information
jlizen committed Dec 30, 2024
1 parent 0df5320 commit 09b9c77
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 66 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ jobs:
run: |
cargo test --locked --all
cargo test --locked -p tokio-rustls --features early-data --test early-data
# we run all test suites against this feature since it shifts the default behavior globally
cargo test --locked -p tokio-rustls --features compute-heavy-future-executor
# we run all test suites against this feature
# to capture any regressions that come from changes to the handshake future state machine
cargo test --locked -p tokio-rustls --features vacation
lints:
name: Lints
Expand Down
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 3 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ rust-version = "1.70"
exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
# implicitly enables the tokio feature for compute-heavy-future-executor
# (defaulting to strategy of spawn_blocking w/ concurrency conctorl)
compute-heavy-future-executor = { version = "0.1", optional = true}
vacation = { version = "0.1", optional = true, default-features = false }
pin-project-lite = { version = "0.2.15", optional = true }
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
tokio = "1.0"
Expand All @@ -24,7 +22,7 @@ tokio = "1.0"
default = ["logging", "tls12", "aws_lc_rs"]
aws_lc_rs = ["rustls/aws_lc_rs"]
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"]
vacation = ["dep:vacation", "pin-project-lite"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
Expand All @@ -37,7 +35,5 @@ futures-util = "0.3.1"
lazy_static = "1.1"
rcgen = { version = "0.13", features = ["pem"] }
tokio = { version = "1.0", features = ["full"] }
vacation = { version = "0.1", features = ["tokio"] }
webpki-roots = "0.26"

[patch.crates-io]
compute-heavy-future-executor = { path = "../compute-heavy-future-executor" }
5 changes: 4 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ fn poll_handle_early_data<IO>(
where
IO: AsyncRead + AsyncWrite + Unpin,
{
use crate::common::PacketProcessingMode;

if let TlsState::EarlyData(pos, data) = state {
use std::io::Write;

Expand Down Expand Up @@ -321,7 +323,8 @@ where

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx, false))?;
// TODO: also model as using `vacation` executor
ready!(stream.handshake(cx, PacketProcessingMode::Sync))?;
}

// write early data (fallback)
Expand Down
6 changes: 4 additions & 2 deletions src/common/async_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use super::{Stream, TlsState};
/// Full result of sync closure
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
/// Executor result wrapping sync closure result
type SyncExecutorResult<S> = Result<SessionResult<S>, compute_heavy_future_executor::Error>;
type SyncExecutorResult<S> = Result<SessionResult<S>, vacation::Error>;
/// Future wrapping waiting on executor
type SessionFuture<S> = Box<dyn Future<Output = SyncExecutorResult<S>> + Unpin + Send>;

Expand Down Expand Up @@ -53,7 +53,9 @@ where
)),
};

let future = compute_heavy_future_executor::execute_sync(closure);
// TODO: if we ever start also delegating non-handshake byte processing, make this chance of blocking
// variable and set by caller
let future = vacation::execute_sync(closure, vacation::ChanceOfBlocking::High);

Self {
future: Box::new(Box::pin(future)),
Expand Down
19 changes: 9 additions & 10 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use rustls::server::AcceptedAlert;
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::{Stream, SyncWriteAdapter, TlsState};
use crate::common::{PacketProcessingMode, Stream, SyncWriteAdapter, TlsState};

#[cfg(feature = "compute-heavy-future-executor")]
#[cfg(feature = "vacation")]
use super::async_session::AsyncSession;

pub(crate) trait IoSession {
Expand All @@ -34,7 +34,7 @@ pub(crate) trait IoSession {

pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
#[cfg(feature = "compute-heavy-future-executor")]
#[cfg(feature = "vacation")]
AsyncSession(AsyncSession<IS>),
End,
SendAlert {
Expand All @@ -61,7 +61,7 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
#[cfg(feature = "compute-heavy-future-executor")]
#[cfg(feature = "vacation")]
MidHandshake::AsyncSession(mut async_session) => {
let pinned = Pin::new(&mut async_session);
let session_result = ready!(pinned.poll(cx));
Expand Down Expand Up @@ -94,7 +94,7 @@ where
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
#[cfg(feature = "compute-heavy-future-executor")]
#[cfg(feature = "vacation")]
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => {
// TODO: downcast to decide on closure, for now we only do this for
// process_packets
Expand Down Expand Up @@ -132,12 +132,11 @@ where
};
}


while tls_stream.session.is_handshaking() {
#[cfg(feature = "compute-heavy-future-executor")]
try_poll!(tls_stream.handshake(cx, true));
#[cfg(not(feature = "compute-heavy-future-executor"))]
try_poll!(tls_stream.handshake(cx, false));
#[cfg(feature = "vacation")]
try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Async));
#[cfg(not(feature = "vacation"))]
try_poll!(tls_stream.handshake(cx, PacketProcessingMode::Sync));
}

try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));
Expand Down
30 changes: 23 additions & 7 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "compute-heavy-future-executor")]
#[cfg(feature = "vacation")]
mod async_session;
mod handshake;
pub(crate) use handshake::{IoSession, MidHandshake};
Expand All @@ -21,6 +21,14 @@ pub enum TlsState {
FullyShutdown,
}

/// Whether to delegate the call to the `vacation` executor,
/// only kicks in if `vacation` feature is enabled
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum PacketProcessingMode {
Async,
Sync,
}

impl TlsState {
#[inline]
pub fn shutdown_read(&mut self) {
Expand Down Expand Up @@ -92,7 +100,11 @@ where
}

#[allow(unused_variables)]
pub fn read_io(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll<io::Result<usize>> {
pub fn read_io(
&mut self,
cx: &mut Context,
packet_processing_mode: PacketProcessingMode,
) -> Poll<io::Result<usize>> {
let mut reader = SyncReadAdapter { io: self.io, cx };

let n: usize = match self.session.read_tls(&mut reader) {
Expand All @@ -101,8 +113,8 @@ where
Err(err) => return Poll::Ready(Err(err)),
};

#[cfg(feature = "compute-heavy-future-executor")]
if process_packets_async {
#[cfg(feature = "vacation")]
if packet_processing_mode == PacketProcessingMode::Async {
// TODO: stop modeling errors as IO, use enum on types of async session processing
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WouldBlock,
Expand Down Expand Up @@ -131,7 +143,11 @@ where
}
}

pub fn handshake(&mut self, cx: &mut Context, process_packets_async: bool) -> Poll<io::Result<(usize, usize)>> {
pub fn handshake(
&mut self,
cx: &mut Context,
packet_processing_mode: PacketProcessingMode,
) -> Poll<io::Result<(usize, usize)>> {
let mut wrlen = 0;
let mut rdlen = 0;

Expand Down Expand Up @@ -164,7 +180,7 @@ where
}

while !self.eof && self.session.wants_read() {
match self.read_io(cx, process_packets_async) {
match self.read_io(cx, packet_processing_mode) {
Poll::Ready(Ok(0)) => self.eof = true,
Poll::Ready(Ok(n)) => rdlen += n,
Poll::Pending => {
Expand Down Expand Up @@ -208,7 +224,7 @@ where

// read a packet
while !self.eof && self.session.wants_read() {
match self.read_io(cx, false) {
match self.read_io(cx, PacketProcessingMode::Sync) {
Poll::Ready(Ok(0)) => {
break;
}
Expand Down
43 changes: 23 additions & 20 deletions src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustls::pki_types::ServerName;
use rustls::{ClientConnection, Connection, ServerConnection};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};

use super::Stream;
use super::{PacketProcessingMode, Stream};

struct Good<'a>(&'a mut Connection);

Expand Down Expand Up @@ -229,12 +229,13 @@ async fn stream_handshake() -> io::Result<()> {
{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?;
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?;

assert!(r > 0);
assert!(w > 0);

poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, false)).await?; // finish server handshake
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
// finish server handshake
}

assert!(!server.is_handshaking());
Expand All @@ -253,12 +254,12 @@ async fn stream_buffered_handshake() -> io::Result<()> {
{
let mut good = BufWriter::new(Good(&mut server));
let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = poll_fn(|cx| stream.handshake(cx, false)).await?;
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?;

assert!(r > 0);
assert!(w > 0);

poll_fn(|cx| stream.handshake(cx, false)).await?; // finish server handshake
poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // finish server handshake
}

assert!(!server.is_handshaking());
Expand All @@ -275,7 +276,7 @@ async fn stream_handshake_eof() -> io::Result<()> {
let mut stream = Stream::new(&mut bad, &mut client);

let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx, false);
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::UnexpectedEof))
Expand All @@ -292,7 +293,7 @@ async fn stream_handshake_write_eof() -> io::Result<()> {
let mut stream = Stream::new(&mut io, &mut client);

let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx, false);
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::WriteZero))
Expand All @@ -310,7 +311,7 @@ async fn stream_handshake_regression_issues_77() -> io::Result<()> {
let mut stream = Stream::new(&mut bad, &mut client);

let mut cx = Context::from_waker(noop_waker_ref());
let r = stream.handshake(&mut cx, false);
let r = stream.handshake(&mut cx, PacketProcessingMode::Sync);
assert_eq!(
r.map_err(|err| err.kind()),
Poll::Ready(Err(io::ErrorKind::InvalidData))
Expand Down Expand Up @@ -366,31 +367,33 @@ async fn async_process_packets() -> io::Result<()> {
let mut stream = Stream::new(&mut good, &mut client);

// if feature is enabled, we expect a blocking response on process packets throughout the handshake,
#[cfg(feature = "compute-heavy-future-executor")]
{ let result = poll_fn(|cx| stream.handshake(cx, true)).await;
#[cfg(feature = "vacation")]
{
let result = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await;

assert_eq!(
result.err().map(|e| e.kind()),
Some(io::ErrorKind::WouldBlock)
);

// finish the handshake without delegating to async session
poll_fn(|cx| stream.handshake(cx, false)).await?; // client handshake
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, true)).await?; // server handshake
poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Sync)).await?; // client handshake
poll_fn(|cx: &mut Context<'_>| stream.handshake(cx, PacketProcessingMode::Sync)).await?;
// server handshake
}

// if feature is disabled, we expect normal handling
#[cfg(not(feature = "compute-heavy-future-executor"))]
// if feature is disabled, we expect normal handling even if async is passed in
#[cfg(not(feature = "vacation"))]
{
{
let (r, w) = poll_fn(|cx| stream.handshake(cx, true)).await?; // client handshake
let (r, w) = poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?; // client handshake

assert!(r > 0);
assert!(w > 0);

poll_fn(|cx| stream.handshake(cx, true)).await?; // server handshake

poll_fn(|cx| stream.handshake(cx, PacketProcessingMode::Async)).await?;
// server handshake
}

}

// once handshake is done, there is no longer blocking sending data over the stream
Expand Down Expand Up @@ -426,7 +429,7 @@ fn do_handshake(
let mut stream = Stream::new(&mut good, client);

while stream.session.is_handshaking() {
ready!(stream.handshake(cx, false))?;
ready!(stream.handshake(cx, PacketProcessingMode::Sync))?;
}

while stream.session.wants_write() {
Expand Down
Loading

0 comments on commit 09b9c77

Please sign in to comment.