Skip to content

Commit

Permalink
GraphQL Fragment Spread Logic (#293)
Browse files Browse the repository at this point in the history
* Implement fragment spread deepest path logic.

* Reformat files.
  • Loading branch information
umaannamalai authored and TimPansino committed Jul 29, 2021
1 parent 3d88d90 commit f1c6af2
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 97 deletions.
101 changes: 81 additions & 20 deletions newrelic/hooks/framework_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

def graphql_version():
from graphql import __version__ as version

return tuple(int(v) for v in version.split("."))


Expand Down Expand Up @@ -69,13 +70,13 @@ def wrap_executor_context_init(wrapped, instance, args, kwargs):
instance.field_resolver = wrap_resolver(instance.field_resolver)
instance.field_resolver._nr_wrapped = True


return result


def bind_operation_v3(operation, root_value):
return operation


def bind_operation_v2(exe_context, operation, root_value):
return operation

Expand Down Expand Up @@ -105,12 +106,22 @@ def wrap_execute_operation(wrapped, instance, args, kwargs):
if get_node_value(field, "name") in GRAPHQL_INTROSPECTION_FIELDS:
ignore_transaction()

deepest_path = traverse_deepest_unique_path(fields)
if graphql_version() <= (3, 0, 0):
fragments = args[
0
].fragments # In v2, args[0] is the ExecutionContext object
else:
fragments = instance.fragments # instance is the ExecutionContext object
deepest_path = traverse_deepest_unique_path(fields, fragments)
trace.deepest_path = deepest_path = ".".join(deepest_path) or ""

transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=11)
result = wrapped(*args, **kwargs)
transaction_name = "%s/%s/%s" % (operation_type, operation_name, deepest_path) if deepest_path else "%s/%s" % (operation_type, operation_name)
transaction_name = (
"%s/%s/%s" % (operation_type, operation_name, deepest_path)
if deepest_path
else "%s/%s" % (operation_type, operation_name)
)
transaction.set_transaction_name(transaction_name, "GraphQL", priority=14)

return result
Expand All @@ -123,74 +134,121 @@ def get_node_value(field, attr, subattr="value"):
return field_name


def is_fragment_spread_node(field):
# Resolve version specific imports
try:
from graphql.language.ast import FragmentSpread
except ImportError:
from graphql import FragmentSpreadNode as FragmentSpread

return isinstance(field, FragmentSpread)


def is_fragment(field):
# Resolve version specific imports
try:
from graphql.language.ast import FragmentSpread, InlineFragment
except ImportError:
from graphql import FragmentSpreadNode as FragmentSpread, InlineFragmentNode as InlineFragment
from graphql import (
FragmentSpreadNode as FragmentSpread,
InlineFragmentNode as InlineFragment,
)

_fragment_types = (InlineFragment, FragmentSpread)

return isinstance(field, _fragment_types)


def is_named_fragment(field):
# Resolve version specific imports
try:
from graphql.language.ast import NamedType
except ImportError:
from graphql import NamedTypeNode as NamedType

return is_fragment(field) and getattr(field, "type_condition", None) is not None and isinstance(field.type_condition, NamedType)
return (
is_fragment(field)
and getattr(field, "type_condition", None) is not None
and isinstance(field.type_condition, NamedType)
)


def traverse_deepest_unique_path(fields):
deepest_path = deque()
def filter_ignored_fields(fields):
filtered_fields = [
f for f in fields if get_node_value(f, "name") not in GRAPHQL_IGNORED_FIELDS
]
return filtered_fields


def traverse_deepest_unique_path(fields, fragments):
deepest_path = deque()
while fields is not None and len(fields) > 0:
fields = [f for f in fields if get_node_value(f, "name") not in GRAPHQL_IGNORED_FIELDS]
fields = filter_ignored_fields(fields)
if len(fields) != 1: # Either selections is empty, or non-unique
return deepest_path
field = fields[0]

field_name = get_node_value(field, "name")
fragment_selection_set = []

if is_named_fragment(field):
name = get_node_value(field.type_condition, "name")
if name:
deepest_path.append("%s<%s>" % (deepest_path.pop(), name))

