diff --git a/newrelic/common/streaming_utils.py b/newrelic/common/streaming_utils.py index 8dda947eb8..ccd0b44efb 100644 --- a/newrelic/common/streaming_utils.py +++ b/newrelic/common/streaming_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import logging import threading try: @@ -20,9 +21,10 @@ except: AttributeValue = None +_logger = logging.getLogger(__name__) -class StreamBuffer(object): +class StreamBuffer(object): def __init__(self, maxlen): self._queue = collections.deque(maxlen=maxlen) self._notify = self.condition() @@ -64,18 +66,46 @@ def stats(self): return seen, dropped - def __next__(self): - while True: - if self._shutdown: - raise StopIteration + def __iter__(self): + return StreamBufferIterator(self) + - try: - return self._queue.popleft() - except IndexError: - pass +class StreamBufferIterator(object): + def __init__(self, stream_buffer): + self.stream_buffer = stream_buffer + self._notify = self.stream_buffer._notify + self._shutdown = False + self._stream = None - with self._notify: - if not self._shutdown and not self._queue: + def shutdown(self): + with self._notify: + self._shutdown = True + self._notify.notify_all() + + def stream_closed(self): + return self._shutdown or self.stream_buffer._shutdown or (self._stream and self._stream.done()) + + def __next__(self): + with self._notify: + while True: + # When a gRPC stream receives a server side disconnect (usually in the form of an OK code) + # the item it is waiting to consume from the iterator will not be sent, and will inevitably + # be lost. To prevent this, StopIteration is raised by shutting down the iterator and + # notifying to allow the thread to exit. Iterators cannot be reused or race conditions may + # occur between iterator shutdown and restart, so a new iterator must be created from the + # streaming buffer. + if self.stream_closed(): + _logger.debug("gRPC stream is closed. Shutting down and refusing to iterate.") + if not self._shutdown: + self.shutdown() + raise StopIteration + + try: + return self.stream_buffer._queue.popleft() + except IndexError: + pass + + if not self.stream_closed() and not self.stream_buffer._queue: self._notify.wait() next = __next__ @@ -90,10 +120,8 @@ def __init__(self, *args, **kwargs): if args: arg = args[0] if len(args) > 1: - raise TypeError( - "SpanProtoAttrs expected at most 1 argument, got %d", - len(args)) - elif hasattr(arg, 'keys'): + raise TypeError("SpanProtoAttrs expected at most 1 argument, got %d", len(args)) + elif hasattr(arg, "keys"): for k in arg: self[k] = arg[k] else: @@ -104,8 +132,7 @@ def __init__(self, *args, **kwargs): self[k] = kwargs[k] def __setitem__(self, key, value): - super(SpanProtoAttrs, self).__setitem__(key, - SpanProtoAttrs.get_attribute_value(value)) + super(SpanProtoAttrs, self).__setitem__(key, SpanProtoAttrs.get_attribute_value(value)) def copy(self): copy = SpanProtoAttrs() diff --git a/newrelic/core/agent_streaming.py b/newrelic/core/agent_streaming.py index b25bad0f3f..9ba88799eb 100644 --- a/newrelic/core/agent_streaming.py +++ b/newrelic/core/agent_streaming.py @@ -48,7 +48,8 @@ def __init__(self, endpoint, stream_buffer, metadata, record_metric, ssl=True): self._endpoint = endpoint self._ssl = ssl self.metadata = metadata - self.request_iterator = stream_buffer + self.stream_buffer = stream_buffer + self.request_iterator = iter(stream_buffer) self.response_processing_thread = threading.Thread( target=self.process_responses, name="NR-StreamingRpc-process-responses" ) @@ -68,6 +69,12 @@ def create_channel(self): self.rpc = self.channel.stream_stream(self.PATH, Span.SerializeToString, RecordStatus.FromString) + def create_response_iterator(self): + with self.stream_buffer._notify: + self.request_iterator = iter(self.stream_buffer) + self.request_iterator._stream = reponse_iterator = self.rpc(self.request_iterator, metadata=self.metadata) + return reponse_iterator + @staticmethod def condition(*args, **kwargs): return threading.Condition(*args, **kwargs) @@ -114,6 +121,12 @@ def process_responses(self): "response code. The agent will attempt " "to reestablish the stream immediately." ) + + # Reconnect channel for load balancing + self.request_iterator.shutdown() + self.channel.close() + self.create_channel() + else: self.record_metric( "Supportability/InfiniteTracing/Span/Response/Error", @@ -153,6 +166,7 @@ def process_responses(self): ) # Reconnect channel with backoff + self.request_iterator.shutdown() self.channel.close() self.notify.wait(retry_time) if self.closed: @@ -164,7 +178,8 @@ def process_responses(self): if self.closed: break - response_iterator = self.rpc(self.request_iterator, metadata=self.metadata) + response_iterator = self.create_response_iterator() + _logger.info("Streaming RPC connect completed.") try: