diff --git a/sable_ircd/src/command/handlers/chathistory.rs b/sable_ircd/src/command/handlers/chathistory.rs index 46b51df2..bf46a4ee 100644 --- a/sable_ircd/src/command/handlers/chathistory.rs +++ b/sable_ircd/src/command/handlers/chathistory.rs @@ -136,8 +136,11 @@ async fn handle_chathistory( } }; - let log = server.node().history(); - match log.get_entries(source.id(), target_id, request).await { + let history_service = LocalHistoryService::new(server.node()); + match history_service + .get_entries(source.id(), target_id, request) + .await + { Ok(entries) => send_history_entries(server, response, target, entries)?, Err(HistoryError::InvalidTarget(_)) => Err(invalid_target_error())?, }; @@ -149,17 +152,19 @@ async fn handle_chathistory( // For listing targets, we iterate backwards through time; this allows us to just collect the // first timestamp we see for each target and know that it's the most recent one -async fn list_targets( - server: &ClientServer, - into: impl MessageSink, - source: &wrapper::User<'_>, +async fn list_targets<'a>( + server: &'a ClientServer, + into: impl MessageSink + 'a, + source: &'a wrapper::User<'_>, from_ts: Option, to_ts: Option, limit: Option, ) { - let log = server.node().history(); + let history_service = LocalHistoryService::new(server.node()); - let found_targets = log.list_targets(source.id(), to_ts, from_ts, limit).await; + let found_targets = history_service + .list_targets(source.id(), to_ts, from_ts, limit) + .await; // The appropriate cap here is Batch - chathistory is enabled because we got here, // but can be used without batch support. @@ -195,7 +200,7 @@ fn send_history_entries<'a>( server: &ClientServer, into: impl MessageSink, target: &str, - entries: impl Iterator, + entries: impl IntoIterator, ) -> CommandResult { let batch = into .batch("chathistory", ClientCapability::Batch) @@ -205,7 +210,7 @@ fn send_history_entries<'a>( for entry in entries { // Ignore errors here; it's possible that a message has been expired out of network state // but a reference to it still exists in the history log - let _ = server.send_item(entry, &batch, entry); + let _ = server.send_item(&entry, &batch, &entry); } Ok(()) diff --git a/sable_network/src/history/local_service.rs b/sable_network/src/history/local_service.rs index 4c8a9bd4..7d91f3ef 100644 --- a/sable_network/src/history/local_service.rs +++ b/sable_network/src/history/local_service.rs @@ -20,7 +20,103 @@ fn target_id_for_entry(for_user: UserId, entry: &HistoryLogEntry) -> Option { + node: &'a NetworkNode, +} + +impl<'a> LocalHistoryService<'a> { + pub fn new(node: &'a NetworkNode) -> Self { + LocalHistoryService { node } + } + + fn get_history_for_target( + &self, + source: UserId, + target: TargetId, + from_ts: Option, + to_ts: Option, + backward_limit: usize, + forward_limit: usize, + ) -> Result, HistoryError> { + let mut backward_entries = Vec::new(); + let mut forward_entries = Vec::new(); + let mut target_exists = false; + + // Keep the lock on the NetworkHistoryLog between the backward and the forward + // search to make sure both have a consistent state + let log = self.node.history(); + + if backward_limit != 0 { + let from_ts = if forward_limit == 0 { + from_ts + } else { + // HACK: This is AROUND so we want to capture messages whose timestamp matches exactly + // (it's a message in the middle of the range) + from_ts.map(|from_ts| from_ts + 1) + }; + + for entry in log.entries_for_user_reverse(source) { + target_exists = true; + if matches!(from_ts, Some(ts) if entry.timestamp >= ts) { + // Skip over until we hit the timestamp window we're interested in + continue; + } + if matches!(to_ts, Some(ts) if entry.timestamp <= ts) { + // If we hit this then we've passed the requested window and should stop + break; + } + + if let Some(event_target) = target_id_for_entry(source, entry) { + if event_target == target { + backward_entries.push(entry.clone()); + } + } + + if backward_limit <= backward_entries.len() { + break; + } + } + } + + if forward_limit != 0 { + for entry in log.entries_for_user(source) { + target_exists = true; + if matches!(from_ts, Some(ts) if entry.timestamp <= ts) { + // Skip over until we hit the timestamp window we're interested in + continue; + } + if matches!(to_ts, Some(ts) if entry.timestamp >= ts) { + // If we hit this then we've passed the requested window and should stop + break; + } + + if let Some(event_target) = target_id_for_entry(source, entry) { + if event_target == target { + forward_entries.push(entry.clone()); + } + } + + if forward_limit <= forward_entries.len() { + break; + } + } + } + + if target_exists { + // "The order of returned messages within the batch is implementation-defined, but SHOULD be + // ascending time order or some approximation thereof, regardless of the subcommand used." + // -- https://ircv3.net/specs/extensions/chathistory#returned-message-notes + Ok(backward_entries + .into_iter() + .rev() + .chain(forward_entries.into_iter())) + } else { + Err(HistoryError::InvalidTarget(target)) + } + } +} + +impl<'a> HistoryService for LocalHistoryService<'a> { async fn list_targets( &self, user: UserId, @@ -30,7 +126,7 @@ impl HistoryService for NetworkHistoryLog { ) -> HashMap { let mut found_targets = HashMap::new(); - for entry in self.entries_for_user_reverse(user) { + for entry in self.node.history().entries_for_user_reverse(user) { if matches!(after_ts, Some(ts) if entry.timestamp >= ts) { // Skip over until we hit the timestamp window we're interested in continue; @@ -59,11 +155,10 @@ impl HistoryService for NetworkHistoryLog { user: UserId, target: TargetId, request: HistoryRequest, - ) -> Result, HistoryError> { + ) -> Result, HistoryError> { match request { #[rustfmt::skip] - HistoryRequest::Latest { to_ts, limit } => get_history_for_target( - self, + HistoryRequest::Latest { to_ts, limit } => self.get_history_for_target( user, target, None, @@ -73,8 +168,7 @@ impl HistoryService for NetworkHistoryLog { ), HistoryRequest::Before { from_ts, limit } => { - get_history_for_target( - self, + self.get_history_for_target( user, target, Some(from_ts), @@ -83,8 +177,7 @@ impl HistoryService for NetworkHistoryLog { 0, // Forward limit ) } - HistoryRequest::After { start_ts, limit } => get_history_for_target( - self, + HistoryRequest::After { start_ts, limit } => self.get_history_for_target( user, target, Some(start_ts), @@ -93,8 +186,7 @@ impl HistoryService for NetworkHistoryLog { limit, ), HistoryRequest::Around { around_ts, limit } => { - get_history_for_target( - self, + self.get_history_for_target( user, target, Some(around_ts), @@ -109,8 +201,7 @@ impl HistoryService for NetworkHistoryLog { limit, } => { if start_ts <= end_ts { - get_history_for_target( - self, + self.get_history_for_target( user, target, Some(start_ts), @@ -121,8 +212,7 @@ impl HistoryService for NetworkHistoryLog { } else { // Search backward from start_ts instead of swapping start_ts and end_ts, // because we want to match the last messages first in case we reach the limit - get_history_for_target( - self, + self.get_history_for_target( user, target, Some(start_ts), @@ -135,85 +225,3 @@ impl HistoryService for NetworkHistoryLog { } } } - -fn get_history_for_target( - log: &NetworkHistoryLog, - source: UserId, - target: TargetId, - from_ts: Option, - to_ts: Option, - backward_limit: usize, - forward_limit: usize, -) -> Result, HistoryError> { - let mut backward_entries = Vec::new(); - let mut forward_entries = Vec::new(); - let mut target_exists = false; - - if backward_limit != 0 { - let from_ts = if forward_limit == 0 { - from_ts - } else { - // HACK: This is AROUND so we want to capture messages whose timestamp matches exactly - // (it's a message in the middle of the range) - from_ts.map(|from_ts| from_ts + 1) - }; - - for entry in log.entries_for_user_reverse(source) { - target_exists = true; - if matches!(from_ts, Some(ts) if entry.timestamp >= ts) { - // Skip over until we hit the timestamp window we're interested in - continue; - } - if matches!(to_ts, Some(ts) if entry.timestamp <= ts) { - // If we hit this then we've passed the requested window and should stop - break; - } - - if let Some(event_target) = target_id_for_entry(source, entry) { - if event_target == target { - backward_entries.push(entry); - } - } - - if backward_limit <= backward_entries.len() { - break; - } - } - } - - if forward_limit != 0 { - for entry in log.entries_for_user(source) { - target_exists = true; - if matches!(from_ts, Some(ts) if entry.timestamp <= ts) { - // Skip over until we hit the timestamp window we're interested in - continue; - } - if matches!(to_ts, Some(ts) if entry.timestamp >= ts) { - // If we hit this then we've passed the requested window and should stop - break; - } - - if let Some(event_target) = target_id_for_entry(source, entry) { - if event_target == target { - forward_entries.push(entry); - } - } - - if forward_limit <= forward_entries.len() { - break; - } - } - } - - if target_exists { - // "The order of returned messages within the batch is implementation-defined, but SHOULD be - // ascending time order or some approximation thereof, regardless of the subcommand used." - // -- https://ircv3.net/specs/extensions/chathistory#returned-message-notes - Ok(backward_entries - .into_iter() - .rev() - .chain(forward_entries.into_iter())) - } else { - Err(HistoryError::InvalidTarget(target)) - } -} diff --git a/sable_network/src/history/log.rs b/sable_network/src/history/log.rs index a9151fd1..9d714bf1 100644 --- a/sable_network/src/history/log.rs +++ b/sable_network/src/history/log.rs @@ -10,7 +10,7 @@ use concurrent_log::ConcurrentLog; pub type LogEntryId = usize; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct HistoryLogEntry { pub id: LogEntryId, pub timestamp: i64, diff --git a/sable_network/src/history/mod.rs b/sable_network/src/history/mod.rs index 8188c62f..74aa1716 100644 --- a/sable_network/src/history/mod.rs +++ b/sable_network/src/history/mod.rs @@ -5,6 +5,7 @@ pub use service::*; mod local_service; use crate::network::NetworkStateChange; +pub use local_service::LocalHistoryService; /// Implemented by types that provide metadata for a historic state change pub trait HistoryItem { diff --git a/sable_network/src/history/service.rs b/sable_network/src/history/service.rs index 8b6d3b9b..a4d97a66 100644 --- a/sable_network/src/history/service.rs +++ b/sable_network/src/history/service.rs @@ -91,12 +91,14 @@ pub trait HistoryService { after_ts: Option, before_ts: Option, limit: Option, - ) -> impl Future> + Send; + ) -> impl Future> + Send + Sync; fn get_entries( &self, user: UserId, target: TargetId, request: HistoryRequest, - ) -> impl Future, HistoryError>> + Send; + ) -> impl Future, HistoryError>> + + Send + + Sync; }