Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP) feat: add support for processing handshake packets async via vacation #99

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ jobs:
run: |
cargo test --locked --all
cargo test --locked -p tokio-rustls --features early-data --test early-data
# we run all test suites against this feature since it shifts the default behavior globally
cargo test --locked -p tokio-rustls --features compute-heavy-future-executor

lints:
name: Lints
Expand Down
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,18 @@ rust-version = "1.70"
exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
# implicitly enables the tokio feature for compute-heavy-future-executor
# (defaulting to strategy of spawn_blocking w/ concurrency conctorl)
compute-heavy-future-executor = { version = "0.1", optional = true}
pin-project-lite = { version = "0.2.15", optional = true }
rustls = { version = "0.23.15", default-features = false, features = ["std"] }
tokio = "1.0"

[features]
default = ["logging", "tls12", "aws_lc_rs"]
aws_lc_rs = ["rustls/aws_lc_rs"]
aws-lc-rs = ["aws_lc_rs"] # Alias because Cargo features commonly use `-`
compute-heavy-future-executor = ["dep:compute-heavy-future-executor", "pin-project-lite"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
Expand All @@ -33,3 +38,6 @@ lazy_static = "1.1"
rcgen = { version = "0.13", features = ["pem"] }
tokio = { version = "1.0", features = ["full"] }
webpki-roots = "0.26"

[patch.crates-io]
compute-heavy-future-executor = { path = "../compute-heavy-future-executor" }
jlizen marked this conversation as resolved.
Show resolved Hide resolved
36 changes: 35 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ where
self.get_ref().0.as_raw_socket()
}
}
#[cfg(feature = "early-data")]
type TlsStreamExtras = Option<Waker>;
#[cfg(not(feature = "early-data"))]
type TlsStreamExtras = ();

impl<IO> IoSession for TlsStream<IO> {
type Io = IO;
type Session = ClientConnection;
type Extras = TlsStreamExtras;

#[inline]
fn skip_handshake(&self) -> bool {
Expand All @@ -80,6 +85,35 @@ impl<IO> IoSession for TlsStream<IO> {
fn into_io(self) -> Self::Io {
self.io
}

#[inline]
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras) {
#[cfg(feature = "early-data")]
return (self.state, self.io, self.session, self.early_waker);

#[cfg(not(feature = "early-data"))]
(self.state, self.io, self.session, ())
}

#[inline]
#[allow(unused_variables)]
fn from_inner(
state: TlsState,
io: Self::Io,
session: Self::Session,
extras: Self::Extras,
) -> Self {
#[cfg(feature = "early-data")]
return Self {
io,
session,
state,
early_waker: extras,
};

#[cfg(not(feature = "early-data"))]
Self { io, session, state }
}
}

impl<IO> AsyncRead for TlsStream<IO>
Expand Down Expand Up @@ -287,7 +321,7 @@ where

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
ready!(stream.handshake(cx, false))?;
}

// write early data (fallback)
Expand Down
128 changes: 128 additions & 0 deletions src/common/async_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::{
future::Future,
io,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};

use pin_project_lite::pin_project;
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::IoSession;

use super::{Stream, TlsState};

/// Full result of sync closure
type SessionResult<S> = Result<S, (Option<S>, io::Error)>;
/// Executor result wrapping sync closure result
type SyncExecutorResult<S> = Result<SessionResult<S>, compute_heavy_future_executor::Error>;
/// Future wrapping waiting on executor
type SessionFuture<S> = Box<dyn Future<Output = SyncExecutorResult<S>> + Unpin + Send>;

pin_project! {
/// Session is off doing compute-heavy sync work, such as initializing the session or processing handshake packets.
/// Might be on another thread / external threadpool.
///
/// This future sleeps on it in current worker thread until it completes.
pub(crate) struct AsyncSession<IS: IoSession> {
#[pin]
future: SessionFuture<IS::Session>,
io: IS::Io,
state: TlsState,
extras: IS::Extras,
}
}

impl<IS, SD> AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
pub(crate) fn process_packets(stream: IS) -> Self {
let (state, io, mut session, extras) = stream.into_inner();

let closure = move || match session.process_new_packets() {
Ok(_) => Ok(session),
Err(err) => Err((
Some(session),
io::Error::new(io::ErrorKind::InvalidData, err),
)),
};

let future = compute_heavy_future_executor::execute_sync(closure);

Self {
future: Box::new(Box::pin(future)),
io,
state,
extras,
}
}

