diff --git a/CMakeLists.txt b/CMakeLists.txt index f58d0bf8f4..1dc9a36c1d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -688,6 +688,7 @@ add_library (seastar include/seastar/util/closeable.hh include/seastar/util/source_location-compat.hh include/seastar/util/short_streams.hh + include/seastar/websocket/common.hh include/seastar/websocket/server.hh src/core/alien.cc src/core/file.cc @@ -777,6 +778,8 @@ add_library (seastar src/util/read_first_line.cc src/util/tmp_file.cc src/util/short_streams.cc + src/websocket/parser.cc + src/websocket/common.cc src/websocket/server.cc ) diff --git a/include/seastar/websocket/common.hh b/include/seastar/websocket/common.hh new file mode 100644 index 0000000000..7050ff54cc --- /dev/null +++ b/include/seastar/websocket/common.hh @@ -0,0 +1,173 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2024 ScyllaDB + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +extern sstring magic_key_suffix; + +using handler_t = std::function(input_stream&, output_stream&)>; + +class server; + +/// \defgroup websocket WebSocket +/// \addtogroup websocket +/// @{ + +/*! + * \brief an error in handling a WebSocket connection + */ +class exception : public std::exception { + std::string _msg; +public: + exception(std::string_view msg) : _msg(msg) {} + virtual const char* what() const noexcept { + return _msg.c_str(); + } +}; + +/*! + * \brief a server WebSocket connection + */ +class connection : public boost::intrusive::list_base_hook<> { +protected: + using buff_t = temporary_buffer; + + /*! + * \brief Implementation of connection's data source. + */ + class connection_source_impl final : public data_source_impl { + queue* data; + + public: + connection_source_impl(queue* data) : data(data) {} + + virtual future get() override { + return data->pop_eventually().then_wrapped([](future f){ + try { + return make_ready_future(std::move(f.get())); + } catch(...) { + return current_exception_as_future(); + } + }); + } + + virtual future<> close() override { + data->push(buff_t(0)); + return make_ready_future<>(); + } + }; + + /*! + * \brief Implementation of connection's data sink. + */ + class connection_sink_impl final : public data_sink_impl { + queue* data; + public: + connection_sink_impl(queue* data) : data(data) {} + + virtual future<> put(net::packet d) override { + net::fragment f = d.frag(0); + return data->push_eventually(temporary_buffer{std::move(f.base), f.size}); + } + + size_t buffer_size() const noexcept override { + return data->max_size(); + } + + virtual future<> close() override { + data->push(buff_t(0)); + return make_ready_future<>(); + } + }; + + /*! + * \brief This function processess received PING frame. + * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + */ + future<> handle_ping(); + /*! + * \brief This function processess received PONG frame. + * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + */ + future<> handle_pong(); + + static const size_t PIPE_SIZE = 512; + connected_socket _fd; + input_stream _read_buf; + output_stream _write_buf; + bool _done = false; + + websocket_parser _websocket_parser; + queue > _input_buffer; + input_stream _input; + queue > _output_buffer; + output_stream _output; + + sstring _subprotocol; + handler_t _handler; +public: + /*! + * \param fd established socket used for communication + */ + connection(connected_socket&& fd) + : _fd(std::move(fd)) + , _read_buf(_fd.input()) + , _write_buf(_fd.output()) + , _input_buffer{PIPE_SIZE} + , _output_buffer{PIPE_SIZE} + { + _input = input_stream{data_source{ + std::make_unique(&_input_buffer)}}; + _output = output_stream{data_sink{ + std::make_unique(&_output_buffer)}}; + } + + /*! + * \brief close the socket + */ + void shutdown_input(); + future<> close(bool send_close = true); + +protected: + future<> read_one(); + future<> response_loop(); + /*! + * \brief Packs buff in websocket frame and sends it to the client. + */ + future<> send_data(opcodes opcode, temporary_buffer&& buff); +}; + +std::string sha1_base64(std::string_view source); +std::string encode_base64(std::string_view source); + +extern logger websocket_logger; + +/// @} +} diff --git a/include/seastar/websocket/parser.hh b/include/seastar/websocket/parser.hh new file mode 100644 index 0000000000..836d54d254 --- /dev/null +++ b/include/seastar/websocket/parser.hh @@ -0,0 +1,143 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include + +namespace seastar::experimental::websocket { + +/// \addtogroup websocket +/// @{ + +/*! + * \brief Possible type of a websocket frame. + */ +enum opcodes { + CONTINUATION = 0x0, + TEXT = 0x1, + BINARY = 0x2, + CLOSE = 0x8, + PING = 0x9, + PONG = 0xA, + INVALID = 0xFF, +}; + +struct frame_header { + static constexpr uint8_t FIN = 7; + static constexpr uint8_t RSV1 = 6; + static constexpr uint8_t RSV2 = 5; + static constexpr uint8_t RSV3 = 4; + static constexpr uint8_t MASKED = 7; + + uint8_t fin : 1; + uint8_t rsv1 : 1; + uint8_t rsv2 : 1; + uint8_t rsv3 : 1; + uint8_t opcode : 4; + uint8_t masked : 1; + uint8_t length : 7; + frame_header(const char* input) { + this->fin = (input[0] >> FIN) & 1; + this->rsv1 = (input[0] >> RSV1) & 1; + this->rsv2 = (input[0] >> RSV2) & 1; + this->rsv3 = (input[0] >> RSV3) & 1; + this->opcode = input[0] & 0b1111; + this->masked = (input[1] >> MASKED) & 1; + this->length = (input[1] & 0b1111111); + } + // Returns length of the rest of the header. + uint64_t get_rest_of_header_length() { + size_t next_read_length = sizeof(uint32_t); // Masking key + if (length == 126) { + next_read_length += sizeof(uint16_t); + } else if (length == 127) { + next_read_length += sizeof(uint64_t); + } + return next_read_length; + } + uint8_t get_fin() {return fin;} + uint8_t get_rsv1() {return rsv1;} + uint8_t get_rsv2() {return rsv2;} + uint8_t get_rsv3() {return rsv3;} + uint8_t get_opcode() {return opcode;} + uint8_t get_masked() {return masked;} + uint8_t get_length() {return length;} + + bool is_opcode_known() { + //https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 + return opcode < 0xA && !(opcode < 0x8 && opcode > 0x2); + } +}; + +class websocket_parser { + enum class parsing_state : uint8_t { + flags_and_payload_data, + payload_length_and_mask, + payload + }; + enum class connection_state : uint8_t { + valid, + closed, + error + }; + using consumption_result_t = consumption_result; + using buff_t = temporary_buffer; + // What parser is currently doing. + parsing_state _state; + // State of connection - can be valid, closed or should be closed + // due to error. + connection_state _cstate; + sstring _buffer; + std::unique_ptr _header; + uint64_t _payload_length = 0; + uint64_t _consumed_payload_length = 0; + uint32_t _masking_key; + buff_t _result; + + static future dont_stop() { + return make_ready_future(continue_consuming{}); + } + static future stop(buff_t data) { + return make_ready_future(stop_consuming(std::move(data))); + } + uint64_t remaining_payload_length() const { + return _payload_length - _consumed_payload_length; + } + + // Removes mask from payload given in p. + void remove_mask(buff_t& p, size_t n) { + char *payload = p.get_write(); + for (uint64_t i = 0, j = 0; i < n; ++i, j = (j + 1) % 4) { + payload[i] ^= static_cast(((_masking_key << (j * 8)) >> 24)); + } + } +public: + websocket_parser() : _state(parsing_state::flags_and_payload_data), + _cstate(connection_state::valid), + _masking_key(0) {} + future operator()(temporary_buffer data); + bool is_valid() { return _cstate == connection_state::valid; } + bool eof() { return _cstate == connection_state::closed; } + opcodes opcode() const; + buff_t result(); +}; + +/// @} +} diff --git a/include/seastar/websocket/server.hh b/include/seastar/websocket/server.hh index aeed48c4e3..5200a238d7 100644 --- a/include/seastar/websocket/server.hh +++ b/include/seastar/websocket/server.hh @@ -22,275 +22,49 @@ #pragma once #include -#include #include #include #include -#include #include #include #include +#include namespace seastar::experimental::websocket { -using handler_t = std::function(input_stream&, output_stream&)>; - -class server; - -/// \defgroup websocket WebSocket /// \addtogroup websocket /// @{ /*! - * \brief an error in handling a WebSocket connection + * \brief a server WebSocket connection */ -class exception : public std::exception { - std::string _msg; -public: - exception(std::string_view msg) : _msg(msg) {} - virtual const char* what() const noexcept { - return _msg.c_str(); - } -}; - -/*! - * \brief Possible type of a websocket frame. - */ -enum opcodes { - CONTINUATION = 0x0, - TEXT = 0x1, - BINARY = 0x2, - CLOSE = 0x8, - PING = 0x9, - PONG = 0xA, - INVALID = 0xFF, -}; - -struct frame_header { - static constexpr uint8_t FIN = 7; - static constexpr uint8_t RSV1 = 6; - static constexpr uint8_t RSV2 = 5; - static constexpr uint8_t RSV3 = 4; - static constexpr uint8_t MASKED = 7; - - uint8_t fin : 1; - uint8_t rsv1 : 1; - uint8_t rsv2 : 1; - uint8_t rsv3 : 1; - uint8_t opcode : 4; - uint8_t masked : 1; - uint8_t length : 7; - frame_header(const char* input) { - this->fin = (input[0] >> FIN) & 1; - this->rsv1 = (input[0] >> RSV1) & 1; - this->rsv2 = (input[0] >> RSV2) & 1; - this->rsv3 = (input[0] >> RSV3) & 1; - this->opcode = input[0] & 0b1111; - this->masked = (input[1] >> MASKED) & 1; - this->length = (input[1] & 0b1111111); - } - // Returns length of the rest of the header. - uint64_t get_rest_of_header_length() { - size_t next_read_length = sizeof(uint32_t); // Masking key - if (length == 126) { - next_read_length += sizeof(uint16_t); - } else if (length == 127) { - next_read_length += sizeof(uint64_t); - } - return next_read_length; - } - uint8_t get_fin() {return fin;} - uint8_t get_rsv1() {return rsv1;} - uint8_t get_rsv2() {return rsv2;} - uint8_t get_rsv3() {return rsv3;} - uint8_t get_opcode() {return opcode;} - uint8_t get_masked() {return masked;} - uint8_t get_length() {return length;} +class server_connection : public connection { - bool is_opcode_known() { - //https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 - return opcode < 0xA && !(opcode < 0x8 && opcode > 0x2); - } -}; - -class websocket_parser { - enum class parsing_state : uint8_t { - flags_and_payload_data, - payload_length_and_mask, - payload - }; - enum class connection_state : uint8_t { - valid, - closed, - error - }; - using consumption_result_t = consumption_result; - using buff_t = temporary_buffer; - // What parser is currently doing. - parsing_state _state; - // State of connection - can be valid, closed or should be closed - // due to error. - connection_state _cstate; - sstring _buffer; - std::unique_ptr _header; - uint64_t _payload_length; - uint64_t _consumed_payload_length = 0; - uint32_t _masking_key; - buff_t _result; - - static future dont_stop() { - return make_ready_future(continue_consuming{}); - } - static future stop(buff_t data) { - return make_ready_future(stop_consuming(std::move(data))); - } - uint64_t remaining_payload_length() const { - return _payload_length - _consumed_payload_length; - } - - // Removes mask from payload given in p. - void remove_mask(buff_t& p, size_t n) { - char *payload = p.get_write(); - for (uint64_t i = 0, j = 0; i < n; ++i, j = (j + 1) % 4) { - payload[i] ^= static_cast(((_masking_key << (j * 8)) >> 24)); - } - } -public: - websocket_parser() : _state(parsing_state::flags_and_payload_data), - _cstate(connection_state::valid), - _payload_length(0), - _masking_key(0) {} - future operator()(temporary_buffer data); - bool is_valid() { return _cstate == connection_state::valid; } - bool eof() { return _cstate == connection_state::closed; } - opcodes opcode() const; - buff_t result(); -}; - -/*! - * \brief a WebSocket connection - */ -class connection : public boost::intrusive::list_base_hook<> { - using buff_t = temporary_buffer; - - /*! - * \brief Implementation of connection's data source. - */ - class connection_source_impl final : public data_source_impl { - queue* data; - - public: - connection_source_impl(queue* data) : data(data) {} - - virtual future get() override { - return data->pop_eventually().then_wrapped([](future f){ - try { - return make_ready_future(std::move(f.get())); - } catch(...) { - return current_exception_as_future(); - } - }); - } - - virtual future<> close() override { - data->push(buff_t(0)); - return make_ready_future<>(); - } - }; - - /*! - * \brief Implementation of connection's data sink. - */ - class connection_sink_impl final : public data_sink_impl { - queue* data; - public: - connection_sink_impl(queue* data) : data(data) {} - - virtual future<> put(net::packet d) override { - net::fragment f = d.frag(0); - return data->push_eventually(temporary_buffer{std::move(f.base), f.size}); - } - - size_t buffer_size() const noexcept override { - return data->max_size(); - } - - virtual future<> close() override { - data->push(buff_t(0)); - return make_ready_future<>(); - } - }; - - /*! - * \brief This function processess received PING frame. - * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 - */ - future<> handle_ping(); - /*! - * \brief This function processess received PONG frame. - * https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 - */ - future<> handle_pong(); - - static const size_t PIPE_SIZE = 512; server& _server; - connected_socket _fd; - input_stream _read_buf; - output_stream _write_buf; http_request_parser _http_parser; - bool _done = false; - websocket_parser _websocket_parser; - queue > _input_buffer; - input_stream _input; - queue > _output_buffer; - output_stream _output; - - sstring _subprotocol; - handler_t _handler; public: /*! * \param server owning \ref server * \param fd established socket used for communication */ - connection(server& server, connected_socket&& fd) - : _server(server) - , _fd(std::move(fd)) - , _read_buf(_fd.input()) - , _write_buf(_fd.output()) - , _input_buffer{PIPE_SIZE} - , _output_buffer{PIPE_SIZE} - { - _input = input_stream{data_source{ - std::make_unique(&_input_buffer)}}; - _output = output_stream{data_sink{ - std::make_unique(&_output_buffer)}}; + server_connection(server& server, connected_socket&& fd) + : connection(std::move(fd)) + , _server(server) { on_new_connection(); } - ~connection(); + ~server_connection(); /*! - * \brief serve WebSocket protocol on a connection + * \brief serve WebSocket protocol on a server_connection */ future<> process(); - /*! - * \brief close the socket - */ - void shutdown_input(); - future<> close(bool send_close = true); protected: future<> read_loop(); - future<> read_one(); future<> read_http_upgrade_request(); - future<> response_loop(); void on_new_connection(); - /*! - * \brief Packs buff in websocket frame and sends it to the client. - */ - future<> send_data(opcodes opcode, temporary_buffer&& buff); - }; /*! @@ -301,7 +75,7 @@ protected: */ class server { std::vector _listeners; - boost::intrusive::list _connections; + boost::intrusive::list _connections; std::map _handlers; gate _task_gate; public: @@ -332,7 +106,7 @@ public: */ void register_handler(const std::string& name, handler_t handler); - friend class connection; + friend class server_connection; protected: void accept(server_socket &listener); future accept_one(server_socket &listener); diff --git a/src/websocket/common.cc b/src/websocket/common.cc new file mode 100644 index 0000000000..040529b603 --- /dev/null +++ b/src/websocket/common.cc @@ -0,0 +1,157 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * Copyright 2024 ScyllaDB + */ + +#include +#include +#include +#include +#include +#include + +namespace seastar::experimental::websocket { + +sstring magic_key_suffix = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +logger websocket_logger("websocket"); + +future<> connection::handle_ping() { + // TODO + return make_ready_future<>(); +} + +future<> connection::handle_pong() { + // TODO + return make_ready_future<>(); +} + +future<> connection::send_data(opcodes opcode, temporary_buffer&& buff) { + char header[10] = {'\x80', 0}; + size_t header_size = 2; + + header[0] += opcode; + + if ((126 <= buff.size()) && (buff.size() <= std::numeric_limits::max())) { + header[1] = 0x7E; + write_be(header + 2, buff.size()); + header_size += sizeof(uint16_t); + } else if (std::numeric_limits::max() < buff.size()) { + header[1] = 0x7F; + write_be(header + 2, buff.size()); + header_size += sizeof(uint64_t); + } else { + header[1] = uint8_t(buff.size()); + } + + scattered_message msg; + msg.append(sstring(header, header_size)); + msg.append(std::move(buff)); + return _write_buf.write(std::move(msg)).then([this] { + return _write_buf.flush(); + }); +} + +future<> connection::response_loop() { + return do_until([this] {return _done;}, [this] { + // FIXME: implement error handling + return _output_buffer.pop_eventually().then([this] ( + temporary_buffer buf) { + return send_data(opcodes::BINARY, std::move(buf)); + }); + }).finally([this]() { + return _write_buf.close(); + }); +} + +void connection::shutdown_input() { + _fd.shutdown_input(); +} + +future<> connection::close(bool send_close) { + return [this, send_close]() { + if (send_close) { + return send_data(opcodes::CLOSE, temporary_buffer(0)); + } else { + return make_ready_future<>(); + } + }().finally([this] { + _done = true; + return when_all_succeed(_input.close(), _output.close()).discard_result().finally([this] { + _fd.shutdown_output(); + }); + }); +} + +future<> connection::read_one() { + return _read_buf.consume(_websocket_parser).then([this] () mutable { + if (_websocket_parser.is_valid()) { + // FIXME: implement error handling + switch(_websocket_parser.opcode()) { + // We do not distinguish between these 3 types. + case opcodes::CONTINUATION: + case opcodes::TEXT: + case opcodes::BINARY: + return _input_buffer.push_eventually(_websocket_parser.result()); + case opcodes::CLOSE: + websocket_logger.debug("Received close frame."); + // datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + return close(true); + case opcodes::PING: + websocket_logger.debug("Received ping frame."); + return handle_ping(); + case opcodes::PONG: + websocket_logger.debug("Received pong frame."); + return handle_pong(); + default: + // Invalid - do nothing. + ; + } + } else if (_websocket_parser.eof()) { + return close(false); + } + websocket_logger.debug("Reading from socket has failed."); + return close(true); + }); +} + +std::string sha1_base64(std::string_view source) { + unsigned char hash[20]; + assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); + if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); + ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); + } + return encode_base64(std::string_view(reinterpret_cast(hash), sizeof(hash))); +} + +std::string encode_base64(std::string_view source) { + gnutls_datum_t src_data{ + .data = reinterpret_cast(const_cast(source.data())), + .size = static_cast(source.size()) + }; + gnutls_datum_t encoded_data; + if (int ret = gnutls_base64_encode2(&src_data, &encoded_data); ret != GNUTLS_E_SUCCESS) { + throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); + } + auto free_encoded_data = defer([&] () noexcept { gnutls_free(encoded_data.data); }); + // base64_encoded.data is "unsigned char *" + return std::string(reinterpret_cast(encoded_data.data), encoded_data.size); +} + +} diff --git a/src/websocket/parser.cc b/src/websocket/parser.cc new file mode 100644 index 0000000000..8de5fa753c --- /dev/null +++ b/src/websocket/parser.cc @@ -0,0 +1,136 @@ +/* + * This file is open source software, licensed to you under the terms + * of the Apache License, Version 2.0 (the "License"). See the NOTICE file + * distributed with this work for additional information regarding copyright + * ownership. You may not use this file except in compliance with the License. + * + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace seastar::experimental::websocket { + +opcodes websocket_parser::opcode() const { + if (_header) { + return opcodes(_header->opcode); + } else { + return opcodes::INVALID; + } +} + +websocket_parser::buff_t websocket_parser::result() { + return std::move(_result); +} + +future websocket_parser::operator()( + temporary_buffer data) { + if (data.size() == 0) { + // EOF + _cstate = connection_state::closed; + return websocket_parser::stop(std::move(data)); + } + if (_state == parsing_state::flags_and_payload_data) { + if (_buffer.length() + data.size() >= 2) { + // _buffer.length() is less than 2 when entering this if body due to how + // the rest of code is structured. The else branch will never increase + // _buffer.length() to >=2 and other paths to this condition will always + // have buffer cleared. + assert(_buffer.length() < 2); + + size_t hlen = _buffer.length(); + _buffer.append(data.get(), 2 - hlen); + data.trim_front(2 - hlen); + _header = std::make_unique(_buffer.data()); + _buffer = {}; + + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 + // We must close the connection if data isn't masked. + if ((!_header->masked) || + // RSVX must be 0 + (_header->rsv1 | _header->rsv2 | _header->rsv3) || + // Opcode must be known. + (!_header->is_opcode_known())) { + _cstate = connection_state::error; + return websocket_parser::stop(std::move(data)); + } + _state = parsing_state::payload_length_and_mask; + } else { + _buffer.append(data.get(), data.size()); + return websocket_parser::dont_stop(); + } + } + if (_state == parsing_state::payload_length_and_mask) { + size_t const required_bytes = _header->get_rest_of_header_length(); + if (_buffer.length() + data.size() >= required_bytes) { + if (_buffer.length() < required_bytes) { + size_t hlen = _buffer.length(); + _buffer.append(data.get(), required_bytes - hlen); + data.trim_front(required_bytes - hlen); + } + _payload_length = _header->length; + char const *input = _buffer.data(); + if (_header->length == 126) { + _payload_length = consume_be(input); + } else if (_header->length == 127) { + _payload_length = consume_be(input); + } + + _masking_key = consume_be(input); + _buffer = {}; + _state = parsing_state::payload; + } else { + _buffer.append(data.get(), data.size()); + return websocket_parser::dont_stop(); + } + } + if (_state == parsing_state::payload) { + if (data.size() < remaining_payload_length()) { + // data has insufficient data to complete the frame - consume data.size() bytes + if (_result.empty()) { + _result = temporary_buffer(remaining_payload_length()); + _consumed_payload_length = 0; + } + std::copy(data.begin(), data.end(), _result.get_write() + _consumed_payload_length); + _consumed_payload_length += data.size(); + return websocket_parser::dont_stop(); + } else { + // data has sufficient data to complete the frame - consume remaining_payload_length() + auto consumed_bytes = remaining_payload_length(); + if (_result.empty()) { + // Try to avoid memory copies in case when network packets contain one or more full + // websocket frames. + if (consumed_bytes == data.size()) { + _result = std::move(data); + data = temporary_buffer(0); + } else { + _result = data.share(); + _result.trim(consumed_bytes); + data.trim_front(consumed_bytes); + } + } else { + std::copy(data.begin(), data.begin() + consumed_bytes, + _result.get_write() + _consumed_payload_length); + data.trim_front(consumed_bytes); + } + remove_mask(_result, _payload_length); + _consumed_payload_length = 0; + _state = parsing_state::flags_and_payload_data; + return websocket_parser::stop(std::move(data)); + } + } + _cstate = connection_state::error; + return websocket_parser::stop(std::move(data)); +} + +} diff --git a/src/websocket/server.cc b/src/websocket/server.cc index 6d52837518..12c36e02a5 100644 --- a/src/websocket/server.cc +++ b/src/websocket/server.cc @@ -25,14 +25,10 @@ #include #include #include -#include #include -#include -#include namespace seastar::experimental::websocket { -static sstring magic_key_suffix = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static sstring http_upgrade_reply_template = "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" @@ -40,20 +36,6 @@ static sstring http_upgrade_reply_template = "Sec-WebSocket-Version: 13\r\n" "Sec-WebSocket-Accept: "; -static logger wlogger("websocket"); - -opcodes websocket_parser::opcode() const { - if (_header) { - return opcodes(_header->opcode); - } else { - return opcodes::INVALID; - } -} - -websocket_parser::buff_t websocket_parser::result() { - return std::move(_result); -} - void server::listen(socket_address addr, listen_options lo) { _listeners.push_back(seastar::listen(addr, lo)); accept(_listeners.back()); @@ -74,10 +56,10 @@ void server::accept(server_socket &listener) { future server::accept_one(server_socket &listener) { return listener.accept().then([this](accept_result ar) { - auto conn = std::make_unique(*this, std::move(ar.connection)); + auto conn = std::make_unique(*this, std::move(ar.connection)); (void)try_with_gate(_task_gate, [conn = std::move(conn)]() mutable { return conn->process().finally([conn = std::move(conn)] { - wlogger.debug("Connection is finished"); + websocket_logger.debug("Connection is finished"); }); }).handle_exception_type([](const gate_closed_exception &e) {}); return make_ready_future(stop_iteration::no); @@ -85,11 +67,11 @@ future server::accept_one(server_socket &listener) { // We expect a ECONNABORTED when server::stop is called, // no point in warning about that. if (e.code().value() != ECONNABORTED) { - wlogger.error("accept failed: {}", e); + websocket_logger.error("accept failed: {}", e); } return make_ready_future(stop_iteration::yes); }).handle_exception([](std::exception_ptr ex) { - wlogger.info("accept failed: {}", ex); + websocket_logger.info("accept failed: {}", ex); return make_ready_future(stop_iteration::yes); }); } @@ -104,48 +86,27 @@ future<> server::stop() { } return _task_gate.close().finally([this] { - return parallel_for_each(_connections, [] (connection& conn) { + return parallel_for_each(_connections, [] (server_connection& conn) { return conn.close(true).handle_exception([] (auto ignored) {}); }); }); } -connection::~connection() { +server_connection::~server_connection() { _server._connections.erase(_server._connections.iterator_to(*this)); } -void connection::on_new_connection() { +void server_connection::on_new_connection() { _server._connections.push_back(*this); } -future<> connection::process() { +future<> server_connection::process() { return when_all_succeed(read_loop(), response_loop()).discard_result().handle_exception([] (const std::exception_ptr& e) { - wlogger.debug("Processing failed: {}", e); + websocket_logger.debug("Processing failed: {}", e); }); } -static std::string sha1_base64(std::string_view source) { - unsigned char hash[20]; - assert(sizeof(hash) == gnutls_hash_get_len(GNUTLS_DIG_SHA1)); - if (int ret = gnutls_hash_fast(GNUTLS_DIG_SHA1, source.data(), source.size(), hash); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_hash_fast: {}", gnutls_strerror(ret))); - } - gnutls_datum_t hash_data{ - .data = hash, - .size = sizeof(hash), - }; - gnutls_datum_t base64_encoded; - if (int ret = gnutls_base64_encode2(&hash_data, &base64_encoded); - ret != GNUTLS_E_SUCCESS) { - throw websocket::exception(fmt::format("gnutls_base64_encode2: {}", gnutls_strerror(ret))); - } - auto free_base64_encoded = defer([&] () noexcept { gnutls_free(base64_encoded.data); }); - // base64_encoded.data is "unsigned char *" - return std::string(reinterpret_cast(base64_encoded.data), base64_encoded.size); -} - -future<> connection::read_http_upgrade_request() { +future<> server_connection::read_http_upgrade_request() { _http_parser.init(); co_await _read_buf.consume(_http_parser); @@ -170,17 +131,17 @@ future<> connection::read_http_upgrade_request() { } this->_handler = this->_server._handlers[subprotocol]; this->_subprotocol = subprotocol; - wlogger.debug("Sec-WebSocket-Protocol: {}", subprotocol); + websocket_logger.debug("Sec-WebSocket-Protocol: {}", subprotocol); sstring sec_key = req->get_header("Sec-Websocket-Key"); sstring sec_version = req->get_header("Sec-Websocket-Version"); sstring sha1_input = sec_key + magic_key_suffix; - wlogger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); + websocket_logger.debug("Sec-Websocket-Key: {}, Sec-Websocket-Version: {}", sec_key, sec_version); std::string sha1_output = sha1_base64(sha1_input); - wlogger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); + websocket_logger.debug("SHA1 output: {} of size {}", sha1_output, sha1_output.size()); co_await _write_buf.write(http_upgrade_reply_template); co_await _write_buf.write(sha1_output); @@ -192,152 +153,7 @@ future<> connection::read_http_upgrade_request() { co_await _write_buf.flush(); } -future websocket_parser::operator()( - temporary_buffer data) { - if (data.size() == 0) { - // EOF - _cstate = connection_state::closed; - return websocket_parser::stop(std::move(data)); - } - if (_state == parsing_state::flags_and_payload_data) { - if (_buffer.length() + data.size() >= 2) { - // _buffer.length() is less than 2 when entering this if body due to how - // the rest of code is structured. The else branch will never increase - // _buffer.length() to >=2 and other paths to this condition will always - // have buffer cleared. - assert(_buffer.length() < 2); - - size_t hlen = _buffer.length(); - _buffer.append(data.get(), 2 - hlen); - data.trim_front(2 - hlen); - _header = std::make_unique(_buffer.data()); - _buffer = {}; - - // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 - // We must close the connection if data isn't masked. - if ((!_header->masked) || - // RSVX must be 0 - (_header->rsv1 | _header->rsv2 | _header->rsv3) || - // Opcode must be known. - (!_header->is_opcode_known())) { - _cstate = connection_state::error; - return websocket_parser::stop(std::move(data)); - } - _state = parsing_state::payload_length_and_mask; - } else { - _buffer.append(data.get(), data.size()); - return websocket_parser::dont_stop(); - } - } - if (_state == parsing_state::payload_length_and_mask) { - size_t const required_bytes = _header->get_rest_of_header_length(); - if (_buffer.length() + data.size() >= required_bytes) { - if (_buffer.length() < required_bytes) { - size_t hlen = _buffer.length(); - _buffer.append(data.get(), required_bytes - hlen); - data.trim_front(required_bytes - hlen); - } - _payload_length = _header->length; - char const *input = _buffer.data(); - if (_header->length == 126) { - _payload_length = consume_be(input); - } else if (_header->length == 127) { - _payload_length = consume_be(input); - } - - _masking_key = consume_be(input); - _buffer = {}; - _state = parsing_state::payload; - } else { - _buffer.append(data.get(), data.size()); - return websocket_parser::dont_stop(); - } - } - if (_state == parsing_state::payload) { - if (data.size() < remaining_payload_length()) { - // data has insufficient data to complete the frame - consume data.size() bytes - if (_result.empty()) { - _result = temporary_buffer(remaining_payload_length()); - _consumed_payload_length = 0; - } - std::copy(data.begin(), data.end(), _result.get_write() + _consumed_payload_length); - _consumed_payload_length += data.size(); - return websocket_parser::dont_stop(); - } else { - // data has sufficient data to complete the frame - consume remaining_payload_length() - auto consumed_bytes = remaining_payload_length(); - if (_result.empty()) { - // Try to avoid memory copies in case when network packets contain one or more full - // websocket frames. - if (consumed_bytes == data.size()) { - _result = std::move(data); - data = temporary_buffer(0); - } else { - _result = data.share(); - _result.trim(consumed_bytes); - data.trim_front(consumed_bytes); - } - } else { - std::copy(data.begin(), data.begin() + consumed_bytes, - _result.get_write() + _consumed_payload_length); - data.trim_front(consumed_bytes); - } - remove_mask(_result, _payload_length); - _consumed_payload_length = 0; - _state = parsing_state::flags_and_payload_data; - return websocket_parser::stop(std::move(data)); - } - } - _cstate = connection_state::error; - return websocket_parser::stop(std::move(data)); -} - -future<> connection::handle_ping() { - // TODO - return make_ready_future<>(); -} - -future<> connection::handle_pong() { - // TODO - return make_ready_future<>(); -} - - -future<> connection::read_one() { - return _read_buf.consume(_websocket_parser).then([this] () mutable { - if (_websocket_parser.is_valid()) { - // FIXME: implement error handling - switch(_websocket_parser.opcode()) { - // We do not distinguish between these 3 types. - case opcodes::CONTINUATION: - case opcodes::TEXT: - case opcodes::BINARY: - return _input_buffer.push_eventually(_websocket_parser.result()); - case opcodes::CLOSE: - wlogger.debug("Received close frame."); - /* - * datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 - */ - return close(true); - case opcodes::PING: - wlogger.debug("Received ping frame."); - return handle_ping(); - case opcodes::PONG: - wlogger.debug("Received pong frame."); - return handle_pong(); - default: - // Invalid - do nothing. - ; - } - } else if (_websocket_parser.eof()) { - return close(false); - } - wlogger.debug("Reading from socket has failed."); - return close(true); - }); -} - -future<> connection::read_loop() { +future<> server_connection::read_loop() { return read_http_upgrade_request().then([this] { return when_all_succeed( _handler(_input, _output).handle_exception([this] (std::exception_ptr e) mutable { @@ -352,63 +168,6 @@ future<> connection::read_loop() { }); } -void connection::shutdown_input() { - _fd.shutdown_input(); -} - -future<> connection::close(bool send_close) { - return [this, send_close]() { - if (send_close) { - return send_data(opcodes::CLOSE, temporary_buffer(0)); - } else { - return make_ready_future<>(); - } - }().finally([this] { - _done = true; - return when_all_succeed(_input.close(), _output.close()).discard_result().finally([this] { - _fd.shutdown_output(); - }); - }); -} - -future<> connection::send_data(opcodes opcode, temporary_buffer&& buff) { - char header[10] = {'\x80', 0}; - size_t header_size = 2; - - header[0] += opcode; - - if ((126 <= buff.size()) && (buff.size() <= std::numeric_limits::max())) { - header[1] = 0x7E; - write_be(header + 2, buff.size()); - header_size += sizeof(uint16_t); - } else if (std::numeric_limits::max() < buff.size()) { - header[1] = 0x7F; - write_be(header + 2, buff.size()); - header_size += sizeof(uint64_t); - } else { - header[1] = uint8_t(buff.size()); - } - - scattered_message msg; - msg.append(sstring(header, header_size)); - msg.append(std::move(buff)); - return _write_buf.write(std::move(msg)).then([this] { - return _write_buf.flush(); - }); -} - -future<> connection::response_loop() { - return do_until([this] {return _done;}, [this] { - // FIXME: implement error handling - return _output_buffer.pop_eventually().then([this] ( - temporary_buffer buf) { - return send_data(opcodes::BINARY, std::move(buf)); - }); - }).finally([this]() { - return _write_buf.close(); - }); -} - bool server::is_handler_registered(std::string const& name) { return _handlers.find(name) != _handlers.end(); } diff --git a/tests/unit/websocket_test.cc b/tests/unit/websocket_test.cc index bc65e21cc5..46a0bb540d 100644 --- a/tests/unit/websocket_test.cc +++ b/tests/unit/websocket_test.cc @@ -62,7 +62,7 @@ future<> test_websocket_handshake_common(std::string subprotocol) { }); }); }); - websocket::connection conn(dummy, acceptor.get().connection); + websocket::server_connection conn(dummy, acceptor.get().connection); future<> serve = conn.process(); auto close = defer([&conn, &input, &output, &serve] () noexcept { conn.close().get(); @@ -132,7 +132,7 @@ future<> test_websocket_handler_registration_common(std::string subprotocol) { }); }); }); - websocket::connection conn(ws, acceptor.get().connection); + websocket::server_connection conn(ws, acceptor.get().connection); future<> serve = conn.process(); auto close = defer([&conn, &input, &output, &serve] () noexcept {