diff --git a/include/seastar/websocket/server.hh b/include/seastar/websocket/server.hh index e8bfef99ce..aeed48c4e3 100644 --- a/include/seastar/websocket/server.hh +++ b/include/seastar/websocket/server.hh @@ -324,7 +324,13 @@ public: bool is_handler_registered(std::string const& name); - void register_handler(std::string&& name, handler_t handler); + /*! + * \brief Register a handler for specific subprotocol + * \param name The name of the subprotocol. If it is empty string, then the handler is used + * when the protocol is not specified + * \param handler Handler for incoming WebSocket messages. + */ + void register_handler(const std::string& name, handler_t handler); friend class connection; protected: diff --git a/src/websocket/server.cc b/src/websocket/server.cc index 7dbaa6a696..6d52837518 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -164,9 +164,6 @@ future<> connection::read_http_upgrade_request() { } sstring subprotocol = req->get_header("Sec-WebSocket-Protocol"); - if (subprotocol.empty()) { - throw websocket::exception("Subprotocol header missing."); - } if (!_server.is_handler_registered(subprotocol)) { throw websocket::exception("Subprotocol not supported."); @@ -187,8 +184,10 @@ future<> connection::read_http_upgrade_request() { co_await _write_buf.write(http_upgrade_reply_template); co_await _write_buf.write(sha1_output); - co_await _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26); - co_await _write_buf.write(_subprotocol); + if (!_subprotocol.empty()) { + co_await _write_buf.write("\r\nSec-WebSocket-Protocol: ", 26); + co_await _write_buf.write(_subprotocol); + } co_await _write_buf.write("\r\n\r\n", 4); co_await _write_buf.flush(); } @@ -414,7 +413,7 @@ bool server::is_handler_registered(std::string const& name) { return _handlers.find(name) != _handlers.end(); } -void server::register_handler(std::string&& name, handler_t handler) { +void server::register_handler(const std::string& name, handler_t handler) { _handlers[name] = handler; } diff --git a/tests/unit/websocket_test.cc b/tests/unit/websocket_test.cc index 1d71a673da..bc65e21cc5 100644 --- a/tests/unit/websocket_test.cc +++ b/tests/unit/websocket_test.cc @@ -11,17 +11,30 @@ using namespace seastar; using namespace seastar::experimental; +using namespace std::literals::string_view_literals; + +std::string build_request(std::string_view key_base64, std::string_view subprotocol) { + std::string subprotocol_line; + if (!subprotocol.empty()) { + subprotocol_line = fmt::format("Sec-WebSocket-Protocol: {}\r\n", subprotocol); + } + + return fmt::format( + "GET / HTTP/1.1\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: {}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "{}" + "\r\n", + key_base64, + subprotocol_line); +} + +future<> test_websocket_handshake_common(std::string subprotocol) { + return seastar::async([=] { + const std::string request = build_request("dGhlIHNhbXBsZSBub25jZQ==", subprotocol); -SEASTAR_TEST_CASE(test_websocket_handshake) { - return seastar::async([] { - const std::string request = - "GET / HTTP/1.1\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" - "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Protocol: echo\r\n" - "\r\n"; loopback_connection_factory factory; loopback_socket_impl lsi(factory); @@ -32,7 +45,7 @@ SEASTAR_TEST_CASE(test_websocket_handshake) { auto output = sock.output(); websocket::server dummy; - dummy.register_handler("echo", [] (input_stream& in, + dummy.register_handler(subprotocol, [] (input_stream& in, output_stream& out) { return repeat([&in, &out]() { return in.read().then([&out](temporary_buffer f) { @@ -81,10 +94,16 @@ SEASTAR_TEST_CASE(test_websocket_handshake) { }); } +SEASTAR_TEST_CASE(test_websocket_handshake) { + return test_websocket_handshake_common("echo"); +} +SEASTAR_TEST_CASE(test_websocket_handshake_no_subprotocol) { + return test_websocket_handshake_common(""); +} -SEASTAR_TEST_CASE(test_websocket_handler_registration) { - return seastar::async([] { +future<> test_websocket_handler_registration_common(std::string subprotocol) { + return seastar::async([=] { loopback_connection_factory factory; loopback_socket_impl lsi(factory); @@ -96,7 +115,7 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) { // Setup server websocket::server ws; - ws.register_handler("echo", [] (input_stream& in, + ws.register_handler(subprotocol, [] (input_stream& in, output_stream& out) { return repeat([&in, &out]() { return in.read().then([&out](temporary_buffer f) { @@ -124,17 +143,15 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) { }); // handshake - const std::string request = - "GET / HTTP/1.1\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" - "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Protocol: echo\r\n" - "\r\n"; + const std::string request = build_request("dGhlIHNhbXBsZSBub25jZQ==", subprotocol); output.write(request).get(); output.flush().get(); - input.read_exactly(186).get(); + + unsigned reply_size = 156; + if (!subprotocol.empty()) { + reply_size += ("\r\nSec-WebSocket-Protocol: "sv).size() + subprotocol.size(); + } + input.read_exactly(reply_size).get(); unsigned ws_frame_len = 10; @@ -156,6 +173,14 @@ SEASTAR_TEST_CASE(test_websocket_handler_registration) { }); } +SEASTAR_TEST_CASE(test_websocket_handler_registration) { + return test_websocket_handler_registration_common("echo"); +} + +SEASTAR_TEST_CASE(test_websocket_handler_registration_no_subprotocol) { + return test_websocket_handler_registration_common(""); +} + // Simple wrapper to help create a testable input_stream. class test_source_impl : public data_source_impl { std::vector> _bufs{};