From d9e1d6136a4a62def75e4d26c4d919b5411f3655 Mon Sep 17 00:00:00 2001 From: David Thomas Date: Sun, 24 Mar 2024 13:06:55 +0000 Subject: [PATCH] Run rustfmt --- src/main.rs | 166 ++++++++++++++++++++++++++++---------------------- src/models.rs | 16 ++--- 2 files changed, 100 insertions(+), 82 deletions(-) diff --git a/src/main.rs b/src/main.rs index 67ba9d2..3ad1962 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,15 @@ #![warn(clippy::pedantic)] -use std::{str::FromStr, collections::HashMap}; +use std::{collections::HashMap, str::FromStr, sync::OnceLock, time::Duration}; use anyhow::Result; -use std::sync::OnceLock; -use hmac::{Mac as _, digest::FixedOutput}; +use axum::{extract::Path, http::HeaderValue, response::Response}; +use hmac::{digest::FixedOutput, Mac as _}; use subtle::ConstantTimeEq; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use axum::{http::HeaderValue, response::Response}; - - -mod models; mod macros; +mod models; type ResponseResult = Result; @@ -44,13 +41,12 @@ impl From for PatreonTierInfo { entitled_servers: match tier { PatreonTier::Basic => 2, PatreonTier::Extra => 5, - } + }, } } } - -#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq, Hash)] #[serde(transparent)] struct DiscordUserId(u64); @@ -61,8 +57,10 @@ struct Config { extra_tier_id: String, webhook_secret: String, bind_address: Option, - #[serde(default)] preset_members: Vec, - #[serde(deserialize_with = "add_bearer")] creator_access_token: HeaderValue, + #[serde(default)] + preset_members: Vec, + #[serde(deserialize_with = "add_bearer")] + creator_access_token: HeaderValue, } fn add_bearer<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result { @@ -81,7 +79,6 @@ struct State { static STATE: OnceLock = OnceLock::new(); - fn check_md5(key: &[u8], untrusted_signature: &[u8], untrusted_data: &[u8]) -> Result { let mut mac = hmac::Hmac::::new_from_slice(key)?; mac.update(untrusted_data); @@ -90,15 +87,12 @@ fn check_md5(key: &[u8], untrusted_signature: &[u8], untrusted_data: &[u8]) -> R Ok(correct_sig.ct_eq(untrusted_signature).into()) } - - - #[derive(serde::Deserialize)] struct FetchMember { - member_id: DiscordUserId + member_id: DiscordUserId, } -async fn fetch_member(axum::extract::Path(payload): axum::extract::Path) -> impl axum::response::IntoResponse { +async fn fetch_member(Path(payload): Path) -> impl axum::response::IntoResponse { let state = STATE.get().unwrap(); let members = state.members.read().expect("poison"); @@ -106,7 +100,7 @@ async fn fetch_member(axum::extract::Path(payload): axum::extract::Path impl axum::response::IntoResponse { - let state = STATE.get().unwrap(); + let state = STATE.get().unwrap(); let members = state.members.read().expect("poison"); axum::Json(members.clone()) @@ -117,20 +111,23 @@ async fn refresh_members() { state.refresh_task.send(()).await.unwrap(); } -async fn webhook_recv( - headers: axum::http::HeaderMap, - payload: String, -) -> ResponseResult<()> { +async fn webhook_recv(headers: axum::http::HeaderMap, payload: String) -> ResponseResult<()> { if check_md5( STATE.get().unwrap().config.webhook_secret.as_bytes(), require!(headers.get("X-Patreon-Signature"), Ok(())).as_bytes(), - payload.as_bytes() + payload.as_bytes(), )? { return Err(Error::SignatureMismatch); }; let event = require!(headers.get("X-Patreon-Event"), Ok(())).to_str()?; - if matches!(event, "members:pledge:create" | "members:pledge:delete" | "members:pledge:update" | "members:create") { + if matches!( + event, + "members:pledge:create" + | "members:pledge:delete" + | "members:pledge:update" + | "members:create" + ) { fill_members().await?; // Just refresh all the members } else { tracing::info!("Unknown event: {event}"); @@ -139,45 +136,47 @@ async fn webhook_recv( Ok(()) } - #[tokio::main] async fn main() -> Result<()> { let fmt_layer = tracing_subscriber::fmt::layer(); let filter = tracing_subscriber::filter::LevelFilter::from_str( - &std::env::var("LOG_LEVEL") - .unwrap_or_else(|_| String::from("INFO")) + &std::env::var("LOG_LEVEL").unwrap_or_else(|_| String::from("INFO")), )?; - tracing_subscriber::registry().with(fmt_layer).with(filter).init(); + tracing_subscriber::registry() + .with(fmt_layer) + .with(filter) + .init(); let mut config: Config = toml::from_str(&std::fs::read_to_string("config.toml")?)?; let bind_address: std::net::SocketAddr = config.bind_address.take().unwrap().parse()?; - STATE.set(State { + let state = State { config, reqwest: reqwest::Client::new(), members: std::sync::RwLock::new(HashMap::new()), refresh_task: { let (tx, mut rx) = tokio::sync::mpsc::channel(1); - tokio::spawn(async move {loop { - let res = tokio::time::timeout( - std::time::Duration::from_secs(60 * 60), - rx.recv() - ).await; + tokio::spawn(async move { + loop { + let res = tokio::time::timeout(Duration::from_secs(60 * 60), rx.recv()).await; - if res.as_ref().map(Option::is_none).unwrap_or(false) { - break - } + if res.as_ref().map(Option::is_none).unwrap_or(false) { + break; + } - match fill_members().await { - Ok(len) => tracing::info!("Refreshed {len} members"), - Err(err) => tracing::error!("{err:?}"), + match fill_members().await { + Ok(len) => tracing::info!("Refreshed {len} members"), + Err(err) => tracing::error!("{err:?}"), + } } - }}); + }); tx - } - }).is_err().then(|| unreachable!()); + }, + }; + + STATE.set(state).is_err().then(|| unreachable!()); fill_members().await?; @@ -191,46 +190,62 @@ async fn main() -> Result<()> { let listener = tokio::net::TcpListener::bind(bind_address).await?; axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(async {drop(tokio::signal::ctrl_c().await)}) + .with_graceful_shutdown(async { drop(tokio::signal::ctrl_c().await) }) .await?; Ok(()) } +const BASE_URL: &str = "https://www.patreon.com/api/oauth2/v2"; +fn check_tier(member: &models::RawPatreonMember, tier_id: &str) -> bool { + member + .relationships + .currently_entitled_tiers + .data + .iter() + .any(|tier| tier.id == tier_id) +} -const BASE_URL: &str = "https://www.patreon.com/api/oauth2/v2"; +fn get_member_tier( + config: &Config, + member: &models::RawPatreonMember, + user: &models::RawPatreonUser, +) -> Option<(DiscordUserId, Option)> { + let socials = user.attributes.social_connections.as_ref()?; + let user_id = socials.discord.as_ref()?.user_id.as_ref()?; + + let id = DiscordUserId(user_id.parse().unwrap()); + let tier = if check_tier(member, &config.extra_tier_id) { + Some(PatreonTier::Extra) + } else if check_tier(member, &config.basic_tier_id) { + Some(PatreonTier::Basic) + } else { + None + }; -fn get_member_tier(config: &Config, member: &models::RawPatreonMember, user: &models::RawPatreonUser) -> Option<(DiscordUserId, Option)> { - user.attributes.social_connections.as_ref().and_then(|socials| socials.discord.as_ref()).and_then(|discord_info| { - let check_tier = |tier_id| member.relationships.currently_entitled_tiers.data.iter().any(|tier| tier_id == &tier.id); - - discord_info.user_id.as_ref().map(|user_id| ( - DiscordUserId(user_id.parse().unwrap()), - if check_tier(&config.extra_tier_id) { - Some(PatreonTier::Extra) - } else if check_tier(&config.basic_tier_id) { - Some(PatreonTier::Basic) - } else { - None - } - )) - }) + Some((id, tier)) } async fn fill_members() -> Result { let state = STATE.get().unwrap(); - let mut url = reqwest::Url::parse(&format!("{BASE_URL}/campaigns/{}/members", state.config.campaign_id))?; + let reqwest = &state.reqwest; + let mut url = reqwest::Url::parse(&format!( + "{BASE_URL}/campaigns/{}/members", + state.config.campaign_id + ))?; + url.query_pairs_mut() .append_pair("fields[user]", "social_connections") .append_pair("include", "user,currently_entitled_tiers") .finish(); let mut next_cursor = Some(String::new()); - let headers = reqwest::header::HeaderMap::from_iter([ - (reqwest::header::AUTHORIZATION, state.config.creator_access_token.clone()) - ]); + let headers = reqwest::header::HeaderMap::from_iter([( + reqwest::header::AUTHORIZATION, + state.config.creator_access_token.clone(), + )]); let mut members = { let members = state.members.read().expect("poison"); @@ -241,10 +256,8 @@ async fn fill_members() -> Result { let mut url = url.clone(); url.query_pairs_mut().append_pair("page[cursor]", &cursor); - let resp: models::RawPatreonResponse = state.reqwest - .get(url).headers(headers.clone()) - .send().await?.error_for_status()? - .json().await?; + let resp = reqwest.get(url).headers(headers.clone()).send().await?; + let resp: models::RawPatreonResponse = resp.error_for_status()?.json().await?; members.extend(resp.data.into_iter().filter_map(|member| { let user_id = &member.relationships.user.data.id; @@ -255,10 +268,15 @@ async fn fill_members() -> Result { }) })); - next_cursor = resp.meta.pagination.cursors.and_then(|cursors| cursors.next); + next_cursor = resp + .meta + .pagination + .cursors + .and_then(|cursors| cursors.next); } - members.extend(state.config.preset_members.iter().map(|id| (*id, PatreonTierInfo::fake()))); + let preset_members = state.config.preset_members.iter(); + members.extend(preset_members.map(|id| (*id, PatreonTierInfo::fake()))); members.shrink_to_fit(); let len = members.len(); @@ -266,7 +284,6 @@ async fn fill_members() -> Result { Ok(len) } - #[derive(Debug)] enum Error { SignatureMismatch, @@ -293,7 +310,8 @@ impl axum::response::IntoResponse for Error { tracing::error!("{self:?}"); ( axum::http::StatusCode::INTERNAL_SERVER_ERROR, - self.to_string() - ).into_response() + self.to_string(), + ) + .into_response() } } diff --git a/src/models.rs b/src/models.rs index d099d7a..154e7f3 100644 --- a/src/models.rs +++ b/src/models.rs @@ -18,12 +18,12 @@ pub struct RawPatreonRelationships { #[derive(serde::Deserialize)] pub struct RawPatreonIdData { - pub data: RawPatreonId + pub data: RawPatreonId, } #[derive(serde::Deserialize)] pub struct RawPatreonId { - pub id: String + pub id: String, } #[derive(serde::Deserialize)] @@ -34,27 +34,27 @@ pub struct RawPatreonTierRelationship { #[derive(serde::Deserialize)] pub struct RawPatreonUser { pub id: String, - pub attributes: RawPatreonUserAttributes + pub attributes: RawPatreonUserAttributes, } #[derive(serde::Deserialize)] pub struct RawPatreonUserAttributes { - pub social_connections: Option + pub social_connections: Option, } #[derive(serde::Deserialize)] pub struct RawPatreonSocialConnections { - pub discord: Option + pub discord: Option, } #[derive(serde::Deserialize)] pub struct RawPatreonDiscordConnection { - pub user_id: Option + pub user_id: Option, } #[derive(serde::Deserialize)] pub struct RawPatreonMeta { - pub pagination: RawPatreonPagination + pub pagination: RawPatreonPagination, } #[derive(serde::Deserialize)] @@ -64,5 +64,5 @@ pub struct RawPatreonPagination { #[derive(serde::Deserialize)] pub struct RawPatreonCursors { - pub next: Option + pub next: Option, }