Skip to content

Commit

Permalink
add health check grpc (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpprian authored Aug 6, 2024
1 parent 15338e7 commit 1f519f2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
};
Expand All @@ -24,6 +24,7 @@ pub(crate) struct ProxyServer {
current_id: Arc<AtomicU64>,
clients: Arc<Mutex<ClientMap>>,
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
pub(crate) connected: Arc<AtomicBool>,
}

impl ProxyServer {
Expand All @@ -34,6 +35,7 @@ impl ProxyServer {
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
results: Arc::new(Mutex::new(HashMap::new())),
connected: Arc::new(AtomicBool::new(false)),
}
}

Expand All @@ -59,9 +61,11 @@ impl ProxyServer {
let (tx, rx) = oneshot::channel();
let mut results = self.results.lock().unwrap();
results.insert(id, tx);
self.connected.store(true, Ordering::Relaxed);
Ok(rx)
} else {
error!("Defguard core is disconnected");
self.connected.store(false, Ordering::Relaxed);
Err(ApiError::Unexpected("Defguard core is disconnected".into()))
}
}
Expand All @@ -73,6 +77,7 @@ impl Clone for ProxyServer {
current_id: Arc::clone(&self.current_id),
clients: Arc::clone(&self.clients),
results: Arc::clone(&self.results),
connected: Arc::clone(&self.connected),
}
}
}
Expand All @@ -95,15 +100,18 @@ impl proxy_server::Proxy for ProxyServer {

let (tx, rx) = mpsc::unbounded_channel();
self.clients.lock().unwrap().insert(address, tx);
self.connected.store(true, Ordering::Relaxed);

let clients = Arc::clone(&self.clients);
let results = Arc::clone(&self.results);
let connected = Arc::clone(&self.connected);
let mut in_stream = request.into_inner();
tokio::spawn(async move {
while let Some(result) = in_stream.next().await {
match result {
Ok(response) => {
debug!("Received message from Defguard core: {response:?}");
connected.store(true, Ordering::Relaxed);
// Discard empty payloads.
if let Some(payload) = response.payload {
if let Some(rx) = results.lock().unwrap().remove(&response.id) {
Expand All @@ -119,6 +127,7 @@ impl proxy_server::Proxy for ProxyServer {
}
}
info!("Defguard core client disconnected: {address}");
connected.store(false, Ordering::Relaxed);
clients.lock().unwrap().remove(&address);
});

Expand Down
12 changes: 11 additions & 1 deletion src/http.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::{
fs::read_to_string,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::atomic::Ordering,
time::Duration,
};

use anyhow::Context;
use axum::{
body::Body,
extract::{ConnectInfo, FromRef},
extract::{ConnectInfo, FromRef, State},
http::{Request, StatusCode},
routing::get,
serve, Json, Router,
Expand Down Expand Up @@ -66,6 +67,14 @@ async fn healthcheck() -> &'static str {
"I'm alive!"
}

async fn healthcheckgrpc(State(state): State<AppState>) -> (StatusCode, &'static str) {
if state.grpc_server.connected.load(Ordering::Relaxed) {
(StatusCode::OK, "Alive")
} else {
(StatusCode::SERVICE_UNAVAILABLE, "Not connected to store")
}
}

// Retrieves client address from the request. Uses either the left most x-forwarded-for
// header value, or socket address if the header is not present.
fn get_client_addr(request: &Request<Body>) -> String {
Expand Down Expand Up @@ -178,6 +187,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
.nest("/password-reset", password_reset::router())
.nest("/client-mfa", desktop_client_mfa::router())
.route("/health", get(healthcheck))
.route("/health-grpc", get(healthcheckgrpc))
.route("/info", get(app_info)),
)
.fallback_service(get(handle_404))
Expand Down

0 comments on commit 1f519f2

Please sign in to comment.