diff --git a/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingGrpcMethodBinder.cs b/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingGrpcMethodBinder.cs index e638274ba..3dc0a3cba 100644 --- a/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingGrpcMethodBinder.cs +++ b/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingGrpcMethodBinder.cs @@ -47,7 +47,7 @@ public void BindUnary(IMagicOnio context.AddMethod(grpcMethod, RoutePatternFactory.Parse(routePath), metadata, async (context) => { - var serverCallContext = new MagicOnionJsonTranscodingServerCallContext(method); + var serverCallContext = new MagicOnionJsonTranscodingServerCallContext(context, method); // Grpc.AspNetCore.Server expects that UserState has the key "__HttpContext" and that HttpContext is set to it. // https://github.com/grpc/grpc-dotnet/blob/5a58c24efc1d0b7c5ff88e7b0582ea891b90b17f/src/Grpc.AspNetCore.Server/ServerCallContextExtensions.cs#L30 diff --git a/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingServerCallContext.cs b/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingServerCallContext.cs index 10aca4e45..d8785e8a0 100644 --- a/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingServerCallContext.cs +++ b/src/MagicOnion.Server.JsonTranscoding/MagicOnionJsonTranscodingServerCallContext.cs @@ -1,24 +1,77 @@ +using System.Net.Sockets; using Grpc.AspNetCore.Server; using Grpc.Core; using MagicOnion.Server.Binder; +using Microsoft.AspNetCore.Http; namespace MagicOnion.Server.JsonTranscoding; -public class MagicOnionJsonTranscodingServerCallContext(IMagicOnionGrpcMethod method) : ServerCallContext, IServerCallContextFeature +public class MagicOnionJsonTranscodingServerCallContext(HttpContext httpContext, IMagicOnionGrpcMethod method) : ServerCallContext, IServerCallContextFeature { - protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) => throw new NotImplementedException(); + Metadata? requestHeaders; + + public ServerCallContext ServerCallContext => this; + + protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders) + { + foreach (var header in responseHeaders) + { + var key = header.IsBinary ? header.Key + "-bin" : header.Key; + var value = header.IsBinary ? Convert.ToBase64String(header.ValueBytes) : header.Value; + + httpContext.Response.Headers.TryAdd(key, value); + } + + return Task.CompletedTask; + } protected override ContextPropagationToken CreatePropagationTokenCore(ContextPropagationOptions? options) => throw new NotImplementedException(); protected override string MethodCore { get; } = $"{method.ServiceName}/{method.MethodName}"; - protected override string HostCore => throw new NotImplementedException(); - protected override string PeerCore => throw new NotImplementedException(); - protected override DateTime DeadlineCore => throw new NotImplementedException(); - protected override Metadata RequestHeadersCore => throw new NotImplementedException(); - protected override CancellationToken CancellationTokenCore => throw new NotImplementedException(); + + protected override string HostCore { get; } = httpContext.Request.Host.Value ?? string.Empty; + + protected override string PeerCore { get; } = httpContext.Connection.RemoteIpAddress switch + { + { AddressFamily: AddressFamily.InterNetwork } => $"ipv4:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}", + { AddressFamily: AddressFamily.InterNetworkV6 } => $"ipv6:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}", + { } => $"unknown:{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}", + _ => "unknown" + }; + + protected override DateTime DeadlineCore => DateTime.MaxValue; // No deadline + + protected override Metadata RequestHeadersCore + { + get + { + if (requestHeaders is null) + { + requestHeaders = new Metadata(); + foreach (var header in httpContext.Request.Headers) + { + var key = header.Key; + var value = header.Value; + if (key.EndsWith("-bin")) + { + key = key.Substring(0, key.Length - 4); + requestHeaders.Add(key, Convert.FromBase64String(value.ToString())); + } + else + { + requestHeaders.Add(key, value.ToString()); + } + } + } + + return requestHeaders; + } + } + + protected override CancellationToken CancellationTokenCore => httpContext.RequestAborted; protected override Metadata ResponseTrailersCore => throw new NotImplementedException(); protected override Status StatusCore { get; set; } protected override WriteOptions? WriteOptionsCore { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } protected override AuthContext AuthContextCore => throw new NotImplementedException(); - public ServerCallContext ServerCallContext => this; + } diff --git a/tests/MagicOnion.Server.JsonTranscoding.Tests/UnaryFunctionalTests.cs b/tests/MagicOnion.Server.JsonTranscoding.Tests/UnaryFunctionalTests.cs index 67213ecfd..0c3243966 100644 --- a/tests/MagicOnion.Server.JsonTranscoding.Tests/UnaryFunctionalTests.cs +++ b/tests/MagicOnion.Server.JsonTranscoding.Tests/UnaryFunctionalTests.cs @@ -237,6 +237,42 @@ public async Task ThrowWithReturnStatusCode() } record ErrorResponse(int Code, string Detail); + + [Fact] + public async Task CallContextInfo() + { + // Arrange + var httpClient = factory.CreateDefaultClient(); + + // Act + var response = await httpClient.PostAsync($"http://localhost/webapi/ITestService/CallContextInfo", new StringContent(string.Empty, new MediaTypeHeaderValue("application/json"))); + var content = await response.Content.ReadAsStringAsync(); + + // Assert + var dict = JsonSerializer.Deserialize>(content) ?? throw new InvalidOperationException("Failed to deserialize a response."); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("localhost", dict["Host"]); + Assert.Equal("unknown", dict["Peer"]); + Assert.Equal("ITestService/CallContextInfo", dict["Method"]); + } + + [Fact] + public async Task CallContext_WriteResponseHeader() + { + // Arrange + var httpClient = factory.CreateDefaultClient(); + var requestBody = """ + { "key": "x-test-header", "value": "12345" } + """; + + // Act + var response = await httpClient.PostAsync($"http://localhost/webapi/ITestService/CallContext_WriteResponseHeader", new StringContent(requestBody, new MediaTypeHeaderValue("application/json"))); + var content = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("12345", response.Headers.TryGetValues("x-test-header", out var values) ? string.Join(",", values) : null); + } } public interface ITestService : IService @@ -251,6 +287,9 @@ public interface ITestService : IService UnaryResult ThrowAsync(); UnaryResult ThrowWithReturnStatusCodeAsync(int statusCode, string detail); + + UnaryResult> CallContextInfo(); + UnaryResult CallContext_WriteResponseHeader(string key, string value); } [MessagePackObject] @@ -316,4 +355,23 @@ public UnaryResult ThrowWithReturnStatusCodeAsync(int statusCode, string detail) { throw new ReturnStatusException((StatusCode)statusCode, detail); } + + public UnaryResult> CallContextInfo() + { + return UnaryResult.FromResult(new Dictionary() + { + ["Method"] = this.Context.CallContext.Method, + ["Peer"] = this.Context.CallContext.Peer, + ["Host"] = this.Context.CallContext.Host, + ["Deadline.Ticks"] = this.Context.CallContext.Deadline.Ticks.ToString(), + ["RequestHeaders"] = JsonSerializer.Serialize(this.Context.CallContext.RequestHeaders), + }); + } + + public async UnaryResult CallContext_WriteResponseHeader(string key, string value) + { + var metadata = new Metadata(); + metadata.Add(key, value); + await Context.CallContext.WriteResponseHeadersAsync(metadata); + } }