diff --git a/server/src/http/broadcaster.rs b/server/src/http/broadcaster.rs index 9632cd60..02c690e1 100644 --- a/server/src/http/broadcaster.rs +++ b/server/src/http/broadcaster.rs @@ -1,8 +1,4 @@ -use std::{ - hash::{Hash, Hasher}, - sync::Arc, - time::Duration, -}; +use std::{hash::Hash, sync::Arc, time::Duration}; use actix_web::{rt::time::interval, web::Json}; use actix_web_lab::{ @@ -12,7 +8,6 @@ use actix_web_lab::{ use dashmap::DashMap; use futures::future; use prometheus::{register_int_gauge, IntGauge}; -use serde::Serialize; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, warn}; @@ -21,31 +16,56 @@ use unleash_types::client_features::{ClientFeatures, Query}; use crate::{ error::EdgeError, feature_cache::FeatureCache, - filters::{filter_client_features, name_prefix_filter, project_filter, FeatureFilterSet}, - tokens::cache_key, + filters::{filter_client_features, name_prefix_filter, FeatureFilter, FeatureFilterSet}, types::{EdgeJsonResult, EdgeResult, EdgeToken}, }; -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -struct QueryWrapper { - query: Query, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct StreamingQuery { + pub projects: Vec, + pub name_prefix: Option, + pub environment: String, +} + +impl Into for StreamingQuery { + fn into(self) -> Query { + Query { + tags: None, + name_prefix: self.name_prefix, + environment: Some(self.environment), + inline_segment_constraints: None, + projects: Some(self.projects), + } + } } -impl Hash for QueryWrapper { - fn hash(&self, state: &mut H) { - serde_json::to_string(&self.query).unwrap().hash(state); +impl From<(&Query, &EdgeToken)> for StreamingQuery { + fn from((query, token): (&Query, &EdgeToken)) -> Self { + Self { + projects: token.projects.clone(), + name_prefix: query.name_prefix.clone(), + environment: match token.environment { + Some(ref env) => env.clone(), + None => token.token.clone(), + }, + } } } +#[derive(Clone, Debug)] +struct ClientData { + token: String, + sender: mpsc::Sender, +} + #[derive(Clone, Debug)] struct ClientGroup { - clients: Vec>, - token: EdgeToken, + clients: Vec, // last_hash: u64 } pub struct Broadcaster { - active_connections: DashMap, + active_connections: DashMap, features_cache: Arc, } @@ -104,13 +124,16 @@ impl Broadcaster { for mut group in self.active_connections.iter_mut() { let mut ok_clients = Vec::new(); - for client in &group.clients { - if client + for ClientData { token, sender } in &group.clients { + if sender .send(sse::Event::Comment("keep-alive".into())) .await .is_ok() { - ok_clients.push(client.clone()); + ok_clients.push(ClientData { + token: token.clone(), + sender: sender.clone(), + }); } } @@ -136,7 +159,9 @@ impl Broadcaster { ) -> EdgeResult> { let (tx, rx) = mpsc::channel(10); - let features = self.resolve_features(&token, query.clone()).await?; + let streaming_query = StreamingQuery::from((&query, &token)); + + let features = self.resolve_features(streaming_query.clone()).await?; tx.send( sse::Data::new_json(&features)? .event("unleash-connected") @@ -145,43 +170,44 @@ impl Broadcaster { .await?; self.active_connections - .entry(QueryWrapper { query }) + .entry(streaming_query) .and_modify(|group| { - group.clients.push(tx.clone()); + group.clients.push(ClientData { + token: token.token.clone(), + sender: tx.clone(), + }); }) .or_insert(ClientGroup { - clients: vec![tx.clone()], - token, + clients: vec![ClientData { + token: token.token.clone(), + sender: tx.clone(), + }], }); Ok(rx) } - fn get_query_filters(query: &Query, token: &EdgeToken) -> FeatureFilterSet { + fn get_query_filters(query: &StreamingQuery) -> FeatureFilterSet { let filter_set = if let Some(name_prefix) = &query.name_prefix { FeatureFilterSet::from(Box::new(name_prefix_filter(name_prefix.clone()))) } else { FeatureFilterSet::default() } - .with_filter(project_filter(token)); + .with_filter(project_filter(query.projects.clone())); filter_set } - async fn resolve_features( - &self, - validated_token: &EdgeToken, - query: Query, - ) -> EdgeJsonResult { - let filter_set = Broadcaster::get_query_filters(&query, validated_token); + async fn resolve_features(&self, query: StreamingQuery) -> EdgeJsonResult { + let filter_set = Broadcaster::get_query_filters(&query); let features = self .features_cache - .get(&cache_key(validated_token)) + .get(&query.environment) .map(|client_features| filter_client_features(&client_features, &filter_set)); match features { Some(features) => Ok(Json(ClientFeatures { - query: Some(query), + query: Some(query.into()), ..features })), // Note: this is a simplification for now, using the following assumptions: @@ -202,7 +228,7 @@ impl Broadcaster { let (query, group) = entry.pair(); let event_data = self - .resolve_features(&group.token, query.query.clone()) + .resolve_features(query.clone()) .await .and_then(|features| sse::Data::new_json(&features).map_err(|e| e.into())); @@ -223,64 +249,23 @@ impl Broadcaster { // disconnected clients will get swept up by `remove_stale_clients` let send_events = client_events .iter() - .map(|(client, event)| client.send(event.clone())); + .map(|(ClientData { sender, .. }, event)| sender.send(event.clone())); let _ = future::join_all(send_events).await; } } -// probably not worth taking this out of the broadcaster. It relies on resolving features etc, which is part of the broadcaster -// async fn broadcast(active_connections: &DashMap) { -// let mut client_events = Vec::new(); -// for entry in active_connections.iter() { -// let (query, group) = entry.pair(); - -// let event_data = self -// .resolve_features(&group.token, group.filter_set.clone(), query.query.clone()) -// .await -// .and_then(|features| sse::Data::new_json(&features).map_err(|e| e.into())); - -// match event_data { -// Ok(sse_data) => { -// let event: Event = sse_data.event("unleash-updated").into(); - -// for client in &group.clients { -// client_events.push((client.clone(), event.clone())); -// } -// } -// Err(e) => { -// warn!("Failed to broadcast features: {:?}", e); -// } -// } -// } -// // try to send to all clients, ignoring failures -// // disconnected clients will get swept up by `remove_stale_clients` -// let send_events = client_events -// .iter() -// .map(|(client, event)| client.send(event.clone())); - -// let _ = future::join_all(send_events).await; -// } - -// -// fn filter_client_groups( -// update_type: UpdateType, -// all_connections: &DashMap, -// ) -> std::iter::Filter< -// dashmap::iter::Iter<'_, QueryWrapper, ClientGroup>, -// impl FnMut(&dashmap::mapref::multiple::RefMulti<'_, QueryWrapper, ClientGroup>) -> bool, -// > { -// all_connections -// .iter() -// .filter(|entry| *entry.key) -// // match update_type { -// // UpdateType::Full(environment) | -// // UpdateType::Update(environment) => all_connections -// // .iter() -// // .filter(|entry| entry.value().token.project == key) - -// // } -// } +fn project_filter(projects: Vec) -> FeatureFilter { + Box::new(move |feature| { + if let Some(feature_project) = &feature.project { + projects.is_empty() + || projects.contains(&"*".to_string()) + || projects.contains(feature_project) + } else { + false + } + }) +} #[cfg(test)] mod test {