diff --git a/newrelic/core/context.py b/newrelic/core/context.py
index 95de15b4ea..7560855aef 100644
--- a/newrelic/core/context.py
+++ b/newrelic/core/context.py
@@ -46,7 +46,7 @@ def log_propagation_failure(s):
         elif trace is not None:
             self.trace = trace
         elif trace_cache_id is not None:
-            self.trace = self.trace_cache._cache.get(trace_cache_id, None)
+            self.trace = self.trace_cache.get(trace_cache_id, None)
             if self.trace is None:
                 log_propagation_failure("No trace with id %d." % trace_cache_id)
         elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
@@ -60,11 +60,11 @@ def __enter__(self):
             self.thread_id = self.trace_cache.current_thread_id()

             # Save previous cache contents
-            self.restore = self.trace_cache._cache.get(self.thread_id, None)
+            self.restore = self.trace_cache.get(self.thread_id, None)
             self.should_restore = True

             # Set context in trace cache
-            self.trace_cache._cache[self.thread_id] = self.trace
+            self.trace_cache[self.thread_id] = self.trace

         return self

@@ -72,10 +72,10 @@ def __exit__(self, exc, value, tb):
         if self.should_restore:
             if self.restore is not None:
                 # Restore previous contents
-                self.trace_cache._cache[self.thread_id] = self.restore
+                self.trace_cache[self.thread_id] = self.restore
             else:
                 # Remove entry from cache
-                self.trace_cache._cache.pop(self.thread_id)
+                self.trace_cache.pop(self.thread_id)


 def context_wrapper(func, trace=None, request=None, trace_cache_id=None, strict=True):
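
Note on the ContextOf hunks above: the context manager saves whatever entry currently occupies the caller's slot in the trace cache, installs the propagated trace, and on exit either puts the saved entry back or drops its own entry. With TraceCache now behaving like a dict, that is expressed with plain mapping operations. A rough, self-contained sketch of the save/restore pattern, using a plain dict and a hypothetical ContextSketch class (illustrative only, not agent code):

```python
# Rough sketch (illustrative only, not the agent's ContextOf) of the
# save/restore pattern used above, with a plain dict standing in for the
# trace cache keyed by thread/task id.
class ContextSketch(object):
    def __init__(self, cache, key, trace):
        self.cache = cache
        self.key = key
        self.trace = trace
        self.restore = None

    def __enter__(self):
        self.restore = self.cache.get(self.key, None)  # save previous entry
        self.cache[self.key] = self.trace              # install new context
        return self

    def __exit__(self, exc, value, tb):
        if self.restore is not None:
            self.cache[self.key] = self.restore        # put the old entry back
        else:
            self.cache.pop(self.key, None)             # or drop the one we added


cache = {}
with ContextSketch(cache, "thread-1", "trace-A"):
    assert cache["thread-1"] == "trace-A"
assert "thread-1" not in cache
```
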
diff --git a/newrelic/core/trace_cache.py b/newrelic/core/trace_cache.py
index 1634d0d0b2..5f0ddcd3da 100644
--- a/newrelic/core/trace_cache.py
+++ b/newrelic/core/trace_cache.py
@@ -28,6 +28,11 @@
 except ImportError:
     import _thread as thread

+try:
+    from collections.abc import MutableMapping
+except ImportError:
+    from collections import MutableMapping
+
 from newrelic.core.config import global_settings
 from newrelic.core.loop_node import LoopNode

@@ -92,7 +97,7 @@ class TraceCacheActiveTraceError(RuntimeError):
     pass


-class TraceCache(object):
+class TraceCache(MutableMapping):
     asyncio = cached_module("asyncio")
     greenlet = cached_module("greenlet")

@@ -100,7 +105,7 @@ def __init__(self):
         self._cache = weakref.WeakValueDictionary()

     def __repr__(self):
-        return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self._cache.items())))
+        return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self.items())))

     def current_thread_id(self):
         """Returns the thread ID for the caller.
@@ -135,10 +140,10 @@ def current_thread_id(self):
     def task_start(self, task):
         trace = self.current_trace()
         if trace:
-            self._cache[id(task)] = trace
+            self[id(task)] = trace

     def task_stop(self, task):
-        self._cache.pop(id(task), None)
+        self.pop(id(task), None)

     def current_transaction(self):
         """Return the transaction object if one exists for the currently
@@ -146,11 +151,11 @@ def current_transaction(self):

         """

-        trace = self._cache.get(self.current_thread_id())
+        trace = self.get(self.current_thread_id())
         return trace and trace.transaction

     def current_trace(self):
-        return self._cache.get(self.current_thread_id())
+        return self.get(self.current_thread_id())

     def active_threads(self):
         """Returns an iterator over all current stack frames for all
@@ -169,7 +174,7 @@ def active_threads(self):
         # First yield up those for real Python threads.

         for thread_id, frame in sys._current_frames().items():
-            trace = self._cache.get(thread_id)
+            trace = self.get(thread_id)
             transaction = trace and trace.transaction
             if transaction is not None:
                 if transaction.background_task:
@@ -197,7 +202,7 @@ def active_threads(self):

         debug = global_settings().debug
         if debug.enable_coroutine_profiling:
-            for thread_id, trace in list(self._cache.items()):
+            for thread_id, trace in self.items():
                 transaction = trace.transaction
                 if transaction and transaction._greenlet is not None:
                     gr = transaction._greenlet()
@@ -212,7 +217,7 @@ def prepare_for_root(self):
         trace in the cache is from a different task (for asyncio). Returns
         the current trace after the cache is updated."""
         thread_id = self.current_thread_id()
-        trace = self._cache.get(thread_id)
+        trace = self.get(thread_id)
         if not trace:
             return None

@@ -221,11 +226,11 @@ def prepare_for_root(self):
             task = current_task(self.asyncio)
             if task is not None and id(trace._task) != id(task):
-                self._cache.pop(thread_id, None)
+                self.pop(thread_id, None)
                 return None

         if trace.root and trace.root.exited:
-            self._cache.pop(thread_id, None)
+            self.pop(thread_id, None)
             return None

         return trace

@@ -240,8 +245,8 @@ def save_trace(self, trace):

         thread_id = trace.thread_id

-        if thread_id in self._cache:
-            cache_root = self._cache[thread_id].root
+        if thread_id in self:
+            cache_root = self[thread_id].root
             if cache_root and cache_root is not trace.root and not cache_root.exited:
                 # Cached trace exists and has a valid root still
                 _logger.error(
@@ -253,7 +258,7 @@ def save_trace(self, trace):

             raise TraceCacheActiveTraceError("transaction already active")

-        self._cache[thread_id] = trace
+        self[thread_id] = trace

         # We judge whether we are actually running in a coroutine by
         # seeing if the current thread ID is actually listed in the set
@@ -284,7 +289,7 @@ def pop_current(self, trace):

         thread_id = trace.thread_id
         parent = trace.parent
-        self._cache[thread_id] = parent
+        self[thread_id] = parent

     def complete_root(self, root):
         """Completes a trace specified by the given root
@@ -301,7 +306,7 @@ def complete_root(self, root):
             to_complete = []

             for task_id in task_ids:
-                entry = self._cache.get(task_id)
+                entry = self.get(task_id)

                 if entry and entry is not root and entry.root is root:
                     to_complete.append(entry)
@@ -316,12 +321,12 @@ def complete_root(self, root):
         thread_id = root.thread_id

-        if thread_id not in self._cache:
+        if thread_id not in self:
             thread_id = self.current_thread_id()

-            if thread_id not in self._cache:
+            if thread_id not in self:
                 raise TraceCacheNoActiveTraceError("no active trace")

-        current = self._cache.get(thread_id)
+        current = self.get(thread_id)

         if root is not current:
             _logger.error(
@@ -333,7 +338,7 @@ def complete_root(self, root):

             raise RuntimeError("not the current trace")

-        del self._cache[thread_id]
+        del self[thread_id]
         root._greenlet = None

     def record_event_loop_wait(self, start_time, end_time):
@@ -359,7 +364,7 @@ def record_event_loop_wait(self, start_time, end_time):
         task = getattr(transaction.root_span, "_task", None)
         loop = get_event_loop(task)

-        for trace in list(self._cache.values()):
+        for trace in self.values():
             if trace in seen:
                 continue

@@ -390,6 +395,62 @@ def record_event_loop_wait(self, start_time, end_time):
                 root.increment_child_count()
                 root.add_child(node)

+    # MutableMapping methods
+
+    def items(self):
+        """
+        Safely iterates on self._cache.items() indirectly using a list of value references
+        to avoid RuntimeErrors from size changes during iteration.
+        """
+        for wr in self._cache.valuerefs():
+            value = wr()  # Dereferenced value is potentially no longer live.
+            if (
+                value is not None
+            ):  # A None value means the referent has been garbage collected and is no longer live. Ignore it.
+                yield wr.key, value  # wr.key is the original dict key
+
+    def keys(self):
+        """
+        Iterates on self._cache.keys() indirectly using a list of value references
+        to avoid RuntimeErrors from size changes during iteration.
+
+        NOTE: Returned keys are keys to weak references which may at any point be garbage collected.
+        It is only safe to retrieve values from the trace cache using trace_cache.get(key, None).
+        Retrieving values using trace_cache[key] can cause a KeyError if the item has been garbage collected.
+        """
+        for wr in self._cache.valuerefs():
+            yield wr.key  # wr.key is the original dict key
+
+    def values(self):
+        """
+        Safely iterates on self._cache.values() indirectly using a list of value references
+        to avoid RuntimeErrors from size changes during iteration.
+        """
+        for wr in self._cache.valuerefs():
+            value = wr()  # Dereferenced value is potentially no longer live.
+            if (
+                value is not None
+            ):  # A None value means the referent has been garbage collected and is no longer live. Ignore it.
+                yield value
+
+    def __getitem__(self, key):
+        return self._cache.__getitem__(key)
+
+    def __setitem__(self, key, value):
+        self._cache.__setitem__(key, value)
+
+    def __delitem__(self, key):
+        self._cache.__delitem__(key)
+
+    def __iter__(self):
+        return self.keys()
+
+    def __len__(self):
+        return self._cache.__len__()
+
+    def __bool__(self):
+        return bool(self._cache.__len__())
+

 _trace_cache = TraceCache()
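
For readers unfamiliar with the weakref API behind the new items()/keys()/values() methods: WeakValueDictionary.valuerefs() returns a plain list of KeyedRef objects, so entries can be added or garbage collected while that list is walked without tripping "dictionary changed size during iteration". A small self-contained sketch of the failure mode and the workaround (illustrative only, not part of the patch):

```python
# Illustrative sketch (not part of this patch): why iteration goes through
# WeakValueDictionary.valuerefs() instead of items()/values().
import gc
import weakref


class Value(object):  # weakref-able stand-in for a trace object
    pass


strong_refs = {i: Value() for i in range(5)}
cache = weakref.WeakValueDictionary()
for key, value in strong_refs.items():
    cache[key] = value

# Iterating cache.items()/values() directly can raise "RuntimeError:
# dictionary changed size during iteration" when another thread adds or
# removes entries mid-loop. valuerefs() snapshots the entries into a list
# first, so the loop below is immune to concurrent size changes.
for wr in cache.valuerefs():  # list of KeyedRef objects
    value = wr()  # may be None if the referent was already collected
    if value is not None:
        print(wr.key, value)  # wr.key is the original dict key

del strong_refs[3]  # drop the only strong reference to one value
gc.collect()
assert 3 not in cache  # the weakref entry disappears on its own
```
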
diff --git a/tests/agent_features/test_async_context_propagation.py b/tests/agent_features/test_async_context_propagation.py
index 8026cbbccf..47d16cfc56 100644
--- a/tests/agent_features/test_async_context_propagation.py
+++ b/tests/agent_features/test_async_context_propagation.py
@@ -13,11 +13,11 @@
 # limitations under the License.

 import pytest
-from testing_support.fixtures import (
-    function_not_called,
-    override_generic_settings,
+from testing_support.fixtures import function_not_called, override_generic_settings
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
 )
-from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
+
 from newrelic.api.application import application_instance as application
 from newrelic.api.background_task import BackgroundTask, background_task
 from newrelic.api.database_trace import database_trace
@@ -131,7 +131,7 @@ def handle_exception(loop, context):
     # The agent should have removed all traces from the cache since
     # run_until_complete has terminated (all callbacks scheduled inside the
     # task have run)
-    assert not trace_cache()._cache
+    assert not trace_cache()

     # Assert that no exceptions have occurred
     assert not exceptions, exceptions
@@ -286,7 +286,7 @@ def _test():

     # The agent should have removed all traces from the cache since
     # run_until_complete has terminated
-    assert not trace_cache()._cache
+    assert not trace_cache()

     # Assert that no exceptions have occurred
     assert not exceptions, exceptions
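
Side note on the assertion change in these tests: because TraceCache now defines __len__ and __bool__, emptiness can be asserted on the cache object itself instead of reaching into the private _cache attribute. A quick illustrative check (not part of the patch; DummyTrace is a stand-in):

```python
# Illustrative only: emptiness checks now go through TraceCache.__bool__/__len__
# rather than poking at the private WeakValueDictionary.
from newrelic.core.trace_cache import TraceCache


class DummyTrace(object):  # weakref-compatible stand-in for a real trace
    pass


tc = TraceCache()
assert not tc                # empty cache is falsy
obj = DummyTrace()
tc[1] = obj
assert tc and len(tc) == 1   # a strong reference keeps the entry alive
del obj                      # on CPython the weakref entry is dropped immediately
assert not tc
```
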
diff --git a/tests/agent_features/test_event_loop_wait_time.py b/tests/agent_features/test_event_loop_wait_time.py
index 69e6fc1024..84c65dcdc5 100644
--- a/tests/agent_features/test_event_loop_wait_time.py
+++ b/tests/agent_features/test_event_loop_wait_time.py
@@ -140,7 +140,7 @@ def _test():
 def test_record_event_loop_wait_outside_task():
     # Insert a random trace into the trace cache
     trace = FunctionTrace(name="testing")
-    trace_cache()._cache[0] = trace
+    trace_cache()[0] = trace

     @background_task(name="test_record_event_loop_wait_outside_task")
     def _test():
diff --git a/tests/agent_unittests/test_trace_cache.py b/tests/agent_unittests/test_trace_cache.py
new file mode 100644
index 0000000000..e0f7db84fa
--- /dev/null
+++ b/tests/agent_unittests/test_trace_cache.py
@@ -0,0 +1,129 @@
+# Copyright 2010 New Relic, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import threading
+
+import pytest
+
+from newrelic.core.trace_cache import TraceCache
+
+_TEST_CONCURRENT_ITERATION_TC_SIZE = 20
+
+
+class DummyTrace(object):
+    pass
+
+
+@pytest.fixture(scope="function")
+def trace_cache():
+    return TraceCache()
+
+
+def test_trace_cache_methods(trace_cache):
+    """Test that the MutableMapping methods are functional for the trace cache."""
+    obj = DummyTrace()  # weakref-compatible object
+
+    trace_cache[1] = obj
+    assert 1 in trace_cache
+    assert bool(trace_cache)
+    assert list(trace_cache)
+
+    del trace_cache[1]
+    assert 1 not in trace_cache
+    assert not bool(trace_cache)
+
+    trace_cache[1] = obj
+    assert trace_cache.get(1, None)
+    assert trace_cache.pop(1, None)
+
+    trace_cache[1] = obj
+    assert len(trace_cache) == 1
+    assert len(list(trace_cache.items())) == 1
+    assert len(list(trace_cache.keys())) == 1
+    assert len(list(trace_cache.values())) == 1
+
+
+@pytest.fixture(scope="function")
+def iterate_trace_cache(trace_cache):
+    def _iterate_trace_cache(shutdown):
+        while True:
+            if shutdown.is_set():
+                return
+            for k, v in trace_cache.items():
+                pass
+            for v in trace_cache.values():
+                pass
+            for v in trace_cache.keys():
+                pass
+
+    return _iterate_trace_cache
+
+
+@pytest.fixture(scope="function")
+def change_weakref_dict_size(trace_cache):
+    def _change_weakref_dict_size(shutdown, obj_refs):
+        """
+        Cause RuntimeErrors when iterating on the trace_cache by:
+        - Repeatedly popping and re-adding batches of keys to force size changes.
+        - Periodically replacing some object refs so their referents are garbage collected,
+          forcing the weakref dict to drop those entries and change size again.
+        """
+
+        dict_size_change = _TEST_CONCURRENT_ITERATION_TC_SIZE // 2  # Remove up to half of items
+        while True:
+            if shutdown.is_set():
+                return
+
+            # Delete and re-add keys
+            for i in range(dict_size_change):
+                trace_cache._cache.pop(i, None)
+            for i in range(dict_size_change):
+                trace_cache._cache[i] = obj_refs[i]
+
+            # Replace every 3rd obj ref, causing the WeakValueDictionary to drop the old entry.
+            for i in range(0, len(obj_refs), 3):
+                obj_refs[i] = DummyTrace()
+
+    return _change_weakref_dict_size
+
+
+def test_concurrent_iteration(iterate_trace_cache, change_weakref_dict_size):
+    """
+    Test for exceptions related to trace_cache changing size during iteration.
+
+    The WeakValueDictionary used internally is particularly prone to this: iterating
+    on it in any way other than indirectly through WeakValueDictionary.valuerefs()
+    can raise RuntimeErrors, because it walks an underlying dict that may change size.
+    """
+    obj_refs = [DummyTrace() for _ in range(_TEST_CONCURRENT_ITERATION_TC_SIZE)]
+    shutdown = threading.Event()
+
+    t1 = threading.Thread(target=change_weakref_dict_size, args=(shutdown, obj_refs))
+    t2 = threading.Thread(target=iterate_trace_cache, args=(shutdown,))
+    t1.daemon = True
+    t2.daemon = True
+    t1.start()
+    t2.start()
+
+    # Run for 1 second, then shut down. Stop immediately if a thread dies with an exception.
+    t2.join(timeout=1)
+    assert t1.is_alive(), "Thread exited with exception."
+    assert t2.is_alive(), "Thread exited with exception."
+    shutdown.set()
+
+    # Ensure threads shut down, using a timeout to prevent hangs
+    t1.join(timeout=1)
+    t2.join(timeout=1)
+    assert not t1.is_alive(), "Thread failed to exit."
+    assert not t2.is_alive(), "Thread failed to exit."
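
A design note the new unit test relies on: TraceCache only implements the five MutableMapping abstract methods (plus custom items/keys/values and __bool__); `in`, get(), and pop() come for free from the MutableMapping mixins. A minimal sketch of that mechanism with a hypothetical MiniCache class (illustrative only, not agent code):

```python
# Minimal sketch (illustrative, not agent code): MutableMapping derives
# __contains__, get, pop, setdefault, update, etc. from the five abstract
# methods, which is what lets the test drive TraceCache like a dict.
try:
    from collections.abc import MutableMapping
except ImportError:  # Python 2
    from collections import MutableMapping


class MiniCache(MutableMapping):
    def __init__(self):
        self._data = {}

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


cache = MiniCache()
cache["a"] = 1
assert "a" in cache            # __contains__ comes from the mixin
assert cache.get("b") is None  # so does get() ...
assert cache.pop("a") == 1     # ... and pop()
assert not cache               # falls back to __len__ when empty
```
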
diff --git a/tests/coroutines_asyncio/test_context_propagation.py b/tests/coroutines_asyncio/test_context_propagation.py
index 09fccffb2d..ef26aacc11 100644
--- a/tests/coroutines_asyncio/test_context_propagation.py
+++ b/tests/coroutines_asyncio/test_context_propagation.py
@@ -15,12 +15,11 @@
 import sys

 import pytest
-from testing_support.fixtures import (
-    function_not_called,
-    override_generic_settings,
+from testing_support.fixtures import function_not_called, override_generic_settings
+from testing_support.validators.validate_transaction_metrics import (
+    validate_transaction_metrics,
 )
-from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics

 from newrelic.api.application import application_instance as application
 from newrelic.api.background_task import BackgroundTask, background_task
 from newrelic.api.database_trace import database_trace
@@ -132,7 +131,7 @@ def handle_exception(loop, context):
     # The agent should have removed all traces from the cache since
     # run_until_complete has terminated (all callbacks scheduled inside the
     # task have run)
-    assert not trace_cache()._cache
+    assert not trace_cache()

     # Assert that no exceptions have occurred
     assert not exceptions, exceptions
@@ -290,7 +289,7 @@ def _test():

     # The agent should have removed all traces from the cache since
     # run_until_complete has terminated
-    assert not trace_cache()._cache
+    assert not trace_cache()

     # Assert that no exceptions have occurred
     assert not exceptions, exceptions