Skip to content

Commit

Permalink
Run rustfmt
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Mar 24, 2024
1 parent a679d7e commit d9e1d61
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 82 deletions.
166 changes: 92 additions & 74 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
#![warn(clippy::pedantic)]

use std::{str::FromStr, collections::HashMap};
use std::{collections::HashMap, str::FromStr, sync::OnceLock, time::Duration};

use anyhow::Result;
use std::sync::OnceLock;
use hmac::{Mac as _, digest::FixedOutput};
use axum::{extract::Path, http::HeaderValue, response::Response};
use hmac::{digest::FixedOutput, Mac as _};
use subtle::ConstantTimeEq;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use axum::{http::HeaderValue, response::Response};


mod models;
mod macros;
mod models;

type ResponseResult<T> = Result<T, Error>;

Expand Down Expand Up @@ -44,13 +41,12 @@ impl From<PatreonTier> for PatreonTierInfo {
entitled_servers: match tier {
PatreonTier::Basic => 2,
PatreonTier::Extra => 5,
}
},
}
}
}


#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(transparent)]
struct DiscordUserId(u64);

Expand All @@ -61,8 +57,10 @@ struct Config {
extra_tier_id: String,
webhook_secret: String,
bind_address: Option<String>,
#[serde(default)] preset_members: Vec<DiscordUserId>,
#[serde(deserialize_with = "add_bearer")] creator_access_token: HeaderValue,
#[serde(default)]
preset_members: Vec<DiscordUserId>,
#[serde(deserialize_with = "add_bearer")]
creator_access_token: HeaderValue,
}

