Skip to content

Commit

Permalink
Merge pull request #885 from Cysharp/feature/ReuseStreamingHubContext
Browse files Browse the repository at this point in the history
Reuse SteramingHubContext
  • Loading branch information
mayuki authored Jan 6, 2025
2 parents 8d1e7d9 + ac257bf commit 2d88d28
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 97 deletions.
155 changes: 86 additions & 69 deletions src/MagicOnion.Server/Hubs/StreamingHub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public abstract class StreamingHubBase<THubInterface, TReceiver> : ServiceBase<T
StreamingHubHeartbeatHandle heartbeatHandle = default!;
TimeProvider timeProvider = default!;
bool isReturnExceptionStackTraceInErrorDetail = false;
UniqueHashDictionary<StreamingHubHandler> handlers = default!;

protected static readonly Task<Nil> NilTask = Task.FromResult(Nil.Default);
protected static readonly ValueTask CompletedTask = new ValueTask();
Expand All @@ -45,7 +46,7 @@ public abstract class StreamingHubBase<THubInterface, TReceiver> : ServiceBase<T
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
SingleWriter = true,
});

public HubGroupRepository<TReceiver> Group { get; private set; } = default!;
Expand Down Expand Up @@ -90,11 +91,13 @@ async Task<DuplexStreamingResult<StreamingHubPayload, StreamingHubPayload>> IStr
var serviceProvider = streamingContext.ServiceContext.ServiceProvider;

var features = this.Context.CallContext.GetHttpContext().Features;
streamingHubFeature = features.Get<IStreamingHubFeature>()!; // TODO: GetRequiredFeature
streamingHubFeature = features.GetRequiredFeature<IStreamingHubFeature>();
var magicOnionOptions = serviceProvider.GetRequiredService<IOptions<MagicOnionOptions>>().Value;
timeProvider = magicOnionOptions.TimeProvider ?? TimeProvider.System;
isReturnExceptionStackTraceInErrorDetail = magicOnionOptions.IsReturnExceptionStackTraceInErrorDetail;

handlers = streamingHubFeature.Handlers;

var remoteProxyFactory = serviceProvider.GetRequiredService<IRemoteProxyFactory>();
var remoteSerializer = serviceProvider.GetRequiredService<IRemoteSerializer>();
this.remoteClientResultPendingTasks = new RemoteClientResultPendingTaskRegistry(magicOnionOptions.ClientResultsDefaultTimeout, timeProvider);
Expand Down Expand Up @@ -169,32 +172,66 @@ async Task HandleMessageAsync()
// eg: Send the current game state to the client.
await OnConnected();

var handlers = streamingHubFeature.Handlers;

// Starts a loop that consumes the request queue.
var consumeRequestsTask = ConsumeRequestQueueAsync(ct);
_ = ConsumeRequestQueueAsync();

// Main loop of StreamingHub.
// Be careful to allocation and performance.
while (await reader.MoveNext(ct))
{
var payload = reader.Current;

await ProcessMessageAsync(payload, handlers, ct);
await ProcessMessageAsync(payload, ct);

// NOTE: DO NOT return the StreamingHubPayload to the pool here.
// Client requests may be pending at this point.
}
}

