Skip to content

Commit

Permalink
feat: add support for processing handshake packets async via compute-…
Browse files Browse the repository at this point in the history
…heavy-future-executor
  • Loading branch information
jlizen committed Dec 30, 2024
1 parent cd399ab commit 0df5320
Show file tree
Hide file tree
Showing 11 changed files with 507 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ jobs:
run: |
cargo test --locked --all
cargo test --locked -p tokio-rustls --features early-data --test early-data
# we run all test suites against this feature since it shifts the default behavior globally
cargo test --locked -p tokio-rustls --features compute-heavy-future-executor
lints:
name: Lints
Expand Down
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@ rust-version = "1.70"
exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
# implicitly enables the tokio feature for compute-heavy-future-executor
# (defaulting to strategy of spawn_blocking w/ concurrency conctorl)
compute-heavy-future-executor = { version = "0.1", optional = true}
pin-project-lite = { version = "0.2.15", optional = true }
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
tokio = "1.0"

[features]
default = ["logging", "tls12", "aws_lc_rs"]
aws_lc_rs = ["rustls/aws_lc_rs"]
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
Expand All @@ -33,3 +38,6 @@ lazy_static = "1.1"
rcgen = { version = "0.13", features = ["pem"] }
tokio = { version = "1.0", features = ["full"] }
webpki-roots = "0.26"

[patch.crates-io]
compute-heavy-future-executor = { path = "../compute-heavy-future-executor" }
36 changes: 35 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ where
self.get_ref().0.as_raw_socket()
}
}
#[cfg(feature = "early-data")]
type TlsStreamExtras = Option<Waker>;
#[cfg(not(feature = "early-data"))]
type TlsStreamExtras = ();

impl<IO> IoSession for TlsStream<IO> {
type Io = IO;
type Session = ClientConnection;
type Extras = TlsStreamExtras;

#[inline]
fn skip_handshake(&self) -> bool {
Expand All @@ -80,6 +85,35 @@ impl<IO> IoSession for TlsStream<IO> {
fn into_io(self) -> Self::Io {
self.io
}

#[inline]
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras) {
#[cfg(feature = "early-data")]
return (self.state, self.io, self.session, self.early_waker);

#[cfg(not(feature = "early-data"))]
(self.state, self.io, self.session, ())
}

#[inline]
#[allow(unused_variables)]
fn from_inner(
state: TlsState,
io: Self::Io,
session: Self::Session,
extras: Self::Extras,
) -> Self {
#[cfg(feature = "early-data")]
return Self {
io,
session,
state,
early_waker: extras,
};

#[cfg(not(feature = "early-data"))]
Self { io, session, state }
}
}

impl<IO> AsyncRead for TlsStream<IO>
Expand Down Expand Up @@ -287,7 +321,7 @@ where

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
ready!(stream.handshake(cx, false))?;
}

// write early data (fallback)
Expand Down
128 changes: 128 additions & 0 deletions src/common/async_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::{
future::Future,
io,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};

use pin_project_lite::pin_project;
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::IoSession;

use super::{Stream, TlsState};

/// Full result of sync closure
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
/// Executor result wrapping sync closure result
type SyncExecutorResult<S> = Result<SessionResult<S>, compute_heavy_future_executor::Error>;
/// Future wrapping waiting on executor
type SessionFuture<S> = Box<dyn Future<Output = SyncExecutorResult<S>> + Unpin + Send>;

pin_project! {
/// Session is off doing compute-heavy sync work, such as initializing the session or processing handshake packets.
/// Might be on another thread / external threadpool.
///
/// This future sleeps on it in current worker thread until it completes.
pub(crate) struct AsyncSession<IS: IoSession> {
#[pin]
future: SessionFuture<IS::Session>,
io: IS::Io,
state: TlsState,
extras: IS::Extras,
}
}

impl<IS, SD> AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
pub(crate) fn process_packets(stream: IS) -> Self {
let (state, io, mut session, extras) = stream.into_inner();

let closure = move || match session.process_new_packets() {
Ok(_) => Ok(session),
Err(err) => Err((
Some(session),
io::Error::new(io::ErrorKind::InvalidData, err),
)),
};

let future = compute_heavy_future_executor::execute_sync(closure);

Self {
future: Box::new(Box::pin(future)),
io,
state,
extras,
}
}

pub(crate) fn into_stream(
mut self,
session_result: Result<IS::Session, (Option<IS::Session>, io::Error)>,
cx: &mut Context<'_>,
) -> Result<IS, (io::Error, IS::Io)> {
match session_result {
Ok(session) => Ok(IS::from_inner(self.state, self.io, session, self.extras)),
Err((Some(mut session), err)) => {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let mut tls_stream: Stream<'_, <IS as IoSession>::Io, <IS as IoSession>::Session> =
Stream::new(&mut self.io, &mut session).set_eof(!self.state.readable());
let _ = tls_stream.write_io(cx);

// still drop the tls session and return the io error only
Err((err, self.io))
}
Err((None, err)) => Err((err, self.io)),
}
}

#[inline]
pub fn get_ref(&self) -> &IS::Io {
&self.io
}

#[inline]
pub fn get_mut(&mut self) -> &mut IS::Io {
&mut self.io
}
}

impl<IS, SD> Future for AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
type Output = Result<IS::Session, (Option<IS::Session>, io::Error)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

match ready!(this.future.as_mut().poll(cx)) {
Ok(session_res) => match session_res {
Ok(res) => Poll::Ready(Ok(res)),
// return any session along with the error,
// so the caller can flush any remaining alerts in buffer to i/o
Err((session, err)) => Poll::Ready(Err((
session,
io::Error::new(io::ErrorKind::InvalidData, err),
))),
},
// We don't have a session to flush here because the executor ate it
// TODO: not all errors should be modeled as io
Err(executor_error) => Poll::Ready(Err((
None,
io::Error::new(io::ErrorKind::Other, executor_error),
))),
}
}
}
59 changes: 56 additions & 3 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,32 @@ use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::{Stream, SyncWriteAdapter, TlsState};

#[cfg(feature = "compute-heavy-future-executor")]
use super::async_session::AsyncSession;

pub(crate) trait IoSession {
type Io;
type Session;
type Extras;

fn skip_handshake(&self) -> bool;
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
fn into_io(self) -> Self::Io;
#[allow(dead_code)]
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras);
#[allow(dead_code)]
fn from_inner(
state: TlsState,
io: Self::Io,
session: Self::Session,
extras: Self::Extras,
) -> Self;
}

pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
#[cfg(feature = "compute-heavy-future-executor")]
AsyncSession(AsyncSession<IS>),
End,
SendAlert {
io: IS::Io,
Expand All @@ -32,12 +47,11 @@ pub(crate) enum MidHandshake<IS: IoSession> {
error: io::Error,
},
}

impl<IS, SD> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
type Output = Result<IS, (io::Error, IS::Io)>;
Expand All @@ -47,6 +61,12 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
#[cfg(feature = "compute-heavy-future-executor")]
MidHandshake::AsyncSession(mut async_session) => {
let pinned = Pin::new(&mut async_session);
let session_result = ready!(pinned.poll(cx));
async_session.into_stream(session_result, cx)?
}
MidHandshake::SendAlert {
mut io,
mut alert,
Expand Down Expand Up @@ -74,6 +94,35 @@ where
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
#[cfg(feature = "compute-heavy-future-executor")]
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => {
// TODO: downcast to decide on closure, for now we only do this for
// process_packets

// decompose the stream and send the session to background executor
let mut async_session = AsyncSession::process_packets(stream);

let pinned = Pin::new(&mut async_session);
// poll once to kick off work
match pinned.poll(cx) {
// didn't need to sleep for async session
Poll::Ready(res) => {
let stream = async_session.into_stream(res, cx)?;
// rather than continuing processing here,
// we keep memory management simple and recompose
// our future for a fresh poll
*this = MidHandshake::Handshaking(stream);
// tell executor to immediately poll us again
cx.waker().wake_by_ref();
return Poll::Pending;
}
// task is sleeping until async session is complete
Poll::Pending => {
*this = MidHandshake::AsyncSession(async_session);
return Poll::Pending;
}
}
}
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
Poll::Pending => {
*this = MidHandshake::Handshaking(stream);
Expand All @@ -83,8 +132,12 @@ where
};
}


while tls_stream.session.is_handshaking() {
try_poll!(tls_stream.handshake(cx));
#[cfg(feature = "compute-heavy-future-executor")]
try_poll!(tls_stream.handshake(cx, true));
#[cfg(not(feature = "compute-heavy-future-executor"))]
try_poll!(tls_stream.handshake(cx, false));
}

try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));
Expand Down
Loading

0 comments on commit 0df5320

Please sign in to comment.