diff --git a/src/sdam/description/server.rs b/src/sdam/description/server.rs index b4ad79b2d..87f200589 100644 --- a/src/sdam/description/server.rs +++ b/src/sdam/description/server.rs @@ -192,15 +192,15 @@ impl PartialEq for ServerDescription { } impl ServerDescription { - pub(crate) fn new(address: ServerAddress) -> Self { + pub(crate) fn new(address: &ServerAddress) -> Self { Self { address: match address { ServerAddress::Tcp { host, port } => ServerAddress::Tcp { host: host.to_lowercase(), - port, + port: *port, }, #[cfg(unix)] - ServerAddress::Unix { path } => ServerAddress::Unix { path }, + ServerAddress::Unix { path } => ServerAddress::Unix { path: path.clone() }, }, server_type: Default::default(), last_update_time: None, @@ -214,7 +214,7 @@ impl ServerDescription { mut reply: HelloReply, average_rtt: Duration, ) -> Self { - let mut description = Self::new(address); + let mut description = Self::new(&address); description.average_round_trip_time = Some(average_rtt); description.last_update_time = Some(DateTime::now()); @@ -259,7 +259,7 @@ impl ServerDescription { } pub(crate) fn new_from_error(address: ServerAddress, error: Error) -> Self { - let mut description = Self::new(address); + let mut description = Self::new(&address); description.last_update_time = Some(DateTime::now()); description.average_round_trip_time = None; description.reply = Err(error); @@ -310,7 +310,7 @@ impl ServerDescription { Ok(set_name) } - pub(crate) fn known_hosts(&self) -> Result> { + pub(crate) fn known_hosts(&self) -> Result> { let known_hosts = self .reply .as_ref() @@ -328,7 +328,11 @@ impl ServerDescription { .chain(arbiters.into_iter().flatten()) }); - Ok(known_hosts.into_iter().flatten()) + known_hosts + .into_iter() + .flatten() + .map(ServerAddress::parse) + .collect() } pub(crate) fn invalid_me(&self) -> Result { diff --git a/src/sdam/description/topology.rs b/src/sdam/description/topology.rs index 7859054e2..23ea188ef 100644 --- a/src/sdam/description/topology.rs +++ b/src/sdam/description/topology.rs @@ -170,7 +170,7 @@ impl TopologyDescription { }; for address in options.hosts.iter() { - let description = ServerDescription::new(address.clone()); + let description = ServerDescription::new(address); self.servers.insert(address.to_owned(), description); } @@ -387,7 +387,7 @@ impl TopologyDescription { let mut new = vec![]; for host in hosts { if !self.servers.contains_key(&host) { - new.push((host.clone(), ServerDescription::new(host))); + new.push((host.clone(), ServerDescription::new(&host))); } } if let Some(max) = self.srv_max_hosts { @@ -599,7 +599,7 @@ impl TopologyDescription { return Ok(()); } - self.add_new_servers(server_description.known_hosts()?)?; + self.add_new_servers(server_description.known_hosts()?); if server_description.invalid_me()? { self.servers.remove(&server_description.address); @@ -655,7 +655,7 @@ impl TopologyDescription { { self.servers.insert( server_description.address.clone(), - ServerDescription::new(server_description.address), + ServerDescription::new(&server_description.address), ); self.record_primary_state(); return Ok(()); @@ -688,16 +688,16 @@ impl TopologyDescription { } if let ServerType::RsPrimary = self.servers.get(&address).unwrap().server_type { - self.servers - .insert(address.clone(), ServerDescription::new(address)); + let description = ServerDescription::new(&address); + self.servers.insert(address, description); } } - self.add_new_servers(server_description.known_hosts()?)?; - let known_hosts: HashSet<_> = server_description.known_hosts()?.collect(); + let known_hosts = server_description.known_hosts()?; + self.add_new_servers(known_hosts.clone()); for address in addresses { - if !known_hosts.contains(&address.to_string()) { + if !known_hosts.contains(&address) { self.servers.remove(&address); } } @@ -724,23 +724,11 @@ impl TopologyDescription { } /// Create a new ServerDescription for each address and add it to the topology. - fn add_new_servers<'a>(&mut self, servers: impl Iterator) -> Result<()> { - let servers: Result> = servers.map(ServerAddress::parse).collect(); - - self.add_new_servers_from_addresses(servers?.iter()); - Ok(()) - } - - /// Create a new ServerDescription for each address and add it to the topology. - fn add_new_servers_from_addresses<'a>( - &mut self, - servers: impl Iterator, - ) { - for server in servers { - if !self.servers.contains_key(server) { - self.servers - .insert(server.clone(), ServerDescription::new(server.clone())); - } + fn add_new_servers(&mut self, addresses: impl IntoIterator) { + for address in addresses { + self.servers + .entry(address.clone()) + .or_insert_with(|| ServerDescription::new(&address)); } } } diff --git a/src/sdam/description/topology/server_selection/test.rs b/src/sdam/description/topology/server_selection/test.rs index a12e306ca..91e6bb671 100644 --- a/src/sdam/description/topology/server_selection/test.rs +++ b/src/sdam/description/topology/server_selection/test.rs @@ -103,7 +103,7 @@ impl TestServerDescription { reply, avg_rtt_ms.map(f64_ms_as_duration).unwrap(), ), - None => ServerDescription::new(server_address), + None => ServerDescription::new(&server_address), }; server_desc.last_update_time = self .last_update_time diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index 4703ed6b7..94ee0c0bc 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -335,7 +335,7 @@ impl TopologyWorker { self.update_topology(new_description).await; if self.options.load_balanced == Some(true) { - let base = ServerDescription::new(self.options.hosts[0].clone()); + let base = ServerDescription::new(&self.options.hosts[0]); self.update_server(ServerDescription { server_type: ServerType::LoadBalancer, average_round_trip_time: None, @@ -374,7 +374,9 @@ impl TopologyWorker { UpdateMessage::SyncHosts(hosts) => { self.sync_hosts(hosts).await } - UpdateMessage::ServerUpdate(sd) => self.update_server(*sd).await, + UpdateMessage::ServerUpdate(sd) => { + self.update_server(*sd).await + } UpdateMessage::MonitorError { address, error } => { self.handle_monitor_error(address, error).await } diff --git a/src/test/client.rs b/src/test/client.rs index 67f6269fc..eb3a28c34 100644 --- a/src/test/client.rs +++ b/src/test/client.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, collections::HashMap, future::IntoFuture, time::Duration}; +use std::{borrow::Cow, collections::HashMap, future::IntoFuture, net::Ipv6Addr, time::Duration}; use bson::Document; use serde::{Deserialize, Serialize}; @@ -982,3 +982,44 @@ async fn end_sessions_on_shutdown() { client2.into_client().shutdown().await; assert_eq!(get_end_session_event_count(&mut event_stream).await, 0); } + +#[tokio::test] +async fn ipv6_connect() { + let ipv6_localhost = Ipv6Addr::LOCALHOST.to_string(); + + let client = Client::for_test().await; + // The hello command returns the hostname as "localhost". However, whatsmyuri returns an + // IP-literal, which allows us to detect whether we can re-construct the client with an IPv6 + // address. + let is_ipv6_localhost = client + .database("admin") + .run_command(doc! { "whatsmyuri": 1 }) + .await + .ok() + .and_then(|response| { + response + .get_str("you") + .ok() + .map(|you| you.contains(&ipv6_localhost)) + }) + .unwrap_or(false); + if !is_ipv6_localhost { + log_uncaptured("skipping ipv6_connect due to non-ipv6-localhost configuration"); + return; + } + + let mut options = get_client_options().await.clone(); + for address in options.hosts.iter_mut() { + if let ServerAddress::Tcp { host, .. } = address { + *host = ipv6_localhost.clone(); + } + } + let client = Client::with_options(options).unwrap(); + + let result = client + .database("admin") + .run_command(doc! { "ping": 1 }) + .await + .unwrap(); + assert_eq!(result.get_f64("ok"), Ok(1.0)); +} diff --git a/src/test/spec/trace.rs b/src/test/spec/trace.rs index 51ad63bd9..debce66ff 100644 --- a/src/test/spec/trace.rs +++ b/src/test/spec/trace.rs @@ -459,7 +459,7 @@ fn topology_description_tracing_representation() { let mut servers = HashMap::new(); servers.insert( ServerAddress::default(), - ServerDescription::new(ServerAddress::default()), + ServerDescription::new(&ServerAddress::default()), ); let oid = bson::oid::ObjectId::new();