async ValueTask ConsumeRequestQueueAsync(CancellationToken cancellationToken)
async ValueTask ConsumeRequestQueueAsync()
{
// Create and reuse a single StreamingHubContext for each hub connection.
var hubContext = new StreamingHubContext();

// We need to process client requests sequentially.
await foreach (var request in requests.Reader.ReadAllAsync(cancellationToken))
// NOTE: Do not pass a CancellationToken to avoid allocation. We call Writer.Complete when we want to stop the consumption loop.
await foreach (var request in requests.Reader.ReadAllAsync(default))
{
try
{
await ProcessRequestAsync(request.Handlers, request.MethodId, request.MessageId, request.Body, request.HasResponse);
if (handlers.TryGetValue(request.MethodId, out var handler))
{
hubContext.Initialize(
handler: handler,
streamingServiceContext: (IStreamingServiceContext<StreamingHubPayload, StreamingHubPayload>)Context,
hubInstance: this,
request: request.Body,
messageId: request.MessageId,
timestamp: timeProvider.GetUtcNow().UtcDateTime
);

var isErrorOrInterrupted = false;
var methodStartingTimestamp = timeProvider.GetTimestamp();
MagicOnionServerLog.BeginInvokeHubMethod(Context.Logger, hubContext, hubContext.Request, handler.RequestType);

try
{
await handler.MethodBody.Invoke(hubContext);
}
catch (Exception ex)
{
isErrorOrInterrupted = true;
HandleException(hubContext, ex, request.HasResponse);
}
finally
{
CleanupRequest(hubContext, methodStartingTimestamp, isErrorOrInterrupted);
}
}
else
{
RespondMethodNotFound(request.MethodId, request.MessageId);
}
}
finally
{
Expand All @@ -203,7 +240,46 @@ async ValueTask ConsumeRequestQueueAsync(CancellationToken cancellationToken)
}
}

ValueTask ProcessMessageAsync(StreamingHubPayload payload, UniqueHashDictionary<StreamingHubHandler> handlers, CancellationToken cancellationToken)
void HandleException(StreamingHubContext hubContext, Exception ex, bool hasResponse)
{
if (ex is ReturnStatusException rse)
{
if (hasResponse)
{
hubContext.WriteErrorMessage((int)rse.StatusCode, rse.Detail, null, false);
}
}
else
{
MagicOnionServerLog.Error(Context.Logger, ex, hubContext);
Metrics.StreamingHubException(Context.Metrics, hubContext.Handler, ex);

if (hasResponse)
{
hubContext.WriteErrorMessage((int)StatusCode.Internal, $"An error occurred while processing handler '{hubContext.Handler}'.", ex, isReturnExceptionStackTraceInErrorDetail);
}
}
}

void CleanupRequest(StreamingHubContext hubContext, long methodStartingTimestamp, bool isErrorOrInterrupted)
{
var methodEndingTimestamp = timeProvider.GetTimestamp();
var elapsed = timeProvider.GetElapsedTime(methodStartingTimestamp, methodEndingTimestamp);
MagicOnionServerLog.EndInvokeHubMethod(Context.Logger, hubContext, hubContext.ResponseSize, hubContext.ResponseType, elapsed.TotalMilliseconds, isErrorOrInterrupted);
Metrics.StreamingHubMethodCompleted(Context.Metrics, hubContext.Handler, methodStartingTimestamp, methodEndingTimestamp, isErrorOrInterrupted);

hubContext.Uninitialize();
}

void RespondMethodNotFound(int methodId, int messageId)
{
MagicOnionServerLog.HubMethodNotFound(Context.Logger, Context.ServiceName, methodId);
var payload = StreamingHubPayloadBuilder.BuildError(messageId, (int)StatusCode.Unimplemented, $"StreamingHub method '{methodId}' is not found in StreamingHub.", null, isReturnExceptionStackTraceInErrorDetail);
StreamingServiceContext.QueueResponseStreamWrite(payload);
}


ValueTask ProcessMessageAsync(StreamingHubPayload payload, CancellationToken cancellationToken)
{
var reader = new StreamingHubServerMessageReader(payload.Memory);
var messageType = reader.ReadMessageType();
Expand Down Expand Up @@ -260,65 +336,6 @@ ValueTask ProcessMessageAsync(StreamingHubPayload payload, UniqueHashDictionary<
}
}

[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))]
async ValueTask ProcessRequestAsync(UniqueHashDictionary<StreamingHubHandler> handlers, int methodId, int messageId, ReadOnlyMemory<byte> body, bool hasResponse)
{
if (handlers.TryGetValue(methodId, out var handler))
{
// Create a context for each call to the hub method.
var context = StreamingHubContextPool.Shared.Get();
context.Initialize(
handler: handler,
streamingServiceContext: (IStreamingServiceContext<StreamingHubPayload, StreamingHubPayload>)Context,
hubInstance: this,
request: body,
messageId: messageId,
timestamp: timeProvider.GetUtcNow().UtcDateTime
);

var methodStartingTimestamp = timeProvider.GetTimestamp();
var isErrorOrInterrupted = false;
MagicOnionServerLog.BeginInvokeHubMethod(Context.Logger, context, context.Request, handler.RequestType);
try
{
await handler.MethodBody.Invoke(context);
}
catch (ReturnStatusException ex)
{
if (hasResponse)
{
await context.WriteErrorMessage((int)ex.StatusCode, ex.Detail, null, false);
}
}
catch (Exception ex)
{
isErrorOrInterrupted = true;
MagicOnionServerLog.Error(Context.Logger, ex, context);
Metrics.StreamingHubException(Context.Metrics, handler, ex);

if (hasResponse)
{
await context.WriteErrorMessage((int)StatusCode.Internal, $"An error occurred while processing handler '{handler.ToString()}'.", ex, isReturnExceptionStackTraceInErrorDetail);
}
}
finally
{
var methodEndingTimestamp = timeProvider.GetTimestamp();
MagicOnionServerLog.EndInvokeHubMethod(Context.Logger, context, context.ResponseSize, context.ResponseType, timeProvider.GetElapsedTime(methodStartingTimestamp, methodEndingTimestamp).TotalMilliseconds, isErrorOrInterrupted);
Metrics.StreamingHubMethodCompleted(Context.Metrics, handler, methodStartingTimestamp, methodEndingTimestamp, isErrorOrInterrupted);

StreamingHubContextPool.Shared.Return(context);
}
}
else
{
MagicOnionServerLog.HubMethodNotFound(Context.Logger, Context.ServiceName, methodId);
var payload = StreamingHubPayloadBuilder.BuildError(messageId, (int)StatusCode.Unimplemented, $"StreamingHub method '{methodId}' is not found in StreamingHub.", null, isReturnExceptionStackTraceInErrorDetail);
StreamingServiceContext.QueueResponseStreamWrite(payload);
}
}


