diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py index 7ac32f108445d..d5e994da61156 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py @@ -14,16 +14,25 @@ import argparse import collections -from concurrent import futures +import concurrent.futures import datetime import logging import signal -import sys import threading import time -from typing import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple +from typing import ( + DefaultDict, + Dict, + Iterable, + List, + Mapping, + Sequence, + Set, + Tuple, +) import grpc +from grpc import _typing as grpc_typing import grpc_admin from grpc_channelz.v1 import channelz @@ -57,6 +66,12 @@ PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]] + +# FutureFromCall is both a grpc.Call and grpc.Future +class FutureFromCallType(grpc.Call, grpc.Future): + pass + + _CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500) @@ -69,8 +84,13 @@ class _StatsWatcher: _no_remote_peer: int _lock: threading.Lock _condition: threading.Condition + _metadata_keys: frozenset + _include_all_metadata: bool + _metadata_by_peer: DefaultDict[ + str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer + ] - def __init__(self, start: int, end: int): + def __init__(self, start: int, end: int, metadata_keys: Iterable[str]): self._start = start self._end = end self._rpcs_needed = end - start @@ -80,8 +100,44 @@ def __init__(self, start: int, end: int): ) self._condition = threading.Condition() self._no_remote_peer = 0 + self._metadata_keys = frozenset( + self._sanitize_metadata_key(key) for key in metadata_keys + ) + self._include_all_metadata = "*" in self._metadata_keys + self._metadata_by_peer = collections.defaultdict( + messages_pb2.LoadBalancerStatsResponse.MetadataByPeer + ) + + @classmethod + def _sanitize_metadata_key(cls, metadata_key: str) -> str: + return metadata_key.strip().lower() - def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None: + def _add_metadata( + self, + rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata, + metadata_to_add: grpc_typing.MetadataType, + metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType, + ) -> None: + for key, value in metadata_to_add: + if ( + self._include_all_metadata + or self._sanitize_metadata_key(key) in self._metadata_keys + ): + rpc_metadata.metadata.append( + messages_pb2.LoadBalancerStatsResponse.MetadataEntry( + key=key, value=value, type=metadata_type + ) + ) + + def on_rpc_complete( + self, + request_id: int, + peer: str, + method: str, + *, + initial_metadata: grpc_typing.MetadataType, + trailing_metadata: grpc_typing.MetadataType, + ) -> None: """Records statistics for a single RPC.""" if self._start <= request_id < self._end: with self._condition: @@ -90,6 +146,23 @@ def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None: else: self._rpcs_by_peer[peer] += 1 self._rpcs_by_method[method][peer] += 1 + if self._metadata_keys: + rpc_metadata = ( + messages_pb2.LoadBalancerStatsResponse.RpcMetadata() + ) + self._add_metadata( + rpc_metadata, + initial_metadata, + messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL, + ) + self._add_metadata( + rpc_metadata, + trailing_metadata, + messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING, + ) + self._metadata_by_peer[peer].rpc_metadata.append( + rpc_metadata + ) self._rpcs_needed -= 1 self._condition.notify() @@ -107,6 +180,8 @@ def await_rpc_stats_response( for method, count_by_peer in self._rpcs_by_method.items(): for peer, count in count_by_peer.items(): response.rpcs_by_method[method].rpcs_by_peer[peer] = count + for peer, metadata_by_peer in self._metadata_by_peer.items(): + response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer) response.num_failures = self._no_remote_peer + self._rpcs_needed return response @@ -150,7 +225,7 @@ def GetClientStats( with _global_lock: start = _global_rpc_id + 1 end = start + request.num_rpcs - watcher = _StatsWatcher(start, end) + watcher = _StatsWatcher(start, end, request.metadata_keys) _watchers.add(watcher) response = watcher.await_rpc_stats_response(request.timeout_sec) with _global_lock: @@ -192,7 +267,7 @@ def _start_rpc( request_id: int, stub: test_pb2_grpc.TestServiceStub, timeout: float, - futures: Mapping[int, Tuple[grpc.Future, str]], + futures: Mapping[int, Tuple[FutureFromCallType, str]], ) -> None: logger.debug(f"Sending {method} request to backend: {request_id}") if method == "UnaryCall": @@ -209,7 +284,7 @@ def _start_rpc( def _on_rpc_done( - rpc_id: int, future: grpc.Future, method: str, print_response: bool + rpc_id: int, future: FutureFromCallType, method: str, print_response: bool ) -> None: exception = future.exception() hostname = "" @@ -241,23 +316,29 @@ def _on_rpc_done( if future.code() == grpc.StatusCode.OK: logger.debug("Successful response.") else: - logger.debug(f"RPC failed: {call}") + logger.debug(f"RPC failed: {rpc_id}") with _global_lock: for watcher in _watchers: - watcher.on_rpc_complete(rpc_id, hostname, method) + watcher.on_rpc_complete( + rpc_id, + hostname, + method, + initial_metadata=future.initial_metadata(), + trailing_metadata=future.trailing_metadata(), + ) def _remove_completed_rpcs( - futures: Mapping[int, grpc.Future], print_response: bool + rpc_futures: Mapping[int, FutureFromCallType], print_response: bool ) -> None: logger.debug("Removing completed RPCs") done = [] - for future_id, (future, method) in futures.items(): + for future_id, (future, method) in rpc_futures.items(): if future.done(): _on_rpc_done(future_id, future, method, args.print_response) done.append(future_id) for rpc_id in done: - del futures[rpc_id] + del rpc_futures[rpc_id] def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None: @@ -309,7 +390,7 @@ def _run_single_channel(config: _ChannelConfiguration) -> None: channel = grpc.insecure_channel(server) with channel: stub = test_pb2_grpc.TestServiceStub(channel) - futures: Dict[int, Tuple[grpc.Future, str]] = {} + futures: Dict[int, Tuple[FutureFromCallType, str]] = {} while not _stop_event.is_set(): with config.condition: if config.qps == 0: @@ -438,7 +519,7 @@ def _run( ) channel_configs[method] = channel_config method_handles.append(_MethodHandle(args.num_channels, channel_config)) - _global_server = grpc.server(futures.ThreadPoolExecutor()) + _global_server = grpc.server(concurrent.futures.ThreadPoolExecutor()) _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}") test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server( _LoadBalancerStatsServicer(), _global_server