diff --git a/Cargo.toml b/Cargo.toml index e4c977113b..b697d170ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ want = "0.3" # Optional libc = { version = "0.2", optional = true } +smallvec = { version = "1.12", features = ["const_generics", "const_new"], optional = true } socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } [dev-dependencies] @@ -87,8 +88,8 @@ http1 = [] http2 = ["h2"] # Client/Server -client = [] -server = [] +client = ["dep:smallvec"] +server = ["dep:smallvec"] # `impl Stream` for things stream = [] diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index 37eda04067..ff057e7b0e 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -115,6 +115,7 @@ pub struct Builder { h1_writev: Option, h1_title_case_headers: bool, h1_preserve_header_case: bool, + h1_max_headers: Option, #[cfg(feature = "ffi")] h1_preserve_header_order: bool, h1_read_buf_exact_size: Option, @@ -302,6 +303,7 @@ impl Builder { h1_parser_config: Default::default(), h1_title_case_headers: false, h1_preserve_header_case: false, + h1_max_headers: None, #[cfg(feature = "ffi")] h1_preserve_header_order: false, h1_max_buf_size: None, @@ -434,6 +436,24 @@ impl Builder { self } + /// Set the maximum number of headers. + /// + /// When a response is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If client receives more headers than the buffer size, the error "message header too large" + /// is returned. + /// + /// Note that headers is allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each response, and there will be a performance drop of about 5%. + /// + /// Default is 100. + pub fn max_headers(&mut self, val: usize) -> &mut Self { + self.h1_max_headers = Some(val); + self + } + /// Set whether to support preserving original header order. /// /// Currently, this will record the order in which headers are received, and store this @@ -514,6 +534,9 @@ impl Builder { if opts.h1_preserve_header_case { conn.set_preserve_header_case(); } + if let Some(max_headers) = opts.h1_max_headers { + conn.set_http1_max_headers(max_headers); + } #[cfg(feature = "ffi")] if opts.h1_preserve_header_order { conn.set_preserve_header_order(); diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 5ab72f264e..be92a94886 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -53,6 +53,7 @@ where keep_alive: KA::Busy, method: None, h1_parser_config: ParserConfig::default(), + h1_max_headers: None, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: None, #[cfg(all(feature = "server", feature = "runtime"))] @@ -125,6 +126,10 @@ where self.state.h09_responses = true; } + pub(crate) fn set_http1_max_headers(&mut self, val: usize) { + self.state.h1_max_headers = Some(val); + } + #[cfg(all(feature = "server", feature = "runtime"))] pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) { self.state.h1_header_read_timeout = Some(val); @@ -198,6 +203,7 @@ where cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, h1_parser_config: self.state.h1_parser_config.clone(), + h1_max_headers: self.state.h1_max_headers, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: self.state.h1_header_read_timeout, #[cfg(all(feature = "server", feature = "runtime"))] @@ -822,6 +828,7 @@ struct State { /// a body or not. method: Option, h1_parser_config: ParserConfig, + h1_max_headers: Option, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: Option, #[cfg(all(feature = "server", feature = "runtime"))] diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index ac494b9387..30939ea63f 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -191,6 +191,7 @@ where cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, h1_parser_config: parse_ctx.h1_parser_config.clone(), + h1_max_headers: parse_ctx.h1_max_headers, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: parse_ctx.h1_header_read_timeout, #[cfg(all(feature = "server", feature = "runtime"))] @@ -741,6 +742,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 5a2587a843..37126baf68 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -76,6 +76,7 @@ pub(crate) struct ParseContext<'a> { cached_headers: &'a mut Option, req_method: &'a mut Option, h1_parser_config: ParserConfig, + h1_max_headers: Option, #[cfg(all(feature = "server", feature = "runtime"))] h1_header_read_timeout: Option, #[cfg(all(feature = "server", feature = "runtime"))] diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 1c00d7445d..1a87d34219 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -7,6 +7,7 @@ use bytes::BytesMut; use http::header::ValueIter; use http::header::{self, Entry, HeaderName, HeaderValue}; use http::{HeaderMap, Method, StatusCode, Version}; +use smallvec::{smallvec, smallvec_inline, SmallVec}; #[cfg(all(feature = "server", feature = "runtime"))] use tokio::time::Instant; use tracing::{debug, error, trace, trace_span, warn}; @@ -24,7 +25,7 @@ use crate::proto::h1::{ }; use crate::proto::{BodyLength, MessageHead, RequestHead, RequestLine}; -const MAX_HEADERS: usize = 100; +const DEFAULT_MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific #[cfg(feature = "server")] const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; @@ -169,14 +170,17 @@ impl Http1Transaction for Server { // but we *never* read any of it until after httparse has assigned // values into it. By not zeroing out the stack memory, this saves // a good ~5% on pipeline benchmarks. - let mut headers_indices: [MaybeUninit; MAX_HEADERS] = unsafe { - // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit - MaybeUninit::uninit().assume_init() - }; + let mut headers_indices: SmallVec<[MaybeUninit; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; { - /* SAFETY: it is safe to go from MaybeUninit array to array of MaybeUninit */ - let mut headers: [MaybeUninit>; MAX_HEADERS] = - unsafe { MaybeUninit::uninit().assume_init() }; + let mut headers: SmallVec<[MaybeUninit>; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; trace!(bytes = buf.len(), "Request.parse"); let mut req = httparse::Request::new(&mut []); let bytes = buf.as_ref(); @@ -966,15 +970,18 @@ impl Http1Transaction for Client { // Loop to skip information status code headers (100 Continue, etc). loop { - // Unsafe: see comment in Server Http1Transaction, above. - let mut headers_indices: [MaybeUninit; MAX_HEADERS] = unsafe { - // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit - MaybeUninit::uninit().assume_init() - }; + let mut headers_indices: SmallVec<[MaybeUninit; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; let (len, status, reason, version, headers_len) = { - // SAFETY: We can go safely from MaybeUninit array to array of MaybeUninit - let mut headers: [MaybeUninit>; MAX_HEADERS] = - unsafe { MaybeUninit::uninit().assume_init() }; + let mut headers: SmallVec< + [MaybeUninit>; DEFAULT_MAX_HEADERS], + > = match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; trace!(bytes = buf.len(), "Response.parse"); let mut res = httparse::Response::new(&mut []); let bytes = buf.as_ref(); @@ -1555,6 +1562,7 @@ mod tests { cached_headers: &mut None, req_method: &mut method, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1590,6 +1598,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1620,6 +1629,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1648,6 +1658,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1678,6 +1689,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1712,6 +1724,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config, + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1743,6 +1756,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1769,6 +1783,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1816,6 +1831,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -1844,6 +1860,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2081,6 +2098,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2109,6 +2127,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(m), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2137,6 +2156,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2642,6 +2662,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2664,6 +2685,143 @@ mod tests { assert_eq!(parsed.head.headers["server"], "hello\tworld"); } + #[test] + fn parse_too_large_headers() { + fn gen_req_with_headers(num: usize) -> String { + let mut req = String::from("GET / HTTP/1.1\r\n"); + for i in 0..num { + req.push_str(&format!("key{i}: val{i}\r\n")); + } + req.push_str("\r\n"); + req + } + fn gen_resp_with_headers(num: usize) -> String { + let mut req = String::from("HTTP/1.1 200 OK\r\n"); + for i in 0..num { + req.push_str(&format!("key{i}: val{i}\r\n")); + } + req.push_str("\r\n"); + req + } + fn parse(max_headers: Option, gen_size: usize, should_success: bool) { + { + // server side + let mut bytes = BytesMut::from(gen_req_with_headers(gen_size).as_str()); + let result = Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + h1_max_headers: max_headers, + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, + timer: Time::Empty, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + }, + ); + if should_success { + result.expect("parse ok").expect("parse complete"); + } else { + result.expect_err("parse should err"); + } + } + { + // client side + let mut bytes = BytesMut::from(gen_resp_with_headers(gen_size).as_str()); + let result = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + h1_max_headers: max_headers, + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, + timer: Time::Empty, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + }, + ); + if should_success { + result.expect("parse ok").expect("parse complete"); + } else { + result.expect_err("parse should err"); + } + } + } + + // check generator + assert_eq!( + gen_req_with_headers(0), + String::from("GET / HTTP/1.1\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(1), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(2), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(3), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\nkey2: val2\r\n\r\n") + ); + + // default max_headers is 100, so + // + // - less than or equal to 100, accepted + // + parse(None, 0, true); + parse(None, 1, true); + parse(None, 50, true); + parse(None, 99, true); + parse(None, 100, true); + // + // - more than 100, rejected + // + parse(None, 101, false); + parse(None, 102, false); + parse(None, 200, false); + + // max_headers is 0, parser will reject any headers + // + // - without header, accepted + // + parse(Some(0), 0, true); + // + // - with header(s), rejected + // + parse(Some(0), 1, false); + parse(Some(0), 100, false); + + // max_headers is 200 + // + // - less than or equal to 200, accepted + // + parse(Some(200), 0, true); + parse(Some(200), 1, true); + parse(Some(200), 100, true); + parse(Some(200), 200, true); + // + // - more than 200, rejected + // + parse(Some(200), 201, false); + parse(Some(200), 210, false); + } + #[test] fn test_is_complete_fast() { let s = b"GET / HTTP/1.1\r\na: b\r\n\r\n"; @@ -2756,6 +2914,7 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] @@ -2804,6 +2963,7 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, #[cfg(feature = "runtime")] h1_header_read_timeout: None, #[cfg(feature = "runtime")] diff --git a/src/server/conn.rs b/src/server/conn.rs index 951c9ee5cd..caaa6d2059 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -113,6 +113,7 @@ pub struct Http { h1_keep_alive: bool, h1_title_case_headers: bool, h1_preserve_header_case: bool, + h1_max_headers: Option, #[cfg(all(feature = "http1", feature = "runtime"))] h1_header_read_timeout: Option, h1_writev: Option, @@ -260,6 +261,7 @@ impl Http { h1_title_case_headers: false, h1_preserve_header_case: false, #[cfg(all(feature = "http1", feature = "runtime"))] + h1_max_headers: None, h1_header_read_timeout: None, h1_writev: None, #[cfg(feature = "http2")] @@ -349,6 +351,26 @@ impl Http { self } + /// Set the maximum number of headers. + /// + /// When a request is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If server receives more headers than the buffer size, it responds to the client with + /// "431 Request Header Fields Too Large". + /// + /// Note that headers is allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each request, and there will be a performance drop of about 5%. + /// + /// Default is 100. + #[cfg(feature = "http1")] + #[cfg_attr(docsrs, doc(cfg(feature = "http1")))] + pub fn http1_max_headers(&mut self, val: usize) -> &mut Self { + self.h1_max_headers = Some(val); + self + } + /// Set a timeout for reading client request headers. If a client does not /// transmit the entire header within this time, the connection is closed. /// @@ -623,6 +645,7 @@ impl Http { h1_keep_alive: self.h1_keep_alive, h1_title_case_headers: self.h1_title_case_headers, h1_preserve_header_case: self.h1_preserve_header_case, + h1_max_headers: self.h1_max_headers, #[cfg(all(feature = "http1", feature = "runtime"))] h1_header_read_timeout: self.h1_header_read_timeout, h1_writev: self.h1_writev, @@ -687,6 +710,9 @@ impl Http { if self.h1_preserve_header_case { conn.set_preserve_header_case(); } + if let Some(max_headers) = self.h1_max_headers { + conn.set_http1_max_headers(max_headers); + } #[cfg(all(feature = "http1", feature = "runtime"))] if let Some(header_read_timeout) = self.h1_header_read_timeout { conn.set_http1_header_read_timeout(header_read_timeout); diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index ab833b938b..1e71e681ef 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -42,6 +42,7 @@ pub struct Builder { h1_keep_alive: bool, h1_title_case_headers: bool, h1_preserve_header_case: bool, + h1_max_headers: Option, h1_header_read_timeout: Option, h1_writev: Option, max_buf_size: Option, @@ -208,6 +209,7 @@ impl Builder { h1_keep_alive: true, h1_title_case_headers: false, h1_preserve_header_case: false, + h1_max_headers: None, h1_header_read_timeout: None, h1_writev: None, max_buf_size: None, @@ -260,6 +262,24 @@ impl Builder { self } + /// Set the maximum number of headers. + /// + /// When a request is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If server receives more headers than the buffer size, it responds to the client with + /// "431 Request Header Fields Too Large". + /// + /// Note that headers is allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each request, and there will be a performance drop of about 5%. + /// + /// Default is 100. + pub fn max_headers(&mut self, val: usize) -> &mut Self { + self.h1_max_headers = Some(val); + self + } + /// Set a timeout for reading client request headers. If a client does not /// transmit the entire header within this time, the connection is closed. /// @@ -370,6 +390,9 @@ impl Builder { if self.h1_preserve_header_case { conn.set_preserve_header_case(); } + if let Some(max_headers) = self.h1_max_headers { + conn.set_http1_max_headers(max_headers); + } if let Some(header_read_timeout) = self.h1_header_read_timeout { conn.set_http1_header_read_timeout(header_read_timeout); }