Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
ctiller committed Jan 11, 2024
1 parent 060e945 commit 484541a
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/core/lib/transport/call_filters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ CallFilters::CallFilters() : stack_(nullptr), call_data_(nullptr) {}

CallFilters::CallFilters(RefCountedPtr<Stack> stack)
: stack_(std::move(stack)),
call_data_(gpr_malloc_aligned(stack->data_.call_data_size,
stack->data_.call_data_alignment)) {
call_data_(gpr_malloc_aligned(stack_->data_.call_data_size,
stack_->data_.call_data_alignment)) {
client_initial_metadata_state_.Start();
client_to_server_message_state_.Start();
server_initial_metadata_state_.Start();
Expand All @@ -182,8 +182,8 @@ CallFilters::~CallFilters() {
void CallFilters::SetStack(RefCountedPtr<Stack> stack) {
if (call_data_ != nullptr) gpr_free_aligned(call_data_);
stack_ = std::move(stack);
call_data_ = gpr_malloc_aligned(stack->data_.call_data_size,
stack->data_.call_data_alignment);
call_data_ = gpr_malloc_aligned(stack_->data_.call_data_size,
stack_->data_.call_data_alignment);
client_initial_metadata_state_.Start();
client_to_server_message_state_.Start();
server_initial_metadata_state_.Start();
Expand Down
33 changes: 16 additions & 17 deletions src/core/lib/transport/call_filters.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ struct StackData {
template <typename FilterType>
void AddFinalizer(FilterType* channel_data, size_t call_offset,
void (FilterType::Call::*p)(const grpc_call_final_info*)) {
GPR_DEBUG_ASSERT(p == &FilterType::OnFinalize);
GPR_DEBUG_ASSERT(p == &FilterType::Call::OnFinalize);
finalizers.push_back(Finalizer{
channel_data,
call_offset,
Expand All @@ -711,7 +711,7 @@ struct StackData {
void AddFinalizer(FilterType* channel_data, size_t call_offset,
void (FilterType::Call::*p)(const grpc_call_final_info*,
FilterType*)) {
GPR_DEBUG_ASSERT(p == &FilterType::OnFinalize);
GPR_DEBUG_ASSERT(p == &FilterType::Call::OnFinalize);
finalizers.push_back(Finalizer{
channel_data,
call_offset,
Expand Down Expand Up @@ -921,17 +921,12 @@ class CallFilters {
template <typename FilterType>
void Add(FilterType* filter) {
const size_t call_offset = data_.AddFilter<FilterType>(filter);
data_.AddClientInitialMetadataOp(filter, call_offset,
&FilterType::OnClientInitialMetadata);
data_.AddServerInitialMetadataOp(filter, call_offset,
&FilterType::OnServerInitialMetadata);
data_.AddClientToServerMessageOp(filter, call_offset,
&FilterType::OnClientToServerMessage);
data_.AddServerToClientMessageOp(filter, call_offset,
&FilterType::OnServerToClientMessage);
data_.AddServerTrailingMetadataOp(filter, call_offset,
&FilterType::OnServerTrailingMetadata);
data_.AddFinalizer(filter, call_offset, &FilterType::OnFinalize);
data_.AddClientInitialMetadataOp(filter, call_offset);
data_.AddServerInitialMetadataOp(filter, call_offset);
data_.AddClientToServerMessageOp(filter, call_offset);
data_.AddServerToClientMessageOp(filter, call_offset);
data_.AddServerTrailingMetadataOp(filter, call_offset);
data_.AddFinalizer(filter, call_offset, &FilterType::Call::OnFinalize);
}

RefCountedPtr<Stack> Build();
Expand Down Expand Up @@ -1034,27 +1029,31 @@ class CallFilters {
if (transformer_.IsRunning()) {
return FinishPipeTransformer(transformer_.Step(filters_->call_data_));
}
auto p = state().PollPullValue();
auto p = state().PollPull();
auto* r = p.value_if_ready();
if (r == nullptr) return Pending{};
if (!r->ok()) {
filters_->CancelDueToFailedPipeOperation();
return Failure{};
}
return FinishPipeTransformer(
transformer_.Start(push()->TakeValue(), filters_->call_data_));
return FinishPipeTransformer(transformer_.Start(
layout(), push()->TakeValue(), filters_->call_data_));
}

private:
filters_detail::PipeState& state() { return filters_->*state_ptr; }
Push* push() { return static_cast<Push*>(filters_->*push_ptr); }
const filters_detail::Layout<filters_detail::FallibleOperator<T>>*
layout() {
return &(filters_->stack_->data_.*layout_ptr);
}

Poll<ValueOrFailure<T>> FinishPipeTransformer(
Poll<filters_detail::ResultOr<T>> p) {
auto* r = p.value_if_ready();
if (r == nullptr) return Pending{};
GPR_DEBUG_ASSERT(!transformer_.IsRunning());
state().AckPullValue();
state().AckPull();
if (r->ok != nullptr) return std::move(r->ok);
filters_->PushServerTrailingMetadata(std::move(r->error));
return Failure{};
Expand Down
130 changes: 130 additions & 0 deletions test/core/transport/call_filters_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ MATCHER(IsPending, "") {
return true;
}

MATCHER(IsReady, "") {
if (arg.pending()) {
*result_listener << "is pending";
return false;
}
return true;
}

MATCHER_P(IsReady, value, "") {
if (arg.pending()) {
*result_listener << "is pending";
Expand Down Expand Up @@ -1328,6 +1336,128 @@ TEST(PipeStateTest, DropProcessing) {

} // namespace filters_detail

///////////////////////////////////////////////////////////////////////////////
// CallFilters

TEST(CallFiltersTest, CanBuildStack) {
struct Filter {
struct Call {
void OnClientInitialMetadata(ClientMetadata& md) {}
void OnServerInitialMetadata(ServerMetadata& md) {}
void OnClientToServerMessage(Message& message) {}
void OnServerToClientMessage(Message& message) {}
void OnServerTrailingMetadata(ServerMetadata& md) {}
void OnFinalize(const grpc_call_final_info*) {}
};
};
CallFilters::StackBuilder builder;
Filter f;
builder.Add(&f);
auto stack = builder.Build();
EXPECT_NE(stack, nullptr);
}

TEST(CallFiltersTest, UnaryCall) {
struct Filter {
struct Call {
void OnClientInitialMetadata(ClientMetadata& md, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnClientInitialMetadata"));
}
void OnServerInitialMetadata(ServerMetadata& md, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnServerInitialMetadata"));
}
void OnClientToServerMessage(Message& message, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnClientToServerMessage"));
}
void OnServerToClientMessage(Message& message, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnServerToClientMessage"));
}
void OnServerTrailingMetadata(ServerMetadata& md, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnServerTrailingMetadata"));
}
void OnFinalize(const grpc_call_final_info*, Filter* f) {
f->steps.push_back(absl::StrCat(f->label, ":OnFinalize"));
}
};

const std::string label;
std::vector<std::string>& steps;
};
std::vector<std::string> steps;
Filter f1{"f1", steps};
Filter f2{"f2", steps};
CallFilters::StackBuilder builder;
builder.Add(&f1);
builder.Add(&f2);
CallFilters filters(builder.Build());
auto memory_allocator =
MakeMemoryQuota("test-quota")->CreateMemoryAllocator("foo");
auto arena = MakeScopedArena(1024, &memory_allocator);
promise_detail::Context<Arena> ctx(arena.get());
StrictMock<MockActivity> activity;
activity.Activate();
// Push client initial metadata
auto push_client_initial_metadata = filters.PushClientInitialMetadata(
Arena::MakePooled<ClientMetadata>(arena.get()));
EXPECT_THAT(push_client_initial_metadata(), IsPending());
auto pull_client_initial_metadata = filters.PullClientInitialMetadata();
// Pull client initial metadata, expect a wakeup
EXPECT_CALL(activity, WakeupRequested());
EXPECT_THAT(pull_client_initial_metadata(), IsReady());
Mock::VerifyAndClearExpectations(&activity);
// Push should be done
EXPECT_THAT(push_client_initial_metadata(), IsReady(Success{}));
// Push client to server message
auto push_client_to_server_message = filters.PushClientToServerMessage(
Arena::MakePooled<Message>(SliceBuffer(), 0));
EXPECT_THAT(push_client_to_server_message(), IsPending());
auto pull_client_to_server_message = filters.PullClientToServerMessage();
// Pull client to server message, expect a wakeup
EXPECT_CALL(activity, WakeupRequested());
EXPECT_THAT(pull_client_to_server_message(), IsReady());
Mock::VerifyAndClearExpectations(&activity);
// Push should be done
EXPECT_THAT(push_client_to_server_message(), IsReady(Success{}));
// Push server initial metadata
auto push_server_initial_metadata = filters.PushServerInitialMetadata(
Arena::MakePooled<ServerMetadata>(arena.get()));
EXPECT_THAT(push_server_initial_metadata(), IsPending());
auto pull_server_initial_metadata = filters.PullServerInitialMetadata();
// Pull server initial metadata, expect a wakeup
EXPECT_CALL(activity, WakeupRequested());
EXPECT_THAT(pull_server_initial_metadata(), IsReady());
Mock::VerifyAndClearExpectations(&activity);
// Push should be done
EXPECT_THAT(push_server_initial_metadata(), IsReady(Success{}));
// Push server to client message
auto push_server_to_client_message = filters.PushServerToClientMessage(
Arena::MakePooled<Message>(SliceBuffer(), 0));
EXPECT_THAT(push_server_to_client_message(), IsPending());
auto pull_server_to_client_message = filters.PullServerToClientMessage();
// Pull server to client message, expect a wakeup
EXPECT_CALL(activity, WakeupRequested());
EXPECT_THAT(pull_server_to_client_message(), IsReady());
Mock::VerifyAndClearExpectations(&activity);
// Push should be done
EXPECT_THAT(push_server_to_client_message(), IsReady(Success{}));
// Push server trailing metadata
filters.PushServerTrailingMetadata(
Arena::MakePooled<ServerMetadata>(arena.get()));
// Pull server trailing metadata
auto pull_server_trailing_metadata = filters.PullServerTrailingMetadata();
// Should be done
EXPECT_THAT(pull_server_trailing_metadata(), IsReady());
filters.Finalize(nullptr);
EXPECT_THAT(steps,
::testing::ElementsAre(
"f1:OnClientInitialMetadata", "f2:OnClientInitialMetadata",
"f1:OnClientToServerMessage", "f2:OnClientToServerMessage",
"f2:OnServerInitialMetadata", "f1:OnServerInitialMetadata",
"f2:OnServerToClientMessage", "f1:OnServerToClientMessage",
"f2:OnServerTrailingMetadata", "f1:OnServerTrailingMetadata",
"f1:OnFinalize", "f2:OnFinalize"));
}

} // namespace grpc_core

int main(int argc, char** argv) {
Expand Down

0 comments on commit 484541a

Please sign in to comment.