elif is_fragment(field):
break
if len(list(fragments.values())) != 1:
return deepest_path

# list(fragments.values())[0] 's index is OK because the previous line
# ensures that there is only one field in the list
full_fragment_selection_set = list(fragments.values())[
0
].selection_set.selections
fragment_selection_set = filter_ignored_fields(full_fragment_selection_set)

if len(fragment_selection_set) != 1:
return deepest_path
else:
fragment_field_name = get_node_value(fragment_selection_set[0], "name")
deepest_path.append(fragment_field_name)

else:
if field_name:
deepest_path.append(field_name)

if is_fragment_spread_node(field):
field = fragment_selection_set[0]
if field.selection_set is None:
break
else:
fields = field.selection_set.selections

return deepest_path


def bind_get_middleware_resolvers(middlewares):
return middlewares


def wrap_get_middleware_resolvers(wrapped, instance, args, kwargs):
middlewares = bind_get_middleware_resolvers(*args, **kwargs)
middlewares = [wrap_middleware(m) if not hasattr(m, "_nr_wrapped") else m for m in middlewares]
middlewares = [
wrap_middleware(m) if not hasattr(m, "_nr_wrapped") else m for m in middlewares
]
for m in middlewares:
m._nr_wrapped = True

return wrapped(middlewares)


@function_wrapper
def wrap_middleware(wrapped, instance, args, kwargs):
transaction = current_transaction()
if transaction is None:
return wrapped(*args, **kwargs)

name = callable_name(wrapped)
transaction.set_transaction_name(name, 'GraphQL', priority=12)
transaction.set_transaction_name(name, "GraphQL", priority=12)
with FunctionTrace(name):
with ErrorTrace(ignore=ignore_graphql_duplicate_exception):
return wrapped(*args, **kwargs)
Expand Down Expand Up @@ -253,7 +311,7 @@ def wrap_validate(wrapped, instance, args, kwargs):
if transaction is None:
return wrapped(*args, **kwargs)

transaction.set_transaction_name(callable_name(wrapped),"GraphQL", priority=10)
transaction.set_transaction_name(callable_name(wrapped), "GraphQL", priority=10)

# Run and collect errors
errors = wrapped(*args, **kwargs)
Expand All @@ -267,6 +325,7 @@ def wrap_validate(wrapped, instance, args, kwargs):

return errors


def wrap_parse(wrapped, instance, args, kwargs):
transaction = current_transaction()
if transaction is None:
Expand All @@ -281,7 +340,9 @@ def bind_resolve_field_v3(parent_type, source, field_nodes, path):
return parent_type, field_nodes, path


def bind_resolve_field_v2(exe_context, parent_type, source, field_asts, parent_info, field_path):
def bind_resolve_field_v2(
exe_context, parent_type, source, field_asts, parent_info, field_path
):
return parent_type, field_asts, field_path


Expand Down Expand Up @@ -323,7 +384,8 @@ def bind_execute_graphql_query(
operation_name=None,
middleware=None,
backend=None,
**execute_options):
**execute_options
):

return request_string

Expand All @@ -335,9 +397,9 @@ def wrap_graphql_impl(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)

version = graphql_version()
framework_version = '.'.join(map(str, version))
framework_version = ".".join(map(str, version))

transaction.add_framework_info(name='GraphQL', version=framework_version)
transaction.add_framework_info(name="GraphQL", version=framework_version)

if graphql_version() <= (3, 0, 0):
bind_query = bind_execute_graphql_query
Expand Down Expand Up @@ -378,9 +440,8 @@ def instrument_graphql_execute(module):
wrap_function_wrapper(module, "resolve_field", wrap_resolve_field)

if hasattr(module, "execute_operation"):
wrap_function_wrapper(
module, "execute_operation", wrap_execute_operation
)
wrap_function_wrapper(module, "execute_operation", wrap_execute_operation)


def instrument_graphql_execution_utils(module):
if hasattr(module, "ExecutionContext"):
Expand Down
Loading

0 comments on commit f1c6af2

Please sign in to comment.