Skip to content

Commit

Permalink
[promises] Implement CallInitiator, CallHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
ctiller committed Dec 1, 2023
1 parent a954afa commit d312f49
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 19 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,7 @@ grpc_cc_library(
"//src/core:socket_mutator",
"//src/core:stats_data",
"//src/core:status_helper",
"//src/core:status_flag",
"//src/core:strerror",
"//src/core:thread_quota",
"//src/core:time",
Expand Down
2 changes: 1 addition & 1 deletion src/core/lib/channel/connected_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ grpc_channel_filter MakeConnectedFilter() {

ArenaPromise<ServerMetadataHandle> MakeTransportCallPromise(
Transport* transport, CallArgs call_args, NextPromiseFactory) {
return transport->client_transport()->MakeCallPromise(std::move(call_args));
Crash("unimplemented");
}

const grpc_channel_filter kPromiseBasedTransportFilter =
Expand Down
177 changes: 159 additions & 18 deletions src/core/lib/transport/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/promise/party.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/race.h"
#include "src/core/lib/promise/status_flag.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/transport/connectivity_state.h"
Expand Down Expand Up @@ -226,9 +228,6 @@ struct CallArgs {
PipeSender<MessageHandle>* server_to_client_messages;
};

using NextPromiseFactory =
std::function<ArenaPromise<ServerMetadataHandle>(CallArgs)>;

// TODO(ctiller): eventually drop this when we don't need to reference into
// legacy promise calls anymore
class CallSpineInterface {
Expand All @@ -239,9 +238,23 @@ class CallSpineInterface {
virtual Pipe<MessageHandle>& client_to_server_messages() = 0;
virtual Pipe<MessageHandle>& server_to_client_messages() = 0;
virtual Pipe<ServerMetadataHandle>& server_trailing_metadata() = 0;
GRPC_MUST_USE_RESULT virtual absl::nullopt_t Cancel(
ServerMetadataHandle metadata) = 0;
virtual Latch<ServerMetadataHandle>& cancel_latch() = 0;
virtual Party& party() = 0;
virtual void IncrementRefCount() = 0;
virtual void Unref() = 0;

GRPC_MUST_USE_RESULT absl::nullopt_t Cancel(ServerMetadataHandle metadata) {
GPR_DEBUG_ASSERT(Activity::current() == &party());
auto& c = cancel_latch();
if (c.is_set()) return absl::nullopt;
c.Set(std::move(metadata));
return absl::nullopt;
}

auto WaitForCancel() {
GPR_DEBUG_ASSERT(Activity::current() == &party());
return cancel_latch().Wait();
}

// Wrap a promise so that if it returns failure it automatically cancels
// the rest of the call.
Expand Down Expand Up @@ -303,13 +316,10 @@ class CallSpine final : public CallSpineInterface {
Pipe<ServerMetadataHandle>& server_trailing_metadata() {
return server_trailing_metadata_;
}
absl::nullopt_t Cancel(ServerMetadataHandle metadata) {
GPR_DEBUG_ASSERT(Activity::current() == &party());
if (cancel_latch_.is_set()) return absl::nullopt;
cancel_latch_.Set(std::move(metadata));
return absl::nullopt;
}
Latch<ServerMetadataHandle>& cancel_latch() { return cancel_latch_; }
Party& party() { Crash("unimplemented"); }
void IncrementRefCount() { Crash("unimplemented"); }
void Unref() { Crash("unimplemented"); }

private:
// Initial metadata from client to server
Expand All @@ -326,6 +336,135 @@ class CallSpine final : public CallSpineInterface {
Latch<ServerMetadataHandle> cancel_latch_;
};

class CallInitiator {
public:
explicit CallInitiator(RefCountedPtr<CallSpine> spine)
: spine_(std::move(spine)) {}

auto PushClientInitialMetadata(ClientMetadataHandle md) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->client_initial_metadata().sender.Push(std::move(md)),
[](bool ok) { return StatusFlag(ok); });
}

auto PullServerInitialMetadata() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->server_initial_metadata().receiver.Next(),
[](NextResult<ClientMetadataHandle> md)
-> ValueOrFailure<ClientMetadataHandle> {
if (!md.has_value()) return Failure{};
return std::move(*md);
});
}

auto PullServerTrailingMetadata() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Race(spine_->WaitForCancel(),
Map(spine_->server_trailing_metadata().receiver.Next(),
[spine = spine_](NextResult<ServerMetadataHandle> md)
-> ServerMetadataHandle {
GPR_ASSERT(md.has_value());
return std::move(*md);
}));
}

auto PullMessage() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->server_to_client_messages().receiver.Next();
}

