diff --git a/ci/Cargo.lock.min b/ci/Cargo.lock.min index 134d85a..c7266d6 100644 --- a/ci/Cargo.lock.min +++ b/ci/Cargo.lock.min @@ -58,7 +58,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -154,7 +154,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.90", "which", ] @@ -249,9 +249,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "c31a0499c1dc64f458ad13872de75c0eb7e3fdb0e67964610c914b034fc5956e" dependencies = [ "jobserver", "libc", @@ -343,6 +343,12 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "darling" version = "0.20.10" @@ -364,7 +370,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn", + "syn 2.0.90", ] [[package]] @@ -375,7 +381,21 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn", + "syn 2.0.90", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", ] [[package]] @@ -403,7 +423,7 @@ checksum = "bc2323e10c92e1cf4d86e11538512e6dc03ceb586842970b6332af3d4046a046" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -424,7 +444,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -579,7 +599,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -641,6 +661,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -950,7 +976,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1105,9 +1131,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" @@ -1224,7 +1250,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1245,6 +1271,7 @@ dependencies = [ "bytes", "chrono", "chrono-tz", + "dashmap", "deadpool", "delegate", "futures", @@ -1280,7 +1307,7 @@ name = "neo4rs-macros" version = "0.3.0" dependencies = [ "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1412,7 +1439,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn", + "syn 2.0.90", ] [[package]] @@ -1518,7 +1545,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.90", ] [[package]] @@ -1797,7 +1824,7 @@ checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1814,13 +1841,13 @@ dependencies = [ [[package]] name = "serde_repr" -version = "0.1.19" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +checksum = "cd02c7587ec314570041b2754829f84d873ced14a96d1fd1823531e11db40573" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1871,7 +1898,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1960,7 +1987,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn", + "syn 2.0.90", ] [[package]] @@ -1971,7 +1998,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -1980,6 +2007,17 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.90" @@ -1999,7 +2037,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -2035,7 +2073,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -2046,7 +2084,7 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", "test-case-core", ] @@ -2105,7 +2143,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -2175,7 +2213,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -2351,7 +2389,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.90", "wasm-bindgen-shared", ] @@ -2373,7 +2411,7 @@ checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2659,7 +2697,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", "synstructure", ] @@ -2681,7 +2719,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] [[package]] @@ -2701,7 +2739,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", "synstructure", ] @@ -2730,5 +2768,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.90", ] diff --git a/ci/Cargo.lock.msrv b/ci/Cargo.lock.msrv index 134d85a..c6242c2 100644 --- a/ci/Cargo.lock.msrv +++ b/ci/Cargo.lock.msrv @@ -249,9 +249,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "c31a0499c1dc64f458ad13872de75c0eb7e3fdb0e67964610c914b034fc5956e" dependencies = [ "jobserver", "libc", @@ -343,6 +343,12 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "darling" version = "0.20.10" @@ -378,6 +384,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deadpool" version = "0.12.1" @@ -641,6 +661,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -1105,9 +1131,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" @@ -1245,6 +1271,7 @@ dependencies = [ "bytes", "chrono", "chrono-tz", + "dashmap", "deadpool", "delegate", "futures", diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 1d1a1e9..c6449da 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -31,6 +31,7 @@ unstable-bolt-protocol-impl-v2 = [ backoff = { version = "0.4.0", features = ["tokio"] } bytes = { version = "1.5.0", features = ["serde"] } chrono-tz = "0.10.0" +dashmap = "6.1.0" delegate = "0.13.0" futures = { version = "0.3.0" } log = "0.4.0" diff --git a/lib/src/bolt/request/mod.rs b/lib/src/bolt/request/mod.rs index b97baea..258a609 100644 --- a/lib/src/bolt/request/mod.rs +++ b/lib/src/bolt/request/mod.rs @@ -6,6 +6,7 @@ mod hello; mod pull; mod reset; mod rollback; +mod route; pub use commit::Commit; pub use discard::Discard; diff --git a/lib/src/bolt/request/route.rs b/lib/src/bolt/request/route.rs new file mode 100644 index 0000000..74d97f3 --- /dev/null +++ b/lib/src/bolt/request/route.rs @@ -0,0 +1,71 @@ +use crate::bolt::{ExpectedResponse, Summary}; +use crate::connection::NeoUrl; +use crate::routing::{Route, RoutingTable}; +use serde::ser::SerializeStructVariant; +use serde::{Deserialize, Serialize}; +use std::fmt::{format, Display, Formatter}; + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct Response { + pub(crate) rt: RoutingTable, +} + +impl<'a> ExpectedResponse for Route<'a> { + type Response = Summary; +} + +#[cfg(test)] +mod tests { + use crate::bolt::request::route::Response; + use crate::bolt::{Message, MessageResponse}; + use crate::packstream::bolt; + use crate::routing::{Route, RouteBuilder, Routing}; + use crate::{Database, Version}; + + #[test] + fn serialize() { + let route = RouteBuilder::new(Routing::Yes(vec![("address".into(), "localhost:7687".into())]), vec!["bookmark"]) + .with_db(Database::from("neo4j")) + .build(Version::V4_3); + let bytes = route.to_bytes().unwrap(); + + let expected = bolt() + .structure(3, 0x66) + .tiny_map(1) + .tiny_string("address") + .tiny_string("localhost:7687") + .tiny_list(1) + .tiny_string("bookmark") + .tiny_string("neo4j") + .build(); + + assert_eq!(bytes, expected); + } + + #[test] + fn parse() { + let data = bolt() + .tiny_map(1) + .tiny_string("rt") + .tiny_map(3) + .tiny_string("ttl") + .int64(1000) + .tiny_string("db") + .tiny_string("neo4j") + .tiny_string("servers") + .tiny_list(1) + .tiny_map(2) + .tiny_string("addresses") + .tiny_list(1) + .tiny_string("localhost:7687") + .tiny_string("role") + .tiny_string("ROUTE") + .build(); + + let response = Response::parse(data).unwrap(); + + assert_eq!(response.rt.ttl, 1000); + assert_eq!(response.rt.db.unwrap().as_ref(), "neo4j"); + assert_eq!(response.rt.servers.len(), 1); + } +} diff --git a/lib/src/config.rs b/lib/src/config.rs index 93f442f..0af80cf 100644 --- a/lib/src/config.rs +++ b/lib/src/config.rs @@ -1,5 +1,7 @@ use crate::auth::{ClientCertificate, ConnectionTLSConfig}; use crate::errors::{Error, Result}; +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +use serde::{Deserialize, Deserializer}; use std::path::Path; use std::{ops::Deref, sync::Arc}; @@ -11,6 +13,17 @@ const DEFAULT_MAX_CONNECTIONS: usize = 16; #[derive(Clone, Debug, PartialEq, Eq)] pub struct Database(Arc); +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +impl<'de> Deserialize<'de> for Database { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Ok(Database::from(s)) + } +} + impl From<&str> for Database { fn from(s: &str) -> Self { Database(s.into()) @@ -37,6 +50,12 @@ impl Deref for Database { } } +impl Default for Database { + fn default() -> Self { + Database("neo4j".into()) + } +} + /// The configuration that is used once a connection is alive. #[derive(Debug, Clone)] pub struct LiveConfig { diff --git a/lib/src/connection.rs b/lib/src/connection.rs index a768672..8cf14a3 100644 --- a/lib/src/connection.rs +++ b/lib/src/connection.rs @@ -4,21 +4,25 @@ use crate::bolt::{ ExpectedResponse, Hello, HelloBuilder, Message, MessageResponse, Reset, Summary, }; #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] -use crate::messages::HelloBuilder; +use {crate::messages::HelloBuilder, crate::types::BoltMap}; + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +use crate::routing::Route; +use crate::routing::{Routing, RoutingTable}; use crate::{ connection::stream::ConnectionStream, errors::{Error, Result}, messages::{BoltRequest, BoltResponse}, version::Version, - BoltMap, BoltString, BoltType, + BoltString, }; use bytes::{BufMut, Bytes, BytesMut}; -use log::warn; +use log::{info, warn}; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::crypto::CryptoProvider; use rustls::pki_types::{CertificateDer, UnixTime}; use rustls::{DigitallySignedStruct, SignatureScheme}; -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::{fs::File, io::BufReader, mem, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream}, @@ -49,6 +53,10 @@ impl Connection { Ok(connection) } + pub fn version(&self) -> Version { + self.version + } + pub(crate) async fn prepare(info: &ConnectionInfo) -> Result { let mut stream = match &info.host { Host::Domain(domain) => TcpStream::connect((&**domain, info.port)).await?, @@ -76,6 +84,7 @@ impl Connection { let mut response = [0, 0, 0, 0]; stream.read_exact(&mut response).await?; let version = Version::parse(response)?; + info!("Connected to Neo4j with version {}", version); Ok(version) } @@ -110,11 +119,36 @@ impl Connection { match hello { Summary::Success(_msg) => Ok(()), - Summary::Ignored => todo!(), + Summary::Ignored => Err(Error::RequestIgnoredError), Summary::Failure(msg) => Err(Error::AuthenticationError(msg.message)), } } + #[cfg(feature = "unstable-bolt-protocol-impl-v2")] + pub async fn route(&mut self, route: Route<'_>) -> Result { + let route = self.send_recv_as(route).await?; + + match route { + Summary::Success(msg) => Ok(msg.metadata.rt), + Summary::Ignored => Err(Error::RequestIgnoredError), + Summary::Failure(msg) => Err(Error::RoutingTableError((msg.code, msg.message))), + } + } + + #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] + pub async fn route(&mut self, route: BoltRequest) -> Result { + match self.send_recv(route).await? { + BoltResponse::Success(msg) => { + let rt: BoltMap = msg.get("rt").unwrap(); + Ok(RoutingTable::from(rt)) + } + BoltResponse::Failure(msg) => { + Err(Error::RoutingTableError(msg.get("message").unwrap())) + } + msg => Err(msg.into_error("HELLO")), + } + } + pub async fn reset(&mut self) -> Result<()> { #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] { @@ -235,16 +269,16 @@ impl Connection { } pub(crate) struct ConnectionInfo { - user: Arc, - password: Arc, - host: Host>, - port: u16, - routing: Routing, - encryption: Option<(TlsConnector, ServerName<'static>)>, + pub user: Arc, + pub password: Arc, + pub host: Host>, + pub port: u16, + pub routing: Routing, + pub encryption: Option<(TlsConnector, ServerName<'static>)>, } -impl std::fmt::Debug for ConnectionInfo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for ConnectionInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("ConnectionInfo") .field("user", &self.user) .field("password", &"***") @@ -256,26 +290,6 @@ impl std::fmt::Debug for ConnectionInfo { } } -#[derive(Debug, Clone)] -pub(crate) enum Routing { - No, - Yes(Vec<(BoltString, BoltString)>), -} - -impl From for Option { - fn from(routing: Routing) -> Self { - match routing { - Routing::No => None, - Routing::Yes(routing) => Some( - routing - .into_iter() - .map(|(k, v)| (k, BoltType::String(v))) - .collect(), - ), - } - } -} - impl ConnectionInfo { pub(crate) fn new( uri: &str, @@ -300,8 +314,8 @@ impl ConnectionInfo { .transpose()?; let routing = if routing { - log::warn!(concat!( - "This driver does not yet implement client-side routing. ", + warn!(concat!( + "Client-side routing is in experimental mode.", "It is possible that operations against a cluster (such as Aura) will fail." )); Routing::Yes(url.routing_context()) @@ -399,10 +413,11 @@ impl ConnectionInfo { } } -struct NeoUrl(Url); +#[derive(Clone, Debug)] +pub struct NeoUrl(Url); impl NeoUrl { - fn parse(uri: &str) -> Result { + pub(crate) fn parse(uri: &str) -> Result { let url = match Url::parse(uri) { Ok(url) if url.has_host() => url, // missing scheme @@ -419,11 +434,11 @@ impl NeoUrl { self.0.scheme() } - fn host(&self) -> Host<&str> { + pub(crate) fn host(&self) -> Host<&str> { self.0.host().unwrap() } - fn port(&self) -> u16 { + pub(crate) fn port(&self) -> u16 { self.0.port().unwrap_or(7687) } @@ -433,28 +448,34 @@ impl NeoUrl { fn warn_on_unexpected_components(&self) { if !self.0.username().is_empty() || self.0.password().is_some() { - log::warn!(concat!( + warn!(concat!( "URI contained auth credentials, which are ignored.", "Credentials are passed outside of the URI" )); } if !matches!(self.0.path(), "" | "/") { - log::warn!("URI contained a path, which is ignored."); + warn!("URI contained a path, which is ignored."); } if self.0.query().is_some() { - log::warn!(concat!( + warn!(concat!( "This client does not yet support client-side routing.", "The routing context passed as a query to the URI is ignored." )); } if self.0.fragment().is_some() { - log::warn!("URI contained a fragment, which is ignored."); + warn!("URI contained a fragment, which is ignored."); } } } +impl Display for NeoUrl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + mod stream { use pin_project_lite::pin_project; use tokio::{ diff --git a/lib/src/errors.rs b/lib/src/errors.rs index 6efd65f..10c9a65 100644 --- a/lib/src/errors.rs +++ b/lib/src/errors.rs @@ -93,6 +93,18 @@ pub enum Error { #[error("{0}")] DeserializationError(DeError), + + #[error("Failed to fetch the routing table [{}]: {}", _0.0, _0.1)] + RoutingTableError((String, String)), + + #[error("The request has been ignored by the server. This can happen if the server is under pressure or there was an issue with the memory.")] + RequestIgnoredError, + + #[error("{0}")] + RoutingTableRefreshFailed(String), + + #[error("{0}")] + ServerUnavailableError(String), } #[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/lib/src/graph.rs b/lib/src/graph.rs index 1f59060..14a0d59 100644 --- a/lib/src/graph.rs +++ b/lib/src/graph.rs @@ -1,5 +1,15 @@ -use std::time::Duration; +use crate::routing::RouteBuilder; + +use { + crate::connection::{Connection, ConnectionInfo}, + crate::graph::ConnectionPoolManager::Routed, + crate::routing::{RoundRobinStrategy, RoutedConnectionManager, Routing}, + log::info, + std::sync::Arc, +}; +use crate::graph::ConnectionPoolManager::Normal; +use crate::pool::ManagedConnection; use crate::{ config::{Config, ConfigBuilder, Database, LiveConfig}, errors::Result, @@ -7,7 +17,32 @@ use crate::{ query::Query, stream::DetachedRowStream, txn::Txn, + Operation, }; +use backoff::ExponentialBackoff; +use std::time::Duration; + +#[derive(Clone)] +enum ConnectionPoolManager { + Routed(RoutedConnectionManager), + Normal(ConnectionPool), +} + +impl ConnectionPoolManager { + async fn get(&self, operation: Option) -> Result { + match self { + Routed(manager) => manager.get(operation).await, + Normal(pool) => pool.get().await.map_err(crate::Error::from), + } + } + + fn backoff(&self) -> ExponentialBackoff { + match self { + Routed(manager) => manager.backoff(), + Normal(pool) => pool.manager().backoff(), + } + } +} /// A neo4j database abstraction. /// This type can be cloned and shared across threads, internal resources @@ -15,7 +50,7 @@ use crate::{ #[derive(Clone)] pub struct Graph { config: LiveConfig, - pool: ConnectionPool, + pool: ConnectionPoolManager, } /// Returns a [`Query`] which provides methods like [`Query::param`] to add parameters to the query @@ -28,9 +63,42 @@ impl Graph { /// /// You can build a config using [`ConfigBuilder::default()`]. pub async fn connect(config: Config) -> Result { - let pool = create_pool(&config).await?; - let config = config.into_live_config(); - Ok(Graph { config, pool }) + let info = ConnectionInfo::new( + &config.uri, + &config.user, + &config.password, + &config.tls_config, + )?; + if matches!(info.routing, Routing::Yes(_)) { + let mut connection = Connection::new(&info).await?; + let mut builder = RouteBuilder::new(info.routing, vec![]); + if let Some(db) = config.db.clone() { + builder = builder.with_db(db); + } + let rt = connection + .route(builder.build(connection.version())) + .await?; + connection.reset().await?; + info!("Connected to routing server, routing table: {:?}", rt); + let pool = Routed( + RoutedConnectionManager::new( + &config, + Arc::new(rt.clone()), + Arc::new(RoundRobinStrategy::new(rt)), + ) + .await?, + ); + Ok(Graph { + config: config.into_live_config(), + pool, + }) + } else { + let pool = Normal(create_pool(&config).await?); + Ok(Graph { + config: config.into_live_config(), + pool, + }) + } } /// Connects to the database with default configurations @@ -66,7 +134,7 @@ impl Graph { } async fn impl_start_txn_on(&self, db: Option) -> Result { - let connection = self.pool.get().await?; + let connection = self.pool.get(Some(Operation::Write)).await?; Txn::new(db, self.config.fetch_size, connection).await } @@ -82,7 +150,8 @@ impl Graph { /// /// use [`Graph::execute`] when you are interested in the result stream pub async fn run(&self, q: Query) -> Result<()> { - self.impl_run_on(self.config.db.clone(), q).await + self.impl_run_on(self.config.db.clone(), q, Operation::Read) + .await } /// Runs a query on the provided database using a connection from the connection pool. @@ -97,18 +166,24 @@ impl Graph { /// /// use [`Graph::execute`] when you are interested in the result stream pub async fn run_on(&self, db: impl Into, q: Query) -> Result<()> { - self.impl_run_on(Some(db.into()), q).await + self.impl_run_on(Some(db.into()), q, Operation::Read).await } - async fn impl_run_on(&self, db: Option, q: Query) -> Result<()> { + async fn impl_run_on( + &self, + db: Option, + q: Query, + operation: Operation, + ) -> Result<()> { backoff::future::retry_notify( - self.pool.manager().backoff(), + self.pool.backoff(), || { let pool = &self.pool; let query = &q; let db = db.as_deref(); + let operation = operation.clone(); async move { - let mut connection = pool.get().await.map_err(crate::Error::from)?; + let mut connection = pool.get(Some(operation)).await?; query.run_retryable(db, &mut connection).await } }, @@ -124,7 +199,8 @@ impl Graph { /// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted. /// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry. pub async fn execute(&self, q: Query) -> Result { - self.impl_execute_on(self.config.db.clone(), q).await + self.impl_execute_on(self.config.db.clone(), q, Operation::Write) + .await } /// Executes a query on the provided database and returns a [`DetachedRowStream`] @@ -134,19 +210,26 @@ impl Graph { /// This includes errors during a leader election or when the transaction resources on the server (memory, handles, ...) are exhausted. /// Retries happen with an exponential backoff until a retry delay exceeds 60s, at which point the query fails with the last error as it would without any retry. pub async fn execute_on(&self, db: impl Into, q: Query) -> Result { - self.impl_execute_on(Some(db.into()), q).await + self.impl_execute_on(Some(db.into()), q, Operation::Write) + .await } - async fn impl_execute_on(&self, db: Option, q: Query) -> Result { + async fn impl_execute_on( + &self, + db: Option, + q: Query, + operation: Operation, + ) -> Result { backoff::future::retry_notify( - self.pool.manager().backoff(), + self.pool.backoff(), || { let pool = &self.pool; let fetch_size = self.config.fetch_size; let query = &q; let db = db.as_deref(); + let operation = operation.clone(); async move { - let connection = pool.get().await.map_err(crate::Error::from)?; + let connection = pool.get(Some(operation)).await?; query.execute_retryable(db, fetch_size, connection).await } }, diff --git a/lib/src/lib.rs b/lib/src/lib.rs index b62b21d..7a2064e 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -465,6 +465,7 @@ mod messages; mod packstream; mod pool; mod query; +mod routing; mod row; mod stream; #[cfg(feature = "unstable-result-summary")] @@ -493,5 +494,21 @@ pub use crate::types::{ BoltPoint2D, BoltPoint3D, BoltRelation, BoltString, BoltTime, BoltType, BoltUnboundedRelation, }; pub use crate::version::Version; +use std::fmt::Display; pub(crate) use messages::Success; + +#[derive(Debug, PartialEq, Clone)] +pub enum Operation { + Read, + Write, +} + +impl Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Operation::Read => write!(f, "READ"), + Operation::Write => write!(f, "WRITE"), + } + } +} diff --git a/lib/src/messages.rs b/lib/src/messages.rs index af7e8d6..703cd0e 100644 --- a/lib/src/messages.rs +++ b/lib/src/messages.rs @@ -9,10 +9,13 @@ mod pull; mod record; mod reset; mod rollback; +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +mod route; mod run; mod success; +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +use crate::routing; -use crate::messages::ignore::Ignore; use crate::{ errors::{Error, Result}, types::{BoltMap, BoltWireFormat}, @@ -25,6 +28,7 @@ use failure::Failure; use record::Record; use run::Run; pub(crate) use success::Success; +use ignore::Ignore; #[derive(Debug, PartialEq, Clone)] pub enum BoltResponse { @@ -69,6 +73,8 @@ pub enum BoltRequest { deprecated(since = "0.9.0", note = "Use `crate::bolt::Reset` instead.") )] Reset(reset::Reset), + #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] + Route(routing::Route), } #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] @@ -206,6 +212,8 @@ impl BoltRequest { BoltRequest::Commit(commit) => commit.into_bytes(version)?, BoltRequest::Rollback(rollback) => rollback.into_bytes(version)?, BoltRequest::Reset(reset) => reset.into_bytes(version)?, + #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] + BoltRequest::Route(route) => route.into_bytes(version)?, }; Ok(bytes) } diff --git a/lib/src/messages/route.rs b/lib/src/messages/route.rs new file mode 100644 index 0000000..4587a79 --- /dev/null +++ b/lib/src/messages/route.rs @@ -0,0 +1,79 @@ +use crate::routing::{RoutingTable, Server}; +use crate::BoltMap; + +/// Convert a BoltMap into a RoutingTable +impl From for RoutingTable { + fn from(rt: BoltMap) -> Self { + let ttl = rt.get::("ttl").unwrap_or(0); + let db = rt.get::("db").ok().map(|db| db.into()); + let rt_servers = rt.get::>("servers").unwrap_or_default(); + let server = rt_servers + .iter() + .map(|server| { + let role = server.get::("role").unwrap_or_default(); + let addresses = server.get::>("addresses").unwrap_or_default(); + Server { addresses, role } + }) + .collect::>(); + RoutingTable { + ttl, + db, + servers: server, + } + } +} + +#[cfg(test)] +mod tests { + use crate::messages::BoltRequest; + use crate::routing::{Route, RouteBuilder, Routing}; + use crate::types::{list, map, string, BoltWireFormat}; + use crate::version::Version; + use bytes::*; + + #[test] + fn should_serialize_route() { + let route = RouteBuilder::new(Routing::Yes(vec![("address".into(), "localhost".into())]), vec![]) + .with_db("neo4j".into()) + .build(Version::V4_3); + let r = match route { + BoltRequest::Route(r) => r, + _ => panic!("Expected Route"), + }; + let bytes: Bytes = Route::from(r).into_bytes(Version::V4_1).unwrap(); + + assert_eq!( + bytes, + Bytes::from_static(&[ + 0xB3, + 0x66, + map::TINY | 1, + string::TINY | 7, + b'a', + b'd', + b'd', + b'r', + b'e', + b's', + b's', + string::TINY | 9, + b'l', + b'o', + b'c', + b'a', + b'l', + b'h', + b'o', + b's', + b't', + list::TINY | 0, + string::TINY | 5, + b'n', + b'e', + b'o', + b'4', + b'j', + ]) + ); + } +} diff --git a/lib/src/routing/connection_registry.rs b/lib/src/routing/connection_registry.rs new file mode 100644 index 0000000..67b4382 --- /dev/null +++ b/lib/src/routing/connection_registry.rs @@ -0,0 +1,167 @@ +use crate::pool::{create_pool, ConnectionPool}; +use crate::routing::{RoutingTable, Server}; +use crate::{Config, Error}; +use dashmap::DashMap; +use futures::lock::Mutex; +use log::info; +use std::sync::Arc; + +pub type Registry = DashMap; + +#[derive(Clone)] +pub(crate) struct ConnectionRegistry { + config: Config, + creation_time: Arc>, + ttl: u64, + pub(crate) connections: Registry, // Arc is needed for Clone +} + +impl ConnectionRegistry { + pub(crate) async fn new( + config: &Config, + routing_table: Arc, + ) -> Result { + let ttl = routing_table.ttl; + let connections = Self::build_registry(config, routing_table).await?; + Ok(ConnectionRegistry { + config: config.clone(), + creation_time: Arc::new(Mutex::new( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + )), + ttl, + connections, + }) + } + + async fn build_registry( + config: &Config, + routing_table: Arc, + ) -> Result { + let registry = DashMap::new(); + let servers = routing_table.servers.clone(); + for server in servers.iter() { + registry.insert(server.clone(), create_pool(config).await?); + } + Ok(registry) + } + + pub(crate) async fn update_if_expired(&self, f: F) -> Result<(), Error> + where + F: FnOnce() -> R, + R: std::future::Future>, + { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + info!("Checking if routing table is expired..."); + if let Some(mut guard) = self.creation_time.try_lock() { + if now - *guard > self.ttl { + info!("Routing table expired, refreshing..."); + let routing_table = f().await?; + info!("Routing table refreshed: {:?}", routing_table); + let registry = &self.connections; + let servers = routing_table.servers.clone(); + for server in servers.iter() { + if registry.contains_key(server) { + continue; + } + registry.insert(server.clone(), create_pool(&self.config).await?); + } + registry.retain(|k, _| servers.contains(k)); + info!("Registry updated. New size is {}", registry.len()); + *guard = now; + } + } + Ok(()) + } + /// Retrieve the pool for a specific server. + pub fn get_pool(&self, server: &Server) -> Option { + self.connections.get(server).map(|entry| entry.clone()) + } + + pub fn mark_unavailable(&self, server: &Server) { + self.connections.remove(server); + } + + pub fn servers(&self) -> Vec { + self.connections + .iter() + .map(|entry| entry.key().clone()) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::ConnectionTLSConfig; + use crate::routing::load_balancing::LoadBalancingStrategy; + use crate::routing::RoundRobinStrategy; + use crate::routing::Server; + + #[tokio::test] + async fn test_available_servers() { + let readers = vec![ + Server { + addresses: vec!["host1:7687".to_string()], + role: "READ".to_string(), + }, + Server { + addresses: vec!["host2:7688".to_string()], + role: "READ".to_string(), + }, + ]; + let writers = vec![ + Server { + addresses: vec!["host3:7687".to_string()], + role: "WRITE".to_string(), + }, + Server { + addresses: vec!["host4:7688".to_string()], + role: "WRITE".to_string(), + }, + ]; + let routers = vec![Server { + addresses: vec!["host0:7687".to_string()], + role: "ROUTE".to_string(), + }]; + let cluster_routing_table = RoutingTable { + ttl: 0, + db: None, + servers: readers + .clone() + .into_iter() + .chain(writers.clone()) + .chain(routers.clone()) + .collect(), + }; + let config = Config { + uri: "neo4j://localhost:7687".to_string(), + user: "user".to_string(), + password: "password".to_string(), + max_connections: 10, + db: Some("neo4j".into()), + fetch_size: 0, + tls_config: ConnectionTLSConfig::None, + }; + let registry = ConnectionRegistry::new(&config, Arc::new(cluster_routing_table.clone())) + .await + .unwrap(); + assert_eq!(registry.connections.len(), 5); + let strategy = RoundRobinStrategy::new(cluster_routing_table.clone()); + let router = strategy + .select_router(registry.servers().as_slice()) + .unwrap(); + assert_eq!(router, routers[0]); + registry.mark_unavailable(&writers[0]); + assert_eq!(registry.connections.len(), 4); + let writer = strategy + .select_writer(registry.servers().as_slice()) + .unwrap(); + assert_eq!(writer, writers[1]); + } +} diff --git a/lib/src/routing/load_balancing/mod.rs b/lib/src/routing/load_balancing/mod.rs new file mode 100644 index 0000000..638acf1 --- /dev/null +++ b/lib/src/routing/load_balancing/mod.rs @@ -0,0 +1,9 @@ +pub(crate) mod round_robin_strategy; + +use crate::routing::Server; + +pub trait LoadBalancingStrategy: Sync + Send { + fn select_reader(&self, servers: &[Server]) -> Option; + fn select_writer(&self, servers: &[Server]) -> Option; + fn select_router(&self, servers: &[Server]) -> Option; +} diff --git a/lib/src/routing/load_balancing/round_robin_strategy.rs b/lib/src/routing/load_balancing/round_robin_strategy.rs new file mode 100644 index 0000000..0a589e0 --- /dev/null +++ b/lib/src/routing/load_balancing/round_robin_strategy.rs @@ -0,0 +1,135 @@ +use crate::routing::load_balancing::LoadBalancingStrategy; +use crate::routing::{RoutingTable, Server}; +use std::sync::atomic::AtomicUsize; + +pub struct RoundRobinStrategy { + reader_index: AtomicUsize, + writer_index: AtomicUsize, + router_index: AtomicUsize, +} + +impl RoundRobinStrategy { + pub(crate) fn new(cluster_routing_table: RoutingTable) -> Self { + let readers: Vec = cluster_routing_table + .servers + .iter() + .filter(|s| s.role == "READ") + .cloned() + .collect(); + let writers: Vec = cluster_routing_table + .servers + .iter() + .filter(|s| s.role == "WRITE") + .cloned() + .collect(); + let routers: Vec = cluster_routing_table + .servers + .iter() + .filter(|s| s.role == "ROUTE") + .cloned() + .collect(); + let reader_index = AtomicUsize::new(readers.len()); + let writer_index = AtomicUsize::new(writers.len()); + let router_index = AtomicUsize::new(routers.len()); + RoundRobinStrategy { + reader_index, + writer_index, + router_index, + } + } + + fn select(servers: &[Server], index: &AtomicUsize) -> Option { + if servers.is_empty() { + return None; + } + + index + .compare_exchange( + 0, + servers.len(), + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) + .ok(); + let i = index.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + if let Some(server) = servers.get(i - 1) { + Some(server.clone()) + } else { + //reset index + index.store(servers.len(), std::sync::atomic::Ordering::Relaxed); + servers.last().cloned() + } + } +} + +impl LoadBalancingStrategy for RoundRobinStrategy { + fn select_reader(&self, servers: &[Server]) -> Option { + let readers = servers + .iter() + .filter(|s| s.role == "READ") + .cloned() + .collect::>(); + + Self::select(readers.as_slice(), &self.reader_index) + } + + fn select_writer(&self, servers: &[Server]) -> Option { + let writers = servers + .iter() + .filter(|s| s.role == "WRITE") + .cloned() + .collect::>(); + + Self::select(writers.as_slice(), &self.writer_index) + } + + fn select_router(&self, servers: &[Server]) -> Option { + let routers = servers + .iter() + .filter(|s| s.role == "ROUTE") + .cloned() + .collect::>(); + + Self::select(routers.as_slice(), &self.router_index) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_get_next_server() { + let readers = vec![ + Server { + addresses: vec!["localhost:7687".to_string()], + role: "READ".to_string(), + }, + Server { + addresses: vec!["localhost:7688".to_string()], + role: "READ".to_string(), + }, + ]; + let writers = vec![]; + let cluster_routing_table = RoutingTable { + ttl: 0, + db: None, + servers: readers.clone().into_iter().chain(writers.clone()).collect(), + }; + let strategy = RoundRobinStrategy::new(cluster_routing_table.clone()); + let reader = strategy + .select_reader(cluster_routing_table.servers.as_slice()) + .unwrap(); + assert_eq!(reader, readers[1]); + let reader = strategy + .select_reader(cluster_routing_table.servers.as_slice()) + .unwrap(); + assert_eq!(reader, readers[0]); + let reader = strategy + .select_reader(cluster_routing_table.servers.as_slice()) + .unwrap(); + assert_eq!(reader, readers[1]); + let writer = strategy.select_writer(cluster_routing_table.servers.as_slice()); + assert_eq!(writer, None); + } +} diff --git a/lib/src/routing/mod.rs b/lib/src/routing/mod.rs new file mode 100644 index 0000000..074e0cc --- /dev/null +++ b/lib/src/routing/mod.rs @@ -0,0 +1,224 @@ +mod connection_registry; +mod load_balancing; +mod routed_connection_manager; +use crate::types::{BoltMap, BoltString, BoltType}; +use std::fmt::{Display, Formatter}; +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +use {crate::messages::BoltRequest, crate::types::BoltList, neo4rs_macros::BoltStruct}; +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +use { + serde::ser::SerializeMap, + serde::{ser::SerializeStructVariant, Deserialize, Serialize}, +}; + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Route<'a> { + pub(crate) routing: Routing, + pub(crate) bookmarks: Vec<&'a str>, + pub(crate) db: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Routing { + No, + Yes(Vec<(BoltString, BoltString)>), +} + +impl From for Option { + fn from(routing: Routing) -> Self { + match routing { + Routing::No => None, + Routing::Yes(routing) => Some( + routing + .into_iter() + .map(|(k, v)| (k, BoltType::String(v))) + .collect(), + ), + } + } +} + +#[derive(Debug, Clone, BoltStruct, PartialEq)] +#[signature(0xB3, 0x66)] +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +pub struct Route { + routing: BoltMap, + bookmarks: BoltList, + db: BoltString, // TODO: this can also be null. How do we represent a null string? +} + +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +impl Route { + pub fn new(routing: BoltMap, bookmarks: Vec<&str>, db: Option) -> Self { + Route { + routing, + bookmarks: BoltList::from( + bookmarks + .into_iter() + .map(|b| BoltType::String(BoltString::new(b))) + .collect::>(), + ), + db: BoltString::from(db.map(|d| d.to_string()).unwrap_or("".to_string())), + } + } +} + +// NOTE: this structure will be needed in the future when we implement the Bolt protocol v4.4 +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "unstable-bolt-protocol-impl-v2", derive(Serialize))] +#[allow(dead_code)] +pub struct Extra<'a> { + pub(crate) db: &'a str, + pub(crate) imp_user: &'a str, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "unstable-bolt-protocol-impl-v2", derive(Deserialize))] +pub struct RoutingTable { + pub(crate) ttl: u64, + pub(crate) db: Option, + pub(crate) servers: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "unstable-bolt-protocol-impl-v2", derive(Deserialize))] +pub struct Server { + pub(crate) addresses: Vec, + pub(crate) role: String, // TODO: use an enum here +} + +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +pub struct RouteBuilder { + routing: BoltMap, + bookmarks: BoltList, + db: BoltString, +} + +#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] +impl RouteBuilder { + pub fn new(routing: Routing, bookmarks: Vec<&str>) -> Self { + let map = match routing { + Routing::No => BoltMap::default(), + Routing::Yes(routing) => routing + .into_iter() + .map(|(k, v)| (k, BoltType::String(v))) + .collect(), + }; + RouteBuilder { + routing: map, + bookmarks: BoltList::from( + bookmarks + .into_iter() + .map(|b| BoltType::String(BoltString::new(b))) + .collect::>(), + ), + db: BoltString::from("".to_string()), + } + } + + pub fn with_db(self, db: Database) -> Self { + Self { + db: BoltString::from(db.to_string()), + ..self + } + } + + pub fn build(self, _version: Version) -> BoltRequest { + BoltRequest::Route(Route { + routing: self.routing, + bookmarks: self.bookmarks, + db: self.db, + }) + } +} + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +pub struct RouteBuilder<'a> { + routing: Routing, + bookmarks: Vec<&'a str>, + db: Option, +} + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +impl<'a> RouteBuilder<'a> { + pub fn new(routing: Routing, bookmarks: Vec<&'a str>) -> Self { + Self { + routing, + bookmarks, + db: None, + } + } + + pub fn with_db(self, db: Database) -> Self { + Self { + db: Some(db), + ..self + } + } + + pub fn build(self, _version: Version) -> Route<'a> { + Route { + routing: self.routing, + bookmarks: self.bookmarks, + db: self.db, + } + } +} + +impl Display for RoutingTable { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RoutingTable {{ ttl: {}, db: {:?}, servers: {} }}", + self.ttl, + self.db.clone().unwrap_or_default(), + self.servers + .iter() + .map(|s| s.addresses.join(", ")) + .collect::>() + .join(", ") + ) + } +} + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +impl Serialize for Routing { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Routing::No => serializer.serialize_none(), + Routing::Yes(routing) => { + let mut map = serializer.serialize_map(Some(routing.len()))?; + for (k, v) in routing { + map.serialize_entry(k.to_string().as_str(), v.to_string().as_str())?; + } + map.end() + } + } + } +} + +#[cfg(feature = "unstable-bolt-protocol-impl-v2")] +impl<'a> Serialize for Route<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut structure = serializer.serialize_struct_variant("Request", 0x66, "ROUTE", 3)?; + structure.serialize_field("routing", &self.routing)?; + structure.serialize_field("bookmarks", &self.bookmarks)?; + if let Some(db) = &self.db { + structure.serialize_field("db", db.as_ref())?; + } else { + structure.serialize_field("db", &"")?; + } + structure.end() + } +} + +use crate::{Database, Version}; +pub use load_balancing::round_robin_strategy::RoundRobinStrategy; +pub use routed_connection_manager::RoutedConnectionManager; diff --git a/lib/src/routing/routed_connection_manager.rs b/lib/src/routing/routed_connection_manager.rs new file mode 100644 index 0000000..8bf2b39 --- /dev/null +++ b/lib/src/routing/routed_connection_manager.rs @@ -0,0 +1,145 @@ +use crate::pool::ManagedConnection; +use crate::routing::connection_registry::ConnectionRegistry; +use crate::routing::load_balancing::LoadBalancingStrategy; +use crate::routing::{RouteBuilder, Routing, RoutingTable}; +use crate::{Config, Error, Operation}; +use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; +use futures::lock::Mutex; +use log::{debug, error, info}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Clone)] +pub struct RoutedConnectionManager { + load_balancing_strategy: Arc, + registry: Arc, + bookmarks: Arc>>, + backoff: Arc, + config: Config, +} + +impl RoutedConnectionManager { + pub async fn new( + config: &Config, + routing_table: Arc, + load_balancing_strategy: Arc, + ) -> Result { + let registry = Arc::new(ConnectionRegistry::new(config, routing_table.clone()).await?); + let backoff = Arc::new( + ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_millis(1)) + .with_randomization_factor(0.42) + .with_multiplier(2.0) + .with_max_elapsed_time(Some(Duration::from_secs(60))) + .build(), + ); + + Ok(RoutedConnectionManager { + load_balancing_strategy, + registry, + bookmarks: Arc::new(Mutex::new(vec![])), + backoff, + config: config.clone(), + }) + } + + pub async fn refresh_routing_table(&self) -> Result { + while let Some(router) = self + .load_balancing_strategy + .select_router(self.registry.servers().as_slice()) + { + if let Some(pool) = self.registry.get_pool(&router) { + if let Ok(mut connection) = pool.get().await { + info!( + "Refreshing routing table from router {}", + router.addresses.first().unwrap() + ); + let bookmarks = self.bookmarks.lock().await; + let bookmarks = bookmarks.iter().map(|b| b.as_str()).collect(); + let route = RouteBuilder::new(Routing::Yes(vec![]), bookmarks) + .with_db(self.config.db.clone().unwrap_or_default()) + .build(connection.version()); + match connection.route(route).await { + Ok(rt) => { + debug!("Routing table refreshed: {:?}", rt); + return Ok(rt); + } + Err(e) => { + self.registry.mark_unavailable(&router); + error!( + "Failed to refresh routing table from router {}: {}", + router.addresses.first().unwrap(), + e + ); + } + } + } else { + self.registry.mark_unavailable(&router); + error!( + "Failed to create connection to router `{}`", + router.addresses.first().unwrap() + ); + } + } else { + error!( + "No connection manager available for router `{}` in the registry. Maybe it was marked as unavailable", + router.addresses.first().unwrap() + ); + } + } + // After trying all routers, we still couldn't refresh the routing table: return an error + Err(Error::ServerUnavailableError( + "No router available".to_string(), + )) + } + + pub(crate) async fn get( + &self, + operation: Option, + ) -> Result { + // We probably need to do this in a more efficient way, since this will block the request of a connection + // while we refresh the routing table. We should probably have a separate thread that refreshes the routing + self.registry + .update_if_expired(|| self.refresh_routing_table()) + .await?; + + let op = operation.unwrap_or(Operation::Write); + let available_servers = self.registry.servers(); + while let Some(server) = match op { + Operation::Write => self + .load_balancing_strategy + .select_writer(available_servers.as_slice()), + _ => self + .load_balancing_strategy + .select_reader(available_servers.as_slice()), + } { + if let Some(pool) = self.registry.get_pool(&server) { + match pool.get().await { + Ok(connection) => return Ok(connection), + Err(e) => { + error!( + "Failed to get connection from pool for server `{}`: {}", + server.addresses.first().unwrap(), + e + ); + self.registry.mark_unavailable(&server); + continue; + } + } + } else { + // We couldn't find a connection manager for the server, it was probably marked unavailable + error!( + "No connection manager available for router `{}` in the registry", + server.addresses.first().unwrap() + ); + } + } + Err(Error::RoutingTableRefreshFailed(format!( + "No server available for {op} operation" + ))) + } + + pub(crate) fn backoff(&self) -> ExponentialBackoff { + self.backoff.as_ref().clone() + } +} diff --git a/lib/src/version.rs b/lib/src/version.rs index b6cca9b..bdfcbde 100644 --- a/lib/src/version.rs +++ b/lib/src/version.rs @@ -1,18 +1,20 @@ use crate::errors::{Error, Result}; use bytes::{BufMut, BytesMut}; use std::cmp::PartialEq; -use std::fmt::Debug; +use std::fmt::{Debug, Display, Formatter}; #[derive(Debug, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)] #[non_exhaustive] pub enum Version { V4, V4_1, + V4_3, } impl Version { pub fn add_supported_versions(bytes: &mut BytesMut) { bytes.reserve(16); + bytes.put_u32(0x0304); // V4_3 bytes.put_u32(0x0104); // V4_1 bytes.put_u32(0x0004); // V4 bytes.put_u32(0); @@ -21,6 +23,7 @@ impl Version { pub fn parse(version_bytes: [u8; 4]) -> Result { match version_bytes { + [0, 0, 3, 4] => Ok(Version::V4_3), [0, 0, 1, 4] => Ok(Version::V4_1), [0, 0, 0, 4] => Ok(Version::V4), [0, 0, minor, major] => Err(Error::UnsupportedVersion(major, minor)), @@ -29,12 +32,23 @@ impl Version { } } +impl Display for Version { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Version::V4 => write!(f, "4.0"), + Version::V4_1 => write!(f, "4.1"), + Version::V4_3 => write!(f, "4.3"), + } + } +} + #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn should_parse_version() { + assert_eq!(Version::parse([0, 0, 3, 4]).unwrap(), Version::V4_3); assert_eq!(Version::parse([0, 0, 1, 4]).unwrap(), Version::V4_1); assert_eq!(Version::parse([0, 0, 0, 4]).unwrap(), Version::V4); }