Skip to content

Commit

Permalink
preferred address
Browse files Browse the repository at this point in the history
  • Loading branch information
devsnek authored and Ralith committed Apr 13, 2024
1 parent 88f48b0 commit af2f5eb
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 26 deletions.
30 changes: 29 additions & 1 deletion quinn-proto/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::{fmt, num::TryFromIntError, sync::Arc, time::Duration};
use std::{
fmt,
net::{SocketAddrV4, SocketAddrV6},
num::TryFromIntError,
sync::Arc,
time::Duration,
};

use thiserror::Error;

Expand Down Expand Up @@ -765,6 +771,9 @@ pub struct ServerConfig {
/// Improves behavior for clients that move between different internet connections or suffer NAT
/// rebinding. Enabled by default.
pub(crate) migration: bool,

pub(crate) preferred_address_v4: Option<SocketAddrV4>,
pub(crate) preferred_address_v6: Option<SocketAddrV6>,
}

impl ServerConfig {
Expand All @@ -781,6 +790,9 @@ impl ServerConfig {
retry_token_lifetime: Duration::from_secs(15),

migration: true,

preferred_address_v4: None,
preferred_address_v6: None,
}
}

Expand Down Expand Up @@ -810,6 +822,20 @@ impl ServerConfig {
self.migration = value;
self
}

/// The preferred IPv4 address that will be communicated to clients during handshaking.
/// If the client is able to reach this address, it will switch to it.
pub fn preferred_address_v4(&mut self, address: Option<SocketAddrV4>) -> &mut Self {
self.preferred_address_v4 = address;
self
}

/// The preferred IPv6 address that will be communicated to clients during handshaking.
/// If the client is able to reach this address, it will switch to it.
pub fn preferred_address_v6(&mut self, address: Option<SocketAddrV6>) -> &mut Self {
self.preferred_address_v6 = address;
self
}
}

#[cfg(feature = "rustls")]
Expand Down Expand Up @@ -849,6 +875,8 @@ impl fmt::Debug for ServerConfig {
.field("token_key", &"[ elided ]")
.field("retry_token_lifetime", &self.retry_token_lifetime)
.field("migration", &self.migration)
.field("preferred_address_v4", &self.preferred_address_v4)
.field("preferred_address_v6", &self.preferred_address_v6)
.finish()
}
}
Expand Down
21 changes: 15 additions & 6 deletions quinn-proto/src/connection/cid_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,30 @@ pub(super) struct CidState {
}