fn add_bearer<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<HeaderValue, D::Error> {
Expand All @@ -81,7 +79,6 @@ struct State {

static STATE: OnceLock<State> = OnceLock::new();


fn check_md5(key: &[u8], untrusted_signature: &[u8], untrusted_data: &[u8]) -> Result<bool> {
let mut mac = hmac::Hmac::<md5::Md5>::new_from_slice(key)?;
mac.update(untrusted_data);
Expand All @@ -90,23 +87,20 @@ fn check_md5(key: &[u8], untrusted_signature: &[u8], untrusted_data: &[u8]) -> R
Ok(correct_sig.ct_eq(untrusted_signature).into())
}




#[derive(serde::Deserialize)]
struct FetchMember {
member_id: DiscordUserId
member_id: DiscordUserId,
}

async fn fetch_member(axum::extract::Path(payload): axum::extract::Path<FetchMember>) -> impl axum::response::IntoResponse {
async fn fetch_member(Path(payload): Path<FetchMember>) -> impl axum::response::IntoResponse {
let state = STATE.get().unwrap();
let members = state.members.read().expect("poison");

axum::Json(members.get(&payload.member_id).copied())
}

async fn fetch_members() -> impl axum::response::IntoResponse {
let state = STATE.get().unwrap();
let state = STATE.get().unwrap();
let members = state.members.read().expect("poison");

axum::Json(members.clone())
Expand All @@ -117,20 +111,23 @@ async fn refresh_members() {
state.refresh_task.send(()).await.unwrap();
}

async fn webhook_recv(
headers: axum::http::HeaderMap,
payload: String,
) -> ResponseResult<()> {
async fn webhook_recv(headers: axum::http::HeaderMap, payload: String) -> ResponseResult<()> {
if check_md5(
STATE.get().unwrap().config.webhook_secret.as_bytes(),
require!(headers.get("X-Patreon-Signature"), Ok(())).as_bytes(),
payload.as_bytes()
payload.as_bytes(),
)? {
return Err(Error::SignatureMismatch);
};

let event = require!(headers.get("X-Patreon-Event"), Ok(())).to_str()?;
if matches!(event, "members:pledge:create" | "members:pledge:delete" | "members:pledge:update" | "members:create") {
if matches!(
event,
"members:pledge:create"
| "members:pledge:delete"
| "members:pledge:update"
| "members:create"
) {
fill_members().await?; // Just refresh all the members
} else {
tracing::info!("Unknown event: {event}");
Expand All @@ -139,45 +136,47 @@ async fn webhook_recv(
Ok(())
}


#[tokio::main]
async fn main() -> Result<()> {
let fmt_layer = tracing_subscriber::fmt::layer();
let filter = tracing_subscriber::filter::LevelFilter::from_str(
&std::env::var("LOG_LEVEL")
.unwrap_or_else(|_| String::from("INFO"))
&std::env::var("LOG_LEVEL").unwrap_or_else(|_| String::from("INFO")),
)?;

tracing_subscriber::registry().with(fmt_layer).with(filter).init();
tracing_subscriber::registry()
.with(fmt_layer)
.with(filter)
.init();

let mut config: Config = toml::from_str(&std::fs::read_to_string("config.toml")?)?;
let bind_address: std::net::SocketAddr = config.bind_address.take().unwrap().parse()?;

STATE.set(State {
let state = State {
config,
reqwest: reqwest::Client::new(),
members: std::sync::RwLock::new(HashMap::new()),
refresh_task: {
let (tx, mut rx) = tokio::sync::mpsc::channel(1);

tokio::spawn(async move {loop {
let res = tokio::time::timeout(
std::time::Duration::from_secs(60 * 60),
rx.recv()
).await;
tokio::spawn(async move {
loop {
let res = tokio::time::timeout(Duration::from_secs(60 * 60), rx.recv()).await;

if res.as_ref().map(Option::is_none).unwrap_or(false) {
break
}
if res.as_ref().map(Option::is_none).unwrap_or(false) {
break;
}

match fill_members().await {
Ok(len) => tracing::info!("Refreshed {len} members"),
Err(err) => tracing::error!("{err:?}"),
match fill_members().await {
Ok(len) => tracing::info!("Refreshed {len} members"),
Err(err) => tracing::error!("{err:?}"),
}
}
}});
});
tx
}
}).is_err().then(|| unreachable!());
},
};

STATE.set(state).is_err().then(|| unreachable!());

fill_members().await?;

Expand All @@ -191,46 +190,62 @@ async fn main() -> Result<()> {

let listener = tokio::net::TcpListener::bind(bind_address).await?;
axum::serve(listener, app.into_make_service())
.with_graceful_shutdown(async {drop(tokio::signal::ctrl_c().await)})
.with_graceful_shutdown(async { drop(tokio::signal::ctrl_c().await) })
.await?;

Ok(())
}

const BASE_URL: &str = "https://www.patreon.com/api/oauth2/v2";

fn check_tier(member: &models::RawPatreonMember, tier_id: &str) -> bool {
member
.relationships
.currently_entitled_tiers
.data
.iter()
.any(|tier| tier.id == tier_id)
}

const BASE_URL: &str = "https://www.patreon.com/api/oauth2/v2";
fn get_member_tier(
config: &Config,
member: &models::RawPatreonMember,
user: &models::RawPatreonUser,
) -> Option<(DiscordUserId, Option<PatreonTier>)> {
let socials = user.attributes.social_connections.as_ref()?;
let user_id = socials.discord.as_ref()?.user_id.as_ref()?;

let id = DiscordUserId(user_id.parse().unwrap());
let tier = if check_tier(member, &config.extra_tier_id) {
Some(PatreonTier::Extra)
} else if check_tier(member, &config.basic_tier_id) {
Some(PatreonTier::Basic)
} else {
None
};

fn get_member_tier(config: &Config, member: &models::RawPatreonMember, user: &models::RawPatreonUser) -> Option<(DiscordUserId, Option<PatreonTier>)> {
user.attributes.social_connections.as_ref().and_then(|socials| socials.discord.as_ref()).and_then(|discord_info| {
let check_tier = |tier_id| member.relationships.currently_entitled_tiers.data.iter().any(|tier| tier_id == &tier.id);

discord_info.user_id.as_ref().map(|user_id| (
DiscordUserId(user_id.parse().unwrap()),
if check_tier(&config.extra_tier_id) {
Some(PatreonTier::Extra)
} else if check_tier(&config.basic_tier_id) {
Some(PatreonTier::Basic)
} else {
None
}
))
})
Some((id, tier))
}

async fn fill_members() -> Result<usize> {
let state = STATE.get().unwrap();

let mut url = reqwest::Url::parse(&format!("{BASE_URL}/campaigns/{}/members", state.config.campaign_id))?;
let reqwest = &state.reqwest;
let mut url = reqwest::Url::parse(&format!(
"{BASE_URL}/campaigns/{}/members",
state.config.campaign_id
))?;

url.query_pairs_mut()
.append_pair("fields[user]", "social_connections")
.append_pair("include", "user,currently_entitled_tiers")
.finish();

let mut next_cursor = Some(String::new());
let headers = reqwest::header::HeaderMap::from_iter([
(reqwest::header::AUTHORIZATION, state.config.creator_access_token.clone())
]);
let headers = reqwest::header::HeaderMap::from_iter([(
reqwest::header::AUTHORIZATION,
state.config.creator_access_token.clone(),
)]);

let mut members = {
let members = state.members.read().expect("poison");
Expand All @@ -241,10 +256,8 @@ async fn fill_members() -> Result<usize> {
let mut url = url.clone();
url.query_pairs_mut().append_pair("page[cursor]", &cursor);

let resp: models::RawPatreonResponse = state.reqwest
.get(url).headers(headers.clone())
.send().await?.error_for_status()?
.json().await?;
let resp = reqwest.get(url).headers(headers.clone()).send().await?;
let resp: models::RawPatreonResponse = resp.error_for_status()?.json().await?;

members.extend(resp.data.into_iter().filter_map(|member| {
let user_id = &member.relationships.user.data.id;
Expand All @@ -255,18 +268,22 @@ async fn fill_members() -> Result<usize> {
})
}));

next_cursor = resp.meta.pagination.cursors.and_then(|cursors| cursors.next);
next_cursor = resp
.meta
.pagination
.cursors
.and_then(|cursors| cursors.next);
}

members.extend(state.config.preset_members.iter().map(|id| (*id, PatreonTierInfo::fake())));
let preset_members = state.config.preset_members.iter();
members.extend(preset_members.map(|id| (*id, PatreonTierInfo::fake())));
members.shrink_to_fit();

let len = members.len();
*state.members.write().expect("poison") = members;
Ok(len)
}


#[derive(Debug)]
enum Error {
SignatureMismatch,
Expand All @@ -293,7 +310,8 @@ impl axum::response::IntoResponse for Error {
tracing::error!("{self:?}");
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
self.to_string()
).into_response()
self.to_string(),
)
.into_response()
}
}
16 changes: 8 additions & 8 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ pub struct RawPatreonRelationships {

#[derive(serde::Deserialize)]
pub struct RawPatreonIdData {
pub data: RawPatreonId
pub data: RawPatreonId,
}

#[derive(serde::Deserialize)]
pub struct RawPatreonId {
pub id: String
pub id: String,
}

#[derive(serde::Deserialize)]
Expand All @@ -34,27 +34,27 @@ pub struct RawPatreonTierRelationship {
#[derive(serde::Deserialize)]
pub struct RawPatreonUser {
pub id: String,
pub attributes: RawPatreonUserAttributes
pub attributes: RawPatreonUserAttributes,
}

#[derive(serde::Deserialize)]
pub struct RawPatreonUserAttributes {
pub social_connections: Option<RawPatreonSocialConnections>
pub social_connections: Option<RawPatreonSocialConnections>,
}

#[derive(serde::Deserialize)]
pub struct RawPatreonSocialConnections {
pub discord: Option<RawPatreonDiscordConnection>
pub discord: Option<RawPatreonDiscordConnection>,
}

#[derive(serde::Deserialize)]
pub struct RawPatreonDiscordConnection {
pub user_id: Option<String>
pub user_id: Option<String>,
}

#[derive(serde::Deserialize)]
pub struct RawPatreonMeta {
pub pagination: RawPatreonPagination
pub pagination: RawPatreonPagination,
}

#[derive(serde::Deserialize)]
Expand All @@ -64,5 +64,5 @@ pub struct RawPatreonPagination {

#[derive(serde::Deserialize)]
pub struct RawPatreonCursors {
pub next: Option<String>
pub next: Option<String>,
}

0 comments on commit d9e1d61

Please sign in to comment.