diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b2c221bcfb..e3b264a9fc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1024,3 +1024,68 @@ jobs: name: coverage-${{ github.job }}-${{ strategy.job-index }} path: ./**/.coverage.* retention-days: 1 + + firestore: + env: + TOTAL_GROUPS: 1 + + strategy: + fail-fast: false + matrix: + group-number: [1] + + runs-on: ubuntu-20.04 + container: + image: ghcr.io/newrelic/newrelic-python-agent-ci:latest + options: >- + --add-host=host.docker.internal:host-gateway + timeout-minutes: 30 + + services: + firestore: + # Image set here MUST be repeated down below in options. See comment below. + image: gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators + ports: + - 8080:8080 + # Set health checks to wait 5 seconds in lieu of an actual healthcheck + options: >- + --health-cmd "echo success" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + --health-start-period 5s + gcr.io/google.com/cloudsdktool/google-cloud-cli:437.0.1-emulators /bin/bash -c "gcloud emulators firestore start --host-port=0.0.0.0:8080" || + # This is a very hacky solution. GitHub Actions doesn't provide APIs for setting commands on services, but allows adding arbitrary options. + # --entrypoint won't work as it only accepts an executable and not the [] syntax. + # Instead, we specify the image again and the command afterwards, like a call to docker create. The result is a few environment variables + # and the original command being appended to our hijacked docker create command. We can avoid any issues by adding || to prevent that + # from ever being executed as bash commands. 
+ + steps: + - uses: actions/checkout@v3 + + - name: Fetch git tags + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + git fetch --tags origin + + - name: Get Environments + id: get-envs + run: | + echo "envs=$(tox -l | grep '^${{ github.job }}\-' | ./.github/workflows/get-envs.py)" >> $GITHUB_OUTPUT + env: + GROUP_NUMBER: ${{ matrix.group-number }} + + - name: Test + run: | + tox -vv -e ${{ steps.get-envs.outputs.envs }} -p auto + env: + TOX_PARALLEL_NO_SPINNER: 1 + PY_COLORS: 0 + + - name: Upload Coverage Artifacts + uses: actions/upload-artifact@v3 + with: + name: coverage-${{ github.job }}-${{ strategy.job-index }} + path: ./**/.coverage.* + retention-days: 1 diff --git a/newrelic/api/database_trace.py b/newrelic/api/database_trace.py index 2bc4976887..1069be506f 100644 --- a/newrelic/api/database_trace.py +++ b/newrelic/api/database_trace.py @@ -16,7 +16,7 @@ import logging from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.database_node import DatabaseNode from newrelic.core.stack_trace import current_stack @@ -244,9 +244,9 @@ def create_node(self): ) -def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None): +def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None, async_wrapper=None): def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -273,9 +273,9 @@ def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_database_trace_wrapper_) -def database_trace(sql, dbapi2_module=None): - return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module) 
+def database_trace(sql, dbapi2_module=None, async_wrapper=None): + return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module, async_wrapper=async_wrapper) -def wrap_database_trace(module, object_path, sql, dbapi2_module=None): - wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module)) +def wrap_database_trace(module, object_path, sql, dbapi2_module=None, async_wrapper=None): + wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module, async_wrapper)) diff --git a/newrelic/api/datastore_trace.py b/newrelic/api/datastore_trace.py index fb40abcab3..0401c79ea5 100644 --- a/newrelic/api/datastore_trace.py +++ b/newrelic/api/datastore_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.datastore_node import DatastoreNode @@ -82,6 +82,9 @@ def __enter__(self): self.product = transaction._intern_string(self.product) self.target = transaction._intern_string(self.target) self.operation = transaction._intern_string(self.operation) + self.host = transaction._intern_string(self.host) + self.port_path_or_id = transaction._intern_string(self.port_path_or_id) + self.database_name = transaction._intern_string(self.database_name) datastore_tracer_settings = transaction.settings.datastore_tracer self.instance_reporting_enabled = datastore_tracer_settings.instance_reporting.enabled @@ -92,7 +95,14 @@ def __repr__(self): return "<%s object at 0x%x %s>" % ( self.__class__.__name__, id(self), - dict(product=self.product, target=self.target, operation=self.operation), + dict( + product=self.product, + target=self.target, + operation=self.operation, + host=self.host, + port_path_or_id=self.port_path_or_id, + database_name=self.database_name, + ), ) def 
finalize_data(self, transaction, exc=None, value=None, tb=None): @@ -125,7 +135,7 @@ def create_node(self): ) -def DatastoreTraceWrapper(wrapped, product, target, operation): +def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Wraps a method to time datastore queries. :param wrapped: The function to apply the trace to. @@ -140,6 +150,16 @@ def DatastoreTraceWrapper(wrapped, product, target, operation): or the name of any API function/method in the client library. :type operation: str or callable + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. + :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. 
+ :type async_wrapper: callable or None :rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper` This is typically used to wrap datastore queries such as calls to Redis or @@ -155,7 +175,7 @@ def DatastoreTraceWrapper(wrapped, product, target, operation): """ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -187,7 +207,33 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): else: _operation = operation - trace = DatastoreTrace(_product, _target, _operation, parent=parent, source=wrapped) + if callable(host): + if instance is not None: + _host = host(instance, *args, **kwargs) + else: + _host = host(*args, **kwargs) + else: + _host = host + + if callable(port_path_or_id): + if instance is not None: + _port_path_or_id = port_path_or_id(instance, *args, **kwargs) + else: + _port_path_or_id = port_path_or_id(*args, **kwargs) + else: + _port_path_or_id = port_path_or_id + + if callable(database_name): + if instance is not None: + _database_name = database_name(instance, *args, **kwargs) + else: + _database_name = database_name(*args, **kwargs) + else: + _database_name = database_name + + trace = DatastoreTrace( + _product, _target, _operation, _host, _port_path_or_id, _database_name, parent=parent, source=wrapped + ) if wrapper: # pylint: disable=W0125,W0126 return wrapper(wrapped, trace)(*args, **kwargs) @@ -198,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_) -def datastore_trace(product, target, operation): +def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Decorator allows datastore query to be timed. :param product: The name of the vendor. 
@@ -211,6 +257,16 @@ def datastore_trace(product, target, operation): or the name of any API function/method in the client library. :type operation: str + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. + :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to decorate datastore queries such as calls to Redis or ElasticSearch. @@ -224,10 +280,21 @@ def datastore_trace(product, target, operation): ... time.sleep(*args, **kwargs) """ - return functools.partial(DatastoreTraceWrapper, product=product, target=target, operation=operation) - - -def wrap_datastore_trace(module, object_path, product, target, operation): + return functools.partial( + DatastoreTraceWrapper, + product=product, + target=target, + operation=operation, + host=host, + port_path_or_id=port_path_or_id, + database_name=database_name, + async_wrapper=async_wrapper, + ) + + +def wrap_datastore_trace( + module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None +): """Method applies custom timing to datastore query. :param module: Module containing the method to be instrumented. @@ -244,6 +311,16 @@ def wrap_datastore_trace(module, object_path, product, target, operation): or the name of any API function/method in the client library. :type operation: str + :param host: The name of the server hosting the actual datastore. + :type host: str + :param port_path_or_id: The value passed in can represent either the port, + path, or id of the datastore being connected to. 
+ :type port_path_or_id: str + :param database_name: The name of database where the current query is being + executed. + :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to time database query method calls such as Redis GET. @@ -256,4 +333,6 @@ def wrap_datastore_trace(module, object_path, product, target, operation): ... 'sleep') """ - wrap_object(module, object_path, DatastoreTraceWrapper, (product, target, operation)) + wrap_object( + module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper) + ) diff --git a/newrelic/api/external_trace.py b/newrelic/api/external_trace.py index c43c560c6c..2e147df450 100644 --- a/newrelic/api/external_trace.py +++ b/newrelic/api/external_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.external_node import ExternalNode @@ -66,9 +66,9 @@ def create_node(self): ) -def ExternalTraceWrapper(wrapped, library, url, method=None): +def ExternalTraceWrapper(wrapped, library, url, method=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -103,7 +103,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: 
parent = current_trace() if not parent: @@ -125,9 +125,9 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def external_trace(library, url, method=None): - return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method) +def external_trace(library, url, method=None, async_wrapper=None): + return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method, async_wrapper=async_wrapper) -def wrap_external_trace(module, object_path, library, url, method=None): - wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method)) +def wrap_external_trace(module, object_path, library, url, method=None, async_wrapper=None): + wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method, async_wrapper)) diff --git a/newrelic/api/function_trace.py b/newrelic/api/function_trace.py index 474c1b2266..85d7617b68 100644 --- a/newrelic/api/function_trace.py +++ b/newrelic/api/function_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.function_node import FunctionNode @@ -89,9 +89,9 @@ def create_node(self): ) -def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -147,7 +147,7 @@ def 
dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -171,13 +171,13 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): return functools.partial( - FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup + FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup, async_wrapper=async_wrapper ) def wrap_function_trace( - module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None + module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None ): - return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup)) + return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup, async_wrapper)) diff --git a/newrelic/api/graphql_trace.py b/newrelic/api/graphql_trace.py index 7a2c9ec02f..3d6ae6b09f 100644 --- a/newrelic/api/graphql_trace.py +++ b/newrelic/api/graphql_trace.py @@ -16,7 +16,7 @@ from newrelic.api.time_trace import TimeTrace, current_trace from newrelic.api.transaction import current_transaction -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.graphql_node import 
GraphQLOperationNode, GraphQLResolverNode @@ -109,9 +109,9 @@ def set_transaction_name(self, priority=None): transaction.set_transaction_name(name, "GraphQL", priority=priority) -def GraphQLOperationTraceWrapper(wrapped): +def GraphQLOperationTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -130,12 +130,12 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_operation_trace(): - return functools.partial(GraphQLOperationTraceWrapper) +def graphql_operation_trace(async_wrapper=None): + return functools.partial(GraphQLOperationTraceWrapper, async_wrapper=async_wrapper) -def wrap_graphql_operation_trace(module, object_path): - wrap_object(module, object_path, GraphQLOperationTraceWrapper) +def wrap_graphql_operation_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLOperationTraceWrapper, (async_wrapper,)) class GraphQLResolverTrace(TimeTrace): @@ -193,9 +193,9 @@ def create_node(self): ) -def GraphQLResolverTraceWrapper(wrapped): +def GraphQLResolverTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -214,9 +214,9 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_resolver_trace(): - return functools.partial(GraphQLResolverTraceWrapper) +def graphql_resolver_trace(async_wrapper=None): + return functools.partial(GraphQLResolverTraceWrapper, async_wrapper=async_wrapper) -def 
wrap_graphql_resolver_trace(module, object_path): - wrap_object(module, object_path, GraphQLResolverTraceWrapper) +def wrap_graphql_resolver_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLResolverTraceWrapper, (async_wrapper,)) diff --git a/newrelic/api/memcache_trace.py b/newrelic/api/memcache_trace.py index 6657a9ce27..87f12f9fc7 100644 --- a/newrelic/api/memcache_trace.py +++ b/newrelic/api/memcache_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.memcache_node import MemcacheNode @@ -51,9 +51,9 @@ def create_node(self): ) -def MemcacheTraceWrapper(wrapped, command): +def MemcacheTraceWrapper(wrapped, command, async_wrapper=None): def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -80,9 +80,9 @@ def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_wrapper_memcache_trace_) -def memcache_trace(command): - return functools.partial(MemcacheTraceWrapper, command=command) +def memcache_trace(command, async_wrapper=None): + return functools.partial(MemcacheTraceWrapper, command=command, async_wrapper=async_wrapper) -def wrap_memcache_trace(module, object_path, command): - wrap_object(module, object_path, MemcacheTraceWrapper, (command,)) +def wrap_memcache_trace(module, object_path, command, async_wrapper=None): + wrap_object(module, object_path, MemcacheTraceWrapper, (command, async_wrapper)) diff --git a/newrelic/api/message_trace.py b/newrelic/api/message_trace.py index be819d7044..f564c41cb4 100644 --- 
a/newrelic/api/message_trace.py +++ b/newrelic/api/message_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.message_node import MessageNode @@ -91,9 +91,9 @@ def create_node(self): ) -def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True): +def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -144,7 +144,7 @@ def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_message_trace_wrapper_) -def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True): +def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): return functools.partial( MessageTraceWrapper, library=library, @@ -153,10 +153,11 @@ def message_trace(library, operation, destination_type, destination_name, params destination_name=destination_name, params=params, terminal=terminal, + async_wrapper=async_wrapper, ) -def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True): +def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): wrap_object( - module, object_path, MessageTraceWrapper, (library, operation, 
destination_type, destination_name, params, terminal) + module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal, async_wrapper) ) diff --git a/newrelic/common/async_wrapper.py b/newrelic/common/async_wrapper.py index c5f95308da..2d3db2b4be 100644 --- a/newrelic/common/async_wrapper.py +++ b/newrelic/common/async_wrapper.py @@ -18,7 +18,9 @@ is_coroutine_callable, is_asyncio_coroutine, is_generator_function, + is_async_generator_function, ) +from newrelic.packages import six def evaluate_wrapper(wrapper_string, wrapped, trace): @@ -29,7 +31,6 @@ def evaluate_wrapper(wrapper_string, wrapped, trace): def coroutine_wrapper(wrapped, trace): - WRAPPER = textwrap.dedent(""" @functools.wraps(wrapped) async def wrapper(*args, **kwargs): @@ -61,29 +62,76 @@ def wrapper(*args, **kwargs): return wrapped -def generator_wrapper(wrapped, trace): - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - g = wrapped(*args, **kwargs) - value = None - with trace: - while True: +if six.PY3: + def generator_wrapper(wrapped, trace): + WRAPPER = textwrap.dedent(""" + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + with trace: + result = yield from wrapped(*args, **kwargs) + return result + """) + + try: + return evaluate_wrapper(WRAPPER, wrapped, trace) + except: + return wrapped +else: + def generator_wrapper(wrapped, trace): + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + g = wrapped(*args, **kwargs) + with trace: try: - yielded = g.send(value) + yielded = g.send(None) + while True: + try: + sent = yield yielded + except GeneratorExit as e: + g.close() + raise + except BaseException as e: + yielded = g.throw(e) + else: + yielded = g.send(sent) except StopIteration: - break + return + return wrapper - try: - value = yield yielded - except BaseException as e: - value = yield g.throw(type(e), e) - return wrapper +def async_generator_wrapper(wrapped, trace): + WRAPPER = textwrap.dedent(""" + 
@functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + g = wrapped(*args, **kwargs) + with trace: + try: + yielded = await g.asend(None) + while True: + try: + sent = yield yielded + except GeneratorExit as e: + await g.aclose() + raise + except BaseException as e: + yielded = await g.athrow(e) + else: + yielded = await g.asend(sent) + except StopAsyncIteration: + return + """) + + try: + return evaluate_wrapper(WRAPPER, wrapped, trace) + except: + return wrapped def async_wrapper(wrapped): if is_coroutine_callable(wrapped): return coroutine_wrapper + elif is_async_generator_function(wrapped): + return async_generator_wrapper elif is_generator_function(wrapped): if is_asyncio_coroutine(wrapped): return awaitable_generator_wrapper diff --git a/newrelic/common/coroutine.py b/newrelic/common/coroutine.py index cf4c91f85c..33a4922f56 100644 --- a/newrelic/common/coroutine.py +++ b/newrelic/common/coroutine.py @@ -43,3 +43,11 @@ def _iscoroutinefunction_tornado(fn): def is_coroutine_callable(wrapped): return is_coroutine_function(wrapped) or is_coroutine_function(getattr(wrapped, "__call__", None)) + + +if hasattr(inspect, 'isasyncgenfunction'): + def is_async_generator_function(wrapped): + return inspect.isasyncgenfunction(wrapped) +else: + def is_async_generator_function(wrapped): + return False diff --git a/newrelic/config.py b/newrelic/config.py index 8a041ad344..efeeaaec21 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -2269,6 +2269,87 @@ def _process_module_builtin_defaults(): "instrument_graphql_validate", ) + _process_module_definition( + "google.cloud.firestore_v1.base_client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_base_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.client", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_client", + 
"newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_client", + ) + _process_module_definition( + "google.cloud.firestore_v1.document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_document", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_document", + ) + _process_module_definition( + "google.cloud.firestore_v1.collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_collection", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_collection", + ) + _process_module_definition( + "google.cloud.firestore_v1.query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_query", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_query", + ) + _process_module_definition( + "google.cloud.firestore_v1.aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_aggregation", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_aggregation", + ) + _process_module_definition( + "google.cloud.firestore_v1.batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_batch", + ) + _process_module_definition( + "google.cloud.firestore_v1.bulk_batch", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_bulk_batch", + ) + 
_process_module_definition( + "google.cloud.firestore_v1.transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_transaction", + ) + _process_module_definition( + "google.cloud.firestore_v1.async_transaction", + "newrelic.hooks.datastore_firestore", + "instrument_google_cloud_firestore_v1_async_transaction", + ) + _process_module_definition( "ariadne.asgi", "newrelic.hooks.framework_ariadne", diff --git a/newrelic/hooks/datastore_firestore.py b/newrelic/hooks/datastore_firestore.py new file mode 100644 index 0000000000..6d3196a7c3 --- /dev/null +++ b/newrelic/hooks/datastore_firestore.py @@ -0,0 +1,473 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Instrumentation hooks for the google-cloud-firestore client library.
#
# Wraps the sync and async client/collection/document/query/batch/transaction
# APIs in datastore traces so Firestore calls are reported as Datastore metrics.

from newrelic.api.datastore_trace import wrap_datastore_trace
from newrelic.api.function_trace import wrap_function_trace
from newrelic.common.async_wrapper import generator_wrapper, async_generator_wrapper


def _conn_str_to_host(getter):
    """Safely transform a getter that can retrieve a connection string into the resulting host."""

    def closure(obj, *args, **kwargs):
        try:
            return getter(obj, *args, **kwargs).split(":")[0]
        except Exception:
            return None

    return closure


def _conn_str_to_port(getter):
    """Safely transform a getter that can retrieve a connection string into the resulting port."""

    def closure(obj, *args, **kwargs):
        try:
            return getter(obj, *args, **kwargs).split(":")[1]
        except Exception:
            return None

    return closure


# Default Target ID and Instance Info
_get_object_id = lambda obj, *args, **kwargs: getattr(obj, "id", None)
_get_client_database_string = lambda obj, *args, **kwargs: getattr(
    getattr(obj, "_client", None), "_database_string", None
)
_get_client_target = lambda obj, *args, **kwargs: obj._client._target
_get_client_target_host = _conn_str_to_host(_get_client_target)
_get_client_target_port = _conn_str_to_port(_get_client_target)

# Client Instance Info
_get_database_string = lambda obj, *args, **kwargs: getattr(obj, "_database_string", None)
_get_target = lambda obj, *args, **kwargs: obj._target
_get_target_host = _conn_str_to_host(_get_target)
_get_target_port = _conn_str_to_port(_get_target)

# Query Target ID
_get_parent_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_parent", None), "id", None)

# AggregationQuery Target ID
_get_collection_ref_id = lambda obj, *args, **kwargs: getattr(getattr(obj, "_collection_ref", None), "id", None)

# Instance info extractors for objects that expose connection details
# indirectly through their ``_client`` attribute.
_CLIENT_INSTANCE_INFO = {
    "host": _get_client_target_host,
    "port_path_or_id": _get_client_target_port,
    "database_name": _get_client_database_string,
}

# Instance info extractors for client objects that expose connection details directly.
_SELF_INSTANCE_INFO = {
    "host": _get_target_host,
    "port_path_or_id": _get_target_port,
    "database_name": _get_database_string,
}


def _instrument_methods(module, class_name, methods, target, instance_info=_CLIENT_INSTANCE_INFO, async_wrapper=None):
    """Wrap each existing method of ``module.<class_name>`` in a Firestore datastore trace.

    module: the module object being instrumented.
    class_name: name of the class on ``module`` to wrap; silently skipped when absent.
    methods: iterable of ``(method_name, operation)`` pairs.
    target: callable extracting the trace target from the wrapped instance, or None.
    instance_info: dict of host/port_path_or_id/database_name extractor callables.
    async_wrapper: explicit trace wrapper for methods that return (async) generators;
        None defers to the trace's automatic wrapper detection.
    """
    class_ = getattr(module, class_name, None)
    if class_ is None:
        return
    for method, operation in methods:
        if hasattr(class_, method):
            wrap_datastore_trace(
                module,
                "%s.%s" % (class_name, method),
                product="Firestore",
                target=target,
                operation=operation,
                async_wrapper=async_wrapper,
                **instance_info
            )


def _same(methods):
    """Pair each method name with itself as the reported operation name."""
    return [(method, method) for method in methods]


def instrument_google_cloud_firestore_v1_base_client(module):
    rollup = ("Datastore/all", "Datastore/Firestore/all")
    wrap_function_trace(
        module, "BaseClient.__init__", name="%s:BaseClient.__init__" % module.__name__, terminal=True, rollup=rollup
    )


def instrument_google_cloud_firestore_v1_client(module):
    # Client methods return generators, so an explicit generator wrapper is required.
    _instrument_methods(
        module,
        "Client",
        _same(("collections", "get_all")),
        target=None,
        instance_info=_SELF_INSTANCE_INFO,
        async_wrapper=generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_async_client(module):
    _instrument_methods(
        module,
        "AsyncClient",
        _same(("collections", "get_all")),
        target=None,
        instance_info=_SELF_INSTANCE_INFO,
        async_wrapper=async_generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_collection(module):
    _instrument_methods(module, "CollectionReference", _same(("add", "get")), target=_get_object_id)
    _instrument_methods(
        module,
        "CollectionReference",
        _same(("stream", "list_documents")),
        target=_get_object_id,
        async_wrapper=generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_async_collection(module):
    _instrument_methods(module, "AsyncCollectionReference", _same(("add", "get")), target=_get_object_id)
    _instrument_methods(
        module,
        "AsyncCollectionReference",
        _same(("stream", "list_documents")),
        target=_get_object_id,
        async_wrapper=async_generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_document(module):
    _instrument_methods(
        module, "DocumentReference", _same(("create", "delete", "get", "set", "update")), target=_get_object_id
    )
    _instrument_methods(
        module,
        "DocumentReference",
        _same(("collections",)),
        target=_get_object_id,
        async_wrapper=generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_async_document(module):
    _instrument_methods(
        module, "AsyncDocumentReference", _same(("create", "delete", "get", "set", "update")), target=_get_object_id
    )
    _instrument_methods(
        module,
        "AsyncDocumentReference",
        _same(("collections",)),
        target=_get_object_id,
        async_wrapper=async_generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_query(module):
    _instrument_methods(module, "Query", _same(("get",)), target=_get_parent_id)
    _instrument_methods(module, "Query", _same(("stream",)), target=_get_parent_id, async_wrapper=generator_wrapper)
    _instrument_methods(
        module, "CollectionGroup", _same(("get_partitions",)), target=_get_parent_id, async_wrapper=generator_wrapper
    )


def instrument_google_cloud_firestore_v1_async_query(module):
    _instrument_methods(module, "AsyncQuery", _same(("get",)), target=_get_parent_id)
    _instrument_methods(
        module, "AsyncQuery", _same(("stream",)), target=_get_parent_id, async_wrapper=async_generator_wrapper
    )
    _instrument_methods(
        module,
        "AsyncCollectionGroup",
        _same(("get_partitions",)),
        target=_get_parent_id,
        async_wrapper=async_generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_aggregation(module):
    _instrument_methods(module, "AggregationQuery", _same(("get",)), target=_get_collection_ref_id)
    _instrument_methods(
        module,
        "AggregationQuery",
        _same(("stream",)),
        target=_get_collection_ref_id,
        async_wrapper=generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_async_aggregation(module):
    _instrument_methods(module, "AsyncAggregationQuery", _same(("get",)), target=_get_collection_ref_id)
    _instrument_methods(
        module,
        "AsyncAggregationQuery",
        _same(("stream",)),
        target=_get_collection_ref_id,
        async_wrapper=async_generator_wrapper,
    )


def instrument_google_cloud_firestore_v1_batch(module):
    _instrument_methods(module, "WriteBatch", _same(("commit",)), target=None)


def instrument_google_cloud_firestore_v1_async_batch(module):
    _instrument_methods(module, "AsyncWriteBatch", _same(("commit",)), target=None)


def instrument_google_cloud_firestore_v1_bulk_batch(module):
    _instrument_methods(module, "BulkWriteBatch", _same(("commit",)), target=None)


def instrument_google_cloud_firestore_v1_transaction(module):
    # Leading underscores are trimmed from method names for the operation metric name.
    _instrument_methods(module, "Transaction", [(m, m[1:]) for m in ("_commit", "_rollback")], target=None)


def instrument_google_cloud_firestore_v1_async_transaction(module):
    _instrument_methods(module, "AsyncTransaction", [(m, m[1:]) for m in ("_commit", "_rollback")], target=None)
"GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_awaitable_timing(event_loop, trace, metric): @@ -79,6 +82,8 @@ def _test(): (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) @pytest.mark.parametrize("yield_from", [True, False]) diff --git a/tests/agent_features/_test_async_generator_trace.py b/tests/agent_features/_test_async_generator_trace.py new file mode 100644 index 0000000000..30b970c372 --- /dev/null +++ b/tests/agent_features/_test_async_generator_trace.py @@ -0,0 +1,548 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import sys +import time + +import pytest +from testing_support.fixtures import capture_transaction_metrics, validate_tt_parenting +from testing_support.validators.validate_transaction_errors import ( + validate_transaction_errors, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +asyncio = pytest.importorskip("asyncio") + + +@pytest.mark.parametrize( + "trace,metric", + [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), + ], +) +def test_async_generator_timing(event_loop, trace, metric): + @trace() + async def simple_gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_async_generator_timing", background_task=True, scoped_metrics=[(metric, 
1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_async_generator_timing") + def _test_async_generator_timing(): + async def _test(): + async for _ in simple_gen(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_timing() + + # Check that coroutines time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +class MyException(Exception): + pass + + +@validate_transaction_metrics( + "test_async_generator_error", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=["_test_async_generator_trace:MyException"]) +def test_async_generator_error(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_error") + async def _test(): + gen = agen() + await gen.asend(None) + await gen.athrow(MyException) + + with pytest.raises(MyException): + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_caught_exception", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@validate_transaction_errors(errors=[]) +def test_async_generator_caught_exception(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + time.sleep(0.1) + try: + yield + except ValueError: + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_caught_exception") + def _test_async_generator_caught_exception(): + async def _test(): + gen = agen() + # kickstart the generator (the try/except logic is inside the + # generator) + await gen.asend(None) + await gen.athrow(ValueError) + + # consume the generator + async for _ in gen: + pass + + # The ValueError should not be reraised + event_loop.run_until_complete(_test()) + 
_test_async_generator_caught_exception() + + assert full_metrics[("Function/agen", "")].total_call_time >= 0.2 + + +@validate_transaction_metrics( + "test_async_generator_handles_terminal_nodes", + background_task=True, + scoped_metrics=[("Function/parent", 1), ("Function/agen", None)], + rollup_metrics=[("Function/parent", 1), ("Function/agen", None)], +) +def test_async_generator_handles_terminal_nodes(event_loop): + # sometimes coroutines can be called underneath terminal nodes + # In this case, the trace shouldn't actually be created and we also + # shouldn't get any errors + + @function_trace(name="agen") + async def agen(): + yield + time.sleep(0.1) + + @function_trace(name="parent", terminal=True) + async def parent(): + # parent calls child + async for _ in agen(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_handles_terminal_nodes") + def _test_async_generator_handles_terminal_nodes(): + async def _test(): + await parent() + + event_loop.run_until_complete(_test()) + _test_async_generator_handles_terminal_nodes() + + metric_key = ("Function/parent", "") + assert full_metrics[metric_key].total_exclusive_call_time >= 0.1 + + +@validate_transaction_metrics( + "test_async_generator_close_ends_trace", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_close_ends_trace(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + @background_task(name="test_async_generator_close_ends_trace") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # trace should be ended/recorded by close + await gen.aclose() + + # We may call gen.close as many times as we want + await gen.aclose() + + event_loop.run_until_complete(_test()) + +@validate_tt_parenting( + ( + "TransactionNode", + [ + ( + "FunctionNode", + [ + ("FunctionNode", []), + 
], + ), + ], + ) +) +@validate_transaction_metrics( + "test_async_generator_parents", + background_task=True, + scoped_metrics=[("Function/child", 1), ("Function/parent", 1)], + rollup_metrics=[("Function/child", 1), ("Function/parent", 1)], +) +def test_async_generator_parents(event_loop): + @function_trace(name="child") + async def child(): + yield + time.sleep(0.1) + yield + + @function_trace(name="parent") + async def parent(): + time.sleep(0.1) + yield + async for _ in child(): + pass + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_parents") + def _test_async_generator_parents(): + async def _test(): + async for _ in parent(): + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_parents() + + # Check that the child time is subtracted from the parent time (parenting + # relationship is correctly established) + key = ("Function/parent", "") + assert full_metrics[key].total_exclusive_call_time < 0.2 + + +@validate_transaction_metrics( + "test_asend_receives_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_asend_receives_a_value(event_loop): + _received = [] + @function_trace(name="agen") + async def agen(): + value = yield + _received.append(value) + yield value + + @background_task(name="test_asend_receives_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.asend("foobar") == "foobar" + assert _received and _received[0] == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_yields_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_yields_a_value(event_loop): + @function_trace(name="agen") 
+ async def agen(): + for _ in range(2): + try: + yield + except MyException: + yield "foobar" + + @background_task(name="test_athrow_yields_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + assert await gen.athrow(MyException) == "foobar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_multiple_throws_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + + @background_task(name="test_multiple_throws_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + assert await gen.asend(None) is None + assert await gen.athrow(MyException) == "foo" + assert await gen.athrow(MyException) == "foo" + assert await gen.asend(None) == "bar" + + # finish consumption of the coroutine if necessary + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_athrow_does_not_yield_a_value", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_athrow_does_not_yield_a_value(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(2): + try: + yield + except MyException: + return + + @background_task(name="test_athrow_does_not_yield_a_value") + async def _test(): + gen = agen() + + # kickstart the coroutine + await gen.asend(None) + + # async generator will raise StopAsyncIteration + with pytest.raises(StopAsyncIteration): + await gen.athrow(MyException) + + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize( + "trace", + [ + 
function_trace(name="simple_gen"), + external_trace(library="lib", url="http://foo.com"), + database_trace("select * from foo"), + datastore_trace("lib", "foo", "bar"), + message_trace("lib", "op", "typ", "name"), + memcache_trace("cmd"), + ], +) +def test_async_generator_functions_outside_of_transaction(event_loop, trace): + @trace + async def agen(): + for _ in range(2): + yield "foo" + + async def _test(): + assert [_ async for _ in agen()] == ["foo", "foo"] + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_catching_generator_exit_causes_runtime_error", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_catching_generator_exit_causes_runtime_error(event_loop): + @function_trace(name="agen") + async def agen(): + try: + yield + except GeneratorExit: + yield + + @background_task(name="test_catching_generator_exit_causes_runtime_error") + async def _test(): + gen = agen() + + # kickstart the coroutine (we're inside the try now) + await gen.asend(None) + + # Generators cannot catch generator exit exceptions (which are injected by + # close). This will result in a runtime error. 
+ with pytest.raises(RuntimeError): + await gen.aclose() + + event_loop.run_until_complete(_test()) + + +@validate_transaction_metrics( + "test_async_generator_time_excludes_creation_time", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +def test_async_generator_time_excludes_creation_time(event_loop): + @function_trace(name="agen") + async def agen(): + yield + + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @background_task(name="test_async_generator_time_excludes_creation_time") + def _test_async_generator_time_excludes_creation_time(): + async def _test(): + gen = agen() + time.sleep(0.1) + async for _ in gen: + pass + + event_loop.run_until_complete(_test()) + _test_async_generator_time_excludes_creation_time() + + # check that the trace does not include the time between creation and + # consumption + assert full_metrics[("Function/agen", "")].total_call_time < 0.1 + + +@validate_transaction_metrics( + "test_complete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@background_task(name="test_complete_async_generator") +def test_complete_async_generator(event_loop): + @function_trace(name="agen") + async def agen(): + for i in range(5): + yield i + + async def _test(): + gen = agen() + assert [x async for x in gen] == [x for x in range(5)] + + event_loop.run_until_complete(_test()) + + +@pytest.mark.parametrize("nr_transaction", [True, False]) +def test_incomplete_async_generator(event_loop, nr_transaction): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + def _test_incomplete_async_generator(): + async def _test(): + c = agen() + + async for _ in c: + break + + if nr_transaction: + _test = background_task(name="test_incomplete_async_generator")(_test) + + event_loop.run_until_complete(_test()) + + if nr_transaction: + 
_test_incomplete_async_generator = validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + )(_test_incomplete_async_generator) + + _test_incomplete_async_generator() + + +def test_incomplete_async_generator_transaction_exited(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + @validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + ) + def _test_incomplete_async_generator(): + c = agen() + @background_task(name="test_incomplete_async_generator") + async def _test(): + async for _ in c: + break + + event_loop.run_until_complete(_test()) + + # Remove generator after transaction completes + del c + + _test_incomplete_async_generator() diff --git a/tests/agent_features/test_async_generator_trace.py b/tests/agent_features/test_async_generator_trace.py new file mode 100644 index 0000000000..208cf1588a --- /dev/null +++ b/tests/agent_features/test_async_generator_trace.py @@ -0,0 +1,19 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +# Async Generators were introduced in Python 3.6, but some APIs weren't completely stable until Python 3.7. 
+if sys.version_info >= (3, 7): + from _test_async_generator_trace import * # NOQA diff --git a/tests/agent_features/test_async_wrapper_detection.py b/tests/agent_features/test_async_wrapper_detection.py new file mode 100644 index 0000000000..bb1fd3f1e3 --- /dev/null +++ b/tests/agent_features/test_async_wrapper_detection.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import functools +import time + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +from newrelic.common.async_wrapper import generator_wrapper + +from testing_support.fixtures import capture_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +trace_metric_cases = [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), 
"Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), +] + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_automatic_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace() + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + for _ in gen(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_manual_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace(async_wrapper=generator_wrapper) + def wrapper_func(): + """Function that returns a generator object, obscuring the automatic introspection of async_wrapper()""" + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + return gen() + + for _ in wrapper_func(): + pass + + _test() + + # Check that generators time the total call time (including 
pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 diff --git a/tests/agent_features/test_coroutine_trace.py b/tests/agent_features/test_coroutine_trace.py index 36e365bc46..2043f13268 100644 --- a/tests/agent_features/test_coroutine_trace.py +++ b/tests/agent_features/test_coroutine_trace.py @@ -31,6 +31,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -47,6 +48,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_coroutine_timing(trace, metric): @@ -337,6 +340,37 @@ def coro(): pass +@validate_transaction_metrics( + "test_multiple_throws_yield_a_value", + background_task=True, + scoped_metrics=[("Function/coro", 1)], + rollup_metrics=[("Function/coro", 1)], +) +@background_task(name="test_multiple_throws_yield_a_value") +def test_multiple_throws_yield_a_value(): + @function_trace(name="coro") + def coro(): + value = None + for _ in range(4): + try: + yield value + value = "bar" + except MyException: + value = "foo" + + c = coro() + + # kickstart the coroutine + assert next(c) is None + assert c.throw(MyException) == "foo" + assert c.throw(MyException) == "foo" + assert next(c) == "bar" + + # finish consumption of the coroutine if necessary + for _ in c: + pass + + @pytest.mark.parametrize( "trace", [ diff --git 
a/tests/agent_features/test_datastore_trace.py b/tests/agent_features/test_datastore_trace.py new file mode 100644 index 0000000000..08067e0402 --- /dev/null +++ b/tests/agent_features/test_datastore_trace.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from testing_support.validators.validate_datastore_trace_inputs import ( + validate_datastore_trace_inputs, +) + +from newrelic.api.background_task import background_task +from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper + + +@validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", +) +@background_task() +def test_dt_trace_all_args(): + with DatastoreTrace( + product="Agent Features", + target="test_target", + operation="test_operation", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ): + pass + + +@validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None) +@background_task() +def test_dt_trace_empty(): + with DatastoreTrace(product=None, target=None, operation=None): + pass + + +@background_task() +def test_dt_trace_callable_args(): + def product_callable(): + return "Agent Features" + + def target_callable(): + return "test_target" + + def operation_callable(): + return "test_operation" + + def host_callable(): + return 
"test_host" + + def port_path_id_callable(): + return "test_port" + + def db_name_callable(): + return "test_db_name" + + @validate_datastore_trace_inputs( + operation="test_operation", + target="test_target", + host="test_host", + port_path_or_id="test_port", + database_name="test_db_name", + ) + def _test(): + pass + + wrapped_fn = DatastoreTraceWrapper( + _test, + product=product_callable, + target=target_callable, + operation=operation_callable, + host=host_callable, + port_path_or_id=port_path_id_callable, + database_name=db_name_callable, + ) + wrapped_fn() diff --git a/tests/datastore_firestore/conftest.py b/tests/datastore_firestore/conftest.py new file mode 100644 index 0000000000..28e138fa28 --- /dev/null +++ b/tests/datastore_firestore/conftest.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import uuid + +import pytest + +from google.cloud.firestore import Client, AsyncClient + +from testing_support.db_settings import firestore_settings +from testing_support.fixture.event_loop import event_loop as loop # noqa: F401; pylint: disable=W0611 +from testing_support.fixtures import ( # noqa: F401; pylint: disable=W0611 + collector_agent_registration_fixture, + collector_available_fixture, +) + +from newrelic.api.datastore_trace import DatastoreTrace +from newrelic.api.time_trace import current_trace +from newrelic.common.system_info import LOCALHOST_EQUIVALENTS, gethostname + +DB_SETTINGS = firestore_settings()[0] +FIRESTORE_HOST = DB_SETTINGS["host"] +FIRESTORE_PORT = DB_SETTINGS["port"] + +_default_settings = { + "transaction_tracer.explain_threshold": 0.0, + "transaction_tracer.transaction_threshold": 0.0, + "transaction_tracer.stack_trace_threshold": 0.0, + "debug.log_data_collector_payloads": True, + "debug.record_transaction_failure": True, + "debug.log_explain_plan_queries": True, +} + +collector_agent_registration = collector_agent_registration_fixture( + app_name="Python Agent Test (datastore_firestore)", + default_settings=_default_settings, + linked_applications=["Python Agent Test (datastore)"], +) + + +@pytest.fixture() +def instance_info(): + host = gethostname() if FIRESTORE_HOST in LOCALHOST_EQUIVALENTS else FIRESTORE_HOST + return {"host": host, "port_path_or_id": str(FIRESTORE_PORT), "db.instance": "projects/google-cloud-firestore-emulator/databases/(default)"} + + +@pytest.fixture(scope="session") +def client(): + os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) + client = Client() + # Ensure connection is available + client.collection("healthcheck").document("healthcheck").set( + {}, retry=None, timeout=5 + ) + return client + + +@pytest.fixture(scope="function") +def collection(client): + collection_ = client.collection("firestore_collection_" + 
str(uuid.uuid4())) + yield collection_ + client.recursive_delete(collection_) + + +@pytest.fixture(scope="session") +def async_client(loop): + os.environ["FIRESTORE_EMULATOR_HOST"] = "%s:%d" % (FIRESTORE_HOST, FIRESTORE_PORT) + client = AsyncClient() + loop.run_until_complete(client.collection("healthcheck").document("healthcheck").set({}, retry=None, timeout=5)) # Ensure connection is available + return client + + +@pytest.fixture(scope="function") +def async_collection(async_client, collection): + # Use the same collection name as the collection fixture + yield async_client.collection(collection.id) + + +@pytest.fixture(scope="session") +def assert_trace_for_generator(): + def _assert_trace_for_generator(generator_func, *args, **kwargs): + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + # Check for generator trace on collections + _trace_check = [] + for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. + + return _assert_trace_for_generator + + +@pytest.fixture(scope="session") +def assert_trace_for_async_generator(loop): + def _assert_trace_for_async_generator(generator_func, *args, **kwargs): + _trace_check = [] + txn = current_trace() + assert not isinstance(txn, DatastoreTrace) + + async def coro(): + # Check for generator trace on collections + async for _ in generator_func(*args, **kwargs): + _trace_check.append(isinstance(current_trace(), DatastoreTrace)) + + loop.run_until_complete(coro()) + + assert _trace_check and all(_trace_check) # All checks are True, and at least 1 is present. + assert current_trace() is txn # Generator trace has exited. 
+ + return _assert_trace_for_async_generator diff --git a/tests/datastore_firestore/test_async_batching.py b/tests/datastore_firestore/test_async_batching.py new file mode 100644 index 0000000000..08890c39af --- /dev/null +++ b/tests/datastore_firestore/test_async_batching.py @@ -0,0 +1,68 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_write_batch(async_client, async_collection): + async def _exercise_async_write_batch(): + docs = [async_collection.document(str(x)) for x in range(1, 4)] + async_batch = async_client.batch() + for doc in docs: + async_batch.set(doc, {}) + + await async_batch.commit() + return _exercise_async_write_batch + + +def test_firestore_async_write_batch(loop, exercise_async_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_write_batch", + scoped_metrics=_test_scoped_metrics, + 
rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_write_batch") + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() + + +def test_firestore_async_write_batch_trace_node_datastore_params(loop, exercise_async_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_write_batch()) + + _test() diff --git a/tests/datastore_firestore/test_async_client.py b/tests/datastore_firestore/test_async_client.py new file mode 100644 index 0000000000..1a17181d59 --- /dev/null +++ b/tests/datastore_firestore/test_async_client.py @@ -0,0 +1,83 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def existing_document(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_async_client(async_client, existing_document): + async def _exercise_async_client(): + assert len([_ async for _ in async_client.collections()]) >= 1 + doc = [_ async for _ in async_client.get_all([existing_document])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_async_client + + +def test_firestore_async_client(loop, exercise_async_client): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_client") + def _test(): + loop.run_until_complete(exercise_async_client()) + + _test() + + +@background_task() +def test_firestore_async_client_generators(async_client, collection, assert_trace_for_async_generator): + doc = collection.document("test") + doc.set({}) + + assert_trace_for_async_generator(async_client.collections) + assert_trace_for_async_generator(async_client.get_all, [doc]) + + +def test_firestore_async_client_trace_node_datastore_params(loop, exercise_async_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + 
loop.run_until_complete(exercise_async_client()) + + _test() diff --git a/tests/datastore_firestore/test_async_collections.py b/tests/datastore_firestore/test_async_collections.py new file mode 100644 index 0000000000..a1004a7205 --- /dev/null +++ b/tests/datastore_firestore/test_async_collections.py @@ -0,0 +1,89 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_collections(async_collection): + async def _exercise_async_collections(): + async_collection.document("DoesNotExist") + await async_collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + await async_collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = await async_collection.get() + assert len(documents_get) == 2 + documents_stream = [_ async for _ in async_collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ async for _ in async_collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_async_collections + + +def 
test_firestore_async_collections(loop, exercise_async_collections, async_collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % async_collection.id, 2), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_collections") + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() + + +@background_task() +def test_firestore_async_collections_generators(collection, async_collection, assert_trace_for_async_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_async_generator(async_collection.stream) + assert_trace_for_async_generator(async_collection.list_documents) + + +def test_firestore_async_collections_trace_node_datastore_params(loop, exercise_async_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collections()) + + _test() diff --git a/tests/datastore_firestore/test_async_documents.py b/tests/datastore_firestore/test_async_documents.py new file mode 100644 index 0000000000..9c0a30479d --- /dev/null +++ b/tests/datastore_firestore/test_async_documents.py @@ -0,0 +1,101 @@ +# Copyright 2010 New Relic, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture() +def exercise_async_documents(async_collection): + async def _exercise_async_documents(): + italy_doc = async_collection.document("Italy") + await italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + await italy_doc.get() + italian_cities = italy_doc.collection("cities") + await italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ async for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = async_collection.document("USA") + await usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + await usa_doc.update({"president": "Joe Biden"}) + + await async_collection.document("USA").delete() + return _exercise_async_documents + + +def test_firestore_async_documents(loop, exercise_async_documents): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + 
("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_documents") + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() + + +@background_task() +def test_firestore_async_documents_generators(collection, async_collection, assert_trace_for_async_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + async_subcollection = async_collection.document(subcollection_doc.id) + + assert_trace_for_async_generator(async_subcollection.collections) + + +def test_firestore_async_documents_trace_node_datastore_params(loop, exercise_async_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_documents()) + + _test() diff --git a/tests/datastore_firestore/test_async_query.py b/tests/datastore_firestore/test_async_query.py new file mode 100644 index 0000000000..c3e43d0e4c --- /dev/null +++ b/tests/datastore_firestore/test_async_query.py @@ -0,0 +1,225 @@ +# Copyright 2010 New 
Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== AsyncQuery ===== + +@pytest.fixture() +def exercise_async_query(async_collection): + async def _exercise_async_query(): + async_query = async_collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(await async_query.get()) == 3 + assert len([_ async for _ in async_query.stream()]) == 3 + return _exercise_async_query + + +def test_firestore_async_query(loop, exercise_async_query, async_collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + 
("Datastore/allOther", 2), + ] + # @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_query") + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + + +@background_task() +def test_firestore_async_query_generators(async_collection, assert_trace_for_async_generator): + async_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_async_generator(async_query.stream) + + +def test_firestore_async_query_trace_node_datastore_params(loop, exercise_async_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_query()) + + _test() + +# ===== AsyncAggregationQuery ===== + +@pytest.fixture() +def exercise_async_aggregation_query(async_collection): + async def _exercise_async_aggregation_query(): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert (await async_aggregation_query.get())[0][0].value == 3 + assert [_ async for _ in async_aggregation_query.stream()][0][0].value == 3 + return _exercise_async_aggregation_query + + +def test_firestore_async_aggregation_query(loop, exercise_async_aggregation_query, async_collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_aggregation_query", + scoped_metrics=_test_scoped_metrics, + 
rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_aggregation_query") + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +@background_task() +def test_firestore_async_aggregation_query_generators(async_collection, assert_trace_for_async_generator): + async_aggregation_query = async_collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_async_generator(async_aggregation_query.stream) + + +def test_firestore_async_aggregation_query_trace_node_datastore_params(loop, exercise_async_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_aggregation_query()) + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, async_client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. + + Ordinarily this method would return a coroutine that returns an async_generator of Cursor objects. + Each Cursor must point at a valid document path. To test this, we can patch the RPC to return 1 Cursor + which is pointed at any document available. The get_partitions will take that and make 2 QueryPartition + objects out of it, which should be enough to ensure we can exercise the generator's tracing. 
+ """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + async def mock_partition_query(*args, **kwargs): + async def _mock_partition_query(): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + return _mock_partition_query() + + monkeypatch.setattr(async_client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_async_collection_group(async_client, async_collection): + async def _exercise_async_collection_group(): + async_collection_group = async_client.collection_group(async_collection.id) + assert len(await async_collection_group.get()) + assert len([d async for d in async_collection_group.stream()]) + + partitions = [p async for p in async_collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(await partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_async_collection_group + + +def test_firestore_async_collection_group(loop, exercise_async_collection_group, async_collection, patch_partition_queries): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % async_collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + 
background_task=True, + ) + @background_task(name="test_firestore_async_collection_group") + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() + + +@background_task() +def test_firestore_async_collection_group_generators(async_client, async_collection, assert_trace_for_async_generator, patch_partition_queries): + async_collection_group = async_client.collection_group(async_collection.id) + assert_trace_for_async_generator(async_collection_group.get_partitions, 1) + + +def test_firestore_async_collection_group_trace_node_datastore_params(loop, exercise_async_collection_group, instance_info, patch_partition_queries): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_collection_group()) + + _test() diff --git a/tests/datastore_firestore/test_async_transaction.py b/tests/datastore_firestore/test_async_transaction.py new file mode 100644 index 0000000000..134c080bdd --- /dev/null +++ b/tests/datastore_firestore/test_async_transaction.py @@ -0,0 +1,149 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics +from newrelic.api.background_task import background_task +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_async_transaction_commit(async_client, async_collection): + async def _exercise_async_transaction_commit(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # get a DocumentReference + with pytest.raises(TypeError): # get is currently broken. It attempts to await an async_generator instead of consuming it. + [_ async for _ in async_transaction.get(async_collection.document("doc1"))] + + # get a Query + with pytest.raises(TypeError): # get is currently broken. It attempts to await an async_generator instead of consuming it. + async_query = async_collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ async for _ in async_transaction.get(async_query)]) == 1 + + # get_all on a list of DocumentReferences + with pytest.raises(TypeError): # get_all is currently broken. It attempts to await an async_generator instead of consuming it. 
+ all_docs = async_transaction.get_all([async_collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ async for _ in all_docs]) == 3 + + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 0}) + async_transaction.delete(async_collection.document("doc3")) + + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 2 + return _exercise_async_transaction_commit + + +@pytest.fixture() +def exercise_async_transaction_rollback(async_client, async_collection): + async def _exercise_async_transaction_rollback(): + from google.cloud.firestore import async_transactional + + @async_transactional + async def _exercise(async_transaction): + # set and delete methods + async_transaction.set(async_collection.document("doc2"), {"x": 99}) + async_transaction.delete(async_collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + await _exercise(async_client.transaction()) + assert len([_ async for _ in async_collection.list_documents()]) == 3 + return _exercise_async_transaction_rollback + + +def test_firestore_async_transaction_commit(loop, exercise_async_transaction_commit, async_collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + # ("Datastore/operation/Firestore/get_all", 2), + # ("Datastore/statement/Firestore/%s/stream" % async_collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + # ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), # Should be 5 if not for broken APIs + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + 
@background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback(loop, exercise_async_transaction_rollback, async_collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % async_collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_async_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_async_transaction") + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() + + +def test_firestore_async_transaction_commit_trace_node_datastore_params(loop, exercise_async_transaction_commit, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_commit()) + + _test() + + +def test_firestore_async_transaction_rollback_trace_node_datastore_params(loop, exercise_async_transaction_rollback, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + loop.run_until_complete(exercise_async_transaction_rollback()) + + _test() diff --git a/tests/datastore_firestore/test_batching.py b/tests/datastore_firestore/test_batching.py new file mode 100644 index 0000000000..5dcdd7b396 --- /dev/null +++ b/tests/datastore_firestore/test_batching.py @@ -0,0 +1,124 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + +# ===== WriteBatch ===== + + +@pytest.fixture() +def exercise_write_batch(client, collection): + def _exercise_write_batch(): + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = client.batch() + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_write_batch + + +def test_firestore_write_batch(exercise_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_write_batch") + def _test(): + exercise_write_batch() + + _test() + + +def test_firestore_write_batch_trace_node_datastore_params(exercise_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_write_batch() + + _test() + + +# ===== BulkWriteBatch ===== + + +@pytest.fixture() +def exercise_bulk_write_batch(client, collection): + def 
_exercise_bulk_write_batch(): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + docs = [collection.document(str(x)) for x in range(1, 4)] + batch = BulkWriteBatch(client) + for doc in docs: + batch.set(doc, {}) + + batch.commit() + return _exercise_bulk_write_batch + + +def test_firestore_bulk_write_batch(exercise_bulk_write_batch): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 1), + ("Datastore/allOther", 1), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_bulk_write_batch", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_bulk_write_batch") + def _test(): + exercise_bulk_write_batch() + + _test() + + +def test_firestore_bulk_write_batch_trace_node_datastore_params(exercise_bulk_write_batch, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_bulk_write_batch() + + _test() diff --git a/tests/datastore_firestore/test_client.py b/tests/datastore_firestore/test_client.py new file mode 100644 index 0000000000..06580356ad --- /dev/null +++ b/tests/datastore_firestore/test_client.py @@ -0,0 +1,81 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def sample_data(collection): + doc = collection.document("document") + doc.set({"x": 1}) + return doc + + +@pytest.fixture() +def exercise_client(client, sample_data): + def _exercise_client(): + assert len([_ for _ in client.collections()]) + doc = [_ for _ in client.get_all([sample_data])][0] + assert doc.to_dict()["x"] == 1 + return _exercise_client + + +def test_firestore_client(exercise_client): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/get_all", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_client", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_client") + def _test(): + exercise_client() + + _test() + + +@background_task() +def test_firestore_client_generators(client, sample_data, assert_trace_for_generator): + assert_trace_for_generator(client.collections) + assert_trace_for_generator(client.get_all, [sample_data]) + + +def test_firestore_client_trace_node_datastore_params(exercise_client, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_client() + + _test() diff --git a/tests/datastore_firestore/test_collections.py b/tests/datastore_firestore/test_collections.py new file mode 100644 index 0000000000..c5c443dcef --- /dev/null +++ b/tests/datastore_firestore/test_collections.py @@ -0,0 
+1,92 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_collections(collection): + def _exercise_collections(): + collection.document("DoesNotExist") + collection.add({"capital": "Rome", "currency": "Euro", "language": "Italian"}, "Italy") + collection.add({"capital": "Mexico City", "currency": "Peso", "language": "Spanish"}, "Mexico") + + documents_get = collection.get() + assert len(documents_get) == 2 + documents_stream = [_ for _ in collection.stream()] + assert len(documents_stream) == 2 + documents_list = [_ for _ in collection.list_documents()] + assert len(documents_list) == 2 + return _exercise_collections + + +def test_firestore_collections(exercise_collections, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ("Datastore/statement/Firestore/%s/add" % collection.id, 2), + ] + + _test_rollup_metrics = [ + 
("Datastore/operation/Firestore/add", 2), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_collections", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collections") + def _test(): + exercise_collections() + + _test() + + +@background_task() +def test_firestore_collections_generators(collection, assert_trace_for_generator): + collection.add({}) + collection.add({}) + assert len([_ for _ in collection.list_documents()]) == 2 + + assert_trace_for_generator(collection.stream) + assert_trace_for_generator(collection.list_documents) + + +def test_firestore_collections_trace_node_datastore_params(exercise_collections, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collections() + + _test() diff --git a/tests/datastore_firestore/test_documents.py b/tests/datastore_firestore/test_documents.py new file mode 100644 index 0000000000..2006899601 --- /dev/null +++ b/tests/datastore_firestore/test_documents.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture() +def exercise_documents(collection): + def _exercise_documents(): + italy_doc = collection.document("Italy") + italy_doc.set({"capital": "Rome", "currency": "Euro", "language": "Italian"}) + italy_doc.get() + italian_cities = italy_doc.collection("cities") + italian_cities.add({"capital": "Rome"}) + retrieved_coll = [_ for _ in italy_doc.collections()] + assert len(retrieved_coll) == 1 + + usa_doc = collection.document("USA") + usa_doc.create({"capital": "Washington D.C.", "currency": "Dollar", "language": "English"}) + usa_doc.update({"president": "Joe Biden"}) + + collection.document("USA").delete() + return _exercise_documents + + +def test_firestore_documents(exercise_documents): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/Italy/set", 1), + ("Datastore/statement/Firestore/Italy/get", 1), + ("Datastore/statement/Firestore/Italy/collections", 1), + ("Datastore/statement/Firestore/cities/add", 1), + ("Datastore/statement/Firestore/USA/create", 1), + ("Datastore/statement/Firestore/USA/update", 1), + ("Datastore/statement/Firestore/USA/delete", 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/set", 1), + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/add", 1), + ("Datastore/operation/Firestore/collections", 1), + ("Datastore/operation/Firestore/create", 1), + ("Datastore/operation/Firestore/update", 1), + ("Datastore/operation/Firestore/delete", 1), + ("Datastore/all", 7), + ("Datastore/allOther", 7), + ] + @validate_database_duration() + @validate_transaction_metrics( + 
"test_firestore_documents", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_documents") + def _test(): + exercise_documents() + + _test() + + +@background_task() +def test_firestore_documents_generators(collection, assert_trace_for_generator): + subcollection_doc = collection.document("SubCollections") + subcollection_doc.set({}) + subcollection_doc.collection("collection1").add({}) + subcollection_doc.collection("collection2").add({}) + assert len([_ for _ in subcollection_doc.collections()]) == 2 + + assert_trace_for_generator(subcollection_doc.collections) + + +def test_firestore_documents_trace_node_datastore_params(exercise_documents, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_documents() + + _test() diff --git a/tests/datastore_firestore/test_query.py b/tests/datastore_firestore/test_query.py new file mode 100644 index 0000000000..5e681f53e5 --- /dev/null +++ b/tests/datastore_firestore/test_query.py @@ -0,0 +1,229 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 6): + collection.add({"x": x}) + + subcollection_doc = collection.document("subcollection") + subcollection_doc.set({}) + subcollection_doc.collection("subcollection1").add({}) + + +# ===== Query ===== + + +@pytest.fixture() +def exercise_query(collection): + def _exercise_query(): + query = collection.select("x").limit(10).order_by("x").where(field_path="x", op_string="<=", value=3) + assert len(query.get()) == 3 + assert len([_ for _ in query.stream()]) == 3 + return _exercise_query + + +def test_firestore_query(exercise_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_query") + def _test(): + exercise_query() + + _test() + + +@background_task() +def test_firestore_query_generators(collection, assert_trace_for_generator): + query = collection.select("x").where(field_path="x", op_string="<=", value=3) + assert_trace_for_generator(query.stream) + + +def test_firestore_query_trace_node_datastore_params(exercise_query, instance_info): + 
@validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_query() + + _test() + +# ===== AggregationQuery ===== + + +@pytest.fixture() +def exercise_aggregation_query(collection): + def _exercise_aggregation_query(): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert aggregation_query.get()[0][0].value == 3 + assert [_ for _ in aggregation_query.stream()][0][0].value == 3 + return _exercise_aggregation_query + + +def test_firestore_aggregation_query(exercise_aggregation_query, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 1), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_aggregation_query", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_aggregation_query") + def _test(): + exercise_aggregation_query() + + _test() + + +@background_task() +def test_firestore_aggregation_query_generators(collection, assert_trace_for_generator): + aggregation_query = collection.select("x").where(field_path="x", op_string="<=", value=3).count() + assert_trace_for_generator(aggregation_query.stream) + + +def test_firestore_aggregation_query_trace_node_datastore_params(exercise_aggregation_query, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_aggregation_query() + + _test() + + +# ===== CollectionGroup ===== + + +@pytest.fixture() +def patch_partition_queries(monkeypatch, client, collection, sample_data): + """ + Partitioning is not implemented in the Firestore emulator. 
+ + Ordinarily this method would return a generator of Cursor objects. Each Cursor must point at a valid document path. + To test this, we can patch the RPC to return 1 Cursor which is pointed at any document available. + The get_partitions will take that and make 2 QueryPartition objects out of it, which should be enough to ensure + we can exercise the generator's tracing. + """ + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.types.query import Cursor + + subcollection = collection.document("subcollection").collection("subcollection1") + documents = [d for d in subcollection.list_documents()] + + def mock_partition_query(*args, **kwargs): + yield Cursor(before=False, values=[Value(reference_value=documents[0].path)]) + + monkeypatch.setattr(client._firestore_api, "partition_query", mock_partition_query) + yield + + +@pytest.fixture() +def exercise_collection_group(client, collection, patch_partition_queries): + def _exercise_collection_group(): + collection_group = client.collection_group(collection.id) + assert len(collection_group.get()) + assert len([d for d in collection_group.stream()]) + + partitions = [p for p in collection_group.get_partitions(1)] + assert len(partitions) == 2 + documents = [] + while partitions: + documents.extend(partitions.pop().query().get()) + assert len(documents) == 6 + return _exercise_collection_group + + +def test_firestore_collection_group(exercise_collection_group, client, collection): + _test_scoped_metrics = [ + ("Datastore/statement/Firestore/%s/get" % collection.id, 3), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/get_partitions" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/get", 3), + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/get_partitions", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + 
@validate_transaction_metrics( + "test_firestore_collection_group", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_collection_group") + def _test(): + exercise_collection_group() + + _test() + + +@background_task() +def test_firestore_collection_group_generators(client, collection, assert_trace_for_generator, patch_partition_queries): + collection_group = client.collection_group(collection.id) + assert_trace_for_generator(collection_group.get_partitions, 1) + + +def test_firestore_collection_group_trace_node_datastore_params(exercise_collection_group, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_collection_group() + + _test() diff --git a/tests/datastore_firestore/test_transaction.py b/tests/datastore_firestore/test_transaction.py new file mode 100644 index 0000000000..c322a797ea --- /dev/null +++ b/tests/datastore_firestore/test_transaction.py @@ -0,0 +1,149 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from testing_support.validators.validate_database_duration import ( + validate_database_duration, +) +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) +from testing_support.validators.validate_tt_collector_json import ( + validate_tt_collector_json, +) + +from newrelic.api.background_task import background_task + + +@pytest.fixture(autouse=True) +def sample_data(collection): + for x in range(1, 4): + collection.add({"x": x}, "doc%d" % x) + + +@pytest.fixture() +def exercise_transaction_commit(client, collection): + def _exercise_transaction_commit(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # get a DocumentReference + [_ for _ in transaction.get(collection.document("doc1"))] + + # get a Query + query = collection.select("x").where(field_path="x", op_string=">", value=2) + assert len([_ for _ in transaction.get(query)]) == 1 + + # get_all on a list of DocumentReferences + all_docs = transaction.get_all([collection.document("doc%d" % x) for x in range(1, 4)]) + assert len([_ for _ in all_docs]) == 3 + + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 0}) + transaction.delete(collection.document("doc3")) + + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 2 + return _exercise_transaction_commit + + +@pytest.fixture() +def exercise_transaction_rollback(client, collection): + def _exercise_transaction_rollback(): + from google.cloud.firestore_v1.transaction import transactional + + @transactional + def _exercise(transaction): + # set and delete methods + transaction.set(collection.document("doc2"), {"x": 99}) + transaction.delete(collection.document("doc1")) + raise RuntimeError() + + with pytest.raises(RuntimeError): + _exercise(client.transaction()) + assert len([_ for _ in collection.list_documents()]) == 3 + return _exercise_transaction_rollback + + 
+def test_firestore_transaction_commit(exercise_transaction_commit, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/commit", 1), + ("Datastore/operation/Firestore/get_all", 2), + ("Datastore/statement/Firestore/%s/stream" % collection.id, 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/stream", 1), + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 5), + ("Datastore/allOther", 5), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback(exercise_transaction_rollback, collection): + _test_scoped_metrics = [ + ("Datastore/operation/Firestore/rollback", 1), + ("Datastore/statement/Firestore/%s/list_documents" % collection.id, 1), + ] + + _test_rollup_metrics = [ + ("Datastore/operation/Firestore/list_documents", 1), + ("Datastore/all", 2), + ("Datastore/allOther", 2), + ] + + @validate_database_duration() + @validate_transaction_metrics( + "test_firestore_transaction", + scoped_metrics=_test_scoped_metrics, + rollup_metrics=_test_rollup_metrics, + background_task=True, + ) + @background_task(name="test_firestore_transaction") + def _test(): + exercise_transaction_rollback() + + _test() + + +def test_firestore_transaction_commit_trace_node_datastore_params(exercise_transaction_commit, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + @background_task() + def _test(): + exercise_transaction_commit() + + _test() + + +def test_firestore_transaction_rollback_trace_node_datastore_params(exercise_transaction_rollback, instance_info): + @validate_tt_collector_json(datastore_params=instance_info) + 
@background_task() + def _test(): + exercise_transaction_rollback() + + _test() diff --git a/tests/testing_support/db_settings.py b/tests/testing_support/db_settings.py index e32e2ecfa4..f7bda3d7a8 100644 --- a/tests/testing_support/db_settings.py +++ b/tests/testing_support/db_settings.py @@ -190,6 +190,28 @@ def mongodb_settings(): return settings +def firestore_settings(): + """Return a list of dict of settings for connecting to firestore. + + This only includes the host and port as the collection name is defined in + the firestore conftest file. + Will return the correct settings, depending on which of the environments it + is running in. It attempts to set variables in the following order, where + later environments override earlier ones. + + 1. Local + 2. Github Actions + """ + + host = "host.docker.internal" if "GITHUB_ACTIONS" in os.environ else "127.0.0.1" + instances = 2 + settings = [ + {"host": host, "port": 8080 + instance_num} + for instance_num in range(instances) + ] + return settings + + def elasticsearch_settings(): """Return a list of dict of settings for connecting to elasticsearch. 
diff --git a/tests/testing_support/validators/validate_datastore_trace_inputs.py b/tests/testing_support/validators/validate_datastore_trace_inputs.py index ade4ebea6f..365a14ebda 100644 --- a/tests/testing_support/validators/validate_datastore_trace_inputs.py +++ b/tests/testing_support/validators/validate_datastore_trace_inputs.py @@ -23,7 +23,7 @@ """ -def validate_datastore_trace_inputs(operation=None, target=None): +def validate_datastore_trace_inputs(operation=None, target=None, host=None, port_path_or_id=None, database_name=None): @transient_function_wrapper("newrelic.api.datastore_trace", "DatastoreTrace.__init__") @catch_background_exceptions def _validate_datastore_trace_inputs(wrapped, instance, args, kwargs): @@ -44,6 +44,18 @@ def _bind_params(product, target, operation, host=None, port_path_or_id=None, da assert captured_target == target, "%s didn't match expected %s" % (captured_target, target) if operation is not None: assert captured_operation == operation, "%s didn't match expected %s" % (captured_operation, operation) + if host is not None: + assert captured_host == host, "%s didn't match expected %s" % (captured_host, host) + if port_path_or_id is not None: + assert captured_port_path_or_id == port_path_or_id, "%s didn't match expected %s" % ( + captured_port_path_or_id, + port_path_or_id, + ) + if database_name is not None: + assert captured_database_name == database_name, "%s didn't match expected %s" % ( + captured_database_name, + database_name, + ) return wrapped(*args, **kwargs) diff --git a/tox.ini b/tox.ini index d36edbe736..d46a0ea55d 100644 --- a/tox.ini +++ b/tox.ini @@ -83,6 +83,7 @@ envlist = memcached-datastore_memcache-{py27,py37,py38,py39,py310,py311,pypy27,pypy38}-memcached01, mysql-datastore_mysql-mysql080023-py27, mysql-datastore_mysql-mysqllatest-{py37,py38,py39,py310,py311}, + firestore-datastore_firestore-{py37,py38,py39,py310,py311}, postgres-datastore_postgresql-{py37,py38,py39}, 
postgres-datastore_psycopg2-{py27,py37,py38,py39,py310,py311}-psycopg2latest postgres-datastore_psycopg2cffi-{py27,pypy27,py37,py38,py39,py310,py311}-psycopg2cffilatest, @@ -235,6 +236,7 @@ deps = datastore_elasticsearch: requests datastore_elasticsearch-elasticsearch07: elasticsearch<8.0 datastore_elasticsearch-elasticsearch08: elasticsearch<9.0 + datastore_firestore: google-cloud-firestore datastore_memcache-memcached01: python-memcached<2 datastore_mysql-mysqllatest: mysql-connector-python datastore_mysql-mysql080023: mysql-connector-python<8.0.24 @@ -437,6 +439,7 @@ changedir = datastore_asyncpg: tests/datastore_asyncpg datastore_bmemcached: tests/datastore_bmemcached datastore_elasticsearch: tests/datastore_elasticsearch + datastore_firestore: tests/datastore_firestore datastore_memcache: tests/datastore_memcache datastore_mysql: tests/datastore_mysql datastore_postgresql: tests/datastore_postgresql