Skip to content

Commit

Permalink
Starlette GraphQL Context Propagation (#361)
Browse files Browse the repository at this point in the history
* Add graphql sync tests to fastapi

* Add starlette context propagation

* Move starlette tests to starlette not fastapi
  • Loading branch information
TimPansino committed Sep 16, 2021
1 parent 49c95c3 commit acf43cb
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 86 deletions.
5 changes: 5 additions & 0 deletions newrelic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,6 +2453,11 @@ def _process_module_builtin_defaults():
"newrelic.hooks.framework_starlette",
"instrument_starlette_background_task",
)
_process_module_definition(
"starlette.concurrency",
"newrelic.hooks.framework_starlette",
"instrument_starlette_concurrency",
)

_process_module_definition(
"strawberry.asgi",
Expand Down
56 changes: 56 additions & 0 deletions newrelic/core/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module implements utilities for context propagation for tracing across threads.
"""

from newrelic.common.object_wrapper import function_wrapper
from newrelic.core.trace_cache import trace_cache

class ContextOf(object):
def __init__(self, trace_cache_id):
self.trace_cache = trace_cache()
self.trace = self.trace_cache._cache.get(trace_cache_id)
self.thread_id = None
self.restore = None

def __enter__(self):
if self.trace:
self.thread_id = self.trace_cache.current_thread_id()
self.restore = self.trace_cache._cache.get(self.thread_id)
self.trace_cache._cache[self.thread_id] = self.trace
return self

def __exit__(self, exc, value, tb):
if self.restore:
self.trace_cache._cache[self.thread_id] = self.restore


async def context_wrapper_async(awaitable, trace_cache_id):
with ContextOf(trace_cache_id):
return await awaitable


def context_wrapper(func, trace_cache_id):
@function_wrapper
def _context_wrapper(wrapped, instance, args, kwargs):
with ContextOf(trace_cache_id):
return wrapped(*args, **kwargs)

return _context_wrapper(func)


def current_thread_id():
return trace_cache().current_thread_id()
25 changes: 1 addition & 24 deletions newrelic/hooks/adapter_asgiref.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,12 @@
from newrelic.common.object_wrapper import wrap_function_wrapper
from newrelic.core.trace_cache import trace_cache
from newrelic.core.context import context_wrapper_async, ContextOf


def _bind_thread_handler(loop, source_task, *args, **kwargs):
return source_task


class ContextOf(object):
def __init__(self, trace_cache_id):
self.trace_cache = trace_cache()
self.trace = self.trace_cache._cache.get(trace_cache_id)
self.thread_id = None
self.restore = None

def __enter__(self):
if self.trace:
self.thread_id = self.trace_cache.current_thread_id()
self.restore = self.trace_cache._cache.get(self.thread_id)
self.trace_cache._cache[self.thread_id] = self.trace
return self

def __exit__(self, exc, value, tb):
if self.restore:
self.trace_cache._cache[self.thread_id] = self.restore


async def context_wrapper_async(awaitable, trace_cache_id):
with ContextOf(trace_cache_id):
return await awaitable


def thread_handler_wrapper(wrapped, instance, args, kwargs):
task = _bind_thread_handler(*args, **kwargs)
with ContextOf(id(task)):
Expand Down
1 change: 0 additions & 1 deletion newrelic/hooks/framework_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def wrap_execute_operation(wrapped, instance, args, kwargs):
_logger.warning(
"Runtime instrumentation warning. GraphQL operation found without active GraphQLOperationTrace."
)
breakpoint()
return wrapped(*args, **kwargs)

try:
Expand Down
22 changes: 22 additions & 0 deletions newrelic/hooks/framework_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
wrap_function_wrapper,
)
from newrelic.core.config import should_ignore_error
from newrelic.core.context import context_wrapper, current_thread_id
from newrelic.core.trace_cache import trace_cache


Expand Down Expand Up @@ -204,6 +205,23 @@ def error_middleware_wrapper(wrapped, instance, args, kwargs):
return FunctionTraceWrapper(wrapped)(*args, **kwargs)


def bind_run_in_threadpool(func, *args, **kwargs):
return func, args, kwargs


async def wrap_run_in_threadpool(wrapped, instance, args, kwargs):
transaction = current_transaction()
trace = current_trace()

if not transaction or not trace:
return await wrapped(*args, **kwargs)

func, args, kwargs = bind_run_in_threadpool(*args, **kwargs)
func = context_wrapper(func, current_thread_id())

return await wrapped(func, *args, **kwargs)


def instrument_starlette_applications(module):
framework = framework_details()
version_info = tuple(int(v) for v in framework[1].split(".", 3)[:3])
Expand Down Expand Up @@ -256,3 +274,7 @@ def instrument_starlette_exceptions(module):

def instrument_starlette_background_task(module):
wrap_function_wrapper(module, "BackgroundTask.__call__", wrap_background_method)


def instrument_starlette_concurrency(module):
wrap_function_wrapper(module, "run_in_threadpool", wrap_run_in_threadpool)
12 changes: 0 additions & 12 deletions tests/framework_fastapi/_target_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
# limitations under the License.

from fastapi import FastAPI
from graphene import ObjectType, String, Schema
from graphql.execution.executors.asyncio import AsyncioExecutor
from starlette.graphql import GraphQLApp

from newrelic.api.transaction import current_transaction
from testing_support.asgi_testing import AsgiTest
Expand All @@ -35,13 +32,4 @@ async def non_sync():
return {}


class Query(ObjectType):
hello = String()

def resolve_hello(self, info):
return "Hello!"


app.add_route("/graphql", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query)))

target_application = AsgiTest(app)
49 changes: 1 addition & 48 deletions tests/framework_fastapi/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# limitations under the License.

import pytest
from testing_support.fixtures import dt_enabled, validate_transaction_metrics
from testing_support.validators.validate_span_events import validate_span_events
from testing_support.fixtures import validate_transaction_metrics


@pytest.mark.parametrize("endpoint,transaction_name", (
Expand All @@ -29,49 +28,3 @@ def _test():
assert response.status == 200

_test()


@dt_enabled
def test_graphql_endpoint(app):
from graphql import __version__ as version

FRAMEWORK_METRICS = [
("Python/Framework/GraphQL/%s" % version, 1),
]
_test_scoped_metrics = [
("GraphQL/resolve/GraphQL/hello", 1),
("GraphQL/operation/GraphQL/query/<anonymous>/hello", 1),
]
_test_unscoped_metrics = [
("GraphQL/all", 1),
("GraphQL/GraphQL/all", 1),
("GraphQL/allWeb", 1),
("GraphQL/GraphQL/allWeb", 1),
] + _test_scoped_metrics

_expected_query_operation_attributes = {
"graphql.operation.type": "query",
"graphql.operation.name": "<anonymous>",
"graphql.operation.query": "{ hello }",
}
_expected_query_resolver_attributes = {
"graphql.field.name": "hello",
"graphql.field.parentType": "Query",
"graphql.field.path": "hello",
"graphql.field.returnType": "String",
}

@validate_span_events(exact_agents=_expected_query_operation_attributes)
@validate_span_events(exact_agents=_expected_query_resolver_attributes)
@validate_transaction_metrics(
"query/<anonymous>/hello",
"GraphQL",
scoped_metrics=_test_scoped_metrics,
rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS,
)
def _test():
response = app.make_request("POST", "/graphql", params="query=%7B%20hello%20%7D")
assert response.status == 200
assert "Hello!" in response.body.decode("utf-8")

_test()
37 changes: 37 additions & 0 deletions tests/framework_starlette/_test_graphql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from starlette.applications import Starlette
from starlette.routing import Route
from testing_support.asgi_testing import AsgiTest

from graphene import ObjectType, String, Schema
from graphql.execution.executors.asyncio import AsyncioExecutor
from starlette.graphql import GraphQLApp


class Query(ObjectType):
hello = String()

def resolve_hello(self, info):
return "Hello!"


routes = [
Route("/async", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query))),
Route("/sync", GraphQLApp(schema=Schema(query=Query))),
]

app = Starlette(routes=routes)
target_application = AsgiTest(app)
70 changes: 70 additions & 0 deletions tests/framework_starlette/test_graphql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2010 New Relic, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import pytest
from testing_support.fixtures import dt_enabled, validate_transaction_metrics
from testing_support.validators.validate_span_events import validate_span_events

@pytest.fixture(scope="session")
def target_application():
import _test_graphql

return _test_graphql.target_application

@dt_enabled
@pytest.mark.parametrize("endpoint", ("/async", "/sync"))
def test_graphql_metrics_and_attrs(target_application, endpoint):
from graphql import __version__ as version

FRAMEWORK_METRICS = [
("Python/Framework/GraphQL/%s" % version, 1),
]
_test_scoped_metrics = [
("GraphQL/resolve/GraphQL/hello", 1),
("GraphQL/operation/GraphQL/query/<anonymous>/hello", 1),
]
_test_unscoped_metrics = [
("GraphQL/all", 1),
("GraphQL/GraphQL/all", 1),
("GraphQL/allWeb", 1),
("GraphQL/GraphQL/allWeb", 1),
] + _test_scoped_metrics

_expected_query_operation_attributes = {
"graphql.operation.type": "query",
"graphql.operation.name": "<anonymous>",
"graphql.operation.query": "{ hello }",
}
_expected_query_resolver_attributes = {
"graphql.field.name": "hello",
"graphql.field.parentType": "Query",
"graphql.field.path": "hello",
"graphql.field.returnType": "String",
}

@validate_span_events(exact_agents=_expected_query_operation_attributes)
@validate_span_events(exact_agents=_expected_query_resolver_attributes)
@validate_transaction_metrics(
"query/<anonymous>/hello",
"GraphQL",
scoped_metrics=_test_scoped_metrics,
rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS,
)
def _test():
response = target_application.make_request("POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"})
assert response.status == 200
assert "Hello!" in response.body.decode("utf-8")

_test()
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ deps =
framework_falcon-falcon0200: falcon<2.1
framework_falcon-falconmaster: https://github.com/falconry/falcon/archive/master.zip
framework_fastapi: fastapi
framework_fastapi: graphene
framework_fastapi: asyncio
framework_flask: Flask-Compress
framework_flask-flask0012: flask<0.13
Expand Down Expand Up @@ -283,6 +282,7 @@ deps =
framework_sanic-sanic210300: sanic<21.3.1
framework_sanic-saniclatest: sanic
framework_sanic-sanic{1812,190301,1906}: aiohttp
framework_starlette: graphene
framework_starlette-starlette0014: starlette<0.15
framework_starlette-starlettelatest: starlette
framework_strawberry: starlette
Expand Down

0 comments on commit acf43cb

Please sign in to comment.