pub(crate) fn into_stream(
mut self,
session_result: Result<IS::Session, (Option<IS::Session>, io::Error)>,
cx: &mut Context<'_>,
) -> Result<IS, (io::Error, IS::Io)> {
match session_result {
Ok(session) => Ok(IS::from_inner(self.state, self.io, session, self.extras)),
Err((Some(mut session), err)) => {
// In case we have an alert to send describing this error,
// try a last-gasp write -- but don't predate the primary
// error.
let mut tls_stream: Stream<'_, <IS as IoSession>::Io, <IS as IoSession>::Session> =
Stream::new(&mut self.io, &mut session).set_eof(!self.state.readable());
let _ = tls_stream.write_io(cx);

// still drop the tls session and return the io error only
Err((err, self.io))
}
Err((None, err)) => Err((err, self.io)),
}
}

#[inline]
pub fn get_ref(&self) -> &IS::Io {
&self.io
}

#[inline]
pub fn get_mut(&mut self) -> &mut IS::Io {
&mut self.io
}
}

impl<IS, SD> Future for AsyncSession<IS>
where
IS: IoSession + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
type Output = Result<IS::Session, (Option<IS::Session>, io::Error)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();

match ready!(this.future.as_mut().poll(cx)) {
Ok(session_res) => match session_res {
Ok(res) => Poll::Ready(Ok(res)),
// return any session along with the error,
// so the caller can flush any remaining alerts in buffer to i/o
Err((session, err)) => Poll::Ready(Err((
session,
io::Error::new(io::ErrorKind::InvalidData, err),
))),
},
// We don't have a session to flush here because the executor ate it
// TODO: not all errors should be modeled as io
Err(executor_error) => Poll::Ready(Err((
None,
io::Error::new(io::ErrorKind::Other, executor_error),
))),
}
}
}
59 changes: 56 additions & 3 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,32 @@ use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::{Stream, SyncWriteAdapter, TlsState};

#[cfg(feature = "compute-heavy-future-executor")]
use super::async_session::AsyncSession;

pub(crate) trait IoSession {
type Io;
type Session;
type Extras;

fn skip_handshake(&self) -> bool;
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
fn into_io(self) -> Self::Io;
#[allow(dead_code)]
fn into_inner(self) -> (TlsState, Self::Io, Self::Session, Self::Extras);
#[allow(dead_code)]
fn from_inner(
state: TlsState,
io: Self::Io,
session: Self::Session,
extras: Self::Extras,
) -> Self;
}

pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
#[cfg(feature = "compute-heavy-future-executor")]
AsyncSession(AsyncSession<IS>),
End,
SendAlert {
io: IS::Io,
Expand All @@ -32,12 +47,11 @@ pub(crate) enum MidHandshake<IS: IoSession> {
error: io::Error,
},
}

impl<IS, SD> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin + Send + 'static,
SD: SideData,
{
type Output = Result<IS, (io::Error, IS::Io)>;
Expand All @@ -47,6 +61,12 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
#[cfg(feature = "compute-heavy-future-executor")]
MidHandshake::AsyncSession(mut async_session) => {
let pinned = Pin::new(&mut async_session);
let session_result = ready!(pinned.poll(cx));
async_session.into_stream(session_result, cx)?
}
MidHandshake::SendAlert {
mut io,
mut alert,
Expand Down Expand Up @@ -74,6 +94,35 @@ where
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
#[cfg(feature = "compute-heavy-future-executor")]
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::WouldBlock => {
// TODO: downcast to decide on closure, for now we only do this for
// process_packets

// decompose the stream and send the session to background executor
let mut async_session = AsyncSession::process_packets(stream);

let pinned = Pin::new(&mut async_session);
// poll once to kick off work
match pinned.poll(cx) {
// didn't need to sleep for async session
Poll::Ready(res) => {
let stream = async_session.into_stream(res, cx)?;
// rather than continuing processing here,
// we keep memory management simple and recompose
// our future for a fresh poll
*this = MidHandshake::Handshaking(stream);
// tell executor to immediately poll us again
cx.waker().wake_by_ref();
return Poll::Pending;
}
// task is sleeping until async session is complete
Poll::Pending => {
*this = MidHandshake::AsyncSession(async_session);
return Poll::Pending;
}
}
}
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
Poll::Pending => {
*this = MidHandshake::Handshaking(stream);
Expand All @@ -83,8 +132,12 @@ where
};
}


while tls_stream.session.is_handshaking() {
try_poll!(tls_stream.handshake(cx));
#[cfg(feature = "compute-heavy-future-executor")]
try_poll!(tls_stream.handshake(cx, true));
#[cfg(not(feature = "compute-heavy-future-executor"))]
try_poll!(tls_stream.handshake(cx, false));
}

try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));
Expand Down
Loading
Loading