diff --git a/opentelemetry/benches/context_attach.rs b/opentelemetry/benches/context_attach.rs index f550152ea7..8bdcb06fc2 100644 --- a/opentelemetry/benches/context_attach.rs +++ b/opentelemetry/benches/context_attach.rs @@ -8,7 +8,7 @@ use opentelemetry::{ }; // Run this benchmark with: -// cargo bench --bench current_context +// cargo bench --bench context_attach fn criterion_benchmark(c: &mut Criterion) { let span_context = Context::new().with_remote_span_context(SpanContext::empty_context()); diff --git a/opentelemetry/src/context.rs b/opentelemetry/src/context.rs index 60c07b844e..b58011fbc9 100644 --- a/opentelemetry/src/context.rs +++ b/opentelemetry/src/context.rs @@ -9,7 +9,7 @@ use std::marker::PhantomData; use std::sync::Arc; thread_local! { - static CURRENT_CONTEXT: RefCell = RefCell::new(Context::default()); + static CURRENT_CONTEXT: RefCell = RefCell::new(ContextStack::default()); } /// An execution-scoped collection of values. @@ -122,7 +122,7 @@ impl Context { /// Note: This function will panic if you attempt to attach another context /// while the current one is still borrowed. pub fn map_current(f: impl FnOnce(&Context) -> T) -> T { - CURRENT_CONTEXT.with(|cx| f(&cx.borrow())) + CURRENT_CONTEXT.with(|cx| cx.borrow().map_current_cx(f)) } /// Returns a clone of the current thread's context with the given value. @@ -298,12 +298,10 @@ impl Context { /// assert_eq!(Context::current().get::(), None); /// ``` pub fn attach(self) -> ContextGuard { - let previous_cx = CURRENT_CONTEXT - .try_with(|current| current.replace(self)) - .ok(); + let cx_id = CURRENT_CONTEXT.with(|cx| cx.borrow_mut().push(self)); ContextGuard { - previous_cx, + cx_id, _marker: PhantomData, } } @@ -336,15 +334,16 @@ impl fmt::Debug for Context { /// A guard that resets the current context to the prior context when dropped. #[allow(missing_debug_implementations)] pub struct ContextGuard { - previous_cx: Option, + cx_id: usize, // ensure this type is !Send as it relies on thread locals _marker: PhantomData<*const ()>, } impl Drop for ContextGuard { fn drop(&mut self) { - if let Some(previous_cx) = self.previous_cx.take() { - let _ = CURRENT_CONTEXT.try_with(|current| current.replace(previous_cx)); + let id = self.cx_id; + if id > 0 { + CURRENT_CONTEXT.with(|context_stack| context_stack.borrow_mut().pop_id(id)); } } } @@ -371,6 +370,75 @@ impl Hasher for IdHasher { } } +struct ContextStack { + current_cx: Context, + current_id: usize, + // TODO:ban wrap the whole id thing in its own type + id_count: usize, + // TODO:ban wrap the the tuple in its own type + stack: Vec>, +} + +impl ContextStack { + #[inline(always)] + fn push(&mut self, cx: Context) -> usize { + self.id_count += 512; // TODO:ban clean up this + let next_id = self.stack.len() + 1 + self.id_count; + let current_cx = std::mem::replace(&mut self.current_cx, cx); + self.stack.push(Some((self.current_id, current_cx))); + self.current_id = next_id; + next_id + } + + #[inline(always)] + fn pop_id(&mut self, id: usize) { + if id == 0 { + return; + } + // Are we at the top of the stack? + if id == self.current_id { + // Shrink the stack if possible + while let Some(None) = self.stack.last() { + self.stack.pop(); + } + // There is always the initial context at the bottom of the stack + if let Some(Some((next_id, next_cx))) = self.stack.pop() { + self.current_cx = next_cx; + self.current_id = next_id; + } + } else { + let pos = id & 511; // TODO:ban clean up this + if pos >= self.stack.len() { + // This is an invalid id, ignore it + return; + } + if let Some((pos_id, _)) = self.stack[pos] { + // Is the correct id at this position? + if pos_id == id { + // Clear out this entry + self.stack[pos] = None; + } + } + } + } + + #[inline(always)] + fn map_current_cx(&self, f: impl FnOnce(&Context) -> T) -> T { + f(&self.current_cx) + } +} + +impl Default for ContextStack { + fn default() -> Self { + ContextStack { + current_id: 0, + current_cx: Context::default(), + id_count: 0, + stack: Vec::with_capacity(64), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -415,7 +483,6 @@ mod tests { } #[test] - #[ignore = "overlapping contexts are not supported yet"] fn overlapping_contexts() { #[derive(Debug, PartialEq)] struct ValueA(&'static str);