Skip to content

Commit

Permalink
Merge pull request #874 from Cysharp/feature/JsonTranscodingCallContext
Browse files Browse the repository at this point in the history
Add support for some members of JsonTranscodingServerCallContext.
  • Loading branch information
mayuki authored Dec 2, 2024
2 parents 3d6e5e3 + b2e85e7 commit bef83f6
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void BindUnary<TRequest, TResponse, TRawRequest, TRawResponse>(IMagicOnio

context.AddMethod(grpcMethod, RoutePatternFactory.Parse(routePath), metadata, async (context) =>
{
var serverCallContext = new MagicOnionJsonTranscodingServerCallContext(method);
var serverCallContext = new MagicOnionJsonTranscodingServerCallContext(context, method);

// Grpc.AspNetCore.Server expects that UserState has the key "__HttpContext" and that HttpContext is set to it.
// https://github.com/grpc/grpc-dotnet/blob/5a58c24efc1d0b7c5ff88e7b0582ea891b90b17f/src/Grpc.AspNetCore.Server/ServerCallContextExtensions.cs#L30
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,77 @@
using System.Net.Sockets;
using Grpc.AspNetCore.Server;
using Grpc.Core;
using MagicOnion.Server.Binder;
using Microsoft.AspNetCore.Http;

namespace MagicOnion.Server.JsonTranscoding;

public class MagicOnionJsonTranscodingServerCallContext(IMagicOnionGrpcMethod method) : ServerCallContext, IServerCallContextFeature
public class MagicOnionJsonTranscodingServerCallContext(HttpContext httpContext, IMagicOnionGrpcMethod method) : ServerCallContext, IServerCallContextFeature
{
protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) => throw new NotImplementedException();
Metadata? requestHeaders;

public ServerCallContext ServerCallContext => this;

protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
{
foreach (var header in responseHeaders)
{
var key = header.IsBinary ? header.Key + "-bin" : header.Key;
var value = header.IsBinary ? Convert.ToBase64String(header.ValueBytes) : header.Value;

httpContext.Response.Headers.TryAdd(key, value);
}

return Task.CompletedTask;
}

protected override ContextPropagationToken CreatePropagationTokenCore(ContextPropagationOptions? options) => throw new NotImplementedException();

protected override string MethodCore { get; } = $"{method.ServiceName}/{method.MethodName}";
protected override string HostCore => throw new NotImplementedException();
protected override string PeerCore => throw new NotImplementedException();
protected override DateTime DeadlineCore => throw new NotImplementedException();
protected override Metadata RequestHeadersCore => throw new NotImplementedException();
protected override CancellationToken CancellationTokenCore => throw new NotImplementedException();

protected override string HostCore { get; } = httpContext.Request.Host.Value ?? string.Empty;

protected override string PeerCore { get; } = httpContext.Connection.RemoteIpAddress switch
{
{ AddressFamily: AddressFamily.InterNetwork } => $"ipv4:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}",
{ AddressFamily: AddressFamily.InterNetworkV6 } => $"ipv6:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}",
{ } => $"unknown:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}",
_ => "unknown"
};

protected override DateTime DeadlineCore => DateTime.MaxValue; // No deadline

protected override Metadata RequestHeadersCore
{
get
{
if (requestHeaders is null)
{
requestHeaders = new Metadata();
foreach (var header in httpContext.Request.Headers)
{
var key = header.Key;
var value = header.Value;
if (key.EndsWith("-bin"))
{
key = key.Substring(0, key.Length - 4);
requestHeaders.Add(key, Convert.FromBase64String(value.ToString()));
}
else
{
requestHeaders.Add(key, value.ToString());
}
}
}

return requestHeaders;
}
}

protected override CancellationToken CancellationTokenCore => httpContext.RequestAborted;
protected override Metadata ResponseTrailersCore => throw new NotImplementedException();
protected override Status StatusCore { get; set; }
protected override WriteOptions? WriteOptionsCore { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
protected override AuthContext AuthContextCore => throw new NotImplementedException();
public ServerCallContext ServerCallContext => this;

}
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,42 @@ public async Task ThrowWithReturnStatusCode()
}

record ErrorResponse(int Code, string Detail);

[Fact]
public async Task CallContextInfo()
{
// Arrange
var httpClient = factory.CreateDefaultClient();

// Act
var response = await httpClient.PostAsync($"http://localhost/webapi/ITestService/CallContextInfo", new StringContent(string.Empty, new MediaTypeHeaderValue("application/json")));
var content = await response.Content.ReadAsStringAsync();

// Assert
var dict = JsonSerializer.Deserialize<Dictionary<string, string>>(content) ?? throw new InvalidOperationException("Failed to deserialize a response.");
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("localhost", dict["Host"]);
Assert.Equal("unknown", dict["Peer"]);
Assert.Equal("ITestService/CallContextInfo", dict["Method"]);
}

[Fact]
public async Task CallContext_WriteResponseHeader()
{
// Arrange
var httpClient = factory.CreateDefaultClient();
var requestBody = """
{ "key": "x-test-header", "value": "12345" }
""";

// Act
var response = await httpClient.PostAsync($"http://localhost/webapi/ITestService/CallContext_WriteResponseHeader", new StringContent(requestBody, new MediaTypeHeaderValue("application/json")));
var content = await response.Content.ReadAsStringAsync();

// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("12345", response.Headers.TryGetValues("x-test-header", out var values) ? string.Join(",", values) : null);
}
}

public interface ITestService : IService<ITestService>
Expand All @@ -251,6 +287,9 @@ public interface ITestService : IService<ITestService>

UnaryResult ThrowAsync();
UnaryResult ThrowWithReturnStatusCodeAsync(int statusCode, string detail);

UnaryResult<Dictionary<string, string>> CallContextInfo();
UnaryResult CallContext_WriteResponseHeader(string key, string value);
}

[MessagePackObject]
Expand Down Expand Up @@ -316,4 +355,23 @@ public UnaryResult ThrowWithReturnStatusCodeAsync(int statusCode, string detail)
{
throw new ReturnStatusException((StatusCode)statusCode, detail);
}

public UnaryResult<Dictionary<string, string>> CallContextInfo()
{
return UnaryResult.FromResult(new Dictionary<string, string>()
{
["Method"] = this.Context.CallContext.Method,
["Peer"] = this.Context.CallContext.Peer,
["Host"] = this.Context.CallContext.Host,
["Deadline.Ticks"] = this.Context.CallContext.Deadline.Ticks.ToString(),
["RequestHeaders"] = JsonSerializer.Serialize(this.Context.CallContext.RequestHeaders),
});
}

public async UnaryResult CallContext_WriteResponseHeader(string key, string value)
{
var metadata = new Metadata();
metadata.Add(key, value);
await Context.CallContext.WriteResponseHeadersAsync(metadata);
}
}

0 comments on commit bef83f6

Please sign in to comment.