From 484541afd2abc285215334d62c07a1cd2f76ae47 Mon Sep 17 00:00:00 2001 From: Craig Tiller Date: Thu, 11 Jan 2024 23:59:18 +0000 Subject: [PATCH] x --- src/core/lib/transport/call_filters.cc | 8 +- src/core/lib/transport/call_filters.h | 33 +++--- test/core/transport/call_filters_test.cc | 130 +++++++++++++++++++++++ 3 files changed, 150 insertions(+), 21 deletions(-) diff --git a/src/core/lib/transport/call_filters.cc b/src/core/lib/transport/call_filters.cc index bada7d22afb7a..7379f6d98240b 100644 --- a/src/core/lib/transport/call_filters.cc +++ b/src/core/lib/transport/call_filters.cc @@ -167,8 +167,8 @@ CallFilters::CallFilters() : stack_(nullptr), call_data_(nullptr) {} CallFilters::CallFilters(RefCountedPtr stack) : stack_(std::move(stack)), - call_data_(gpr_malloc_aligned(stack->data_.call_data_size, - stack->data_.call_data_alignment)) { + call_data_(gpr_malloc_aligned(stack_->data_.call_data_size, + stack_->data_.call_data_alignment)) { client_initial_metadata_state_.Start(); client_to_server_message_state_.Start(); server_initial_metadata_state_.Start(); @@ -182,8 +182,8 @@ CallFilters::~CallFilters() { void CallFilters::SetStack(RefCountedPtr stack) { if (call_data_ != nullptr) gpr_free_aligned(call_data_); stack_ = std::move(stack); - call_data_ = gpr_malloc_aligned(stack->data_.call_data_size, - stack->data_.call_data_alignment); + call_data_ = gpr_malloc_aligned(stack_->data_.call_data_size, + stack_->data_.call_data_alignment); client_initial_metadata_state_.Start(); client_to_server_message_state_.Start(); server_initial_metadata_state_.Start(); diff --git a/src/core/lib/transport/call_filters.h b/src/core/lib/transport/call_filters.h index e913349690bdf..ae8d84d7af508 100644 --- a/src/core/lib/transport/call_filters.h +++ b/src/core/lib/transport/call_filters.h @@ -696,7 +696,7 @@ struct StackData { template void AddFinalizer(FilterType* channel_data, size_t call_offset, void (FilterType::Call::*p)(const grpc_call_final_info*)) { - GPR_DEBUG_ASSERT(p == &FilterType::OnFinalize); + GPR_DEBUG_ASSERT(p == &FilterType::Call::OnFinalize); finalizers.push_back(Finalizer{ channel_data, call_offset, @@ -711,7 +711,7 @@ struct StackData { void AddFinalizer(FilterType* channel_data, size_t call_offset, void (FilterType::Call::*p)(const grpc_call_final_info*, FilterType*)) { - GPR_DEBUG_ASSERT(p == &FilterType::OnFinalize); + GPR_DEBUG_ASSERT(p == &FilterType::Call::OnFinalize); finalizers.push_back(Finalizer{ channel_data, call_offset, @@ -921,17 +921,12 @@ class CallFilters { template void Add(FilterType* filter) { const size_t call_offset = data_.AddFilter(filter); - data_.AddClientInitialMetadataOp(filter, call_offset, - &FilterType::OnClientInitialMetadata); - data_.AddServerInitialMetadataOp(filter, call_offset, - &FilterType::OnServerInitialMetadata); - data_.AddClientToServerMessageOp(filter, call_offset, - &FilterType::OnClientToServerMessage); - data_.AddServerToClientMessageOp(filter, call_offset, - &FilterType::OnServerToClientMessage); - data_.AddServerTrailingMetadataOp(filter, call_offset, - &FilterType::OnServerTrailingMetadata); - data_.AddFinalizer(filter, call_offset, &FilterType::OnFinalize); + data_.AddClientInitialMetadataOp(filter, call_offset); + data_.AddServerInitialMetadataOp(filter, call_offset); + data_.AddClientToServerMessageOp(filter, call_offset); + data_.AddServerToClientMessageOp(filter, call_offset); + data_.AddServerTrailingMetadataOp(filter, call_offset); + data_.AddFinalizer(filter, call_offset, &FilterType::Call::OnFinalize); } RefCountedPtr Build(); @@ -1034,27 +1029,31 @@ class CallFilters { if (transformer_.IsRunning()) { return FinishPipeTransformer(transformer_.Step(filters_->call_data_)); } - auto p = state().PollPullValue(); + auto p = state().PollPull(); auto* r = p.value_if_ready(); if (r == nullptr) return Pending{}; if (!r->ok()) { filters_->CancelDueToFailedPipeOperation(); return Failure{}; } - return FinishPipeTransformer( - transformer_.Start(push()->TakeValue(), filters_->call_data_)); + return FinishPipeTransformer(transformer_.Start( + layout(), push()->TakeValue(), filters_->call_data_)); } private: filters_detail::PipeState& state() { return filters_->*state_ptr; } Push* push() { return static_cast(filters_->*push_ptr); } + const filters_detail::Layout>* + layout() { + return &(filters_->stack_->data_.*layout_ptr); + } Poll> FinishPipeTransformer( Poll> p) { auto* r = p.value_if_ready(); if (r == nullptr) return Pending{}; GPR_DEBUG_ASSERT(!transformer_.IsRunning()); - state().AckPullValue(); + state().AckPull(); if (r->ok != nullptr) return std::move(r->ok); filters_->PushServerTrailingMetadata(std::move(r->error)); return Failure{}; diff --git a/test/core/transport/call_filters_test.cc b/test/core/transport/call_filters_test.cc index 1cc7bb0757353..cdb1c61bc4fcb 100644 --- a/test/core/transport/call_filters_test.cc +++ b/test/core/transport/call_filters_test.cc @@ -65,6 +65,14 @@ MATCHER(IsPending, "") { return true; } +MATCHER(IsReady, "") { + if (arg.pending()) { + *result_listener << "is pending"; + return false; + } + return true; +} + MATCHER_P(IsReady, value, "") { if (arg.pending()) { *result_listener << "is pending"; @@ -1328,6 +1336,128 @@ TEST(PipeStateTest, DropProcessing) { } // namespace filters_detail +/////////////////////////////////////////////////////////////////////////////// +// CallFilters + +TEST(CallFiltersTest, CanBuildStack) { + struct Filter { + struct Call { + void OnClientInitialMetadata(ClientMetadata& md) {} + void OnServerInitialMetadata(ServerMetadata& md) {} + void OnClientToServerMessage(Message& message) {} + void OnServerToClientMessage(Message& message) {} + void OnServerTrailingMetadata(ServerMetadata& md) {} + void OnFinalize(const grpc_call_final_info*) {} + }; + }; + CallFilters::StackBuilder builder; + Filter f; + builder.Add(&f); + auto stack = builder.Build(); + EXPECT_NE(stack, nullptr); +} + +TEST(CallFiltersTest, UnaryCall) { + struct Filter { + struct Call { + void OnClientInitialMetadata(ClientMetadata& md, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnClientInitialMetadata")); + } + void OnServerInitialMetadata(ServerMetadata& md, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnServerInitialMetadata")); + } + void OnClientToServerMessage(Message& message, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnClientToServerMessage")); + } + void OnServerToClientMessage(Message& message, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnServerToClientMessage")); + } + void OnServerTrailingMetadata(ServerMetadata& md, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnServerTrailingMetadata")); + } + void OnFinalize(const grpc_call_final_info*, Filter* f) { + f->steps.push_back(absl::StrCat(f->label, ":OnFinalize")); + } + }; + + const std::string label; + std::vector& steps; + }; + std::vector steps; + Filter f1{"f1", steps}; + Filter f2{"f2", steps}; + CallFilters::StackBuilder builder; + builder.Add(&f1); + builder.Add(&f2); + CallFilters filters(builder.Build()); + auto memory_allocator = + MakeMemoryQuota("test-quota")->CreateMemoryAllocator("foo"); + auto arena = MakeScopedArena(1024, &memory_allocator); + promise_detail::Context ctx(arena.get()); + StrictMock activity; + activity.Activate(); + // Push client initial metadata + auto push_client_initial_metadata = filters.PushClientInitialMetadata( + Arena::MakePooled(arena.get())); + EXPECT_THAT(push_client_initial_metadata(), IsPending()); + auto pull_client_initial_metadata = filters.PullClientInitialMetadata(); + // Pull client initial metadata, expect a wakeup + EXPECT_CALL(activity, WakeupRequested()); + EXPECT_THAT(pull_client_initial_metadata(), IsReady()); + Mock::VerifyAndClearExpectations(&activity); + // Push should be done + EXPECT_THAT(push_client_initial_metadata(), IsReady(Success{})); + // Push client to server message + auto push_client_to_server_message = filters.PushClientToServerMessage( + Arena::MakePooled(SliceBuffer(), 0)); + EXPECT_THAT(push_client_to_server_message(), IsPending()); + auto pull_client_to_server_message = filters.PullClientToServerMessage(); + // Pull client to server message, expect a wakeup + EXPECT_CALL(activity, WakeupRequested()); + EXPECT_THAT(pull_client_to_server_message(), IsReady()); + Mock::VerifyAndClearExpectations(&activity); + // Push should be done + EXPECT_THAT(push_client_to_server_message(), IsReady(Success{})); + // Push server initial metadata + auto push_server_initial_metadata = filters.PushServerInitialMetadata( + Arena::MakePooled(arena.get())); + EXPECT_THAT(push_server_initial_metadata(), IsPending()); + auto pull_server_initial_metadata = filters.PullServerInitialMetadata(); + // Pull server initial metadata, expect a wakeup + EXPECT_CALL(activity, WakeupRequested()); + EXPECT_THAT(pull_server_initial_metadata(), IsReady()); + Mock::VerifyAndClearExpectations(&activity); + // Push should be done + EXPECT_THAT(push_server_initial_metadata(), IsReady(Success{})); + // Push server to client message + auto push_server_to_client_message = filters.PushServerToClientMessage( + Arena::MakePooled(SliceBuffer(), 0)); + EXPECT_THAT(push_server_to_client_message(), IsPending()); + auto pull_server_to_client_message = filters.PullServerToClientMessage(); + // Pull server to client message, expect a wakeup + EXPECT_CALL(activity, WakeupRequested()); + EXPECT_THAT(pull_server_to_client_message(), IsReady()); + Mock::VerifyAndClearExpectations(&activity); + // Push should be done + EXPECT_THAT(push_server_to_client_message(), IsReady(Success{})); + // Push server trailing metadata + filters.PushServerTrailingMetadata( + Arena::MakePooled(arena.get())); + // Pull server trailing metadata + auto pull_server_trailing_metadata = filters.PullServerTrailingMetadata(); + // Should be done + EXPECT_THAT(pull_server_trailing_metadata(), IsReady()); + filters.Finalize(nullptr); + EXPECT_THAT(steps, + ::testing::ElementsAre( + "f1:OnClientInitialMetadata", "f2:OnClientInitialMetadata", + "f1:OnClientToServerMessage", "f2:OnClientToServerMessage", + "f2:OnServerInitialMetadata", "f1:OnServerInitialMetadata", + "f2:OnServerToClientMessage", "f1:OnServerToClientMessage", + "f2:OnServerTrailingMetadata", "f1:OnServerTrailingMetadata", + "f1:OnFinalize", "f2:OnFinalize")); +} + } // namespace grpc_core int main(int argc, char** argv) {