Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
ctiller committed Jan 11, 2024
1 parent 21c18fb commit 060e945
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 55 deletions.
33 changes: 31 additions & 2 deletions src/core/lib/promise/status_flag.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,18 @@

namespace grpc_core {

struct Failure {};
struct Success {};
struct Failure {
template <typename Sink>
friend void AbslStringify(Sink& sink, Failure flag) {
sink.Append("failed");
}
};
struct Success {
template <typename Sink>
friend void AbslStringify(Sink& sink, Success flag) {
sink.Append("ok");
}
};

inline bool IsStatusOk(Failure) { return false; }
inline bool IsStatusOk(Success) { return true; }
Expand Down Expand Up @@ -68,10 +78,29 @@ class StatusFlag {

bool operator==(StatusFlag other) const { return value_ == other.value_; }

template <typename Sink>
friend void AbslStringify(Sink& sink, StatusFlag flag) {
if (flag.ok()) {
sink.Append("ok");
} else {
sink.Append("failed");
}
}

private:
bool value_;
};

inline bool operator==(StatusFlag flag, Failure) { return !flag.ok(); }
inline bool operator==(Failure, StatusFlag flag) { return !flag.ok(); }
inline bool operator==(StatusFlag flag, Success) { return flag.ok(); }
inline bool operator==(Success, StatusFlag flag) { return flag.ok(); }

inline bool operator!=(StatusFlag flag, Failure) { return flag.ok(); }
inline bool operator!=(Failure, StatusFlag flag) { return flag.ok(); }
inline bool operator!=(StatusFlag flag, Success) { return !flag.ok(); }
inline bool operator!=(Success, StatusFlag flag) { return !flag.ok(); }

inline bool IsStatusOk(const StatusFlag& flag) { return flag.ok(); }

template <>
Expand Down
48 changes: 40 additions & 8 deletions src/core/lib/transport/call_filters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,12 @@ CallFilters::CallFilters() : stack_(nullptr), call_data_(nullptr) {}
CallFilters::CallFilters(RefCountedPtr<Stack> stack)
: stack_(std::move(stack)),
call_data_(gpr_malloc_aligned(stack->data_.call_data_size,
stack->data_.call_data_alignment)) {}
stack->data_.call_data_alignment)) {
client_initial_metadata_state_.Start();
client_to_server_message_state_.Start();
server_initial_metadata_state_.Start();
server_to_client_message_state_.Start();
}

