diff --git a/node/rest/Cargo.toml b/node/rest/Cargo.toml index 35e6c7a419..f9e853b2a6 100644 --- a/node/rest/Cargo.toml +++ b/node/rest/Cargo.toml @@ -81,6 +81,7 @@ version = "1" [dependencies.tower] version = "0.4" +features = ["buffer", "limit"] [dependencies.tower_governor] version = "0.3" diff --git a/node/rest/src/lib.rs b/node/rest/src/lib.rs index 360dc5fd4e..c31b9d5462 100644 --- a/node/rest/src/lib.rs +++ b/node/rest/src/lib.rs @@ -36,18 +36,21 @@ use snarkvm::{ use anyhow::Result; use axum::{ body::Body, + error_handling::HandleErrorLayer, extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State}, http::{header::CONTENT_TYPE, Method, Request, StatusCode}, middleware, middleware::Next, response::Response, routing::{get, post}, + BoxError, Json, }; use axum_extra::response::ErasedJson; use parking_lot::Mutex; -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::{net::TcpListener, task::JoinHandle}; +use tower::{buffer::BufferLayer, limit::RateLimitLayer, ServiceBuilder}; use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer}; use tower_http::{ cors::{Any, CorsLayer}, @@ -117,6 +120,18 @@ impl, R: Routing> Rest { .expect("Couldn't set up rate limiting for the REST server!"), ); + let global_rate_limiter = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|err: BoxError| async move { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled error: {}", err), + ) + })) + // Use a buffer layer to allow rate limiting. + .layer(BufferLayer::new(1024)) + // Apply a global rate limit of 100 requests/s. + .layer(RateLimitLayer::new(100, Duration::from_secs(1))); + let router = { axum::Router::new() @@ -197,6 +212,7 @@ impl, R: Routing> Rest { // We can leak this because it is created only once and it persists. config: Box::leak(governor_config), }) + .layer(global_rate_limiter) }; let rest_listener = TcpListener::bind(rest_ip).await.unwrap();