Skip to content

Commit

Permalink
add a means to interject a server trailing metadata filter into every…
Browse files Browse the repository at this point in the history
… partial interception stack
  • Loading branch information
ctiller committed Jan 30, 2025
1 parent f9b6316 commit 7ba4779
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 46 deletions.
29 changes: 26 additions & 3 deletions src/core/lib/transport/interception_chain.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,32 @@ class InterceptionChainBuilder final {

// Add a filter that just mutates client initial metadata.
template <typename F>
void AddOnClientInitialMetadata(F f) {
InterceptionChainBuilder& AddOnClientInitialMetadata(F f) {
stack_builder().AddOnClientInitialMetadata(std::move(f));
return *this;
}

// Add a filter that just mutates server trailing metadata.
template <typename F>
void AddOnServerTrailingMetadata(F f) {
InterceptionChainBuilder& AddOnServerTrailingMetadata(F f) {
stack_builder().AddOnServerTrailingMetadata(std::move(f));
return *this;
}

// Immediately: Call AddOnServerTrailingMetadata
// Then, or every interceptor added to the filter from this point on:
// Perform an AddOnServerTrailingMetadata() immediately after
// the interceptor was added - but only if other filters or interceptors
// are added below it.
template <typename F>
InterceptionChainBuilder& AddOnServerTrailingMetadataForEachInterceptor(F f) {
LOG(INFO) << "Add";
AddOnServerTrailingMetadata(f);
on_new_interception_tail_.emplace_back([f](InterceptionChainBuilder* b) {
LOG(INFO) << "AddAnother";
b->AddOnServerTrailingMetadata(f);
});
return *this;
}

void Fail(absl::Status status) {
Expand All @@ -223,7 +241,10 @@ class InterceptionChainBuilder final {

private:
CallFilters::StackBuilder& stack_builder() {
if (!stack_builder_.has_value()) stack_builder_.emplace();
if (!stack_builder_.has_value()) {
stack_builder_.emplace();
for (auto& f : on_new_interception_tail_) f(this);
}
return *stack_builder_;
}

Expand All @@ -249,6 +270,8 @@ class InterceptionChainBuilder final {
ChannelArgs args_;
std::optional<CallFilters::StackBuilder> stack_builder_;
RefCountedPtr<Interceptor> top_interceptor_;
std::vector<absl::AnyInvocable<void(InterceptionChainBuilder*)>>
on_new_interception_tail_;
absl::Status status_;
std::map<size_t, size_t> filter_type_counts_;
static std::atomic<size_t> next_filter_id_;
Expand Down
90 changes: 47 additions & 43 deletions test/core/transport/interception_chain_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ class TestFilter {
void OnClientInitialMetadata(ClientMetadata& md) {
AnnotatePassedThrough(md, I);
}
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnClientToServerMessage;
static const NoInterceptor OnClientToServerHalfClose;
static const NoInterceptor OnServerToClientMessage;
static const NoInterceptor OnServerTrailingMetadata;
static const NoInterceptor OnFinalize;
static inline const NoInterceptor OnServerInitialMetadata;
static inline const NoInterceptor OnClientToServerMessage;
static inline const NoInterceptor OnClientToServerHalfClose;
static inline const NoInterceptor OnServerToClientMessage;
static inline const NoInterceptor OnServerTrailingMetadata;
static inline const NoInterceptor OnFinalize;
};

static absl::StatusOr<std::unique_ptr<TestFilter<I>>> Create(
Expand All @@ -99,19 +99,6 @@ class TestFilter {
std::unique_ptr<int> i_ = std::make_unique<int>(I);
};

template <int I>
const NoInterceptor TestFilter<I>::Call::OnServerInitialMetadata;
template <int I>
const NoInterceptor TestFilter<I>::Call::OnClientToServerMessage;
template <int I>
const NoInterceptor TestFilter<I>::Call::OnClientToServerHalfClose;
template <int I>
const NoInterceptor TestFilter<I>::Call::OnServerToClientMessage;
template <int I>
const NoInterceptor TestFilter<I>::Call::OnServerTrailingMetadata;
template <int I>
const NoInterceptor TestFilter<I>::Call::OnFinalize;

///////////////////////////////////////////////////////////////////////////////
// Test call filter that fails to instantiate

Expand All @@ -120,13 +107,13 @@ class FailsToInstantiateFilter {
public:
class Call {
public:
static const NoInterceptor OnClientInitialMetadata;
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnClientToServerMessage;
static const NoInterceptor OnClientToServerHalfClose;
static const NoInterceptor OnServerToClientMessage;
static const NoInterceptor OnServerTrailingMetadata;
static const NoInterceptor OnFinalize;
static inline const NoInterceptor OnClientInitialMetadata;
static inline const NoInterceptor OnServerInitialMetadata;
static inline const NoInterceptor OnClientToServerMessage;
static inline const NoInterceptor OnClientToServerHalfClose;
static inline const NoInterceptor OnServerToClientMessage;
static inline const NoInterceptor OnServerTrailingMetadata;
static inline const NoInterceptor OnFinalize;
};

static absl::StatusOr<std::unique_ptr<FailsToInstantiateFilter<I>>> Create(
Expand All @@ -136,22 +123,6 @@ class FailsToInstantiateFilter {
}
};

template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientInitialMetadata;
template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerInitialMetadata;
template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnClientToServerMessage;
template <int I>
const NoInterceptor
FailsToInstantiateFilter<I>::Call::OnClientToServerHalfClose;
template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerToClientMessage;
template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnServerTrailingMetadata;
template <int I>
const NoInterceptor FailsToInstantiateFilter<I>::Call::OnFinalize;

///////////////////////////////////////////////////////////////////////////////
// Test call interceptor - consumes calls

Expand Down Expand Up @@ -279,7 +250,8 @@ class InterceptionChainTest : public ::testing::Test {
metadata_ = Arena::MakePooledForOverwrite<ClientMetadata>();
*metadata_ =
unstarted_call_handler.UnprocessedClientInitialMetadata().Copy();
unstarted_call_handler.PushServerTrailingMetadata(
auto handler = unstarted_call_handler.StartCall();
handler.PushServerTrailingMetadata(
ServerMetadataFromStatus(GRPC_STATUS_INTERNAL, "👊 cancelled"));
}

Expand Down Expand Up @@ -435,6 +407,38 @@ TEST_F(InterceptionChainTest, CreationOrderCorrect) {
CreationLogEntry{2, 1}));
}

TEST_F(InterceptionChainTest, AddOnServerTrailingMetadataForEachInterceptor) {
CreationLog log;
auto r =
InterceptionChainBuilder(ChannelArgs())
.AddOnServerTrailingMetadata([](ServerMetadata& md) {
md.Set(
GrpcMessageMetadata(),
Slice::FromCopiedString(absl::StrCat(
"0",
md.get_pointer(GrpcMessageMetadata())->as_string_view())));
})
.AddOnServerTrailingMetadataForEachInterceptor(
[](ServerMetadata& md) {
md.Set(GrpcMessageMetadata(),
Slice::FromCopiedString(absl::StrCat(
"x", md.get_pointer(GrpcMessageMetadata())
->as_string_view())));
})
.Add<TestPassThroughInterceptor<1>>()
.Add<TestPassThroughInterceptor<2>>()
.Add<TestPassThroughInterceptor<3>>()
.Build(destination());
ASSERT_TRUE(r.ok()) << r.status();
auto finished_call = RunCall(r.value().get());
EXPECT_EQ(finished_call.server_metadata->get(GrpcStatusMetadata()),
GRPC_STATUS_INTERNAL);
EXPECT_EQ(finished_call.server_metadata->get_pointer(GrpcMessageMetadata())
->as_string_view(),
"0xxx👊 cancelled");
EXPECT_NE(finished_call.client_metadata, nullptr);
}

} // namespace
} // namespace grpc_core

Expand Down

0 comments on commit 7ba4779

Please sign in to comment.