auto PushMessage(MessageHandle message) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->client_to_server_messages().sender.Push(std::move(message));
}

template <typename Promise>
auto CancelIfFails(Promise promise) {
return spine_->CancelIfFails(std::move(promise));
}

template <typename PromiseFactory>
void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
spine_->SpawnGuarded(name, std::move(promise_factory));
}

template <typename PromiseFactory>
void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
spine_->SpawnInfallible(name, std::move(promise_factory));
}

private:
const RefCountedPtr<CallSpine> spine_;
};

class CallHandler {
public:
explicit CallHandler(RefCountedPtr<CallSpine> spine)
: spine_(std::move(spine)) {}

auto PullClientInitialMetadata() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->client_initial_metadata().receiver.Next(),
[](NextResult<ClientMetadataHandle> md)
-> ValueOrFailure<ClientMetadataHandle> {
if (!md.has_value()) return Failure{};
return std::move(*md);
});
}

auto PushServerInitialMetadata(ClientMetadataHandle md) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->server_initial_metadata().sender.Push(std::move(md)),
[](bool ok) { return StatusFlag(ok); });
}

auto PushServerTrailingMetadata(ClientMetadataHandle md) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return Map(spine_->server_initial_metadata().sender.Push(std::move(md)),
[](bool ok) { return StatusFlag(ok); });
}

auto PullMessage() {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->client_to_server_messages().receiver.Next();
}

auto PushMessage(MessageHandle message) {
GPR_DEBUG_ASSERT(Activity::current() == &spine_->party());
return spine_->server_to_client_messages().sender.Push(std::move(message));
}

template <typename Promise>
auto CancelIfFails(Promise promise) {
return spine_->CancelIfFails(std::move(promise));
}

template <typename PromiseFactory>
void SpawnGuarded(absl::string_view name, PromiseFactory promise_factory) {
spine_->SpawnGuarded(name, std::move(promise_factory));
}

template <typename PromiseFactory>
void SpawnInfallible(absl::string_view name, PromiseFactory promise_factory) {
spine_->SpawnInfallible(name, std::move(promise_factory));
}

private:
const RefCountedPtr<CallSpine> spine_;
};

template <typename CallHalf>
auto OutgoingMessages(CallHalf& h) {
struct Wrapper {
CallHalf& h;
auto Next() { return h.PullMessage(); }
};
return Wrapper{h};
}

using NextPromiseFactory =
std::function<ArenaPromise<ServerMetadataHandle>(CallArgs)>;

} // namespace grpc_core

// forward declarations
Expand Down Expand Up @@ -731,20 +870,22 @@ class FilterStackTransport {

class ClientTransport {
public:
// Create a promise to execute one client call.
virtual ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args) = 0;
virtual void StartCall(CallHandler call_handler) = 0;

protected:
~ClientTransport() = default;
};

class ServerTransport {
public:
// Register the factory function for the filter stack part of a call
// promise.
void SetCallPromiseFactory(
absl::AnyInvocable<ArenaPromise<ServerMetadataHandle>(CallArgs) const>);
// AcceptFunction takes initial metadata for a new call and returns a
// CallInitiator object for it, for the transport to use to communicate with
// the CallHandler object passed to the application.
using AcceptFunction =
absl::AnyInvocable<absl::StatusOr<CallInitiator>(ClientMetadata&) const>;

// Called once slightly after transport setup to register the accept function.
virtual void SetAcceptFunction(AcceptFunction accept_function) = 0;

protected:
~ServerTransport() = default;
Expand Down

0 comments on commit d312f49

Please sign in to comment.