CallFilters::~CallFilters() {
if (call_data_ != nullptr) gpr_free_aligned(call_data_);
Expand All @@ -179,7 +184,10 @@ void CallFilters::SetStack(RefCountedPtr<Stack> stack) {
stack_ = std::move(stack);
call_data_ = gpr_malloc_aligned(stack->data_.call_data_size,
stack->data_.call_data_alignment);
stack_waiter_.Wake();
client_initial_metadata_state_.Start();
client_to_server_message_state_.Start();
server_initial_metadata_state_.Start();
server_to_client_message_state_.Start();
}

void CallFilters::Finalize(const grpc_call_final_info* final_info) {
Expand Down Expand Up @@ -217,7 +225,13 @@ RefCountedPtr<CallFilters::Stack> CallFilters::StackBuilder::Build() {
///////////////////////////////////////////////////////////////////////////////
// CallFilters::PipeState

void CallFilters::PipeState::BeginPush() {
void filters_detail::PipeState::Start() {
GPR_DEBUG_ASSERT(!started_);
started_ = true;
wait_recv_.Wake();
}

void filters_detail::PipeState::BeginPush() {
switch (state_) {
case ValueState::kIdle:
state_ = ValueState::kQueued;
Expand All @@ -237,7 +251,7 @@ void CallFilters::PipeState::BeginPush() {
}
}

void CallFilters::PipeState::AbandonPush() {
void filters_detail::PipeState::DropPush() {
switch (state_) {
case ValueState::kQueued:
case ValueState::kReady:
Expand All @@ -253,7 +267,23 @@ void CallFilters::PipeState::AbandonPush() {
}
}

Poll<StatusFlag> CallFilters::PipeState::PollPush() {
void filters_detail::PipeState::DropPull() {
switch (state_) {
case ValueState::kQueued:
case ValueState::kReady:
case ValueState::kProcessing:
case ValueState::kWaiting:
state_ = ValueState::kError;
wait_send_.Wake();
break;
case ValueState::kIdle:
case ValueState::kClosed:
case ValueState::kError:
break;
}
}

Poll<StatusFlag> filters_detail::PipeState::PollPush() {
switch (state_) {
case ValueState::kIdle:
// Read completed and new read started => we see waiting here
Expand All @@ -269,25 +299,27 @@ Poll<StatusFlag> CallFilters::PipeState::PollPush() {
}
}

Poll<StatusFlag> CallFilters::PipeState::PollPullValue() {
Poll<StatusFlag> filters_detail::PipeState::PollPull() {
switch (state_) {
case ValueState::kWaiting:
return wait_recv_.pending();
case ValueState::kIdle:
state_ = ValueState::kWaiting;
return wait_recv_.pending();
case ValueState::kReady:
case ValueState::kQueued:
if (!started_) return wait_recv_.pending();
state_ = ValueState::kProcessing;
return Success{};
case ValueState::kProcessing:
case ValueState::kWaiting:
Crash("Only one pull allowed to be outstanding");
case ValueState::kClosed:
case ValueState::kError:
return Failure{};
}
}

void CallFilters::PipeState::AckPullValue() {
void filters_detail::PipeState::AckPull() {
switch (state_) {
case ValueState::kProcessing:
state_ = ValueState::kIdle;
Expand Down
130 changes: 85 additions & 45 deletions src/core/lib/transport/call_filters.h
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ class PipeTransformer {
const FallibleOperator<T>* end_ops_;
};

// Per PipeTransformer, but for infallible operation sequences.
template <typename T>
class InfalliblePipeTransformer {
public:
Expand Down Expand Up @@ -836,6 +837,62 @@ class InfalliblePipeTransformer {
const InfallibleOperator<T>* end_ops_;
};

// The current state of a pipe.
// CallFilters expose a set of pipe like objects for client & server initial
// metadata and for messages.
// This class tracks the state of one of those pipes.
// Size matters here: this state is kept for the lifetime of a call, and we keep
// multiple of them.
// This class encapsulates the untyped work of the state machine; there are
// typed wrappers around this class as private members of CallFilters that
// augment it to provide all the functionality that we must.
class PipeState {
public:
// Start the pipe: allows pulls to proceed
void Start();
// A push operation is beginning
void BeginPush();
// A previously started push operation has completed
void DropPush();
// Poll for push completion: occurs after the corresponding Pull()
Poll<StatusFlag> PollPush();
Poll<StatusFlag> PollPull();
// A pulled value has been consumed: we can unblock the push
void AckPull();
// A previously started pull operation has completed
void DropPull();

bool holds_error() const { return state_ == ValueState::kError; }

private:
enum class ValueState : uint8_t {
// Nothing sending nor receiving
kIdle,
// Sent, but not yet received
kQueued,
// Trying to receive, but not yet sent
kWaiting,
// Ready to start processing, but not yet started
// (we have the value to send through the pipe, the reader is waiting,
// but it's not yet been polled)
kReady,
// Processing through filters
kProcessing,
// Closed sending
kClosed,
// Closed due to failure
kError
};
// Waiter for a promise blocked waiting to send.
IntraActivityWaiter wait_send_;
// Waiter for a promise blocked waiting to receive.
IntraActivityWaiter wait_recv_;
// Current state.
ValueState state_ = ValueState::kIdle;
// Has the pipe been started?
bool started_ = false;
};

} // namespace filters_detail

// Execution environment for a stack of filters.
Expand All @@ -844,6 +901,11 @@ class CallFilters {
public:
class StackBuilder;

// A stack is an opaque, immutable type that contains the data necessary to
// execute a call through a given set of filters.
// It's reference counted so that it can be shared between many calls.
// It contains pointers to the individual filters, yet it does not own those
// pointers: it's expected that some other object will track that ownership.
class Stack : public RefCounted<Stack> {
private:
friend class CallFilters;
Expand All @@ -852,6 +914,8 @@ class CallFilters {
const filters_detail::StackData data_;
};

// Build stacks... repeatedly call Add with each filter that contributes to
// the stack, then call Build() to generate a ref counted Stack object.
class StackBuilder {
public:
template <typename FilterType>
Expand Down Expand Up @@ -905,40 +969,8 @@ class CallFilters {
void Finalize(const grpc_call_final_info* final_info);

private:
class PipeState {
public:
void BeginPush();
void AbandonPush();
Poll<StatusFlag> PollPush();
Poll<StatusFlag> PollPullValue();
void AckPullValue();

private:
enum class ValueState : uint8_t {
// Nothing sending nor receiving
kIdle,
// Sent, but not yet received
kQueued,
// Trying to receive, but not yet sent
kWaiting,
// Ready to start processing, but not yet started
// (we have the value to send through the pipe, the reader is waiting,
// but it's not yet been polled)
kReady,
// Processing through filters
kProcessing,
// Closed sending
kClosed,
// Closed due to failure
kError
};
IntraActivityWaiter wait_send_;
IntraActivityWaiter wait_recv_;
ValueState state_ = ValueState::kIdle;
};

template <PipeState(CallFilters::*state_ptr), void*(CallFilters::*push_ptr),
typename T,
template <filters_detail::PipeState(CallFilters::*state_ptr),
void*(CallFilters::*push_ptr), typename T,
filters_detail::Layout<filters_detail::FallibleOperator<T>>(
filters_detail::StackData::*layout_ptr)>
class PipePromise {
Expand All @@ -952,7 +984,7 @@ class CallFilters {
}
~Push() {
if (filters_ != nullptr) {
state().AbandonPush();
state().DropPush();
push_slot() = nullptr;
}
}
Expand All @@ -975,7 +1007,7 @@ class CallFilters {
T TakeValue() { return std::move(value_); }

private:
PipeState& state() { return filters_->*state_ptr; }
filters_detail::PipeState& state() { return filters_->*state_ptr; }
void*& push_slot() { return filters_->*push_ptr; }

CallFilters* filters_;
Expand All @@ -985,11 +1017,20 @@ class CallFilters {
class Pull {
public:
explicit Pull(CallFilters* filters) : filters_(filters) {}
~Pull() {
if (filters_ != nullptr) {
state().DropPull();
}
}

Pull(const Pull&) = delete;
Pull& operator=(const Pull&) = delete;
Pull(Pull&& other)
: filters_(std::exchange(other.filters_, nullptr)),
transformer_(std::move(other.transformer_)) {}
Pull& operator=(Pull&&) = delete;

Poll<ValueOrFailure<T>> operator()() {
if (filters_->stack_ == nullptr) {
return filters_->stack_waiter_.pending();
}
if (transformer_.IsRunning()) {
return FinishPipeTransformer(transformer_.Step(filters_->call_data_));
}
Expand All @@ -1005,7 +1046,7 @@ class CallFilters {
}

private:
PipeState& state() { return filters_->*state_ptr; }
filters_detail::PipeState& state() { return filters_->*state_ptr; }
Push* push() { return static_cast<Push*>(filters_->*push_ptr); }

Poll<ValueOrFailure<T>> FinishPipeTransformer(
Expand All @@ -1030,12 +1071,11 @@ class CallFilters {

RefCountedPtr<Stack> stack_;

PipeState client_initial_metadata_state_;
PipeState server_initial_metadata_state_;
PipeState client_to_server_message_state_;
PipeState server_to_client_message_state_;
filters_detail::PipeState client_initial_metadata_state_;
filters_detail::PipeState server_initial_metadata_state_;
filters_detail::PipeState client_to_server_message_state_;
filters_detail::PipeState server_to_client_message_state_;
IntraActivityWaiter server_trailing_metadata_waiter_;
IntraActivityWaiter stack_waiter_;

void* call_data_;
void* client_initial_metadata_ = nullptr;
Expand Down
Loading

0 comments on commit 060e945

Please sign in to comment.