diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 36348fa544..193e7b3511 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -23,6 +23,7 @@ LeonidCSIT kreuzerkrieg evanc Jesse Towner (jwtowner) +Atul Bagga (atbagga) Abinsula s.r.l. Gianfranco Costamagna (LocutusOfBorg) @@ -36,10 +37,15 @@ Gery Vessere (gery@vessere.com) Cisco Systems Gergely Lukacsy (glukacsy) Chris Deering (deeringc) +Chris O'Gorman (chogorma) Ocedo GmbH Henning Pfeiffer (megaposer) +neXenio GmbH +Patrik Fiedler (xqp) +RenĂ© Meusel (reneme) + thomasschaub Trimble diff --git a/Release/include/cpprest/certificate_info.h b/Release/include/cpprest/certificate_info.h new file mode 100644 index 0000000000..09dd21553e --- /dev/null +++ b/Release/include/cpprest/certificate_info.h @@ -0,0 +1,45 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * Certificate info + * + * For the latest on this and related APIs, please see: https://github.com/Microsoft/cpprestsdk + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +#pragma once + +#include +#include + +namespace web +{ +namespace http +{ +namespace client +{ +using CertificateChain = std::vector>; + +struct certificate_info +{ + std::string host_name; + CertificateChain certificate_chain; + long certificate_error {0}; + bool verified {false}; + + certificate_info(const std::string host) : host_name(host) { } + certificate_info(const std::string host, CertificateChain chain, long error = 0) + : host_name(host), certificate_chain(chain), certificate_error(error) + { + } +}; + +using CertificateChainFunction = std::function certificate_Info)>; + +} // namespace client +} // namespace http +} // namespace web diff --git a/Release/include/cpprest/http_client.h b/Release/include/cpprest/http_client.h index fb7c6067ab..3f899db9c1 100644 --- a/Release/include/cpprest/http_client.h +++ b/Release/include/cpprest/http_client.h @@ -45,6 +45,7 @@ typedef void* native_handle; #endif // __cplusplus_winrt #include "cpprest/asyncrt_utils.h" +#include "cpprest/certificate_info.h" #include "cpprest/details/basic_types.h" #include "cpprest/details/web_utilities.h" #include "cpprest/http_msg.h" @@ -101,6 +102,7 @@ class http_client_config #if !defined(__cplusplus_winrt) , m_validate_certificates(true) #endif + , m_certificate_chain_callback([](const std::shared_ptr&) -> bool { return true; }) #if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) , m_tlsext_sni_enabled(true) #endif @@ -362,6 +364,27 @@ class http_client_config if (m_set_user_nativehandle_options) m_set_user_nativehandle_options(handle); } + + /// + /// Set the certificate chain callback. If set, HTTP client will call this callback in a blocking manner during HTTP + /// connection. + /// + void set_user_certificate_chain_callback(const CertificateChainFunction& callback) + { + m_certificate_chain_callback = callback; + } + + /// + /// Invokes the certificate chain callback. + /// + /// Pointer to the certificate_info struct that has the certificate + /// information. True if the consumer code allows the connection, false otherwise. False will + /// terminate the HTTP connection. + bool invoke_certificate_chain_callback(const std::shared_ptr& certificate_Info) const + { + return m_certificate_chain_callback(certificate_Info); + } + #if !defined(_WIN32) && !defined(__cplusplus_winrt) || defined(CPPREST_FORCE_HTTP_CLIENT_ASIO) /// /// Sets a callback to enable custom setting of the ssl context, at construction time. @@ -416,6 +439,7 @@ class http_client_config bool m_validate_certificates; #endif + CertificateChainFunction m_certificate_chain_callback; std::function m_set_user_nativehandle_options; std::function m_set_user_nativesessionhandle_options; diff --git a/Release/include/cpprest/oauth2.h b/Release/include/cpprest/oauth2.h index b1ec324996..1364345fc5 100644 --- a/Release/include/cpprest/oauth2.h +++ b/Release/include/cpprest/oauth2.h @@ -15,6 +15,7 @@ #ifndef CASA_OAUTH2_H #define CASA_OAUTH2_H +#include "cpprest/certificate_info.h" #include "cpprest/details/web_utilities.h" #include "cpprest/http_msg.h" @@ -219,6 +220,8 @@ class oauth2_config , m_bearer_auth(true) , m_http_basic_auth(true) , m_access_token_key(details::oauth2_strings::access_token) + , m_certificate_chain_callback([](const std::shared_ptr&) -> bool + { return true; }) { } @@ -480,6 +483,23 @@ class oauth2_config /// void set_user_agent(utility::string_t user_agent) { m_user_agent = std::move(user_agent); } + /// + /// Set the certificate chain callback to be used by the http client. + /// + void set_user_certificate_chain_callback(const web::http::client::CertificateChainFunction& callback) + { + m_certificate_chain_callback = callback; + } + + /// + /// Get the cert chain callback. + /// + /// A reference to cert chain callback user by the client. + const web::http::client::CertificateChainFunction& user_certificate_chain_callback() + { + return m_certificate_chain_callback; + } + private: friend class web::http::client::http_client_config; friend class web::http::oauth2::details::oauth2_handler; @@ -523,6 +543,8 @@ class oauth2_config oauth2_token m_token; utility::nonce_generator m_state_generator; + + web::http::client::CertificateChainFunction m_certificate_chain_callback; }; } // namespace experimental diff --git a/Release/include/cpprest/ws_client.h b/Release/include/cpprest/ws_client.h index af17bd6060..bd52cbc8ba 100644 --- a/Release/include/cpprest/ws_client.h +++ b/Release/include/cpprest/ws_client.h @@ -18,8 +18,10 @@ #if !defined(CPPREST_EXCLUDE_WEBSOCKETS) #include "cpprest/asyncrt_utils.h" +#include "cpprest/certificate_info.h" #include "cpprest/details/web_utilities.h" #include "cpprest/http_headers.h" +#include "cpprest/json.h" #include "cpprest/uri.h" #include "cpprest/ws_msg.h" #include "pplx/pplxtasks.h" @@ -79,7 +81,13 @@ class websocket_client_config /// /// Creates a websocket client configuration with default settings. /// - websocket_client_config() : m_sni_enabled(true), m_validate_certificates(true) {} + websocket_client_config() + : m_certificate_chain_callback([](const std::shared_ptr&) -> bool + { return true; }) + , m_sni_enabled(true) + , m_validate_certificates(true) + { + } /// /// Get the web proxy object @@ -199,6 +207,26 @@ class websocket_client_config } #endif + /// Set the certificate chain callback. If set, HTTP client will call this callback in a blocking manner during HTTP + /// connection. + /// + void set_user_certificate_chain_callback(const http::client::CertificateChainFunction& callback) + { + m_certificate_chain_callback = callback; + } + + /// + /// Invokes the certificate chain callback. + /// + /// Pointer to the certificate_info struct that has the certificate + /// information. True if the consumer code allows the connection, false otherwise. False will + /// terminate the HTTP connection. + bool invoke_certificate_chain_callback( + const std::shared_ptr& certificate_Info) const + { + return m_certificate_chain_callback(certificate_Info); + } + private: web::web_proxy m_proxy; web::credentials m_credentials; @@ -206,6 +234,8 @@ class websocket_client_config bool m_sni_enabled; utf8string m_sni_hostname; bool m_validate_certificates; + http::client::CertificateChainFunction m_certificate_chain_callback; + #if !defined(_WIN32) || !defined(__cplusplus_winrt) std::function m_ssl_context_callback; #endif diff --git a/Release/src/http/client/http_client_asio.cpp b/Release/src/http/client/http_client_asio.cpp index 07bb4885bf..195d7ec958 100644 --- a/Release/src/http/client/http_client_asio.cpp +++ b/Release/src/http/client/http_client_asio.cpp @@ -41,6 +41,7 @@ #include "../common/x509_cert_utilities.h" #include "cpprest/base_uri.h" +#include "cpprest/certificate_info.h" #include "cpprest/details/http_helpers.h" #include "http_client_impl.h" #include "pplx/threadpool.h" @@ -754,7 +755,7 @@ class asio_context final : public request_context, public std::enable_shared_fro } auto start_http_request_flow = [proxy_type, proxy_host, proxy_port AND_CAPTURE_MEMBER_FUNCTION_POINTERS]( - std::shared_ptr ctx) { + std::shared_ptr ctx) { if (ctx->m_request._cancellation_token().is_canceled()) { ctx->request_context::report_error(make_error_code(std::errc::operation_canceled).value(), @@ -1123,6 +1124,9 @@ class asio_context final : public request_context, public std::enable_shared_fro // If OpenSSL fails we will doing verification at the end using the whole certificate // chain so wait until the 'leaf' cert. For now return true so OpenSSL continues down // the certificate chain. + const auto& host = utility::conversions::to_utf8string(m_http_client->base_uri().host()); + using namespace web::http::client::details; + if (!preverified) { m_openssl_failed = true; @@ -1130,12 +1134,36 @@ class asio_context final : public request_context, public std::enable_shared_fro if (m_openssl_failed) { - return verify_cert_chain_platform_specific(verifyCtx, m_connection->cn_hostname()); + if (!is_end_certificate_in_chain(verifyCtx)) + { + // Continue until we get the end certificate. + return true; + } + + auto chainFunc = [this](const std::shared_ptr& cert_info) + { return m_http_client->client_config().invoke_certificate_chain_callback(cert_info); }; + + return http::client::details::verify_cert_chain_platform_specific( + verifyCtx, utility::conversions::to_utf8string(host), chainFunc); } -#endif // CPPREST_PLATFORM_ASIO_CERT_VERIFICATION_AVAILABLE +#endif - boost::asio::ssl::rfc2818_verification rfc2818(m_connection->cn_hostname()); - return rfc2818(preverified, verifyCtx); + boost::asio::ssl::rfc2818_verification rfc2818(host); + if (!rfc2818(preverified, verifyCtx)) + { + return false; + } + + auto info = std::make_shared(host, get_X509_cert_chain_encoded_data(verifyCtx)); + info->verified = true; + + if (!is_end_certificate_in_chain(verifyCtx)) + { + // Continue until we get the end certificate. + return true; + } + + return m_http_client->client_config().invoke_certificate_chain_callback(info); } void handle_write_headers(const boost::system::error_code& ec) @@ -1200,7 +1228,8 @@ class asio_context final : public request_context, public std::enable_shared_fro const size_t offset = http::details::chunked_encoding::add_chunked_delimiters( buf, chunkSize + http::details::chunked_encoding::additional_encoding_space, readSize); - this_request->m_body_buf.commit(readSize + http::details::chunked_encoding::additional_encoding_space); + this_request->m_body_buf.commit(readSize + + http::details::chunked_encoding::additional_encoding_space); this_request->m_body_buf.consume(offset); this_request->m_uploaded += static_cast(readSize); @@ -1213,9 +1242,10 @@ class asio_context final : public request_context, public std::enable_shared_fro } else { - this_request->m_connection->async_write( - this_request->m_body_buf, - boost::bind(&asio_context::handle_write_body, this_request, boost::asio::placeholders::error)); + this_request->m_connection->async_write(this_request->m_body_buf, + boost::bind(&asio_context::handle_write_body, + this_request, + boost::asio::placeholders::error)); } }); } @@ -1665,7 +1695,7 @@ class asio_context final : public request_context, public std::enable_shared_fro writeBuffer.putn_nocopy(shared_decompressed->data(), shared_decompressed->size()) .then([this_request, to_read, shared_decompressed AND_CAPTURE_MEMBER_FUNCTION_POINTERS]( - pplx::task op) { + pplx::task op) { try { op.get(); @@ -1688,23 +1718,26 @@ class asio_context final : public request_context, public std::enable_shared_fro else { writeBuffer.putn_nocopy(boost::asio::buffer_cast(m_body_buf.data()), to_read) - .then([this_request, to_read AND_CAPTURE_MEMBER_FUNCTION_POINTERS](pplx::task op) { - try + .then( + [this_request, to_read AND_CAPTURE_MEMBER_FUNCTION_POINTERS](pplx::task op) { - op.wait(); - } - catch (...) - { - this_request->report_exception(std::current_exception()); - return; - } - this_request->m_body_buf.consume(to_read + CRLF.size()); // consume crlf - this_request->m_connection->async_read_until(this_request->m_body_buf, - CRLF, - boost::bind(&asio_context::handle_chunk_header, - this_request, - boost::asio::placeholders::error)); - }); + try + { + op.wait(); + } + catch (...) + { + this_request->report_exception(std::current_exception()); + return; + } + this_request->m_body_buf.consume(to_read + CRLF.size()); // consume crlf + this_request->m_connection->async_read_until( + this_request->m_body_buf, + CRLF, + boost::bind(&asio_context::handle_chunk_header, + this_request, + boost::asio::placeholders::error)); + }); } } } @@ -1795,7 +1828,7 @@ class asio_context final : public request_context, public std::enable_shared_fro writeBuffer.putn_nocopy(shared_decompressed->data(), shared_decompressed->size()) .then([this_request, read_size, shared_decompressed AND_CAPTURE_MEMBER_FUNCTION_POINTERS]( - pplx::task op) { + pplx::task op) { size_t writtenSize = 0; (void)writtenSize; try @@ -1804,9 +1837,10 @@ class asio_context final : public request_context, public std::enable_shared_fro this_request->m_downloaded += static_cast(read_size); this_request->m_body_buf.consume(read_size); this_request->async_read_until_buffersize( - static_cast((std::min)( - static_cast(this_request->m_http_client->client_config().chunksize()), - this_request->m_content_length - this_request->m_downloaded)), + static_cast( + (std::min)(static_cast( + this_request->m_http_client->client_config().chunksize()), + this_request->m_content_length - this_request->m_downloaded)), boost::bind(&asio_context::handle_read_content, this_request, boost::asio::placeholders::error)); @@ -1830,9 +1864,10 @@ class asio_context final : public request_context, public std::enable_shared_fro this_request->m_downloaded += static_cast(writtenSize); this_request->m_body_buf.consume(writtenSize); this_request->async_read_until_buffersize( - static_cast((std::min)( - static_cast(this_request->m_http_client->client_config().chunksize()), - this_request->m_content_length - this_request->m_downloaded)), + static_cast( + (std::min)(static_cast( + this_request->m_http_client->client_config().chunksize()), + this_request->m_content_length - this_request->m_downloaded)), boost::bind(&asio_context::handle_read_content, this_request, boost::asio::placeholders::error)); @@ -1873,8 +1908,7 @@ class asio_context final : public request_context, public std::enable_shared_fro m_timer.expires_from_now(m_duration); auto ctx = m_ctx; m_timer.async_wait([ctx AND_CAPTURE_MEMBER_FUNCTION_POINTERS](const boost::system::error_code& ec) { - handle_timeout(ec, ctx); - }); + handle_timeout(ec, ctx); }); } void reset() @@ -1887,8 +1921,7 @@ class asio_context final : public request_context, public std::enable_shared_fro assert(m_state == started); auto ctx = m_ctx; m_timer.async_wait([ctx AND_CAPTURE_MEMBER_FUNCTION_POINTERS](const boost::system::error_code& ec) { - handle_timeout(ec, ctx); - }); + handle_timeout(ec, ctx); }); } } @@ -2028,12 +2061,12 @@ static bool is_recognized_redirection(status_code code) return is_retrieval_redirection(code) || is_unchanged_redirection(code); } -static bool is_retrieval_request(method method) -{ +static bool is_retrieval_request(method method) +{ return methods::GET == method || methods::HEAD == method; } -static const std::vector request_body_header_names = +static const std::vector request_body_header_names = { header_names::content_encoding, header_names::content_language, @@ -2081,28 +2114,28 @@ uri http_redirect_follower::url_to_follow(const http_response& response) const { // Return immediately if the response is not a supported redirection if (!is_recognized_redirection(response.status_code())) - return{}; + return {}; // Although not required by RFC 7231, config may limit the number of automatic redirects // (followed_urls includes the initial request URL, hence '<' here) if (config.max_redirects() < followed_urls.size()) - return{}; + return {}; // Can't very well automatically redirect if the server hasn't provided a Location const auto location = response.headers().find(header_names::location); if (response.headers().end() == location) - return{}; + return {}; uri to_follow(followed_urls.back().resolve_uri(location->second)); // Config may prohibit automatic redirects from HTTPS to HTTP if (!config.https_to_http_redirects() && followed_urls.back().scheme() == _XPLATSTR("https") && to_follow.scheme() != _XPLATSTR("https")) - return{}; + return {}; // "A client SHOULD detect and intervene in cyclical redirections." if (followed_urls.end() != std::find(followed_urls.begin(), followed_urls.end(), to_follow)) - return{}; + return {}; return to_follow; } @@ -2111,7 +2144,7 @@ pplx::task http_redirect_follower::operator()(http_response respo { // Return immediately if the response doesn't indicate a valid automatic redirect uri to_follow = url_to_follow(response); - if (to_follow.is_empty()) + if (to_follow.is_empty()) return pplx::task_from_result(response); // This implementation only supports retrieval redirects, as it cannot redirect e.g. a POST request @@ -2119,7 +2152,7 @@ pplx::task http_redirect_follower::operator()(http_response respo if (!is_retrieval_request(redirect.method()) && !is_retrieval_redirection(response.status_code())) return pplx::task_from_result(response); - if (!is_retrieval_request(redirect.method())) + if (!is_retrieval_request(redirect.method())) redirect.set_method(methods::GET); // If the reply to this request is also a redirect, we want visibility of that @@ -2152,9 +2185,9 @@ pplx::task asio_client::propagate(http_request request) // Asynchronously send the response with the HTTP client implementation. this->async_send_request(context); - return client_config().max_redirects() > 0 - ? result_task.then(http_redirect_follower(client_config(), request)) - : result_task; + return client_config().max_redirects() > 0 + ? result_task.then(http_redirect_follower(client_config(), request)) + : result_task; } } // namespace details } // namespace client diff --git a/Release/src/http/client/http_client_impl.h b/Release/src/http/client/http_client_impl.h index d9e7d4829e..074e1f301f 100644 --- a/Release/src/http/client/http_client_impl.h +++ b/Release/src/http/client/http_client_impl.h @@ -104,6 +104,7 @@ class request_context pplx::cancellation_token_registration m_cancellationRegistration; std::unique_ptr m_decompressor; + bool m_certificate_chain_verification_failed {false}; protected: request_context(const std::shared_ptr<_http_client_communicator>& client, const http_request& request); diff --git a/Release/src/http/client/http_client_winhttp.cpp b/Release/src/http/client/http_client_winhttp.cpp index d6cdb5384a..28e51ea3b2 100644 --- a/Release/src/http/client/http_client_winhttp.cpp +++ b/Release/src/http/client/http_client_winhttp.cpp @@ -25,6 +25,8 @@ #include "winhttppal.h" #endif #include +#include // for certificate pinning logic +#include // for certificate pinning logic #if _WIN32_WINNT && (_WIN32_WINNT >= _WIN32_WINNT_VISTA) #include @@ -1513,7 +1515,7 @@ class winhttp_client final : public _http_client_communicator { // An actual read always resets compression state for the next chunk _ASSERTE(p_request_context->m_compression_state.m_bytes_processed == - p_request_context->m_compression_state.m_bytes_read); + p_request_context->m_compression_state.m_bytes_read); _ASSERTE(!p_request_context->m_compression_state.m_needs_flush); p_request_context->m_compression_state.m_bytes_read = bytes_read; p_request_context->m_compression_state.m_bytes_processed = 0; @@ -1536,7 +1538,7 @@ class winhttp_client final : public _http_client_communicator hint = web::http::compression::operation_hint::is_last; } else if (p_request_context->m_compression_state.m_bytes_processed == - p_request_context->m_compression_state.m_bytes_read) + p_request_context->m_compression_state.m_bytes_read) { if (p_request_context->m_remaining_to_write && p_request_context->m_remaining_to_write != (std::numeric_limits::max)()) @@ -1554,63 +1556,65 @@ class winhttp_client final : public _http_client_communicator // else we're still compressing bytes from the previous read _ASSERTE(p_request_context->m_compression_state.m_bytes_processed <= - p_request_context->m_compression_state.m_bytes_read); + p_request_context->m_compression_state.m_bytes_read); uint8_t* in = buffer + p_request_context->m_compression_state.m_bytes_processed; size_t inbytes = p_request_context->m_compression_state.m_bytes_read - - p_request_context->m_compression_state.m_bytes_processed; + p_request_context->m_compression_state.m_bytes_processed; return compressor ->compress(in, - inbytes, - &p_request_context->m_body_data.get()[http::details::chunked_encoding::data_offset], - chunk_size, - hint) - .then([p_request_context, bytes_read, hint, chunk_size]( - pplx::task op) -> pplx::task { - http::compression::operation_result r; - - try - { - r = op.get(); - } - catch (...) + inbytes, + &p_request_context->m_body_data.get()[http::details::chunked_encoding::data_offset], + chunk_size, + hint) + .then( + [p_request_context, bytes_read, hint, chunk_size]( + pplx::task op) -> pplx::task { - return pplx::task_from_exception(std::current_exception()); - } + http::compression::operation_result r; - if (hint == web::http::compression::operation_hint::is_last) - { - // We're done reading all chunks, but the compressor may still have compressed bytes to - // drain from previous reads - _ASSERTE(r.done || r.output_bytes_produced == chunk_size); - p_request_context->m_compression_state.m_needs_flush = !r.done; - p_request_context->m_compression_state.m_done = r.done; - } + try + { + r = op.get(); + } + catch (...) + { + return pplx::task_from_exception(std::current_exception()); + } - // Update the number of bytes compressed in this read chunk; if it's been fully compressed, - // we'll reset m_bytes_processed and m_bytes_read after reading the next chunk - p_request_context->m_compression_state.m_bytes_processed += r.input_bytes_processed; - _ASSERTE(p_request_context->m_compression_state.m_bytes_processed <= - p_request_context->m_compression_state.m_bytes_read); - if (p_request_context->m_remaining_to_write != (std::numeric_limits::max)()) - { - _ASSERTE(p_request_context->m_remaining_to_write >= r.input_bytes_processed); - p_request_context->m_remaining_to_write -= r.input_bytes_processed; - } + if (hint == web::http::compression::operation_hint::is_last) + { + // We're done reading all chunks, but the compressor may still have compressed bytes to + // drain from previous reads + _ASSERTE(r.done || r.output_bytes_produced == chunk_size); + p_request_context->m_compression_state.m_needs_flush = !r.done; + p_request_context->m_compression_state.m_done = r.done; + } - if (p_request_context->m_compression_state.m_acquired != nullptr && - p_request_context->m_compression_state.m_bytes_processed == - p_request_context->m_compression_state.m_bytes_read) - { - // Release the acquired buffer back to the streambuf at the earliest possible point - p_request_context->_get_readbuffer().release( - p_request_context->m_compression_state.m_acquired, - p_request_context->m_compression_state.m_bytes_processed); - p_request_context->m_compression_state.m_acquired = nullptr; - } + // Update the number of bytes compressed in this read chunk; if it's been fully compressed, + // we'll reset m_bytes_processed and m_bytes_read after reading the next chunk + p_request_context->m_compression_state.m_bytes_processed += r.input_bytes_processed; + _ASSERTE(p_request_context->m_compression_state.m_bytes_processed <= + p_request_context->m_compression_state.m_bytes_read); + if (p_request_context->m_remaining_to_write != (std::numeric_limits::max)()) + { + _ASSERTE(p_request_context->m_remaining_to_write >= r.input_bytes_processed); + p_request_context->m_remaining_to_write -= r.input_bytes_processed; + } - return pplx::task_from_result(r.output_bytes_produced); - }); + if (p_request_context->m_compression_state.m_acquired != nullptr && + p_request_context->m_compression_state.m_bytes_processed == + p_request_context->m_compression_state.m_bytes_read) + { + // Release the acquired buffer back to the streambuf at the earliest possible point + p_request_context->_get_readbuffer().release( + p_request_context->m_compression_state.m_acquired, + p_request_context->m_compression_state.m_bytes_processed); + p_request_context->m_compression_state.m_acquired = nullptr; + } + + return pplx::task_from_result(r.output_bytes_produced); + }); }; if (p_request_context->m_compression_state.m_bytes_processed < @@ -1993,10 +1997,136 @@ class winhttp_client final : public _http_client_communicator return; } } - + // Check if connection rejected the certificate. + if (p_request_context->m_certificate_chain_verification_failed) + { + p_request_context->report_error( + ERROR_WINHTTP_SECURE_FAILURE, + build_error_msg(ERROR_WINHTTP_SECURE_FAILURE, "WinHttpVerificationFailed")); + break; + } p_request_context->report_error(errorCode, build_error_msg(error_result)); return; } + case WINHTTP_CALLBACK_STATUS_SENDING_REQUEST: + { + // Todo Check where to put this. atbagga + p_request_context->on_send_request_validate_cn(); + + // get actual URL which might be different from the original one due to redirection etc. + DWORD urlSize {0}; + WinHttpQueryOption(hRequestHandle, WINHTTP_OPTION_URL, NULL, &urlSize); + auto urlwchar = new WCHAR[urlSize / sizeof(WCHAR)]; + + WinHttpQueryOption(hRequestHandle, WINHTTP_OPTION_URL, (void*)urlwchar, &urlSize); + utility::string_t url(urlwchar); + + delete[] urlwchar; + + // obtain leaf cert based on which we will be able to build the certificate chain + PCCERT_CONTEXT pCert {nullptr}; + DWORD dwSize = sizeof(pCert); + + WinHttpQueryOption(hRequestHandle, WINHTTP_OPTION_SERVER_CERT_CONTEXT, &pCert, &dwSize); + + std::vector> cert_chain; + DWORD dwErrorStatus = 0; + + if (pCert) + { + CERT_ENHKEY_USAGE keyUsage = {}; + keyUsage.cUsageIdentifier = 0; + keyUsage.rgpszUsageIdentifier = NULL; + + CERT_USAGE_MATCH certUsage = {}; + certUsage.dwType = USAGE_MATCH_TYPE_AND; + certUsage.Usage = keyUsage; + + CERT_CHAIN_PARA chainPara = {}; + chainPara.cbSize = sizeof(CERT_CHAIN_PARA); + chainPara.RequestedUsage = certUsage; + + PCCERT_CHAIN_CONTEXT pChainContext = {}; + + // build the certificate chain relying on the actual intermediate certs returned as part of the TLS + // session disable any network operations to fetch certificates + auto validChain = CertGetCertificateChain( + NULL, + pCert, + NULL, + pCert->hCertStore, + &chainPara, + CERT_CHAIN_CACHE_ONLY_URL_RETRIEVAL | CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY | + CERT_CHAIN_REVOCATION_CHECK_CHAIN | CERT_CHAIN_CACHE_END_CERT | + CERT_CHAIN_DISABLE_AUTH_ROOT_AUTO_UPDATE, + NULL, + &pChainContext); + + if (validChain && pChainContext) + { + dwErrorStatus = pChainContext->TrustStatus.dwErrorStatus; + cert_chain.reserve((int)pChainContext->cChain); + for (size_t i = 0; i < pChainContext->cChain; ++i) + { + auto chain = pChainContext->rgpChain[i]; + for (size_t j = 0; j < chain->cElement; ++j) + { + auto chainElement = chain->rgpElement[j]; + auto cert = chainElement->pCertContext; + if (cert) + { + cert_chain.emplace_back(std::vector( + cert->pbCertEncoded, cert->pbCertEncoded + (int)cert->cbCertEncoded)); + } + } + } + CertFreeCertificateChain(pChainContext); + } + CertFreeCertificateContext(pCert); + } + + utility::string_t host; + + try + { + host = web::uri(url).host(); + } + catch (std::exception e) + { + host = url; + } + + if (host.empty()) + { + host = url; + } + + auto info = std::make_shared( + utility::conversions::to_utf8string(host), cert_chain, dwErrorStatus); + + if (dwErrorStatus == CERT_TRUST_NO_ERROR || dwErrorStatus == CERT_TRUST_REVOCATION_STATUS_UNKNOWN || + dwErrorStatus == (CERT_TRUST_IS_OFFLINE_REVOCATION | CERT_TRUST_REVOCATION_STATUS_UNKNOWN)) + { + info->verified = true; + } + + if (p_request_context->m_http_client->client_config().invoke_certificate_chain_callback(info)) + { + if (!info->verified && p_request_context->m_http_client->client_config().validate_certificates()) + { + p_request_context->m_certificate_chain_verification_failed = true; + } + else + { + p_request_context->m_certificate_chain_verification_failed = false; + } + } + else + { + p_request_context->m_certificate_chain_verification_failed = true; + } + break; + } case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: { if (!p_request_context->m_request.body()) @@ -2036,11 +2166,6 @@ class winhttp_client final : public _http_client_communicator } return; } - case WINHTTP_CALLBACK_STATUS_SENDING_REQUEST: - { - p_request_context->on_send_request_validate_cn(); - return; - } case WINHTTP_CALLBACK_STATUS_SECURE_FAILURE: { p_request_context->report_exception(web::http::http_exception( @@ -2172,6 +2297,17 @@ class winhttp_client final : public _http_client_communicator return; } } + else + { + // The connection is allowed, but did the client allow the connection. + if (p_request_context->m_certificate_chain_verification_failed) + { + p_request_context->report_error( + ERROR_WINHTTP_SECURE_FAILURE, + build_error_msg(ERROR_WINHTTP_SECURE_FAILURE, "WinHttpVerificationFailed")); + break; + } + } // Check whether the request is compressed, and if so, whether we're handling it. if (!p_request_context->handle_compression()) @@ -2472,9 +2608,11 @@ class winhttp_client final : public _http_client_communicator ->decompress( in, inbytes, buffer, chunk_size, web::http::compression::operation_hint::has_more) .then([p_request_context, buffer, chunk_size, process_buffer]( - pplx::task op) { + pplx::task op) { auto r = op.get(); - auto keep_going = [&r, process_buffer](winhttp_request_context* c) -> pplx::task { + auto keep_going = + [&r, process_buffer](winhttp_request_context* c) -> pplx::task + { _ASSERTE(r.input_bytes_processed <= c->m_compression_state.m_chunk_bytes); c->m_compression_state.m_chunk_bytes -= r.input_bytes_processed; c->m_compression_state.m_bytes_processed += r.input_bytes_processed; @@ -2482,8 +2620,8 @@ class winhttp_client final : public _http_client_communicator try { - // See if we still have more work to do for this section and/or for the response - // in general + // See if we still have more work to do for this section and/or for the + // response in general return pplx::task_from_result( process_buffer(c, r.output_bytes_produced, false)); } @@ -2494,8 +2632,8 @@ class winhttp_client final : public _http_client_communicator }; _ASSERTE(p_request_context->m_compression_state.m_bytes_processed + - r.input_bytes_processed <= - p_request_context->m_compression_state.m_bytes_read); + r.input_bytes_processed <= + p_request_context->m_compression_state.m_bytes_read); if (p_request_context->m_compression_state.m_acquired != nullptr) { @@ -2508,15 +2646,17 @@ class winhttp_client final : public _http_client_communicator // We decompressed into our own buffer; let the stream copy the data return p_request_context->_get_writebuffer() .putn_nocopy(buffer, r.output_bytes_produced) - .then([p_request_context, r, keep_going](pplx::task op) { - if (op.get() != r.output_bytes_produced) + .then( + [p_request_context, r, keep_going](pplx::task op) { - return pplx::task_from_exception( - std::runtime_error("Response stream unexpectedly failed to write the " - "requested number of bytes")); - } - return keep_going(p_request_context.get()); - }); + if (op.get() != r.output_bytes_produced) + { + return pplx::task_from_exception(std::runtime_error( + "Response stream unexpectedly failed to write the " + "requested number of bytes")); + } + return keep_going(p_request_context.get()); + }); }); }).then([p_request_context](pplx::task op) { try @@ -2525,8 +2665,8 @@ class winhttp_client final : public _http_client_communicator } catch (...) { - // We're only here to pick up any exception that may have been thrown, and to clean up - // if needed + // We're only here to pick up any exception that may have been thrown, and to clean + // up if needed if (p_request_context->m_compression_state.m_acquired) { p_request_context->_get_writebuffer().commit(0); diff --git a/Release/src/http/client/x509_cert_utilities.cpp b/Release/src/http/client/x509_cert_utilities.cpp index 67fc5ac47b..d1af9284bd 100644 --- a/Release/src/http/client/x509_cert_utilities.cpp +++ b/Release/src/http/client/x509_cert_utilities.cpp @@ -41,17 +41,39 @@ namespace client { namespace details { -static bool verify_X509_cert_chain(const std::vector& certChain, const std::string& hostName); +static bool verify_X509_cert_chain( + const std::vector& certChain, + const std::string& hostName, + const CertificateChainFunction& certInfoFunc); -bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verifyCtx, const std::string& hostName) +#if defined(_WIN32) +#include +#include +#endif + +#include + +bool is_end_certificate_in_chain(boost::asio::ssl::verify_context& verifyCtx) { X509_STORE_CTX* storeContext = verifyCtx.native_handle(); int currentDepth = X509_STORE_CTX_get_error_depth(storeContext); if (currentDepth != 0) + { + return false; + } + return true; +} + +bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verifyCtx, + const std::string& hostName, + const CertificateChainFunction& func) +{ + if (!is_end_certificate_in_chain(verifyCtx)) { return true; } + X509_STORE_CTX* storeContext = verifyCtx.native_handle(); #if (OPENSSL_VERSION_NUMBER < 0x10100000L) STACK_OF(X509)* certStack = X509_STORE_CTX_get_chain(storeContext); #else @@ -89,7 +111,7 @@ bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verif certChain.push_back(std::move(certData)); } - auto verify_result = verify_X509_cert_chain(certChain, hostName); + auto verify_result = verify_X509_cert_chain(certChain, hostName, func); // The Windows Crypto APIs don't do host name checks, use Boost's implementation. #if defined(_WIN32) @@ -102,6 +124,8 @@ bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verif return verify_result; } +#endif + #if defined(ANDROID) || defined(__ANDROID__) using namespace crossplat; @@ -316,7 +340,42 @@ class cf_ref }; } // namespace -bool verify_X509_cert_chain(const std::vector& certChain, const std::string& hostName) +static std::shared_ptr build_certificate_info_ptr(cf_ref& trust, + const std::string& hostName, + long trustResult, + bool isVerified) +{ + auto info = std::make_shared(hostName); + info->certificate_error = trustResult; + info->verified = isVerified; + + CFIndex cnt = SecTrustGetCertificateCount(trust.get()); + if (cnt > 0) + { + info->certificate_chain.reserve(cnt); + for (int i = 0; i < cnt; i++) + { + SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust.get(), i); + if (!cert) + { + break; + } + + cf_ref cdata = SecCertificateCopyData(cert); + if (cdata.get()) + { + const unsigned char* buffer = CFDataGetBytePtr(cdata.get()); + info->certificate_chain.emplace_back( + std::vector(buffer, buffer + CFDataGetLength(cdata.get()))); + } + } + } + return info; +} + +bool verify_X509_cert_chain(const std::vector& certChain, + const std::string& hostName, + const CertificateChainFunction& certInfoFunc /* = nullptr */) { // Build up CFArrayRef with all the certificates. // All this code is basically just to get into the correct structures for the Apple APIs. @@ -360,6 +419,9 @@ bool verify_X509_cert_chain(const std::vector& certChain, const std cf_ref policy = SecPolicyCreateSSL(true /* client side */, cfHostName.get()); cf_ref trust; OSStatus status = SecTrustCreateWithCertificates(certsArray.get(), policy.get(), &trust.get()); + + bool isVerified = false; + if (status == noErr) { // Perform actual certificate verification. @@ -367,27 +429,118 @@ bool verify_X509_cert_chain(const std::vector& certChain, const std status = SecTrustEvaluate(trust.get(), &trustResult); if (status == noErr && (trustResult == kSecTrustResultUnspecified || trustResult == kSecTrustResultProceed)) { - return true; + isVerified = true; } - } - return false; + if (certInfoFunc) + { + auto info = build_certificate_info_ptr(trust, hostName, (long)trustResult, isVerified); + + if (!certInfoFunc(info)) + { + isVerified = false; + } + } + } + return isVerified; } + #endif +std::vector> get_X509_cert_chain_encoded_data(boost::asio::ssl::verify_context& verifyCtx) +{ + std::vector> cert_chain; + + X509_STORE_CTX* storeContext = verifyCtx.native_handle(); + + STACK_OF(X509)* certStack = X509_STORE_CTX_get_chain(storeContext); + + const int numCerts = sk_X509_num(certStack); + if (numCerts < 0) + { + return cert_chain; + } + + cert_chain.reserve(numCerts); + for (int index = 0; index < numCerts; ++index) + { + X509* cert = sk_X509_value(certStack, index); + if (cert) + { + unsigned char* certKeyOut = nullptr; + int resCertificateLength = i2d_X509(cert, &certKeyOut); + if (resCertificateLength < 0 || !certKeyOut) + { + continue; + } + + std::vector certOut(certKeyOut, certKeyOut + resCertificateLength); + cert_chain.emplace_back(certOut); + } + } + return cert_chain; +} + #if defined(_WIN32) -bool verify_X509_cert_chain(const std::vector& certChain, const std::string& hostname) +static std::shared_ptr build_certificate_info_ptr(const chain_context& chain, + const std::string& hostName, + bool isVerified) +{ + auto info = std::make_shared(hostName); + + info->verified = isVerified; + info->certificate_error = chain->TrustStatus.dwErrorStatus; + info->certificate_chain.reserve((int)chain->cChain); + + for (size_t i = 0; i < chain->cChain; ++i) + { + auto pChain = chain->rgpChain[i]; + for (size_t j = 0; j < pChain->cElement; ++j) + { + auto chainElement = pChain->rgpElement[j]; + auto cert = chainElement->pCertContext; + if (cert) + { + info->certificate_chain.emplace_back( + std::vector(cert->pbCertEncoded, cert->pbCertEncoded + (int)cert->cbCertEncoded)); + } + } + } + + return info; +} + +bool verify_X509_cert_chain(const std::vector& certChain, + const std::string& hostName, + const CertificateChainFunction& certInfoFunc /* = nullptr */) { // Create certificate context from server certificate. - winhttp_cert_context cert; - cert.raw = CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, - reinterpret_cast(certChain[0].c_str()), - static_cast(certChain[0].size())); - if (cert.raw == nullptr) + cert_context pCert(CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + reinterpret_cast(certChain[0].c_str()), + static_cast(certChain[0].size()))); + if (pCert == nullptr) { return false; } + // Add all SSL intermediate certs into a store to be used by the OS building the full certificate chain. + HCERTSTORE caMemStore = NULL; + caMemStore = CertOpenStore(CERT_STORE_PROV_MEMORY, (PKCS_7_ASN_ENCODING | X509_ASN_ENCODING), NULL, 0, NULL); + if (caMemStore) + { + for (const auto& certData : certChain) + { + cert_context certContext( + CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, + reinterpret_cast(certData.c_str()), + static_cast(certData.size()))); + if (certContext) + { + CertAddCertificateContextToStore(caMemStore, certContext.get(), CERT_STORE_ADD_ALWAYS, NULL); + } + } + } + // Let the OS build a certificate chain from the server certificate. char oidPkixKpServerAuth[] = szOID_PKIX_KP_SERVER_AUTH; char oidServerGatedCrypto[] = szOID_SERVER_GATED_CRYPTO; @@ -428,20 +581,46 @@ bool verify_X509_cert_chain(const std::vector& certChain, const std 0, &u16HostName[0], }; - CERT_CHAIN_POLICY_PARA policyPara = {sizeof(policyPara)}; - policyPara.pvExtraPolicyPara = &policyData; - CERT_CHAIN_POLICY_STATUS policyStatus = {sizeof(policyStatus)}; - if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chainContext.raw, &policyPara, &policyStatus)) + params.RequestedUsage.Usage.cUsageIdentifier = std::extent::value; + params.RequestedUsage.Usage.rgpszUsageIdentifier = usages; + + PCCERT_CHAIN_CONTEXT pChainContext = {}; + chain_context chain; + + bool isVerified = false; + + auto cSuccess = CertGetCertificateChain( + nullptr, pCert.get(), nullptr, caMemStore, ¶ms, CERT_CHAIN_REVOCATION_CHECK_CHAIN, nullptr, &pChainContext); + + chain.reset(pChainContext); + + if (caMemStore) { - return false; + CertCloseStore(caMemStore, 0); } - if (policyStatus.dwError) + if (cSuccess && chain) { - return false; - } + // Only do revocation checking if it's known. + if (chain->TrustStatus.dwErrorStatus == CERT_TRUST_NO_ERROR || + chain->TrustStatus.dwErrorStatus == CERT_TRUST_REVOCATION_STATUS_UNKNOWN || + chain->TrustStatus.dwErrorStatus == + (CERT_TRUST_IS_OFFLINE_REVOCATION | CERT_TRUST_REVOCATION_STATUS_UNKNOWN)) + { + isVerified = true; + } - return true; + if (certInfoFunc) + { + auto info = build_certificate_info_ptr(chain, hostName, isVerified); + + if (!certInfoFunc(info)) + { + isVerified = false; + } + } + } + return isVerified; } #endif } // namespace details diff --git a/Release/src/http/common/x509_cert_utilities.h b/Release/src/http/common/x509_cert_utilities.h index 854e30534d..a179037a37 100644 --- a/Release/src/http/common/x509_cert_utilities.h +++ b/Release/src/http/common/x509_cert_utilities.h @@ -13,6 +13,9 @@ #pragma once +#include "cpprest/certificate_info.h" +#include + #if defined(_WIN32) #include @@ -94,6 +97,8 @@ namespace client { namespace details { +bool is_end_certificate_in_chain(boost::asio::ssl::verify_context& verifyCtx); + /// /// Using platform specific APIs verifies server certificate. /// Currently implemented to work on Windows, iOS, Android, and OS X. @@ -101,7 +106,15 @@ namespace details /// Boost.ASIO context to get certificate chain from. /// Host name from the URI. /// True if verification passed and server can be trusted, false otherwise. -bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verifyCtx, const std::string& hostName); +bool verify_cert_chain_platform_specific(boost::asio::ssl::verify_context& verifyCtx, + const std::string& hostName, + const CertificateChainFunction& func = nullptr); + +bool verify_X509_cert_chain(const std::vector& certChain, + const std::string& hostName, + const CertificateChainFunction& func = nullptr); + +std::vector> get_X509_cert_chain_encoded_data(boost::asio::ssl::verify_context& verifyCtx); } // namespace details } // namespace client } // namespace http diff --git a/Release/src/http/oauth/oauth2.cpp b/Release/src/http/oauth/oauth2.cpp index 3e54a6e07c..07bffa3242 100644 --- a/Release/src/http/oauth/oauth2.cpp +++ b/Release/src/http/oauth/oauth2.cpp @@ -137,6 +137,7 @@ pplx::task oauth2_config::_request_token(uri_builder& request_body_ub) // configure proxy http_client_config config; config.set_proxy(m_proxy); + config.set_user_certificate_chain_callback(m_certificate_chain_callback); http_client token_client(token_endpoint(), config); diff --git a/Release/src/websockets/client/ws_client_wspp.cpp b/Release/src/websockets/client/ws_client_wspp.cpp index d7c31c4095..e7e747adb9 100644 --- a/Release/src/websockets/client/ws_client_wspp.cpp +++ b/Release/src/websockets/client/ws_client_wspp.cpp @@ -209,25 +209,49 @@ class wspp_callback_client : public websocket_client_callback_impl, #ifdef CPPREST_PLATFORM_ASIO_CERT_VERIFICATION_AVAILABLE m_openssl_failed = false; #endif - sslContext->set_verify_callback([this](bool preverified, boost::asio::ssl::verify_context& verifyCtx) { -#ifdef CPPREST_PLATFORM_ASIO_CERT_VERIFICATION_AVAILABLE - // Attempt to use platform certificate validation when it is available: - // If OpenSSL fails we will doing verification at the end using the whole certificate chain, - // so wait until the 'leaf' cert. For now return true so OpenSSL continues down the certificate - // chain. - if (!preverified) - { - m_openssl_failed = true; - } - if (m_openssl_failed) + sslContext->set_verify_callback( + [this](bool preverified, boost::asio::ssl::verify_context& verifyCtx) { - return http::client::details::verify_cert_chain_platform_specific( - verifyCtx, utility::conversions::to_utf8string(m_uri.host())); - } +#ifdef CPPREST_PLATFORM_ASIO_CERT_VERIFICATION_AVAILABLE + // Attempt to use platform certificate validation when it is available: + // If OpenSSL fails we will doing verification at the end using the whole certificate chain, + // so wait until the 'leaf' cert. For now return true so OpenSSL continues down the + // certificate chain. + using namespace web::http::client::details; + if (!preverified) + { + m_openssl_failed = true; + } + if (m_openssl_failed) + { + if (!http::client::details::is_end_certificate_in_chain(verifyCtx)) + { + // Continue until we get the end certificate. + return true; + } + + auto chainFunc = + [this](const std::shared_ptr& cert_info) + { return m_config.invoke_certificate_chain_callback(cert_info); }; + + return http::client::details::verify_cert_chain_platform_specific( + verifyCtx, utility::conversions::to_utf8string(m_uri.host()), chainFunc); + } #endif - boost::asio::ssl::rfc2818_verification rfc2818(utility::conversions::to_utf8string(m_uri.host())); - return rfc2818(preverified, verifyCtx); - }); + boost::asio::ssl::rfc2818_verification rfc2818( + utility::conversions::to_utf8string(m_uri.host())); + if (!rfc2818(preverified, verifyCtx)) + { + return false; + } + + auto info = std::make_shared( + utility::conversions::to_utf8string(m_uri.host()), + get_X509_cert_chain_encoded_data(verifyCtx)); + info->verified = true; + + return m_config.invoke_certificate_chain_callback(info); + }); #if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER) // OpenSSL stores some per thread state that never will be cleaned up until @@ -243,8 +267,7 @@ class wspp_callback_client : public websocket_client_callback_impl, }); // Options specific to underlying socket. - client.set_socket_init_handler([this](websocketpp::connection_hdl, - boost::asio::ssl::stream& ssl_stream) { + client.set_socket_init_handler([this](websocketpp::connection_hdl, boost::asio::ssl::stream& ssl_stream) { // Support for SNI. if (m_config.is_sni_enabled()) { @@ -254,12 +277,13 @@ class wspp_callback_client : public websocket_client_callback_impl, // OpenSSL runs the string parameter through a macro casting away const with a C style cast. // Do a C++ cast ourselves to avoid warnings. SSL_set_tlsext_host_name(ssl_stream.native_handle(), - const_cast(m_config.server_name().c_str())); + const_cast(m_config.server_name().c_str())); } else { const auto& server_name = utility::conversions::to_utf8string(m_uri.host()); - SSL_set_tlsext_host_name(ssl_stream.native_handle(), const_cast(server_name.c_str())); + SSL_set_tlsext_host_name(ssl_stream.native_handle(), + const_cast(server_name.c_str())); } } }); @@ -295,35 +319,34 @@ class wspp_callback_client : public websocket_client_callback_impl, this->shutdown_wspp_impl(con_hdl, true); }); - client.set_message_handler( - [this](websocketpp::connection_hdl, const websocketpp::config::asio_client::message_type::ptr& msg) { - if (m_external_message_handler) - { - _ASSERTE(m_state >= CONNECTED && m_state < CLOSED); - websocket_incoming_message incoming_msg; + client.set_message_handler([this](websocketpp::connection_hdl, const websocketpp::config::asio_client::message_type::ptr& msg) { + if (m_external_message_handler) + { + _ASSERTE(m_state >= CONNECTED && m_state < CLOSED); + websocket_incoming_message incoming_msg; - switch (msg->get_opcode()) - { - case websocketpp::frame::opcode::binary: - incoming_msg.m_msg_type = websocket_message_type::binary_message; - break; - case websocketpp::frame::opcode::text: - incoming_msg.m_msg_type = websocket_message_type::text_message; - break; - default: - // Unknown message type. Since both websocketpp and our code use the RFC codes, we'll just - // pass it on to the user. - incoming_msg.m_msg_type = static_cast(msg->get_opcode()); - break; - } + switch (msg->get_opcode()) + { + case websocketpp::frame::opcode::binary: + incoming_msg.m_msg_type = websocket_message_type::binary_message; + break; + case websocketpp::frame::opcode::text: + incoming_msg.m_msg_type = websocket_message_type::text_message; + break; + default: + // Unknown message type. Since both websocketpp and our code use the RFC codes, we'll just + // pass it on to the user. + incoming_msg.m_msg_type = static_cast(msg->get_opcode()); + break; + } - // 'move' the payload into a container buffer to avoid any copies. - auto& payload = msg->get_raw_payload(); - incoming_msg.m_body = concurrency::streams::container_buffer(std::move(payload)); + // 'move' the payload into a container buffer to avoid any copies. + auto& payload = msg->get_raw_payload(); + incoming_msg.m_body = concurrency::streams::container_buffer(std::move(payload)); - m_external_message_handler(incoming_msg); - } - }); + m_external_message_handler(incoming_msg); + } + }); client.set_ping_handler([this](websocketpp::connection_hdl, const std::string& msg) { if (m_external_message_handler) @@ -593,8 +616,7 @@ class wspp_callback_client : public websocket_client_callback_impl, } return ec; }) - .then([this_client, msg, is_buf, acquired, sp_allocated, length]( - pplx::task previousTask) mutable { + .then([this_client, msg, is_buf, acquired, sp_allocated, length](pplx::task previousTask) mutable { std::exception_ptr eptr; try { @@ -602,7 +624,8 @@ class wspp_callback_client : public websocket_client_callback_impl, const auto& ec = previousTask.get(); if (ec.value() != 0) { - eptr = std::make_exception_ptr(websocket_exception(ec, build_error_msg(ec, "sending message"))); + eptr = std::make_exception_ptr( + websocket_exception(ec, build_error_msg(ec, "sending message"))); } } catch (...) diff --git a/Release/tests/functional/http/client/connections_and_errors.cpp b/Release/tests/functional/http/client/connections_and_errors.cpp index 847755d80a..4e860b9e70 100644 --- a/Release/tests/functional/http/client/connections_and_errors.cpp +++ b/Release/tests/functional/http/client/connections_and_errors.cpp @@ -60,10 +60,12 @@ static void pending_requests_after_client_impl(const uri& address) // send responses. for (size_t i = 0; i < num_requests; ++i) { - completed_requests.push_back(requests[i].then([&](test_request* request) { - http_asserts::assert_test_request_equals(request, mtd, U("/")); - VERIFY_ARE_EQUAL(0u, request->reply(status_codes::OK)); - })); + completed_requests.push_back(requests[i].then( + [&](test_request* request) + { + http_asserts::assert_test_request_equals(request, mtd, U("/")); + VERIFY_ARE_EQUAL(0u, request->reply(status_codes::OK)); + })); } // verify responses. @@ -106,6 +108,69 @@ SUITE(connections_and_errors) VERIFY_THROWS(t.wait(), web::http::http_exception); } + TEST_FIXTURE(uri_address, cert_pinning_succeed) + { + test_http_server::scoped_server scoped(m_uri); + + http_client_config client_config; + web::credentials cred(U("some_user"), U("some_password")); + client_config.set_credentials(cred); + pplx::cancellation_token_source source; + + client_config.set_user_certificate_chain_callback( + [](const std::shared_ptr&) -> bool + { + // accept any certificate. + return true; + }); + + http_client client(m_uri, client_config); + + scoped.server()->next_request().then( + [&](test_request* p_request) + { + http_asserts::assert_test_request_equals(p_request, methods::GET, U("/")); + p_request->reply(200); + }); + + auto response = client.request(methods::GET, source.get_token()).get(); + + VERIFY_ARE_EQUAL(status_codes::OK, response.status_code()); + } + +#ifdef _WIN32 + + TEST_FIXTURE(uri_address, cert_pinning_failed) + { + test_http_server::scoped_server scoped(m_uri); + + http_client_config client_config; + web::credentials cred(U("some_user"), U("some_password")); + client_config.set_credentials(cred); + pplx::cancellation_token_source source; + + client_config.set_user_certificate_chain_callback( + [](const std::shared_ptr&) -> bool + { + // don't accept any certificate. + return false; + }); + + http_client client(m_uri, client_config); + + scoped.server()->next_request().then( + [&](test_request* p_request) + { + http_asserts::assert_test_request_equals(p_request, methods::GET, U("/")); + p_request->reply(200); + }); + + auto request = client.request(methods::GET, source.get_token()); + + VERIFY_THROWS_HTTP_ERROR_CODE(request.wait(), ERROR_WINHTTP_SECURE_FAILURE); + } +#endif + TEST_FIXTURE(uri_address, server_close_without_responding) { http_client_config config;