diff --git a/cc/client/session_client.cc b/cc/client/session_client.cc index 192bf1f406..b378ed5692 100644 --- a/cc/client/session_client.cc +++ b/cc/client/session_client.cc @@ -29,7 +29,7 @@ namespace oak::client { absl::StatusOr> OakSessionClient::NewChannel(std::unique_ptr transport) { absl::StatusOr> session = - session::ClientSession::Create(config_); + session::ClientSession::Create(config_provider_()); while (!(*session)->IsOpen()) { absl::StatusOr> init_request = diff --git a/cc/client/session_client.h b/cc/client/session_client.h index 0334ddba1d..55332c18f0 100644 --- a/cc/client/session_client.h +++ b/cc/client/session_client.h @@ -37,16 +37,29 @@ class OakSessionClient { session::ClientSession>; // A valid `SessionConfig` can be obtained using - // oak::session::SessionConfigBuilder. - OakSessionClient(session::SessionConfig* config) : config_(config) {} + // oak::session::SessionConfigBuilder. Each session needs its own unique + // SessionConfig instance, so a function to create a new SessionConfig should + // be provided here. + OakSessionClient( + absl::AnyInvocable config_provider) + : config_provider_(std::move(config_provider)) {} - // Use a default configuration, Unattested + NoiseNN - ABSL_DEPRECATED("Use the config-providing variant.") + // Use a default configuration provider, Unattested + NoiseNN + ABSL_DEPRECATED("Use the config-provider-providing variant.") OakSessionClient() - : OakSessionClient( - session::SessionConfigBuilder(session::AttestationType::kUnattested, - session::HandshakeType::kNoiseNN) - .Build()) {} + : OakSessionClient([] { + return session::SessionConfigBuilder( + session::AttestationType::kUnattested, + session::HandshakeType::kNoiseNN) + .Build(); + }) {} + + // Keeping this around briefly until we transition existing clients. + ABSL_DEPRECATED( + "This constructor will lead to UB. Use the config-provider-providing " + "variant.") + OakSessionClient(session::SessionConfig* config) + : OakSessionClient([config] { return config; }) {} // Create a new OakClientChannel instance with the provided session and // transport. @@ -61,7 +74,7 @@ class OakSessionClient { std::unique_ptr transport); private: - session::SessionConfig* config_; + absl::AnyInvocable config_provider_; }; } // namespace oak::client diff --git a/cc/client/session_client_test.cc b/cc/client/session_client_test.cc index b2d5e173a4..a88497c690 100644 --- a/cc/client/session_client_test.cc +++ b/cc/client/session_client_test.cc @@ -65,9 +65,25 @@ session::SessionConfig* TestSessionConfig() { TEST(OakSessionClientTest, CreateSuccessFullyHandshakes) { auto server_session = session::ServerSession::Create(TestSessionConfig()); ASSERT_THAT(server_session, IsOk()); - auto _ = OakSessionClient(TestSessionConfig()) - .NewChannel( - std::make_unique(std::move(*server_session))); + auto channel = OakSessionClient(TestSessionConfig) + .NewChannel(std::make_unique( + std::move(*server_session))); + EXPECT_THAT(channel, IsOk()); +} + +TEST(OakSessionClientTest, CreateSecondClientSuccessFullyHandshakes) { + auto server_session = session::ServerSession::Create(TestSessionConfig()); + ASSERT_THAT(server_session, IsOk()); + auto client = OakSessionClient(TestSessionConfig); + + auto channel = client.NewChannel( + std::make_unique(std::move(*server_session))); + EXPECT_THAT(channel, IsOk()); + + auto server_session2 = session::ServerSession::Create(TestSessionConfig()); + auto channel2 = client.NewChannel( + std::make_unique(std::move(*server_session2))); + EXPECT_THAT(channel2, IsOk()); } TEST(OakSessionClientTest, CreatedSessionCanSend) { @@ -75,7 +91,7 @@ TEST(OakSessionClientTest, CreatedSessionCanSend) { // Hold a pointer for testing behavior below. session::ServerSession* server_session_ptr = server_session->get(); ASSERT_THAT(server_session, IsOk()); - auto channel = OakSessionClient(TestSessionConfig()) + auto channel = OakSessionClient(TestSessionConfig) .NewChannel(std::make_unique( std::move(*server_session))); @@ -93,7 +109,7 @@ TEST(OakSessionClientTest, CreatedSessionCanReceive) { // Hold a pointer for testing behavior below. session::ServerSession* server_session_ptr = server_session->get(); ASSERT_THAT(server_session, IsOk()); - auto channel = OakSessionClient(TestSessionConfig()) + auto channel = OakSessionClient(TestSessionConfig) .NewChannel(std::make_unique( std::move(*server_session))); diff --git a/cc/server/session_server.cc b/cc/server/session_server.cc index a7e3010e56..858b2bd38e 100644 --- a/cc/server/session_server.cc +++ b/cc/server/session_server.cc @@ -32,7 +32,7 @@ namespace oak::server { absl::StatusOr> OakSessionServer::NewChannel(std::unique_ptr transport) { - auto session = session::ServerSession::Create(config_); + auto session = session::ServerSession::Create(config_provider_()); if (!session.ok()) { return util::status::Annotate(session.status(), "Failed to create server session"); diff --git a/cc/server/session_server.h b/cc/server/session_server.h index 20b6aabb3b..7d5732b3b1 100644 --- a/cc/server/session_server.h +++ b/cc/server/session_server.h @@ -37,16 +37,29 @@ class OakSessionServer { session::ServerSession>; // A valid `SessionConfig` can be obtained using - // oak::session::SessionConfigBuilder. - OakSessionServer(session::SessionConfig* config) : config_(config) {} + // oak::session::SessionConfigBuilder. Each session needs its own unique + // SessionConfig instance, so a function to create a new SessionConfig should + // be provided. + OakSessionServer( + absl::AnyInvocable config_provider) + : config_provider_(std::move(config_provider)) {} - // Use a default configuration, Unattested + NoiseNN + // Use a default configuration provider, Unattested + NoiseNN ABSL_DEPRECATED("Use the config-providing variant.") OakSessionServer() - : OakSessionServer( - session::SessionConfigBuilder(session::AttestationType::kUnattested, - session::HandshakeType::kNoiseNN) - .Build()) {} + : OakSessionServer([] { + return session::SessionConfigBuilder( + session::AttestationType::kUnattested, + session::HandshakeType::kNoiseNN) + .Build(); + }) {} + + // Keeping this around briefly until we transition existing clients. + ABSL_DEPRECATED( + "This constructor will lead to UB. Use the config-provider-providing " + "variant.") + OakSessionServer(session::SessionConfig* config) + : OakSessionServer([config] { return config; }) {} // Create a new OakServerChannel instance with the provided session and // transport. @@ -61,7 +74,7 @@ class OakSessionServer { std::unique_ptr transport); private: - session::SessionConfig* config_; + absl::AnyInvocable config_provider_; }; } // namespace oak::server diff --git a/cc/server/session_server_test.cc b/cc/server/session_server_test.cc index 56af9b57c7..de04696fb0 100644 --- a/cc/server/session_server_test.cc +++ b/cc/server/session_server_test.cc @@ -68,9 +68,26 @@ session::SessionConfig* TestSessionConfig() { TEST(OakSessionServerTest, CreateSuccessFullyHandshakes) { auto client_session = session::ClientSession::Create(TestSessionConfig()); ASSERT_THAT(client_session, IsOk()); - auto _ = OakSessionServer(TestSessionConfig()) - .NewChannel( - std::make_unique(std::move(*client_session))); + auto channel = OakSessionServer(TestSessionConfig) + .NewChannel(std::make_unique( + std::move(*client_session))); + ASSERT_THAT(channel, IsOk()); +} + +TEST(OakSessionServerTest, SecondCreateSuccessFullyHandshakes) { + auto server = OakSessionServer(TestSessionConfig); + + auto client_session = session::ClientSession::Create(TestSessionConfig()); + ASSERT_THAT(client_session, IsOk()); + auto channel = server.NewChannel( + std::make_unique(std::move(*client_session))); + ASSERT_THAT(channel, IsOk()); + + auto client_session2 = session::ClientSession::Create(TestSessionConfig()); + ASSERT_THAT(client_session2, IsOk()); + auto channel2 = server.NewChannel( + std::make_unique(std::move(*client_session2))); + ASSERT_THAT(channel2, IsOk()); } TEST(OakSessionServerTest, CreatedSessionCanSend) { @@ -78,7 +95,7 @@ TEST(OakSessionServerTest, CreatedSessionCanSend) { // Hold a pointer for testing behavior below. session::ClientSession* client_session_ptr = client_session->get(); ASSERT_THAT(client_session, IsOk()); - auto channel = OakSessionServer(TestSessionConfig()) + auto channel = OakSessionServer(TestSessionConfig) .NewChannel(std::make_unique( std::move(*client_session))); @@ -96,7 +113,7 @@ TEST(OakSessionServerTest, CreatedSessionCanReceive) { // Hold a pointer for testing behavior below. session::ClientSession* client_session_ptr = client_session->get(); ASSERT_THAT(client_session, IsOk()); - auto channel = OakSessionServer(TestSessionConfig()) + auto channel = OakSessionServer(TestSessionConfig) .NewChannel(std::make_unique( std::move(*client_session)));