Commit

Merge branch 'shush' into shush-connect
ctiller committed Jan 24, 2024
2 parents 5abf497 + 11145c9 commit 4720863
Showing 1 changed file with 96 additions and 15 deletions.
src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py (111 changes: 96 additions & 15 deletions)
@@ -14,16 +14,25 @@

 import argparse
 import collections
-from concurrent import futures
+import concurrent.futures
 import datetime
 import logging
 import signal
 import sys
 import threading
 import time
-from typing import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple
+from typing import (
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+)

 import grpc
+from grpc import _typing as grpc_typing
 import grpc_admin
 from grpc_channelz.v1 import channelz
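Note: `grpc._typing` is a private grpcio module, so the `grpc_typing.MetadataType` alias used below is an internal detail rather than public API. To my understanding it describes a sequence of (key, value) pairs, the shape that `Call.initial_metadata()` and `Call.trailing_metadata()` return; a minimal stand-in, assuming that shape, would be:

```python
from typing import Sequence, Tuple, Union

# Assumed shape of grpc._typing.MetadataType: (key, value) pairs,
# where values may be str or (for "-bin" keys) bytes.
MetadataType = Sequence[Tuple[str, Union[str, bytes]]]
```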

@@ -57,6 +66,12 @@

 PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]

+
+# FutureFromCallType is both a grpc.Call and a grpc.Future
+class FutureFromCallType(grpc.Call, grpc.Future):
+    pass
+
+
 _CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
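For context, `UnaryUnaryMultiCallable.future()` in grpcio returns a single object that acts as both a `grpc.Call` (metadata, status code) and a `grpc.Future` (completion, callbacks); `FutureFromCallType` exists so annotations can name that intersection. A hedged sketch of how the two interfaces meet on one object, assuming a TestService backend at a hypothetical localhost:50051:

```python
import grpc
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc  # interop protos

channel = grpc.insecure_channel("localhost:50051")  # hypothetical backend
stub = test_pb2_grpc.TestServiceStub(channel)

call_future = stub.UnaryCall.future(messages_pb2.SimpleRequest(), timeout=5)
call_future.add_done_callback(lambda f: print(f.code()))  # grpc.Future API
print(call_future.initial_metadata())  # grpc.Call API; blocks until available
```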


@@ -69,8 +84,13 @@ class _StatsWatcher:
     _no_remote_peer: int
     _lock: threading.Lock
     _condition: threading.Condition
+    _metadata_keys: frozenset
+    _include_all_metadata: bool
+    _metadata_by_peer: DefaultDict[
+        str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
+    ]

-    def __init__(self, start: int, end: int):
+    def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
         self._start = start
         self._end = end
         self._rpcs_needed = end - start
@@ -80,8 +100,44 @@ def __init__(self, start: int, end: int):
         )
         self._condition = threading.Condition()
         self._no_remote_peer = 0
+        self._metadata_keys = frozenset(
+            self._sanitize_metadata_key(key) for key in metadata_keys
+        )
+        self._include_all_metadata = "*" in self._metadata_keys
+        self._metadata_by_peer = collections.defaultdict(
+            messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
+        )

+    @classmethod
+    def _sanitize_metadata_key(cls, metadata_key: str) -> str:
+        return metadata_key.strip().lower()
+
-    def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
+    def _add_metadata(
+        self,
+        rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
+        metadata_to_add: grpc_typing.MetadataType,
+        metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
+    ) -> None:
+        for key, value in metadata_to_add:
+            if (
+                self._include_all_metadata
+                or self._sanitize_metadata_key(key) in self._metadata_keys
+            ):
+                rpc_metadata.metadata.append(
+                    messages_pb2.LoadBalancerStatsResponse.MetadataEntry(
+                        key=key, value=value, type=metadata_type
+                    )
+                )
+
+    def on_rpc_complete(
+        self,
+        request_id: int,
+        peer: str,
+        method: str,
+        *,
+        initial_metadata: grpc_typing.MetadataType,
+        trailing_metadata: grpc_typing.MetadataType,
+    ) -> None:
         """Records statistics for a single RPC."""
         if self._start <= request_id < self._end:
             with self._condition:
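gRPC metadata keys are matched case-insensitively (they travel as lowercase HTTP/2 headers), so both the configured keys and each received key are normalized through `_sanitize_metadata_key` before comparison, and `"*"` opts into capturing everything. A small self-contained illustration of the matching rule, with hypothetical values:

```python
configured = frozenset(k.strip().lower() for k in [" XDS-Cluster ", "other-key"])
received = [("xds-cluster", "cluster-a"), ("unrelated-key", "x")]

kept = [(k, v) for k, v in received if k.strip().lower() in configured]
assert kept == [("xds-cluster", "cluster-a")]
```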
@@ -90,6 +146,23 @@ def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
             else:
                 self._rpcs_by_peer[peer] += 1
                 self._rpcs_by_method[method][peer] += 1
+                if self._metadata_keys:
+                    rpc_metadata = (
+                        messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
+                    )
+                    self._add_metadata(
+                        rpc_metadata,
+                        initial_metadata,
+                        messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL,
+                    )
+                    self._add_metadata(
+                        rpc_metadata,
+                        trailing_metadata,
+                        messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING,
+                    )
+                    self._metadata_by_peer[peer].rpc_metadata.append(
+                        rpc_metadata
+                    )
             self._rpcs_needed -= 1
             self._condition.notify()

@@ -107,6 +180,8 @@ def await_rpc_stats_response(
             for method, count_by_peer in self._rpcs_by_method.items():
                 for peer, count in count_by_peer.items():
                     response.rpcs_by_method[method].rpcs_by_peer[peer] = count
+            for peer, metadata_by_peer in self._metadata_by_peer.items():
+                response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer)
             response.num_failures = self._no_remote_peer + self._rpcs_needed
         return response

@@ -150,7 +225,7 @@ def GetClientStats(
         with _global_lock:
             start = _global_rpc_id + 1
             end = start + request.num_rpcs
-            watcher = _StatsWatcher(start, end)
+            watcher = _StatsWatcher(start, end, request.metadata_keys)
             _watchers.add(watcher)
         response = watcher.await_rpc_stats_response(request.timeout_sec)
         with _global_lock:
@@ -192,7 +267,7 @@ def _start_rpc(
     request_id: int,
     stub: test_pb2_grpc.TestServiceStub,
     timeout: float,
-    futures: Mapping[int, Tuple[grpc.Future, str]],
+    futures: Mapping[int, Tuple[FutureFromCallType, str]],
 ) -> None:
     logger.debug(f"Sending {method} request to backend: {request_id}")
     if method == "UnaryCall":
@@ -209,7 +284,7 @@


 def _on_rpc_done(
-    rpc_id: int, future: grpc.Future, method: str, print_response: bool
+    rpc_id: int, future: FutureFromCallType, method: str, print_response: bool
 ) -> None:
     exception = future.exception()
     hostname = ""
@@ -241,23 +316,29 @@ def _on_rpc_done(
         if future.code() == grpc.StatusCode.OK:
             logger.debug("Successful response.")
         else:
-            logger.debug(f"RPC failed: {call}")
+            logger.debug(f"RPC failed: {rpc_id}")
     with _global_lock:
         for watcher in _watchers:
-            watcher.on_rpc_complete(rpc_id, hostname, method)
+            watcher.on_rpc_complete(
+                rpc_id,
+                hostname,
+                method,
+                initial_metadata=future.initial_metadata(),
+                trailing_metadata=future.trailing_metadata(),
+            )


 def _remove_completed_rpcs(
-    futures: Mapping[int, grpc.Future], print_response: bool
+    rpc_futures: Mapping[int, FutureFromCallType], print_response: bool
 ) -> None:
     logger.debug("Removing completed RPCs")
     done = []
-    for future_id, (future, method) in futures.items():
+    for future_id, (future, method) in rpc_futures.items():
         if future.done():
             _on_rpc_done(future_id, future, method, args.print_response)
             done.append(future_id)
     for rpc_id in done:
-        del futures[rpc_id]
+        del rpc_futures[rpc_id]


 def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
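Two details worth noting here: `_remove_completed_rpcs` only calls `_on_rpc_done` once `future.done()` is true, so `initial_metadata()` and `trailing_metadata()` return immediately instead of blocking, and the `futures` parameter is renamed to `rpc_futures`, which keeps the name from colliding with the `concurrent.futures` module now imported at the top. A condensed sketch of the non-blocking drain pattern, assuming `rpc_futures: Dict[int, Tuple[FutureFromCallType, str]]` as in this file:

```python
# Collect ids first: mutating a dict while iterating it raises RuntimeError.
done_ids = [rpc_id for rpc_id, (fut, _) in rpc_futures.items() if fut.done()]
for rpc_id in done_ids:
    fut, method = rpc_futures.pop(rpc_id)
    initial = fut.initial_metadata()    # non-blocking: the RPC already finished
    trailing = fut.trailing_metadata()  # likewise
```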
@@ -309,7 +390,7 @@ def _run_single_channel(config: _ChannelConfiguration) -> None:
     channel = grpc.insecure_channel(server)
     with channel:
         stub = test_pb2_grpc.TestServiceStub(channel)
-        futures: Dict[int, Tuple[grpc.Future, str]] = {}
+        futures: Dict[int, Tuple[FutureFromCallType, str]] = {}
         while not _stop_event.is_set():
             with config.condition:
                 if config.qps == 0:
@@ -438,7 +519,7 @@ def _run(
         )
         channel_configs[method] = channel_config
         method_handles.append(_MethodHandle(args.num_channels, channel_config))
-    _global_server = grpc.server(futures.ThreadPoolExecutor())
+    _global_server = grpc.server(concurrent.futures.ThreadPoolExecutor())
     _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
     test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
         _LoadBalancerStatsServicer(), _global_server
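End to end, a test driver exercises the new plumbing by putting `metadata_keys` on the `LoadBalancerStatsRequest` it sends to this client's stats port. A hedged sketch, assuming the client was started with a hypothetical `--stats_port=50052`:

```python
import grpc
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc

with grpc.insecure_channel("localhost:50052") as channel:  # hypothetical stats port
    stats_stub = test_pb2_grpc.LoadBalancerStatsServiceStub(channel)
    stats = stats_stub.GetClientStats(
        messages_pb2.LoadBalancerStatsRequest(
            num_rpcs=100,
            timeout_sec=30,
            metadata_keys=["xds-cluster"],  # or ["*"] to capture everything
        )
    )
    for peer, metadata in stats.metadatas_by_peer.items():
        print(peer, metadata)
```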
