Skip to content

Commit

Permalink
RUST-2109 Fix comparison of IPv6 addresses when updating the topology (
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelatkinson authored Dec 12, 2024
1 parent bcff155 commit b1490b5
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 38 deletions.
18 changes: 11 additions & 7 deletions src/sdam/description/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,15 @@ impl PartialEq for ServerDescription {
}

impl ServerDescription {
pub(crate) fn new(address: ServerAddress) -> Self {
pub(crate) fn new(address: &ServerAddress) -> Self {
Self {
address: match address {
ServerAddress::Tcp { host, port } => ServerAddress::Tcp {
host: host.to_lowercase(),
port,
port: *port,
},
#[cfg(unix)]
ServerAddress::Unix { path } => ServerAddress::Unix { path },
ServerAddress::Unix { path } => ServerAddress::Unix { path: path.clone() },
},
server_type: Default::default(),
last_update_time: None,
Expand All @@ -214,7 +214,7 @@ impl ServerDescription {
mut reply: HelloReply,
average_rtt: Duration,
) -> Self {
let mut description = Self::new(address);
let mut description = Self::new(&address);
description.average_round_trip_time = Some(average_rtt);
description.last_update_time = Some(DateTime::now());

Expand Down Expand Up @@ -259,7 +259,7 @@ impl ServerDescription {
}

pub(crate) fn new_from_error(address: ServerAddress, error: Error) -> Self {
let mut description = Self::new(address);
let mut description = Self::new(&address);
description.last_update_time = Some(DateTime::now());
description.average_round_trip_time = None;
description.reply = Err(error);
Expand Down Expand Up @@ -310,7 +310,7 @@ impl ServerDescription {
Ok(set_name)
}

pub(crate) fn known_hosts(&self) -> Result<impl Iterator<Item = &String>> {
pub(crate) fn known_hosts(&self) -> Result<Vec<ServerAddress>> {
let known_hosts = self
.reply
.as_ref()
Expand All @@ -328,7 +328,11 @@ impl ServerDescription {
.chain(arbiters.into_iter().flatten())
});

Ok(known_hosts.into_iter().flatten())
known_hosts
.into_iter()
.flatten()
.map(ServerAddress::parse)
.collect()
}

pub(crate) fn invalid_me(&self) -> Result<bool> {
Expand Down
40 changes: 14 additions & 26 deletions src/sdam/description/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ impl TopologyDescription {
};

for address in options.hosts.iter() {
let description = ServerDescription::new(address.clone());
let description = ServerDescription::new(address);
self.servers.insert(address.to_owned(), description);
}

Expand Down Expand Up @@ -387,7 +387,7 @@ impl TopologyDescription {
let mut new = vec![];
for host in hosts {
if !self.servers.contains_key(&host) {
new.push((host.clone(), ServerDescription::new(host)));
new.push((host.clone(), ServerDescription::new(&host)));
}
}
if let Some(max) = self.srv_max_hosts {
Expand Down Expand Up @@ -599,7 +599,7 @@ impl TopologyDescription {
return Ok(());
}

self.add_new_servers(server_description.known_hosts()?)?;
self.add_new_servers(server_description.known_hosts()?);

if server_description.invalid_me()? {
self.servers.remove(&server_description.address);
Expand Down Expand Up @@ -655,7 +655,7 @@ impl TopologyDescription {
{
self.servers.insert(
server_description.address.clone(),
ServerDescription::new(server_description.address),
ServerDescription::new(&server_description.address),
);
self.record_primary_state();
return Ok(());
Expand Down Expand Up @@ -688,16 +688,16 @@ impl TopologyDescription {
}

if let ServerType::RsPrimary = self.servers.get(&address).unwrap().server_type {
self.servers
.insert(address.clone(), ServerDescription::new(address));
let description = ServerDescription::new(&address);
self.servers.insert(address, description);
}
}

self.add_new_servers(server_description.known_hosts()?)?;
let known_hosts: HashSet<_> = server_description.known_hosts()?.collect();
let known_hosts = server_description.known_hosts()?;
self.add_new_servers(known_hosts.clone());

for address in addresses {
if !known_hosts.contains(&address.to_string()) {
if !known_hosts.contains(&address) {
self.servers.remove(&address);
}
}
Expand All @@ -724,23 +724,11 @@ impl TopologyDescription {
}

/// Create a new ServerDescription for each address and add it to the topology.
fn add_new_servers<'a>(&mut self, servers: impl Iterator<Item = &'a String>) -> Result<()> {
let servers: Result<Vec<_>> = servers.map(ServerAddress::parse).collect();

self.add_new_servers_from_addresses(servers?.iter());
Ok(())
}

/// Create a new ServerDescription for each address and add it to the topology.
fn add_new_servers_from_addresses<'a>(
&mut self,
servers: impl Iterator<Item = &'a ServerAddress>,
) {
for server in servers {
if !self.servers.contains_key(server) {
self.servers
.insert(server.clone(), ServerDescription::new(server.clone()));
}
fn add_new_servers(&mut self, addresses: impl IntoIterator<Item = ServerAddress>) {
for address in addresses {
self.servers
.entry(address.clone())
.or_insert_with(|| ServerDescription::new(&address));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/sdam/description/topology/server_selection/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl TestServerDescription {
reply,
avg_rtt_ms.map(f64_ms_as_duration).unwrap(),
),
None => ServerDescription::new(server_address),
None => ServerDescription::new(&server_address),
};
server_desc.last_update_time = self
.last_update_time
Expand Down
6 changes: 4 additions & 2 deletions src/sdam/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ impl TopologyWorker {
self.update_topology(new_description).await;

if self.options.load_balanced == Some(true) {
let base = ServerDescription::new(self.options.hosts[0].clone());
let base = ServerDescription::new(&self.options.hosts[0]);
self.update_server(ServerDescription {
server_type: ServerType::LoadBalancer,
average_round_trip_time: None,
Expand Down Expand Up @@ -374,7 +374,9 @@ impl TopologyWorker {
UpdateMessage::SyncHosts(hosts) => {
self.sync_hosts(hosts).await
}
UpdateMessage::ServerUpdate(sd) => self.update_server(*sd).await,
UpdateMessage::ServerUpdate(sd) => {
self.update_server(*sd).await
}
UpdateMessage::MonitorError { address, error } => {
self.handle_monitor_error(address, error).await
}
Expand Down
43 changes: 42 additions & 1 deletion src/test/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Cow, collections::HashMap, future::IntoFuture, time::Duration};
use std::{borrow::Cow, collections::HashMap, future::IntoFuture, net::Ipv6Addr, time::Duration};

use bson::Document;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -982,3 +982,44 @@ async fn end_sessions_on_shutdown() {
client2.into_client().shutdown().await;
assert_eq!(get_end_session_event_count(&mut event_stream).await, 0);
}

#[tokio::test]
async fn ipv6_connect() {
let ipv6_localhost = Ipv6Addr::LOCALHOST.to_string();

let client = Client::for_test().await;
// The hello command returns the hostname as "localhost". However, whatsmyuri returns an
// IP-literal, which allows us to detect whether we can re-construct the client with an IPv6
// address.
let is_ipv6_localhost = client
.database("admin")
.run_command(doc! { "whatsmyuri": 1 })
.await
.ok()
.and_then(|response| {
response
.get_str("you")
.ok()
.map(|you| you.contains(&ipv6_localhost))
})
.unwrap_or(false);
if !is_ipv6_localhost {
log_uncaptured("skipping ipv6_connect due to non-ipv6-localhost configuration");
return;
}

let mut options = get_client_options().await.clone();
for address in options.hosts.iter_mut() {
if let ServerAddress::Tcp { host, .. } = address {
*host = ipv6_localhost.clone();
}
}
let client = Client::with_options(options).unwrap();

let result = client
.database("admin")
.run_command(doc! { "ping": 1 })
.await
.unwrap();
assert_eq!(result.get_f64("ok"), Ok(1.0));
}
2 changes: 1 addition & 1 deletion src/test/spec/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ fn topology_description_tracing_representation() {
let mut servers = HashMap::new();
servers.insert(
ServerAddress::default(),
ServerDescription::new(ServerAddress::default()),
ServerDescription::new(&ServerAddress::default()),
);

let oid = bson::oid::ObjectId::new();
Expand Down

0 comments on commit b1490b5

Please sign in to comment.