// Interface methods for Client

THubInterface IStreamingHub<THubInterface, TReceiver>.FireAndForget()
Expand Down
42 changes: 14 additions & 28 deletions src/MagicOnion.Server/Hubs/StreamingHubContext.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,11 @@
using MessagePack;
using System.Collections.Concurrent;
using System.Diagnostics;
using MagicOnion.Internal;
using Microsoft.Extensions.ObjectPool;
using MagicOnion.Server.Hubs.Internal;

namespace MagicOnion.Server.Hubs;

internal class StreamingHubContextPool
{
const int MaxRetainedCount = 16;
readonly ObjectPool<StreamingHubContext> pool = new DefaultObjectPool<StreamingHubContext>(new Policy(), MaxRetainedCount);

public static StreamingHubContextPool Shared { get; } = new();

public StreamingHubContext Get() => pool.Get();
public void Return(StreamingHubContext ctx) => pool.Return(ctx);

class Policy : IPooledObjectPolicy<StreamingHubContext>
{
public StreamingHubContext Create()
{
return new StreamingHubContext();
}

public bool Return(StreamingHubContext obj)
{
obj.Uninitialize();
return true;
}
}
}

public class StreamingHubContext
{
IStreamingServiceContext<StreamingHubPayload, StreamingHubPayload> streamingServiceContext = default!;
Expand Down Expand Up @@ -62,6 +37,7 @@ public ConcurrentDictionary<string, object> Items

public IServiceContext ServiceContext => streamingServiceContext;

internal StreamingHubHandler Handler => handler;
internal int MessageId { get; private set; }
internal int MethodId => handler.MethodId;

Expand All @@ -70,6 +46,11 @@ public ConcurrentDictionary<string, object> Items

internal void Initialize(StreamingHubHandler handler, IStreamingServiceContext<StreamingHubPayload, StreamingHubPayload> streamingServiceContext, object hubInstance, ReadOnlyMemory<byte> request, DateTime timestamp, int messageId)
{
#if DEBUG
Debug.Assert(this.handler is null);
Debug.Assert(this.streamingServiceContext is null);
Debug.Assert(this.HubInstance is null);
#endif
this.handler = handler;
this.streamingServiceContext = streamingServiceContext;
HubInstance = hubInstance;
Expand All @@ -80,6 +61,12 @@ internal void Initialize(StreamingHubHandler handler, IStreamingServiceContext<S

internal void Uninitialize()
{
#if DEBUG
Debug.Assert(this.handler is not null);
Debug.Assert(this.streamingServiceContext is not null);
Debug.Assert(this.HubInstance is not null);
#endif

handler = default!;
streamingServiceContext = default!;
HubInstance = default!;
Expand Down Expand Up @@ -135,10 +122,9 @@ static async ValueTask Await(StreamingHubContext ctx, ValueTask<T> value)
}
}

internal ValueTask WriteErrorMessage(int statusCode, string detail, Exception? ex, bool isReturnExceptionStackTraceInErrorDetail)
internal void WriteErrorMessage(int statusCode, string detail, Exception? ex, bool isReturnExceptionStackTraceInErrorDetail)
{
WriteMessageCore(StreamingHubPayloadBuilder.BuildError(MessageId, statusCode, detail, ex, isReturnExceptionStackTraceInErrorDetail));
return default;
}

void WriteMessageCore(StreamingHubPayload payload)
Expand Down

0 comments on commit 2d88d28

Please sign in to comment.