Skip to content

Commit

Permalink
Merge pull request #2942 from ljedrz/feat/rest_rate_limiting
Browse files Browse the repository at this point in the history
Per-IP REST server rate limiting
  • Loading branch information
howardwu authored Dec 27, 2023
2 parents d3df68e + c1d8d5f commit 4dd3c09
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 63 deletions.
319 changes: 289 additions & 30 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions cli/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ pub struct Start {
/// Specify the IP address and port for the REST server
#[clap(default_value = "0.0.0.0:3033", long = "rest")]
pub rest: SocketAddr,
/// Specify the requests per second (RPS) rate limit per IP for the REST server
#[clap(default_value = "10", long = "rest-rps")]
pub rest_rps: u32,
/// If the flag is set, the node will not initialize the REST server
#[clap(long)]
pub norest: bool,
Expand Down Expand Up @@ -439,9 +442,9 @@ impl Start {
// Initialize the node.
let bft_ip = if self.dev.is_some() { self.bft } else { None };
match node_type {
NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev).await,
NodeType::Validator => Node::new_validator(self.node, bft_ip, rest_ip, self.rest_rps, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev).await,
NodeType::Prover => Node::new_prover(self.node, account, &trusted_peers, genesis, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, self.rest_rps, account, &trusted_peers, genesis, cdn, self.dev).await,
}
}

Expand Down
15 changes: 10 additions & 5 deletions node/rest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ parallel = [ "rayon" ]
version = "1.0.76"

[dependencies.axum]
version = "0.6"
features = [ "headers" ]
version = "0.7"

[dependencies.axum-extra]
version = "0.8.0"
features = [ "erased-json" ]
version = "0.9.0"
features = [ "erased-json", "typed-header" ]

[dependencies.http]
version = "1.0"
Expand Down Expand Up @@ -80,8 +79,14 @@ version = "0.3"
[dependencies.tokio]
version = "1"

[dependencies.tower-http]
[dependencies.tower]
version = "0.4"

[dependencies.tower_governor]
version = "0.2"

[dependencies.tower-http]
version = "0.5"
features = [ "cors", "trace" ]

[dependencies.tracing]
Expand Down
10 changes: 5 additions & 5 deletions node/rest/src/helpers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ use snarkvm::prelude::*;
use ::time::OffsetDateTime;
use anyhow::{anyhow, Result};
use axum::{
headers::authorization::{Authorization, Bearer},
body::Body,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
RequestPartsExt,
};
use axum_extra::{
headers::authorization::{Authorization, Bearer},
TypedHeader,
};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
Expand Down Expand Up @@ -70,10 +73,7 @@ impl Claims {
}
}

