diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 7d2a5b2433..bc78d8b0dc 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -6,6 +6,38 @@ mod optional_path; pub mod rejection; mod with_rejection; +/// Private mod, public trait trick +mod spoof { + pub trait FromSpoofableRequestParts: Sized { + type Rejection: axum::response::IntoResponse; + + fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> impl std::future::Future> + Send; + } +} + +/// Wrap spoofable extractor +#[derive(Debug)] +pub struct Spoofable(pub E); + +/// Allow `Spoofable` to be used with spoofable extractors in handlers +impl FromRequestParts for Spoofable +where + E: spoof::FromSpoofableRequestParts, + S: Sync, +{ + type Rejection = E::Rejection; + + async fn from_request_parts( + parts: &mut axum::http::request::Parts, + state: &S, + ) -> Result { + E::from_request_parts(parts, state).await.map(Spoofable) + } +} + #[cfg(feature = "form")] mod form; @@ -24,6 +56,8 @@ pub mod multipart; #[cfg(feature = "scheme")] mod scheme; +use axum::extract::FromRequestParts; + pub use self::{ cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection, }; diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs index 891d5c0bdd..6c16b76c3a 100644 --- a/axum-extra/src/extract/scheme.rs +++ b/axum-extra/src/extract/scheme.rs @@ -1,10 +1,7 @@ //! Extractor that parses the scheme of a request. //! See [`Scheme`] for more details. -use axum::{ - extract::FromRequestParts, - response::{IntoResponse, Response}, -}; +use axum::response::{IntoResponse, Response}; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, @@ -34,7 +31,7 @@ impl IntoResponse for SchemeMissing { } } -impl FromRequestParts for Scheme +impl super::spoof::FromSpoofableRequestParts for Scheme where S: Send + Sync, { @@ -83,12 +80,12 @@ fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { #[cfg(test)] mod tests { use super::*; - use crate::test_helpers::TestClient; + use crate::{extract::Spoofable, test_helpers::TestClient}; use axum::{routing::get, Router}; use http::header::HeaderName; fn test_client() -> TestClient { - async fn scheme_as_body(Scheme(scheme): Scheme) -> String { + async fn scheme_as_body(Spoofable(Scheme(scheme)): Spoofable) -> String { scheme } diff --git a/examples/spoofable-scheme/Cargo.toml b/examples/spoofable-scheme/Cargo.toml new file mode 100644 index 0000000000..b346fd2072 --- /dev/null +++ b/examples/spoofable-scheme/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "spoofable-scheme" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +axum-extra = { path = "../../axum-extra", features = ["scheme"] } +tokio = { version = "1.0", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + diff --git a/examples/spoofable-scheme/src/main.rs b/examples/spoofable-scheme/src/main.rs new file mode 100644 index 0000000000..b8614779f8 --- /dev/null +++ b/examples/spoofable-scheme/src/main.rs @@ -0,0 +1,40 @@ +//! Example of application using spoofable extractors +//! +//! Run with +//! +//! ```not_rust +//! cargo run -p spoofable-scheme +//! ``` +//! +//! Test with curl: +//! +//! ```not_rust +//! curl -i http://localhost:3000/ -H "X-Forwarded-Proto: http" +//! ``` + +use axum::{routing::get, Router}; +use axum_extra::extract::{Scheme, Spoofable}; +use tokio::net::TcpListener; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // build our application with some routes + let app = Router::new().route("/", get(f)); + + let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn f(Spoofable(Scheme(scheme)): Spoofable) -> String { + scheme +}