diff --git a/server/src/client_api.rs b/server/src/client_api.rs index e17a2380..c655a5ec 100644 --- a/server/src/client_api.rs +++ b/server/src/client_api.rs @@ -50,9 +50,7 @@ pub async fn stream_features( let (validated_token, _filter_set, query) = get_feature_filter(&edge_token, &token_cache, filter_query.clone())?; - broadcaster - .connect(validated_token, filter_query, query) - .await + broadcaster.connect(validated_token, query).await } #[utoipa::path( diff --git a/server/src/http/broadcaster.rs b/server/src/http/broadcaster.rs index d97ce09e..aeec915a 100644 --- a/server/src/http/broadcaster.rs +++ b/server/src/http/broadcaster.rs @@ -1,13 +1,6 @@ -use std::{ - hash::{Hash, Hasher}, - sync::Arc, - time::Duration, -}; +use std::{hash::Hash, sync::Arc, time::Duration}; -use actix_web::{ - rt::time::interval, - web::{Json, Query}, -}; +use actix_web::{rt::time::interval, web::Json}; use actix_web_lab::{ sse::{self, Event, Sse}, util::InfallibleStream, @@ -15,40 +8,63 @@ use actix_web_lab::{ use dashmap::DashMap; use futures::future; use prometheus::{register_int_gauge, IntGauge}; -use serde::Serialize; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, warn}; -use unleash_types::client_features::{ClientFeatures, Query as FlagQuery}; +use unleash_types::client_features::{ClientFeatures, Query}; use crate::{ error::EdgeError, feature_cache::FeatureCache, - filters::{filter_client_features, name_prefix_filter, project_filter, FeatureFilterSet}, - tokens::cache_key, - types::{EdgeJsonResult, EdgeResult, EdgeToken, FeatureFilters}, + filters::{filter_client_features, name_prefix_filter, FeatureFilter, FeatureFilterSet}, + types::{EdgeJsonResult, EdgeResult, EdgeToken}, }; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -struct QueryWrapper { - query: FlagQuery, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct StreamingQuery { + pub projects: Vec, + pub name_prefix: Option, + pub environment: String, } -impl Hash for QueryWrapper { - fn hash(&self, state: &mut H) { - serde_json::to_string(&self.query).unwrap().hash(state); +impl From for Query { + fn from(value: StreamingQuery) -> Self { + Self { + tags: None, + name_prefix: value.name_prefix, + environment: Some(value.environment), + inline_segment_constraints: Some(false), + projects: Some(value.projects), + } + } +} + +impl From<(&Query, &EdgeToken)> for StreamingQuery { + fn from((query, token): (&Query, &EdgeToken)) -> Self { + Self { + projects: token.projects.clone(), + name_prefix: query.name_prefix.clone(), + environment: match token.environment { + Some(ref env) => env.clone(), + None => token.token.clone(), + }, + } } } +#[derive(Clone, Debug)] +struct ClientData { + token: String, + sender: mpsc::Sender, +} + #[derive(Clone, Debug)] struct ClientGroup { - clients: Vec>, - filter_set: Query, - token: EdgeToken, + clients: Vec, } pub struct Broadcaster { - active_connections: DashMap, + active_connections: DashMap, features_cache: Arc, } @@ -101,88 +117,88 @@ impl Broadcaster { async fn heartbeat(&self) { let mut active_connections = 0i64; for mut group in self.active_connections.iter_mut() { - let mut ok_clients = Vec::new(); + let clients = std::mem::take(&mut group.clients); + let ok_clients = &mut group.clients; - for client in &group.clients { - if client + for ClientData { token, sender } in clients { + if sender .send(sse::Event::Comment("keep-alive".into())) .await .is_ok() { - ok_clients.push(client.clone()); + ok_clients.push(ClientData { token, sender }); } } active_connections += ok_clients.len() as i64; - group.clients = ok_clients; } CONNECTED_STREAMING_CLIENTS.set(active_connections) } - /// Registers client with broadcaster, returning an SSE response body. pub async fn connect( &self, token: EdgeToken, - filter_set: Query, - query: unleash_types::client_features::Query, + query: Query, ) -> EdgeResult>>> { - let (tx, rx) = mpsc::channel(10); + self.create_connection(StreamingQuery::from((&query, &token)), &token.token) + .await + .map(Sse::from_infallible_receiver) + } - let features = &self - .resolve_features(&token, filter_set.clone(), query.clone()) - .await?; + async fn create_connection( + &self, + query: StreamingQuery, + token: &str, + ) -> EdgeResult> { + let (tx, rx) = mpsc::channel(10); + let features = self.resolve_features(query.clone()).await?; tx.send( - sse::Data::new_json(features)? + sse::Data::new_json(&features)? .event("unleash-connected") .into(), ) .await?; self.active_connections - .entry(QueryWrapper { query }) + .entry(query) .and_modify(|group| { - group.clients.push(tx.clone()); + group.clients.push(ClientData { + token: token.into(), + sender: tx.clone(), + }); }) .or_insert(ClientGroup { - clients: vec![tx.clone()], - filter_set, - token, + clients: vec![ClientData { + token: token.into(), + sender: tx.clone(), + }], }); - Ok(Sse::from_infallible_receiver(rx)) - } - fn get_query_filters( - filter_query: Query, - token: &EdgeToken, - ) -> FeatureFilterSet { - let query_filters = filter_query.into_inner(); + Ok(rx) + } - let filter_set = if let Some(name_prefix) = query_filters.name_prefix { - FeatureFilterSet::from(Box::new(name_prefix_filter(name_prefix))) + fn get_query_filters(query: &StreamingQuery) -> FeatureFilterSet { + let filter_set = if let Some(name_prefix) = &query.name_prefix { + FeatureFilterSet::from(Box::new(name_prefix_filter(name_prefix.clone()))) } else { FeatureFilterSet::default() } - .with_filter(project_filter(token)); + .with_filter(project_filter(query.projects.clone())); filter_set } - async fn resolve_features( - &self, - validated_token: &EdgeToken, - filter_set: Query, - query: FlagQuery, - ) -> EdgeJsonResult { - let filter_set = Broadcaster::get_query_filters(filter_set.clone(), validated_token); + async fn resolve_features(&self, query: StreamingQuery) -> EdgeJsonResult { + let filter_set = Broadcaster::get_query_filters(&query); let features = self .features_cache - .get(&cache_key(validated_token)) + .get(&query.environment) .map(|client_features| filter_client_features(&client_features, &filter_set)); match features { Some(features) => Ok(Json(ClientFeatures { - query: Some(query), + query: Some(query.into()), ..features })), // Note: this is a simplification for now, using the following assumptions: @@ -196,11 +212,12 @@ impl Broadcaster { /// Broadcast new features to all clients. pub async fn broadcast(&self) { let mut client_events = Vec::new(); + for entry in self.active_connections.iter() { let (query, group) = entry.pair(); let event_data = self - .resolve_features(&group.token, group.filter_set.clone(), query.query.clone()) + .resolve_features(query.clone()) .await .and_then(|features| sse::Data::new_json(&features).map_err(|e| e.into())); @@ -221,8 +238,20 @@ impl Broadcaster { // disconnected clients will get swept up by `remove_stale_clients` let send_events = client_events .iter() - .map(|(client, event)| client.send(event.clone())); + .map(|(ClientData { sender, .. }, event)| sender.send(event.clone())); let _ = future::join_all(send_events).await; } } + +fn project_filter(projects: Vec) -> FeatureFilter { + Box::new(move |feature| { + if let Some(feature_project) = &feature.project { + projects.is_empty() + || projects.contains(&"*".to_string()) + || projects.contains(feature_project) + } else { + false + } + }) +}