From 6211e8a8664d7939a45b2f9e555e2159b2fb2246 Mon Sep 17 00:00:00 2001 From: cpprian Date: Thu, 1 Aug 2024 12:11:46 +0200 Subject: [PATCH] add health check grpc --- src/grpc.rs | 11 ++++++++++- src/http.rs | 12 +++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/grpc.rs b/src/grpc.rs index 0e140f8..44568e3 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -2,7 +2,7 @@ use std::{ collections::HashMap, net::SocketAddr, sync::{ - atomic::{AtomicU64, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, Mutex, }, }; @@ -24,6 +24,7 @@ pub(crate) struct ProxyServer { current_id: Arc, clients: Arc>, results: Arc>>>, + pub(crate) connected: Arc, } impl ProxyServer { @@ -34,6 +35,7 @@ impl ProxyServer { current_id: Arc::new(AtomicU64::new(1)), clients: Arc::new(Mutex::new(HashMap::new())), results: Arc::new(Mutex::new(HashMap::new())), + connected: Arc::new(AtomicBool::new(false)), } } @@ -59,9 +61,11 @@ impl ProxyServer { let (tx, rx) = oneshot::channel(); let mut results = self.results.lock().unwrap(); results.insert(id, tx); + self.connected.store(true, Ordering::Relaxed); Ok(rx) } else { error!("Defguard core is disconnected"); + self.connected.store(false, Ordering::Relaxed); Err(ApiError::Unexpected("Defguard core is disconnected".into())) } } @@ -73,6 +77,7 @@ impl Clone for ProxyServer { current_id: Arc::clone(&self.current_id), clients: Arc::clone(&self.clients), results: Arc::clone(&self.results), + connected: Arc::clone(&self.connected), } } } @@ -95,15 +100,18 @@ impl proxy_server::Proxy for ProxyServer { let (tx, rx) = mpsc::unbounded_channel(); self.clients.lock().unwrap().insert(address, tx); + self.connected.store(true, Ordering::Relaxed); let clients = Arc::clone(&self.clients); let results = Arc::clone(&self.results); + let connected = Arc::clone(&self.connected); let mut in_stream = request.into_inner(); tokio::spawn(async move { while let Some(result) = in_stream.next().await { match result { Ok(response) => { debug!("Received message from Defguard core: {response:?}"); + connected.store(true, Ordering::Relaxed); // Discard empty payloads. if let Some(payload) = response.payload { if let Some(rx) = results.lock().unwrap().remove(&response.id) { @@ -119,6 +127,7 @@ impl proxy_server::Proxy for ProxyServer { } } info!("Defguard core client disconnected: {address}"); + connected.store(false, Ordering::Relaxed); clients.lock().unwrap().remove(&address); }); diff --git a/src/http.rs b/src/http.rs index f4e0cb1..0967b97 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,13 +1,14 @@ use std::{ fs::read_to_string, net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::atomic::Ordering, time::Duration, }; use anyhow::Context; use axum::{ body::Body, - extract::{ConnectInfo, FromRef}, + extract::{ConnectInfo, FromRef, State}, http::{Request, StatusCode}, routing::get, serve, Json, Router, @@ -66,6 +67,14 @@ async fn healthcheck() -> &'static str { "I'm alive!" } +async fn healthcheckgrpc(State(state): State) -> (StatusCode, &'static str) { + if state.grpc_server.connected.load(Ordering::Relaxed) { + (StatusCode::OK, "Alive") + } else { + (StatusCode::SERVICE_UNAVAILABLE, "Not connected to store") + } +} + // Retrieves client address from the request. Uses either the left most x-forwarded-for // header value, or socket address if the header is not present. fn get_client_addr(request: &Request) -> String { @@ -178,6 +187,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> { .nest("/password-reset", password_reset::router()) .nest("/client-mfa", desktop_client_mfa::router()) .route("/health", get(healthcheck)) + .route("/health-grpc", get(healthcheckgrpc)) .route("/info", get(app_info)), ) .fallback_service(get(handle_404))