diff --git a/newrelic/config.py b/newrelic/config.py index a1ba9a9eee..bf1095555a 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2453,6 +2453,11 @@ def _process_module_builtin_defaults(): "newrelic.hooks.framework_starlette", "instrument_starlette_background_task", ) + _process_module_definition( + "starlette.concurrency", + "newrelic.hooks.framework_starlette", + "instrument_starlette_concurrency", + ) _process_module_definition( "strawberry.asgi", diff --git a/newrelic/core/context.py b/newrelic/core/context.py new file mode 100644 index 0000000000..4e438de047 --- /dev/null +++ b/newrelic/core/context.py @@ -0,0 +1,56 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements utilities for context propagation for tracing across threads. +""" + +from newrelic.common.object_wrapper import function_wrapper +from newrelic.core.trace_cache import trace_cache + +class ContextOf(object): + def __init__(self, trace_cache_id): + self.trace_cache = trace_cache() + self.trace = self.trace_cache._cache.get(trace_cache_id) + self.thread_id = None + self.restore = None + + def __enter__(self): + if self.trace: + self.thread_id = self.trace_cache.current_thread_id() + self.restore = self.trace_cache._cache.get(self.thread_id) + self.trace_cache._cache[self.thread_id] = self.trace + return self + + def __exit__(self, exc, value, tb): + if self.restore: + self.trace_cache._cache[self.thread_id] = self.restore + + +async def context_wrapper_async(awaitable, trace_cache_id): + with ContextOf(trace_cache_id): + return await awaitable + + +def context_wrapper(func, trace_cache_id): + @function_wrapper + def _context_wrapper(wrapped, instance, args, kwargs): + with ContextOf(trace_cache_id): + return wrapped(*args, **kwargs) + + return _context_wrapper(func) + + +def current_thread_id(): + return trace_cache().current_thread_id() diff --git a/newrelic/hooks/adapter_asgiref.py b/newrelic/hooks/adapter_asgiref.py index 29c83d849a..d6dea1e049 100644 --- a/newrelic/hooks/adapter_asgiref.py +++ b/newrelic/hooks/adapter_asgiref.py @@ -1,35 +1,12 @@ from newrelic.common.object_wrapper import wrap_function_wrapper from newrelic.core.trace_cache import trace_cache +from newrelic.core.context import context_wrapper_async, ContextOf def _bind_thread_handler(loop, source_task, *args, **kwargs): return source_task -class ContextOf(object): - def __init__(self, trace_cache_id): - self.trace_cache = trace_cache() - self.trace = self.trace_cache._cache.get(trace_cache_id) - self.thread_id = None - self.restore = None - - def __enter__(self): - if self.trace: - self.thread_id = self.trace_cache.current_thread_id() - self.restore = self.trace_cache._cache.get(self.thread_id) - self.trace_cache._cache[self.thread_id] = self.trace - return self - - def __exit__(self, exc, value, tb): - if self.restore: - self.trace_cache._cache[self.thread_id] = self.restore - - -async def context_wrapper_async(awaitable, trace_cache_id): - with ContextOf(trace_cache_id): - return await awaitable - - def thread_handler_wrapper(wrapped, instance, args, kwargs): task = _bind_thread_handler(*args, **kwargs) with ContextOf(id(task)): diff --git a/newrelic/hooks/framework_graphql.py b/newrelic/hooks/framework_graphql.py index c71604a0c3..95971c1a7f 100644 --- a/newrelic/hooks/framework_graphql.py +++ b/newrelic/hooks/framework_graphql.py @@ -113,7 +113,6 @@ def wrap_execute_operation(wrapped, instance, args, kwargs): _logger.warning( "Runtime instrumentation warning. GraphQL operation found without active GraphQLOperationTrace." ) - breakpoint() return wrapped(*args, **kwargs) try: diff --git a/newrelic/hooks/framework_starlette.py b/newrelic/hooks/framework_starlette.py index e38f280c5f..47a2128dce 100644 --- a/newrelic/hooks/framework_starlette.py +++ b/newrelic/hooks/framework_starlette.py @@ -25,6 +25,7 @@ wrap_function_wrapper, ) from newrelic.core.config import should_ignore_error +from newrelic.core.context import context_wrapper, current_thread_id from newrelic.core.trace_cache import trace_cache @@ -204,6 +205,23 @@ def error_middleware_wrapper(wrapped, instance, args, kwargs): return FunctionTraceWrapper(wrapped)(*args, **kwargs) +def bind_run_in_threadpool(func, *args, **kwargs): + return func, args, kwargs + + +async def wrap_run_in_threadpool(wrapped, instance, args, kwargs): + transaction = current_transaction() + trace = current_trace() + + if not transaction or not trace: + return await wrapped(*args, **kwargs) + + func, args, kwargs = bind_run_in_threadpool(*args, **kwargs) + func = context_wrapper(func, current_thread_id()) + + return await wrapped(func, *args, **kwargs) + + def instrument_starlette_applications(module): framework = framework_details() version_info = tuple(int(v) for v in framework[1].split(".", 3)[:3]) @@ -256,3 +274,7 @@ def instrument_starlette_exceptions(module): def instrument_starlette_background_task(module): wrap_function_wrapper(module, "BackgroundTask.__call__", wrap_background_method) + + +def instrument_starlette_concurrency(module): + wrap_function_wrapper(module, "run_in_threadpool", wrap_run_in_threadpool) diff --git a/tests/framework_fastapi/_target_application.py b/tests/framework_fastapi/_target_application.py index ca45b056f6..cef1b0af16 100644 --- a/tests/framework_fastapi/_target_application.py +++ b/tests/framework_fastapi/_target_application.py @@ -13,9 +13,6 @@ # limitations under the License. from fastapi import FastAPI -from graphene import ObjectType, String, Schema -from graphql.execution.executors.asyncio import AsyncioExecutor -from starlette.graphql import GraphQLApp from newrelic.api.transaction import current_transaction from testing_support.asgi_testing import AsgiTest @@ -35,13 +32,4 @@ async def non_sync(): return {} -class Query(ObjectType): - hello = String() - - def resolve_hello(self, info): - return "Hello!" - - -app.add_route("/graphql", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query))) - target_application = AsgiTest(app) diff --git a/tests/framework_fastapi/test_application.py b/tests/framework_fastapi/test_application.py index 657ba127be..41860409ad 100644 --- a/tests/framework_fastapi/test_application.py +++ b/tests/framework_fastapi/test_application.py @@ -13,8 +13,7 @@ # limitations under the License. import pytest -from testing_support.fixtures import dt_enabled, validate_transaction_metrics -from testing_support.validators.validate_span_events import validate_span_events +from testing_support.fixtures import validate_transaction_metrics @pytest.mark.parametrize("endpoint,transaction_name", ( @@ -29,49 +28,3 @@ def _test(): assert response.status == 200 _test() - - -@dt_enabled -def test_graphql_endpoint(app): - from graphql import __version__ as version - - FRAMEWORK_METRICS = [ - ("Python/Framework/GraphQL/%s" % version, 1), - ] - _test_scoped_metrics = [ - ("GraphQL/resolve/GraphQL/hello", 1), - ("GraphQL/operation/GraphQL/query//hello", 1), - ] - _test_unscoped_metrics = [ - ("GraphQL/all", 1), - ("GraphQL/GraphQL/all", 1), - ("GraphQL/allWeb", 1), - ("GraphQL/GraphQL/allWeb", 1), - ] + _test_scoped_metrics - - _expected_query_operation_attributes = { - "graphql.operation.type": "query", - "graphql.operation.name": "", - "graphql.operation.query": "{ hello }", - } - _expected_query_resolver_attributes = { - "graphql.field.name": "hello", - "graphql.field.parentType": "Query", - "graphql.field.path": "hello", - "graphql.field.returnType": "String", - } - - @validate_span_events(exact_agents=_expected_query_operation_attributes) - @validate_span_events(exact_agents=_expected_query_resolver_attributes) - @validate_transaction_metrics( - "query//hello", - "GraphQL", - scoped_metrics=_test_scoped_metrics, - rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS, - ) - def _test(): - response = app.make_request("POST", "/graphql", params="query=%7B%20hello%20%7D") - assert response.status == 200 - assert "Hello!" in response.body.decode("utf-8") - - _test() diff --git a/tests/framework_starlette/_test_graphql.py b/tests/framework_starlette/_test_graphql.py new file mode 100644 index 0000000000..0b7feac62a --- /dev/null +++ b/tests/framework_starlette/_test_graphql.py @@ -0,0 +1,37 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from starlette.applications import Starlette +from starlette.routing import Route +from testing_support.asgi_testing import AsgiTest + +from graphene import ObjectType, String, Schema +from graphql.execution.executors.asyncio import AsyncioExecutor +from starlette.graphql import GraphQLApp + + +class Query(ObjectType): + hello = String() + + def resolve_hello(self, info): + return "Hello!" + + +routes = [ + Route("/async", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query))), + Route("/sync", GraphQLApp(schema=Schema(query=Query))), +] + +app = Starlette(routes=routes) +target_application = AsgiTest(app) diff --git a/tests/framework_starlette/test_graphql.py b/tests/framework_starlette/test_graphql.py new file mode 100644 index 0000000000..fd7d2ffcb3 --- /dev/null +++ b/tests/framework_starlette/test_graphql.py @@ -0,0 +1,70 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import pytest +from testing_support.fixtures import dt_enabled, validate_transaction_metrics +from testing_support.validators.validate_span_events import validate_span_events + +@pytest.fixture(scope="session") +def target_application(): + import _test_graphql + + return _test_graphql.target_application + +@dt_enabled +@pytest.mark.parametrize("endpoint", ("/async", "/sync")) +def test_graphql_metrics_and_attrs(target_application, endpoint): + from graphql import __version__ as version + + FRAMEWORK_METRICS = [ + ("Python/Framework/GraphQL/%s" % version, 1), + ] + _test_scoped_metrics = [ + ("GraphQL/resolve/GraphQL/hello", 1), + ("GraphQL/operation/GraphQL/query//hello", 1), + ] + _test_unscoped_metrics = [ + ("GraphQL/all", 1), + ("GraphQL/GraphQL/all", 1), + ("GraphQL/allWeb", 1), + ("GraphQL/GraphQL/allWeb", 1), + ] + _test_scoped_metrics + + _expected_query_operation_attributes = { + "graphql.operation.type": "query", + "graphql.operation.name": "", + "graphql.operation.query": "{ hello }", + } + _expected_query_resolver_attributes = { + "graphql.field.name": "hello", + "graphql.field.parentType": "Query", + "graphql.field.path": "hello", + "graphql.field.returnType": "String", + } + + @validate_span_events(exact_agents=_expected_query_operation_attributes) + @validate_span_events(exact_agents=_expected_query_resolver_attributes) + @validate_transaction_metrics( + "query//hello", + "GraphQL", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS, + ) + def _test(): + response = target_application.make_request("POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"}) + assert response.status == 200 + assert "Hello!" in response.body.decode("utf-8") + + _test() diff --git a/tox.ini b/tox.ini index 1cf1f6b3f7..853f3b4ca7 100644 --- a/tox.ini +++ b/tox.ini @@ -247,7 +247,6 @@ deps = framework_falcon-falcon0200: falcon<2.1 framework_falcon-falconmaster: https://github.com/falconry/falcon/archive/master.zip framework_fastapi: fastapi - framework_fastapi: graphene framework_fastapi: asyncio framework_flask: Flask-Compress framework_flask-flask0012: flask<0.13 @@ -283,6 +282,7 @@ deps = framework_sanic-sanic210300: sanic<21.3.1 framework_sanic-saniclatest: sanic framework_sanic-sanic{1812,190301,1906}: aiohttp + framework_starlette: graphene framework_starlette-starlette0014: starlette<0.15 framework_starlette-starlettelatest: starlette framework_strawberry: starlette