diff --git a/axum-core/src/response/into_response.rs b/axum-core/src/response/into_response.rs index 915b55eff8..9bc0c2fec2 100644 --- a/axum-core/src/response/into_response.rs +++ b/axum-core/src/response/into_response.rs @@ -1,4 +1,6 @@ -use super::{IntoResponseParts, Response, ResponseParts}; +use super::{ + IntoResponseFailed, IntoResponseParts, OverrideAllStatusCodes, Response, ResponseParts, +}; use crate::{body::Body, BoxError}; use bytes::{buf::Chain, Buf, Bytes, BytesMut}; use http::{ @@ -329,7 +331,9 @@ where { fn into_response(self) -> Response { let mut res = self.1.into_response(); - *res.status_mut() = self.0; + if res.extensions().get::().is_none() { + *res.status_mut() = self.0; + } res } } @@ -405,18 +409,16 @@ macro_rules! impl_into_response { let ($($ty),*, res) = self; let res = res.into_response(); - let parts = ResponseParts { res }; - - $( - let parts = match $ty.into_response_parts(parts) { + if res.extensions().get::().is_none() { + let parts = ResponseParts { res }; + let parts = match ($($ty,)*).into_response_parts(parts) { Ok(parts) => parts, - Err(err) => { - return err.into_response(); - } + Err(err) => return err.into_response(), }; - )* - - parts.res + parts.res + } else { + res + } } } @@ -430,16 +432,40 @@ macro_rules! impl_into_response { let (status, $($ty),*, res) = self; let res = res.into_response(); - let parts = ResponseParts { res }; - - $( - let parts = match $ty.into_response_parts(parts) { + if res.extensions().get::().is_none() { + let parts = ResponseParts { res }; + let mut parts = match ($($ty,)*).into_response_parts(parts) { Ok(parts) => parts, - Err(err) => { - return err.into_response(); - } + Err(err) => return err.into_response(), }; - )* + + // Don't call `(status, parts.res).into_response()` since that checks for + // `IntoResponseFailed` and skips setting the status. We've already done that + // check here so overriding the status is required if returning + // `(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)` + *parts.res.status_mut() = status; + parts.res + } else { + res + } + } + } + + #[allow(non_snake_case)] + impl IntoResponse for (OverrideAllStatusCodes, $($ty),*, R) + where + $( $ty: IntoResponseParts, )* + R: IntoResponse, + { + fn into_response(self) -> Response { + let (status, $($ty),*, res) = self; + + let res = res.into_response(); + let parts = ResponseParts { res }; + let parts = match ($($ty,)*).into_response_parts(parts) { + Ok(parts) => parts, + Err(err) => return err.into_response(), + }; (status, parts.res).into_response() } @@ -455,17 +481,22 @@ macro_rules! impl_into_response { let (outer_parts, $($ty),*, res) = self; let res = res.into_response(); - let parts = ResponseParts { res }; - $( - let parts = match $ty.into_response_parts(parts) { + if res.extensions().get::().is_none() { + let parts = ResponseParts { res }; + let mut parts = match ($($ty,)*).into_response_parts(parts) { Ok(parts) => parts, - Err(err) => { - return err.into_response(); - } + Err(err) => return err.into_response(), }; - )* - (outer_parts, parts.res).into_response() + // Don't call `(outer_parts, parts.res).into_response()` for the same reason we + // don't call `(status, parts.res).into_response()` in the above impl. + *parts.res.status_mut() = outer_parts.status; + parts.res.headers_mut().extend(outer_parts.headers); + parts.res.extensions_mut().extend(outer_parts.extensions); + parts.res + } else { + res + } } } diff --git a/axum-core/src/response/into_response_parts.rs b/axum-core/src/response/into_response_parts.rs index 955648238d..2deabb6448 100644 --- a/axum-core/src/response/into_response_parts.rs +++ b/axum-core/src/response/into_response_parts.rs @@ -241,7 +241,9 @@ macro_rules! impl_into_response_parts { let res = match $ty.into_response_parts(res) { Ok(res) => res, Err(err) => { - return Err(err.into_response()); + let mut err_res = err.into_response(); + err_res.extensions_mut().insert(super::IntoResponseFailed); + return Err(err_res); } }; )* diff --git a/axum-core/src/response/mod.rs b/axum-core/src/response/mod.rs index dd6728b1c2..0a9a478219 100644 --- a/axum-core/src/response/mod.rs +++ b/axum-core/src/response/mod.rs @@ -4,6 +4,10 @@ //! //! [`axum::response`]: https://docs.rs/axum/0.7/axum/response/index.html +use std::convert::Infallible; + +use http::StatusCode; + use crate::body::Body; mod append_headers; @@ -128,3 +132,56 @@ where Self(value.into_response()) } } + +/// ``` +/// todo!(); +/// ``` +#[derive(Copy, Clone, Debug)] +pub struct IntoResponseFailed; + +impl IntoResponseParts for IntoResponseFailed { + type Error = Infallible; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.extensions_mut().insert(self); + Ok(res) + } +} + +/// Not sure it makes sense to return `IntoResponseFailed` as the whole response. You should +/// probably at least combine it with a status code. +/// +/// ```compile_fail +/// fn foo() +/// where +/// axum_core::response::IntoResponseFailed: axum_core::response::IntoResponse, +/// {} +/// ``` +#[allow(dead_code)] +fn into_response_failed_doesnt_impl_into_response() {} + +/// Override all status codes regardless if [`IntoResponseFailed`] is used or not. +/// +/// See the docs for [`IntoResponseFailed`] for more details. +#[derive(Debug, Copy, Clone, Default)] +pub struct OverrideAllStatusCodes(pub StatusCode); + +impl IntoResponse for OverrideAllStatusCodes { + fn into_response(self) -> Response { + let mut res = ().into_response(); + *res.status_mut() = self.0; + res + } +} + +impl IntoResponse for (OverrideAllStatusCodes, R) +where + R: IntoResponse, +{ + fn into_response(self) -> Response { + let (OverrideAllStatusCodes(status), res) = self; + let mut res = res.into_response(); + *res.status_mut() = status; + res + } +} diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index d563807403..faaca151ce 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -2,7 +2,7 @@ use axum::{ extract::{rejection::BytesRejection, FromRequest, Request}, - response::{IntoResponse, Response}, + response::{IntoResponse, IntoResponseFailed, Response}, }; use bytes::{Bytes, BytesMut}; use http::StatusCode; @@ -122,7 +122,12 @@ where let mut buf = BytesMut::with_capacity(128); match &self.0.encode(&mut buf) { Ok(()) => buf.into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + IntoResponseFailed, + err.to_string(), + ) + .into_response(), } } } diff --git a/axum-extra/src/response/erased_json.rs b/axum-extra/src/response/erased_json.rs index 5088ff35fb..781f3aeb1e 100644 --- a/axum-extra/src/response/erased_json.rs +++ b/axum-extra/src/response/erased_json.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use axum::{ http::{header, HeaderValue, StatusCode}, - response::{IntoResponse, Response}, + response::{IntoResponse, IntoResponseFailed, Response}, }; use bytes::{BufMut, Bytes, BytesMut}; use serde::Serialize; @@ -77,7 +77,12 @@ impl IntoResponse for ErasedJson { bytes, ) .into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + IntoResponseFailed, + err.to_string(), + ) + .into_response(), } } } diff --git a/axum/src/form.rs b/axum/src/form.rs index f754c4c1b8..40b7585b56 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -1,6 +1,6 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest, RawForm}; -use axum_core::response::{IntoResponse, Response}; +use axum_core::response::{IntoResponse, IntoResponseFailed, Response}; use axum_core::RequestExt; use http::header::CONTENT_TYPE; use http::StatusCode; @@ -113,7 +113,12 @@ where body, ) .into_response(), - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + IntoResponseFailed, + err.to_string(), + ) + .into_response(), } } } diff --git a/axum/src/json.rs b/axum/src/json.rs index 7082ac8bc3..8f66e70469 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,6 +1,6 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest}; -use axum_core::response::{IntoResponse, Response}; +use axum_core::response::{IntoResponse, IntoResponseFailed, Response}; use bytes::{BufMut, Bytes, BytesMut}; use http::{ header::{self, HeaderMap, HeaderValue}, @@ -204,6 +204,7 @@ where header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), )], + IntoResponseFailed, err.to_string(), ) .into_response(), diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index dd616dff57..47c71d0ef6 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -20,7 +20,8 @@ pub use crate::Extension; #[doc(inline)] pub use axum_core::response::{ - AppendHeaders, ErrorResponse, IntoResponse, IntoResponseParts, Response, ResponseParts, Result, + AppendHeaders, ErrorResponse, IntoResponse, IntoResponseFailed, IntoResponseParts, Response, + ResponseParts, Result, }; #[doc(inline)] @@ -87,10 +88,16 @@ impl IntoResponse for NoContent { #[cfg(test)] mod tests { use crate::extract::Extension; + use crate::test_helpers::*; + use crate::Json; use crate::{routing::get, Router}; - use axum_core::response::IntoResponse; + use axum_core::response::OverrideAllStatusCodes; + use axum_core::response::{ + IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts, + }; use http::HeaderMap; use http::{StatusCode, Uri}; + use std::collections::HashMap; // just needs to compile #[allow(dead_code)] @@ -249,6 +256,282 @@ mod tests { .route("/", get(header_array_extension_mixed_body)); } + #[test] + fn status_code_tuple_doesnt_override_error() { + // sanity check where there is just one status code + assert_eq!( + StatusCode::INTERNAL_SERVER_ERROR.into_response().status(), + StatusCode::INTERNAL_SERVER_ERROR + ); + assert_eq!( + (StatusCode::INTERNAL_SERVER_ERROR,) + .into_response() + .status(), + StatusCode::INTERNAL_SERVER_ERROR + ); + + // non-5xx status should be changed + assert_eq!( + (StatusCode::SEE_OTHER, StatusCode::NO_CONTENT) + .into_response() + .status(), + StatusCode::SEE_OTHER + ); + let res = ( + StatusCode::SEE_OTHER, + [("location", "foo")], + StatusCode::NO_CONTENT, + ) + .into_response(); + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "foo"); + + // 5xx status codes are also changed + assert_eq!( + (StatusCode::SEE_OTHER, StatusCode::INTERNAL_SERVER_ERROR) + .into_response() + .status(), + StatusCode::SEE_OTHER + ); + let res = ( + StatusCode::SEE_OTHER, + [("location", "foo")], + StatusCode::INTERNAL_SERVER_ERROR, + ) + .into_response(); + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "foo"); + + // the status is not changed if `IntoResponseFailed` is used + assert_eq!( + ( + StatusCode::SEE_OTHER, + (IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR) + ) + .into_response() + .status(), + StatusCode::INTERNAL_SERVER_ERROR + ); + let res = ( + StatusCode::SEE_OTHER, + [("location", "foo")], + (IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert!(res.headers().get("location").is_none()); + + // response parts from the inner response do run + let res = ( + // with status override + StatusCode::SEE_OTHER, + [("location", "foo")], + ( + [("x-bar", "bar")], + IntoResponseFailed, + [("x-foo", "foo")], + StatusCode::INTERNAL_SERVER_ERROR, + ), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert!(res.headers().get("location").is_none()); + assert_eq!(res.headers()["x-foo"], "foo"); + assert_eq!(res.headers()["x-bar"], "bar"); + + let res = ( + // without status override + [("location", "foo")], + ( + [("x-bar", "bar")], + IntoResponseFailed, + [("x-foo", "foo")], + StatusCode::INTERNAL_SERVER_ERROR, + ), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert!(res.headers().get("location").is_none()); + assert_eq!(res.headers()["x-foo"], "foo"); + assert_eq!(res.headers()["x-bar"], "bar"); + + // (Parts, ...) + let res = ( + Response::new(()).into_parts().0, + [("location", "foo")], + ( + [("x-bar", "bar")], + IntoResponseFailed, + [("x-foo", "foo")], + StatusCode::INTERNAL_SERVER_ERROR, + ), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert!(res.headers().get("location").is_none()); + assert_eq!(res.headers()["x-foo"], "foo"); + assert_eq!(res.headers()["x-bar"], "bar"); + + // (Response<()>, ...) + let res = ( + Response::new(()), + [("location", "foo")], + ( + [("x-bar", "bar")], + IntoResponseFailed, + [("x-foo", "foo")], + StatusCode::INTERNAL_SERVER_ERROR, + ), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert!(res.headers().get("location").is_none()); + assert_eq!(res.headers()["x-foo"], "foo"); + assert_eq!(res.headers()["x-bar"], "bar"); + } + #[test] + fn into_response_parts_failing_sets_extension() { + struct Fail; + + impl IntoResponseParts for Fail { + type Error = (); + + fn into_response_parts( + self, + _res: ResponseParts, + ) -> Result { + Err(()) + } + } + + impl IntoResponse for Fail { + fn into_response(self) -> Response { + (self, ()).into_response() + } + } + + assert!(Fail + .into_response() + .extensions() + .get::() + .is_some()); + + assert!((StatusCode::INTERNAL_SERVER_ERROR, Fail, ()) + .into_response() + .extensions() + .get::() + .is_some()); + + assert!((Response::new(()).into_parts().0, Fail, ()) + .into_response() + .extensions() + .get::() + .is_some()); + + assert!((Response::new(()), Fail, ()) + .into_response() + .extensions() + .get::() + .is_some()); + } + #[test] + fn doenst_override_status_code_when_using_into_response_failed_at_same_level() { + assert_eq!( + (StatusCode::INTERNAL_SERVER_ERROR, IntoResponseFailed, ()) + .into_response() + .status(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + + #[derive(Clone)] + struct Thing; + + let res = ( + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header("x-foo", "foo") + .extension(Thing) + .body(()) + .unwrap() + .into_parts() + .0, + IntoResponseFailed, + (), + ) + .into_response(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR,); + assert_eq!(res.headers()["x-foo"], "foo"); + assert!(res.extensions().get::().is_some()); + + // just a sanity check + assert_eq!( + (IntoResponseFailed, ()).into_response().status(), + StatusCode::OK, + ); + } + #[test] + fn force_overriding_status_code() { + assert_eq!( + OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT) + .into_response() + .status(), + StatusCode::IM_A_TEAPOT + ); + + assert_eq!( + (OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),) + .into_response() + .status(), + StatusCode::IM_A_TEAPOT + ); + + assert_eq!( + (OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT), ()) + .into_response() + .status(), + StatusCode::IM_A_TEAPOT + ); + + assert_eq!( + ( + OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT), + IntoResponseFailed, + StatusCode::INTERNAL_SERVER_ERROR, + ) + .into_response() + .status(), + StatusCode::IM_A_TEAPOT + ); + } + #[crate::test] + async fn status_code_tuple_doesnt_override_error_json() { + let app = Router::new() + .route( + "/", + get(|| async { + let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]); + (StatusCode::IM_A_TEAPOT, Json(not_json_compatible)) + }), + ) + .route( + "/two", + get(|| async { + let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]); + ( + OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT), + Json(not_json_compatible), + ) + }), + ); + + let client = TestClient::new(app); + + let res = client.get("/").await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let res = client.get("/two").await; + assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); + } #[test] fn no_content() { assert_eq!(