pub async fn auth_middleware<B>(request: Request<B>, next: Next<B>) -> Result<Response, Response>
where
B: Send,
{
pub async fn auth_middleware(request: Request<Body>, next: Next) -> Result<Response, Response> {
// Deconstruct the request to extract the auth token.
let (mut parts, body) = request.into_parts();
let auth: TypedHeader<Authorization<Bearer>> =
Expand Down
53 changes: 40 additions & 13 deletions node/rest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,23 @@ use snarkvm::{

use anyhow::Result;
use axum::{
body::Body,
error_handling::HandleErrorLayer,
extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
http::{header::CONTENT_TYPE, Method, Request, StatusCode},
middleware,
middleware::Next,
response::Response,
routing::{get, post},
BoxError,
Json,
};
use axum_extra::response::ErasedJson;
use parking_lot::Mutex;
use std::{net::SocketAddr, sync::Arc};
use tokio::task::JoinHandle;
use tokio::{net::TcpListener, task::JoinHandle};
use tower::ServiceBuilder;
use tower_governor::{errors::display_error, governor::GovernorConfigBuilder, GovernorLayer};
use tower_http::{
cors::{Any, CorsLayer},
trace::TraceLayer,
Expand All @@ -67,16 +72,17 @@ pub struct Rest<N: Network, C: ConsensusStorage<N>, R: Routing<N>> {

impl<N: Network, C: 'static + ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
/// Initializes a new instance of the server.
pub fn start(
pub async fn start(
rest_ip: SocketAddr,
rest_rps: u32,
consensus: Option<Consensus<N>>,
ledger: Ledger<N, C>,
routing: Arc<R>,
) -> Result<Self> {
// Initialize the server.
let mut server = Self { consensus, ledger, routing, handles: Default::default() };
// Spawn the server.
server.spawn_server(rest_ip);
server.spawn_server(rest_ip, rest_rps).await;
// Return the server.
Ok(server)
}
Expand All @@ -95,12 +101,24 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
}

impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
fn spawn_server(&mut self, rest_ip: SocketAddr) {
async fn spawn_server(&mut self, rest_ip: SocketAddr, rest_rps: u32) {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([CONTENT_TYPE]);

// Log the REST rate limit per IP.
debug!("REST rate limit per IP - {rest_rps} RPS");

// Prepare the rate limiting setup.
let governor_config = Box::new(
GovernorConfigBuilder::default()
.per_second(1)
.burst_size(rest_rps)
.finish()
.expect("Couldn't set up rate limiting for the REST server!"),
);

let router = {
axum::Router::new()

Expand Down Expand Up @@ -174,25 +192,34 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
.layer(cors)
// Cap body size at 10MB.
.layer(DefaultBodyLimit::max(10 * 1024 * 1024))
.layer(
ServiceBuilder::new()
// this middleware goes above `GovernorLayer` because it will receive
// errors returned by `GovernorLayer`
.layer(HandleErrorLayer::new(|e: BoxError| async move {
display_error(e)
}))
.layer(GovernorLayer {
// We can leak this because it is created only once and it persists.
config: Box::leak(governor_config),
}),
)
};

let rest_listener = TcpListener::bind(rest_ip).await.unwrap();
self.handles.lock().push(tokio::spawn(async move {
axum::Server::bind(&rest_ip)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
axum::serve(rest_listener, router.into_make_service_with_connect_info::<SocketAddr>())
.await
.expect("couldn't start rest server");
}))
}
}

async fn log_middleware<B>(
async fn log_middleware(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
request: Request<B>,
next: Next<B>,
) -> Result<Response, StatusCode>
where
B: Send,
{
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
info!("Received '{} {}' from '{addr}'", request.method(), request.uri());

Ok(next.run(request).await)
Expand Down
3 changes: 2 additions & 1 deletion node/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {
pub async fn new(
node_ip: SocketAddr,
rest_ip: Option<SocketAddr>,
rest_rps: u32,
account: Account<N>,
trusted_peers: &[SocketAddr],
genesis: Block<N>,
Expand Down Expand Up @@ -134,7 +135,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()))?);
node.rest = Some(Rest::start(rest_ip, rest_rps, None, ledger.clone(), Arc::new(node.clone())).await?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down
23 changes: 19 additions & 4 deletions node/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ impl<N: Network> Node<N> {
/// Initializes a new validator node.
pub async fn new_validator(
node_ip: SocketAddr,
rest_ip: Option<SocketAddr>,
bft_ip: Option<SocketAddr>,
rest_ip: Option<SocketAddr>,
rest_rps: u32,
account: Account<N>,
trusted_peers: &[SocketAddr],
trusted_validators: &[SocketAddr],
Expand All @@ -50,8 +51,19 @@ impl<N: Network> Node<N> {
dev: Option<u16>,
) -> Result<Self> {
Ok(Self::Validator(Arc::new(
Validator::new(node_ip, rest_ip, bft_ip, account, trusted_peers, trusted_validators, genesis, cdn, dev)
.await?,
Validator::new(
node_ip,
bft_ip,
rest_ip,
rest_rps,
account,
trusted_peers,
trusted_validators,
genesis,
cdn,
dev,
)
.await?,
)))
}

Expand All @@ -70,13 +82,16 @@ impl<N: Network> Node<N> {
pub async fn new_client(
node_ip: SocketAddr,
rest_ip: Option<SocketAddr>,
rest_rps: u32,
account: Account<N>,
trusted_peers: &[SocketAddr],
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
) -> Result<Self> {
Ok(Self::Client(Arc::new(Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev).await?)))
Ok(Self::Client(Arc::new(
Client::new(node_ip, rest_ip, rest_rps, account, trusted_peers, genesis, cdn, dev).await?,
)))
}

/// Returns the node type.
Expand Down
9 changes: 6 additions & 3 deletions node/src/validator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {
/// Initializes a new validator node.
pub async fn new(
node_ip: SocketAddr,
rest_ip: Option<SocketAddr>,
bft_ip: Option<SocketAddr>,
rest_ip: Option<SocketAddr>,
rest_rps: u32,
account: Account<N>,
trusted_peers: &[SocketAddr],
trusted_validators: &[SocketAddr],
Expand Down Expand Up @@ -141,7 +142,8 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()))?);
node.rest =
Some(Rest::start(rest_ip, rest_rps, Some(consensus), ledger.clone(), Arc::new(node.clone())).await?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down Expand Up @@ -483,8 +485,9 @@ mod tests {

let validator = Validator::<CurrentNetwork, ConsensusMemory<CurrentNetwork>>::new(
node,
Some(rest),
None,
Some(rest),
10,
account,
&[],
&[],
Expand Down
2 changes: 2 additions & 0 deletions node/tests/common/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub async fn client() -> Client<CurrentNetwork, ConsensusMemory<CurrentNetwork>>
Client::new(
"127.0.0.1:0".parse().unwrap(),
None,
10,
Account::<CurrentNetwork>::from_str("APrivateKey1zkp2oVPTci9kKcUprnbzMwq95Di1MQERpYBhEeqvkrDirK1").unwrap(),
&[],
sample_genesis_block(),
Expand Down Expand Up @@ -50,6 +51,7 @@ pub async fn validator() -> Validator<CurrentNetwork, ConsensusMemory<CurrentNet
"127.0.0.1:0".parse().unwrap(),
None,
None,
10,
Account::<CurrentNetwork>::from_str("APrivateKey1zkp2oVPTci9kKcUprnbzMwq95Di1MQERpYBhEeqvkrDirK1").unwrap(),
&[],
&[],
Expand Down

0 comments on commit 4dd3c09

Please sign in to comment.