impl CidState {
pub(crate) fn new(cid_len: usize, cid_lifetime: Option<Duration>, now: Instant) -> Self {
pub(crate) fn new(
cid_len: usize,
cid_lifetime: Option<Duration>,
now: Instant,
issued: u64,
) -> Self {
let mut active_seq = FxHashSet::default();
// Add sequence number of CID used in handshaking into tracking set
active_seq.insert(0);
// Add sequence number of CIDs used in handshaking into tracking set
for seq in 0..issued {
active_seq.insert(seq);
}
let mut this = Self {
retire_timestamp: VecDeque::new(),
issued: 1, // One CID is already supplied during handshaking
issued,
active_seq,
prev_retire_seq: 0,
retire_seq: 0,
cid_len,
cid_lifetime,
};
// Track lifetime of cid used in handshaking
this.track_lifetime(0, now);
// Track lifetime of CIDs used in handshaking
for seq in 0..issued {
this.track_lifetime(seq, now);
}
this
}

Expand Down
8 changes: 7 additions & 1 deletion quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl Connection {
init_cid: ConnectionId,
loc_cid: ConnectionId,
rem_cid: ConnectionId,
pref_addr_cid: Option<ConnectionId>,
remote: SocketAddr,
local_ip: Option<IpAddr>,
crypto: Box<dyn crypto::Session>,
Expand Down Expand Up @@ -275,7 +276,12 @@ impl Connection {
crypto,
handshake_cid: loc_cid,
rem_handshake_cid: rem_cid,
local_cid_state: CidState::new(cid_gen.cid_len(), cid_gen.cid_lifetime(), now),
local_cid_state: CidState::new(
cid_gen.cid_len(),
cid_gen.cid_lifetime(),
now,
if pref_addr_cid.is_some() { 2 } else { 1 },
),
path: PathData::new(
remote,
config.initial_rtt,
Expand Down
43 changes: 37 additions & 6 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{
collections::{hash_map, HashMap},
convert::TryFrom,
fmt, iter,
fmt,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
Expand Down Expand Up @@ -30,7 +30,7 @@ use crate::{
ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent,
EndpointEventInner, IssuedCid,
},
transport_parameters::TransportParameters,
transport_parameters::{PreferredAddress, TransportParameters},
ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU,
MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
};
Expand Down Expand Up @@ -390,6 +390,7 @@ impl Endpoint {
remote_id,
loc_cid,
remote_id,
None,
FourTuple {
remote,
local_ip: None,
Expand All @@ -413,8 +414,8 @@ impl Endpoint {
for _ in 0..num {
let id = self.new_cid(ch);
let meta = &mut self.connections[ch];
meta.cids_issued += 1;
let sequence = meta.cids_issued;
meta.cids_issued += 1;
meta.loc_cids.insert(sequence, id);
ids.push(IssuedCid {
sequence,
Expand Down Expand Up @@ -506,6 +507,7 @@ impl Endpoint {
mut incoming: Incoming,
now: Instant,
buf: &mut BytesMut,
server_config: Option<Arc<ServerConfig>>,
) -> Result<(ConnectionHandle, Connection), AcceptError> {
let packet_number = incoming.packet.header.number.expand(0);
let InitialHeader {
Expand All @@ -530,7 +532,8 @@ impl Endpoint {
});
}

let server_config = self.server_config.as_ref().unwrap().clone();
let server_config =
server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());

if incoming
.crypto
Expand Down Expand Up @@ -562,6 +565,19 @@ impl Endpoint {
params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid));
params.original_dst_cid = Some(incoming.orig_dst_cid);
params.retry_src_cid = incoming.retry_src_cid;
let mut pref_addr_cid = None;
if server_config.preferred_address_v4.is_some()
|| server_config.preferred_address_v6.is_some()
{
let cid = self.new_cid(ch);
pref_addr_cid = Some(cid);
params.preferred_address = Some(PreferredAddress {
address_v4: server_config.preferred_address_v4,
address_v6: server_config.preferred_address_v6,
connection_id: cid,
stateless_reset_token: ResetToken::new(&*self.config.reset_key, &cid),
});
}

let tls = server_config.crypto.clone().start_session(version, &params);
let transport_config = server_config.transport.clone();
Expand All @@ -571,6 +587,7 @@ impl Endpoint {
dst_cid,
loc_cid,
src_cid,
pref_addr_cid,
incoming.addresses,
now,
tls,
Expand Down Expand Up @@ -718,6 +735,7 @@ impl Endpoint {
init_cid: ConnectionId,
loc_cid: ConnectionId,
rem_cid: ConnectionId,
pref_addr_cid: Option<ConnectionId>,
addresses: FourTuple,
now: Instant,
tls: Box<dyn crypto::Session>,
Expand All @@ -734,6 +752,7 @@ impl Endpoint {
init_cid,
loc_cid,
rem_cid,
pref_addr_cid,
addresses.remote,
addresses.local_ip,
tls,
Expand All @@ -745,10 +764,22 @@ impl Endpoint {
path_validated,
);

let mut cids_issued = 0;
let mut loc_cids = FxHashMap::default();

loc_cids.insert(cids_issued, loc_cid);
cids_issued += 1;

if let Some(cid) = pref_addr_cid {
debug_assert_eq!(cids_issued, 1, "preferred address cid seq must be 1");
loc_cids.insert(cids_issued, cid);
cids_issued += 1;
}

let id = self.connections.insert(ConnectionMeta {
init_cid,
cids_issued: 0,
loc_cids: iter::once((0, loc_cid)).collect(),
cids_issued,
loc_cids,
addresses,
reset_token: None,
});
Expand Down
5 changes: 4 additions & 1 deletion quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,10 @@ impl TestEndpoint {
now: Instant,
) -> Result<ConnectionHandle, ConnectionError> {
let mut buf = BytesMut::new();
match self.endpoint.accept(incoming, now, &mut buf) {
match self
.endpoint
.accept(incoming, now, &mut buf, Default::default())
{
Ok((ch, conn)) => {
self.connections.insert(ch, conn);
self.accepted = Some(Ok(ch));
Expand Down
2 changes: 1 addition & 1 deletion quinn-proto/src/transport_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ mod test {
max_udp_payload_size: 1200u32.into(),
preferred_address: Some(PreferredAddress {
address_v4: Some(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 42)),
address_v6: None,
address_v6: Some(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 24, 0, 0)),
connection_id: ConnectionId::new(&[0x42]),
stateless_reset_token: [0xab; RESET_TOKEN_SIZE].into(),
}),
Expand Down
16 changes: 11 additions & 5 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,19 @@ pub(crate) struct EndpointInner {
}

impl EndpointInner {
pub(crate) fn accept(&self, incoming: proto::Incoming) -> Result<Connecting, ConnectionError> {
pub(crate) fn accept(
&self,
incoming: proto::Incoming,
server_config: Option<Arc<ServerConfig>>,
) -> Result<Connecting, ConnectionError> {
let mut state = self.state.lock().unwrap();
let mut response_buffer = BytesMut::new();
match state
.inner
.accept(incoming, Instant::now(), &mut response_buffer)
{
match state.inner.accept(
incoming,
Instant::now(),
&mut response_buffer,
server_config,
) {
Ok((handle, conn)) => {
let socket = state.socket.clone();
let runtime = state.runtime.clone();
Expand Down
12 changes: 8 additions & 4 deletions quinn/src/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::{
future::{Future, IntoFuture},
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use proto::ConnectionError;
use proto::{ConnectionError, ServerConfig};
use thiserror::Error;

use crate::{
Expand All @@ -23,9 +24,12 @@ impl Incoming {
}

/// Attempt to accept this incoming connection (an error may still occur)
pub fn accept(mut self) -> Result<Connecting, ConnectionError> {
pub fn accept(
mut self,
server_config: Option<Arc<ServerConfig>>,
) -> Result<Connecting, ConnectionError> {
let state = self.0.take().unwrap();
state.endpoint.accept(state.inner)
state.endpoint.accept(state.inner, server_config)
}

/// Reject this incoming connection attempt
Expand Down Expand Up @@ -120,6 +124,6 @@ impl IntoFuture for Incoming {
type IntoFuture = IncomingFuture;

fn into_future(self) -> Self::IntoFuture {
IncomingFuture(self.accept())
IncomingFuture(self.accept(None))
}
}
2 changes: 1 addition & 1 deletion quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ async fn zero_rtt() {
let endpoint2 = endpoint.clone();
tokio::spawn(async move {
for _ in 0..2 {
let incoming = endpoint2.accept().await.unwrap().accept().unwrap();
let incoming = endpoint2.accept().await.unwrap().accept(None).unwrap();
let (connection, established) = incoming.into_0rtt().unwrap_or_else(|_| unreachable!());
let c = connection.clone();
tokio::spawn(async move {
Expand Down

0 comments on commit af2f5eb

Please sign in to comment.