diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 1f9974de02..d7d188e52e 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -19,7 +19,12 @@ mod query; #[cfg(feature = "multipart")] pub mod multipart; -pub use self::{cached::Cached, optional_path::OptionalPath, with_rejection::WithRejection}; +pub mod user_lang; + +pub use self::{ + cached::Cached, optional_path::OptionalPath, user_lang::UserLanguage, + with_rejection::WithRejection, +}; #[cfg(feature = "cookie")] pub use self::cookie::CookieJar; diff --git a/axum-extra/src/extract/user_lang/config.rs b/axum-extra/src/extract/user_lang/config.rs new file mode 100644 index 0000000000..1c2861d1a6 --- /dev/null +++ b/axum-extra/src/extract/user_lang/config.rs @@ -0,0 +1,100 @@ +use std::sync::Arc; + +use crate::extract::user_lang::{UserLanguage, UserLanguageSource}; + +/// Configuration for the [`UserLanguage`] extractor. +/// +/// By default the [`UserLanguage`] extractor will try to read the +/// languages from the sources returned by [`UserLanguage::default_sources`]. +/// +/// You can override the default behaviour by adding a [`Config`] +/// extension to your routes. +/// +/// You can add sources and specify a fallback language. +/// +/// # Example +/// +/// ```rust +/// use axum::{routing::get, Extension, Router}; +/// use axum_extra::extract::user_lang::{PathSource, QuerySource, UserLanguage}; +/// +/// # fn main() { +/// let app = Router::new() +/// .route("/:lang", get(handler)) +/// .layer(Extension( +/// UserLanguage::config() +/// .add_source(QuerySource::new("lang")) +/// .add_source(PathSource::new("lang")) +/// .build(), +/// )); +/// # let _: Router = app; +/// # } +/// # async fn handler() {} +/// ``` +/// +#[derive(Debug, Clone)] +pub struct Config { + pub(crate) fallback_language: String, + pub(crate) sources: Vec>, +} + +/// Builder to create a [`Config`] for the [`UserLanguage`] extractor. +/// +/// Allows you to declaratively create a [`Config`]. +/// You can create a [`ConfigBuilder`] by calling +/// [`UserLanguage::config`]. +/// +/// # Example +/// +/// ```rust +/// use axum_extra::extract::user_lang::{QuerySource, UserLanguage}; +/// +/// # fn main() { +/// let config = UserLanguage::config() +/// .add_source(QuerySource::new("lang")) +/// .fallback_language("es") +/// .build(); +/// # let _ = config; +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct ConfigBuilder { + fallback_language: String, + sources: Vec>, +} + +impl ConfigBuilder { + /// Set the fallback language. + pub fn fallback_language(mut self, fallback_language: impl Into) -> Self { + self.fallback_language = fallback_language.into(); + self + } + + /// Add a [`UserLanguageSource`]. + pub fn add_source(mut self, source: impl UserLanguageSource + 'static) -> Self { + self.sources.push(Arc::new(source)); + self + } + + /// Create a [`Config`] from this builder. + pub fn build(self) -> Config { + Config { + fallback_language: self.fallback_language, + sources: if !self.sources.is_empty() { + self.sources + } else { + UserLanguage::default_sources().clone() + }, + } + } +} + +impl UserLanguage { + /// Returns a builder for [`Config`]. + pub fn config() -> ConfigBuilder { + ConfigBuilder { + fallback_language: "en".to_owned(), + sources: vec![], + } + } +} diff --git a/axum-extra/src/extract/user_lang/lang.rs b/axum-extra/src/extract/user_lang/lang.rs new file mode 100644 index 0000000000..06cf560bd7 --- /dev/null +++ b/axum-extra/src/extract/user_lang/lang.rs @@ -0,0 +1,219 @@ +use super::{ + sources::{AcceptLanguageSource, PathSource}, + Config, UserLanguageSource, +}; +use axum::{async_trait, extract::FromRequestParts, Extension, RequestPartsExt}; +use http::request::Parts; +use std::{ + convert::Infallible, + sync::{Arc, OnceLock}, +}; + +#[cfg(feature = "query")] +use super::sources::QuerySource; + +/// The users preferred languages, read from the request. +/// +/// This extractor reads the users preferred languages from a +/// configurable list of sources. +/// +/// By default it will try to read from the following sources: +/// * The query parameter `lang` +/// * The path segment `:lang` +/// * The `Accept-Language` header +/// +/// This extractor never fails. If no language could be read from the request, +/// the fallback language will be used. By default the fallback is `en`, but +/// this can be configured. +/// +/// # Configuration +/// +/// To configure the sources for the languages or the fallback language, see [`UserLanguage::config`]. +/// +/// # Custom Sources +/// +/// You can create custom user language sources. See +/// [`UserLanguageSource`] for details. +/// +/// # Example +/// +/// ```rust +/// use axum_extra::extract::UserLanguage; +/// +/// async fn handler(lang: UserLanguage) { +/// println!("Preferred languages: {:?}", lang.preferred_languages()); +/// } +/// ``` +#[derive(Debug, Clone)] +pub struct UserLanguage { + preferred_languages: Vec, + fallback_language: String, +} + +impl UserLanguage { + /// The default sources for the preferred languages. + /// + /// If you do not add a configuration for the [`UserLanguage`] extractor, + /// these sources will be used by default. They are in order: + /// * The query parameter `lang` (if feature `query` is enabled) + /// * The path segment `:lang` + /// * The `Accept-Language` header + pub fn default_sources() -> &'static Vec> { + static DEFAULT_SOURCES: OnceLock>> = OnceLock::new(); + + DEFAULT_SOURCES.get_or_init(|| { + vec![ + #[cfg(feature = "query")] + Arc::new(QuerySource::new("lang")), + Arc::new(PathSource::new("lang")), + Arc::new(AcceptLanguageSource), + ] + }) + } + + /// The users most preferred language as read from the request. + /// + /// This is the first language in the list of [`UserLanguage::preferred_languages`]. + /// If no language could be read from the request, the fallback language + /// will be returned. + pub fn preferred_language(&self) -> &str { + self.preferred_languages + .first() + .unwrap_or(&self.fallback_language) + } + + /// The users preferred languages in order of preference. + /// + /// Preference is first determined by the order of the sources. + /// Within each source the languages are ordered by the users preference, + /// if applicable for the source. For example the `Accept-Language` header + /// source will order the languages by the `q` parameter. + /// + /// This list may be empty if no language could be read from the request. + pub fn preferred_languages(&self) -> &[String] { + self.preferred_languages.as_slice() + } + + /// The language that will be used as a fallback if no language could be + /// read from the request. + pub fn fallback_language(&self) -> &str { + &self.fallback_language + } +} + +#[async_trait] +impl FromRequestParts for UserLanguage +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let (sources, fallback_language) = match parts.extract::>().await { + Ok(Extension(config)) => (Some(config.sources), Some(config.fallback_language)), + Err(_) => (None, None), + }; + + let sources = sources.as_ref().unwrap_or(Self::default_sources()); + let fallback_language = fallback_language.unwrap_or_else(|| "en".to_owned()); + + let mut preferred_languages = Vec::::new(); + + for source in sources { + let languages = source.languages_from_parts(parts).await; + preferred_languages.extend(languages); + } + + Ok(UserLanguage { + preferred_languages, + fallback_language, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::*; + use axum::{routing::get, Router}; + use http::{header::ACCEPT_LANGUAGE, StatusCode}; + + #[derive(Debug)] + struct TestSource(Vec); + + #[async_trait] + impl UserLanguageSource for TestSource { + async fn languages_from_parts(&self, _parts: &mut Parts) -> Vec { + self.0.clone() + } + } + + #[tokio::test] + async fn reads_from_configured_sources_in_specified_order() { + let app = Router::new() + .route("/", get(return_all_langs)) + .layer(Extension( + UserLanguage::config() + .add_source(TestSource(vec!["s1.1".to_owned(), "s1.2".to_owned()])) + .add_source(TestSource(vec!["s2.1".to_owned(), "s2.2".to_owned()])) + .build(), + )); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "s1.1,s1.2,s2.1,s2.2"); + } + + #[tokio::test] + async fn reads_languages_from_default_sources() { + let app = Router::new().route("/:lang", get(return_all_langs)); + + let client = TestClient::new(app); + + let res = client + .get("/de?lang=fr") + .header(ACCEPT_LANGUAGE, "en;q=0.9,es;q=0.8") + .send() + .await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fr,de,en,es"); + } + + #[tokio::test] + async fn falls_back_to_configured_language() { + let app = Router::new().route("/", get(return_lang)).layer(Extension( + UserLanguage::config().fallback_language("fallback").build(), + )); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); + } + + #[tokio::test] + async fn falls_back_to_default_language() { + let app = Router::new().route("/", get(return_lang)); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "en"); + } + + async fn return_lang(lang: UserLanguage) -> String { + lang.preferred_language().to_owned() + } + + async fn return_all_langs(lang: UserLanguage) -> String { + lang.preferred_languages().join(",") + } +} diff --git a/axum-extra/src/extract/user_lang/mod.rs b/axum-extra/src/extract/user_lang/mod.rs new file mode 100644 index 0000000000..cd03355ca4 --- /dev/null +++ b/axum-extra/src/extract/user_lang/mod.rs @@ -0,0 +1,11 @@ +//! Extractor that retrieves the preferred languages of the user. + +mod config; +mod lang; +mod source; +mod sources; + +pub use config::*; +pub use lang::*; +pub use source::*; +pub use sources::*; diff --git a/axum-extra/src/extract/user_lang/source.rs b/axum-extra/src/extract/user_lang/source.rs new file mode 100644 index 0000000000..5071d238d3 --- /dev/null +++ b/axum-extra/src/extract/user_lang/source.rs @@ -0,0 +1,44 @@ +use axum::async_trait; +use http::request::Parts; +use std::fmt::Debug; + +/// A source for the users preferred languages. +/// +/// # Implementing a custom source +/// +/// The following is an example of how to read the language from the query. +/// +/// ```rust +/// use std::collections::HashMap; +/// use axum::{extract::Query, RequestPartsExt}; +/// use axum_extra::extract::user_lang::UserLanguageSource; +/// +/// #[derive(Debug)] +/// pub struct QuerySource; +/// +/// #[axum::async_trait] +/// impl UserLanguageSource for QuerySource { +/// async fn languages_from_parts(&self, parts: &mut http::request::Parts) -> Vec { +/// let Ok(query) = parts.extract::>>().await else { +/// return vec![]; +/// }; +/// +/// let Some(lang) = query.get("lang") else { +/// return vec![]; +/// }; +/// +/// vec![lang.to_owned()] +/// } +/// } +/// ``` +#[async_trait] +pub trait UserLanguageSource: Send + Sync + Debug { + /// Extract a list of user languages from the request parts. + /// + /// The multiple languages are returned, they should be in + /// order of preference of the user, if possible. + /// + /// If no languages could be read from the request, return + /// an empty vec. + async fn languages_from_parts(&self, parts: &mut Parts) -> Vec; +} diff --git a/axum-extra/src/extract/user_lang/sources/header.rs b/axum-extra/src/extract/user_lang/sources/header.rs new file mode 100644 index 0000000000..1dfe6df3de --- /dev/null +++ b/axum-extra/src/extract/user_lang/sources/header.rs @@ -0,0 +1,151 @@ +use axum::async_trait; +use std::cmp::Ordering; + +use crate::extract::user_lang::UserLanguageSource; + +/// A [`UserLanguageSource`] that reads languages from the `Accept-Language` header. +/// +/// This source may return multiple languages. Languages are returned in order of their +/// quality values. +/// +/// # Example +/// +/// ```rust +/// # use axum::{Router, extract::Extension, routing::get}; +/// # use axum_extra::extract::user_lang::{UserLanguage, AcceptLanguageSource}; +/// # +/// let source = AcceptLanguageSource; +/// +/// let app = Router::new() +/// .route("/home", get(handler)) +/// .layer( +/// Extension( +/// UserLanguage::config() +/// .add_source(source) +/// .build(), +/// )); +/// +/// # let _: Router = app; +/// # async fn handler() {} +/// ``` +#[derive(Debug, Clone)] +pub struct AcceptLanguageSource; + +#[async_trait] +impl UserLanguageSource for AcceptLanguageSource { + async fn languages_from_parts(&self, parts: &mut http::request::Parts) -> Vec { + let Some(accept_language) = parts.headers.get("Accept-Language") else { + return vec![]; + }; + + let Ok(accept_language) = accept_language.to_str() else { + return vec![]; + }; + + parse_quality_values(accept_language) + .into_iter() + .filter(|(lang, _)| *lang != "*") + .map(|(lang, _)| lang.to_owned()) + .collect() + } +} + +/// Parse quality values from the `Accept-Language` header. +fn parse_quality_values(values: &str) -> Vec<(&str, f32)> { + let values = values.split(','); + let mut quality_values = Vec::new(); + + for value in values { + let mut value = value.trim().split(';'); + let (value, quality) = (value.next(), value.next()); + + let Some(value) = value else { + // empty quality value entry + continue; + }; + + if value.is_empty() { + // empty quality value entry + continue; + } + + let quality = if let Some(quality) = quality.and_then(|q| q.strip_prefix("q=")) { + quality.parse::().unwrap_or(0.0) + } else { + 1.0 + }; + + quality_values.push((value, quality)); + } + + quality_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); + quality_values +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{header::ACCEPT_LANGUAGE, Request}; + + #[tokio::test] + async fn reads_language_from_accept_header() { + let source = AcceptLanguageSource; + + let request: Request<()> = Request::builder() + .header(ACCEPT_LANGUAGE, "fr,de;q=0.8,en;q=0.9") + .body(()) + .unwrap(); + + let (mut parts, _) = request.into_parts(); + + let languages = source.languages_from_parts(&mut parts).await; + + assert_eq!( + languages, + vec!["fr".to_owned(), "en".to_owned(), "de".to_owned()] + ); + } + + #[tokio::test] + async fn ignores_wildcard_lang() { + let source = AcceptLanguageSource; + + let request: Request<()> = Request::builder() + .header(ACCEPT_LANGUAGE, "fr,de;q=0.8,*;q=0.9") + .body(()) + .unwrap(); + + let (mut parts, _) = request.into_parts(); + + let languages = source.languages_from_parts(&mut parts).await; + + assert_eq!(languages, vec!["fr".to_owned(), "de".to_owned()]); + } + + #[test] + fn parsing_quality_values() { + let values = "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5"; + + let parsed = parse_quality_values(values); + + assert_eq!( + parsed, + vec![ + ("fr-CH", 1.0), + ("fr", 0.9), + ("en", 0.8), + ("de", 0.7), + ("*", 0.5), + ] + ); + } + + #[test] + fn empty_quality_values() { + let values = ""; + + let parsed = parse_quality_values(values); + + assert_eq!(parsed, vec![]); + } +} diff --git a/axum-extra/src/extract/user_lang/sources/mod.rs b/axum-extra/src/extract/user_lang/sources/mod.rs new file mode 100644 index 0000000000..33b9483be7 --- /dev/null +++ b/axum-extra/src/extract/user_lang/sources/mod.rs @@ -0,0 +1,11 @@ +mod header; +mod path; + +#[cfg(feature = "query")] +mod query; + +pub use header::*; +pub use path::*; + +#[cfg(feature = "query")] +pub use query::*; diff --git a/axum-extra/src/extract/user_lang/sources/path.rs b/axum-extra/src/extract/user_lang/sources/path.rs new file mode 100644 index 0000000000..996412f202 --- /dev/null +++ b/axum-extra/src/extract/user_lang/sources/path.rs @@ -0,0 +1,92 @@ +use axum::{async_trait, extract::Path, RequestPartsExt}; +use std::collections::HashMap; + +use crate::extract::user_lang::UserLanguageSource; + +/// A source that reads the user language from the request path. +/// +/// When creating this source you specify the name of the path +/// segment to read the language from. The routes you want to extract +/// the language from must include a path segment with the configured +/// name for this source to be able to read the language. +/// +/// # Example +/// +/// The following example will read the language from +/// the path segment `lang_id`. Your routes need to include +/// a `:lang_id` path segment that will contain the language. +/// +/// ```rust +/// # use axum::{Router, extract::Extension, routing::get}; +/// # use axum_extra::extract::user_lang::{UserLanguage, PathSource}; +/// # +/// // The path segment name is `lang_id`. +/// let source = PathSource::new("lang_id"); +/// +/// // The routes need to include a `:lang_id` path segment. +/// let app = Router::new() +/// .route("/home/:lang_id", get(handler)) +/// .layer( +/// Extension( +/// UserLanguage::config() +/// .add_source(source) +/// .build(), +/// )); +/// +/// # let _: Router = app; +/// # async fn handler() {} +/// ``` +#[derive(Debug, Clone)] +pub struct PathSource { + name: String, +} + +impl PathSource { + /// Create a new path source with a given path segment name. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + fn languages_from_path(&self, path: Path>) -> Vec { + let Some(lang) = path.get(self.name.as_str()) else { + return vec![]; + }; + + vec![lang.to_owned()] + } +} + +#[async_trait] +impl UserLanguageSource for PathSource { + async fn languages_from_parts(&self, parts: &mut http::request::Parts) -> Vec { + let Ok(path) = parts.extract::>>().await else { + return vec![]; + }; + + self.languages_from_path(path) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn reads_language_from_path() { + let source = PathSource::new("lang"); + + // We cannot setup the Path extractor here, as it requires + // UrlParams in the request extensions, which is private to axum. + // + // Instead we test loading from the extracted path directly. + let path = Path({ + let mut path_matches = HashMap::new(); + path_matches.insert("lang".to_owned(), "it".to_owned()); + path_matches + }); + + let languages = source.languages_from_path(path); + + assert_eq!(languages, vec!["it".to_owned()]); + } +} diff --git a/axum-extra/src/extract/user_lang/sources/query.rs b/axum-extra/src/extract/user_lang/sources/query.rs new file mode 100644 index 0000000000..31b3333256 --- /dev/null +++ b/axum-extra/src/extract/user_lang/sources/query.rs @@ -0,0 +1,84 @@ +use axum::{async_trait, RequestPartsExt}; +use std::collections::HashMap; + +use crate::extract::{Query, user_lang::UserLanguageSource}; + +/// A [`UserLanguageSource`] that reads the language from a field in the +/// query string. +/// +/// When creating this source you specify the name of the query +/// field to read the language from. You can add multiple `QuerySource` +/// instances to read from different fields. +/// +/// # Example +/// +/// The following example will read the language from +/// the query field `lang_id`. +/// +/// ```rust +/// # use axum::{Router, extract::Extension, routing::get}; +/// # use axum_extra::extract::user_lang::{UserLanguage, QuerySource}; +/// # +/// // The query field name is `lang_id`. +/// let source = QuerySource::new("lang_id"); +/// +/// let app = Router::new() +/// .route("/home", get(handler)) +/// .layer( +/// Extension( +/// UserLanguage::config() +/// .add_source(source) +/// .build(), +/// )); +/// +/// # let _: Router = app; +/// # async fn handler() {} +/// ``` +#[derive(Debug, Clone)] +pub struct QuerySource { + name: String, +} + +impl QuerySource { + /// Create a new query source with a given query field name. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } +} + +#[async_trait] +impl UserLanguageSource for QuerySource { + async fn languages_from_parts(&self, parts: &mut http::request::Parts) -> Vec { + let Ok(query) = parts.extract::>>().await else { + return vec![]; + }; + + let Some(lang) = query.get(self.name.as_str()) else { + return vec![]; + }; + + vec![lang.to_owned()] + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{Request, Uri}; + + #[tokio::test] + async fn reads_language_from_query() { + let source = QuerySource::new("lang"); + + let request: Request<()> = Request::builder() + .uri(Uri::builder().path_and_query("/?lang=de").build().unwrap()) + .body(()) + .unwrap(); + + let (mut parts, _) = request.into_parts(); + + let languages = source.languages_from_parts(&mut parts).await; + + assert_eq!(languages, vec!["de".to_owned()]); + } +} diff --git a/examples/user-language/Cargo.toml b/examples/user-language/Cargo.toml new file mode 100644 index 0000000000..000e3604fd --- /dev/null +++ b/examples/user-language/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "example-user-language" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +axum-extra = { path = "../../axum-extra" } +tokio = { version = "1.0", features = ["full"] } diff --git a/examples/user-language/src/main.rs b/examples/user-language/src/main.rs new file mode 100644 index 0000000000..0f88a46d3b --- /dev/null +++ b/examples/user-language/src/main.rs @@ -0,0 +1,48 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-user-language +//! ``` + +use axum::{response::Html, routing::get, Extension, Router}; +use axum_extra::extract::user_lang::{PathSource, QuerySource, UserLanguage}; + +#[tokio::main] +async fn main() { + // build our application with some routes + let app = Router::new() + .route("/", get(handler)) + .route("/:lang", get(handler)) + // Add configuration for the `UserLanguage` extractor. + // This step is optional, if omitted the default + // configuration will be used. + .layer(Extension( + UserLanguage::config() + // read the language from the `lang` query parameter + .add_source(QuerySource::new("lang")) + // read the language from the `:lang` segment of the path + .add_source(PathSource::new("lang")) + .build(), + )); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn handler(lang: UserLanguage) -> Html<&'static str> { + println!( + "User prefers content in the following languages (in order): {:?}", + lang.preferred_languages() + ); + + match lang.preferred_language() { + "de" => Html("

Hallo, Welt!

"), + "es" => Html("

Hola, Mundo!

"), + "fr" => Html("

Bonjour, le monde!

"), + _ => Html("

Hello, World!

"), + } +}