diff --git a/cli/src/commands/start.rs b/cli/src/commands/start.rs index eb35fe87db..93ac6c8f34 100644 --- a/cli/src/commands/start.rs +++ b/cli/src/commands/start.rs @@ -93,6 +93,9 @@ pub struct Start { /// If the flag is set, the node will not initialize the REST server #[clap(long)] pub norest: bool, + /// If present, this will restrict the CORS origins for the REST server to the provided list. + #[clap(long)] + pub allowed_origins: Vec, /// If the flag is set, the node will not render the display #[clap(long)] @@ -394,6 +397,7 @@ impl Start { true => None, false => Some(self.rest), }; + let allowed_origins = std::mem::take(&mut self.allowed_origins); // If the display is not enabled, render the welcome message. if self.nodisplay { @@ -431,9 +435,9 @@ impl Start { // Initialize the node. let bft_ip = if self.dev.is_some() { self.bft } else { None }; match node_type { - NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev).await, + NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev, allowed_origins).await, NodeType::Prover => Node::new_prover(self.node, account, &trusted_peers, genesis, self.dev).await, - NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev).await, + NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev, allowed_origins).await, } } diff --git a/node/rest/src/lib.rs b/node/rest/src/lib.rs index fa0d0c9991..2235398786 100644 --- a/node/rest/src/lib.rs +++ b/node/rest/src/lib.rs @@ -36,7 +36,7 @@ use snarkvm::{ use anyhow::Result; use axum::{ extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State}, - http::{header::CONTENT_TYPE, Method, Request, StatusCode}, + http::{header::CONTENT_TYPE, HeaderValue, Method, Request, StatusCode}, middleware, middleware::Next, response::Response, @@ -48,7 +48,7 @@ use parking_lot::Mutex; use std::{net::SocketAddr, sync::Arc}; use tokio::task::JoinHandle; use tower_http::{ - cors::{Any, CorsLayer}, + cors::{AllowOrigin, CorsLayer}, trace::TraceLayer, }; @@ -72,11 +72,12 @@ impl, R: Routing> Rest consensus: Option>, ledger: Ledger, routing: Arc, + allowed_origins: Vec, ) -> Result { // Initialize the server. let mut server = Self { consensus, ledger, routing, handles: Default::default() }; // Spawn the server. - server.spawn_server(rest_ip); + server.spawn_server(rest_ip, allowed_origins); // Return the server. Ok(server) } @@ -95,9 +96,15 @@ impl, R: Routing> Rest { } impl, R: Routing> Rest { - fn spawn_server(&mut self, rest_ip: SocketAddr) { + fn spawn_server(&mut self, rest_ip: SocketAddr, allowed_origins: Vec) { + let allowed_origins = if allowed_origins.is_empty() { + AllowOrigin::any() + } else { + let origins = allowed_origins.iter().map(|s| HeaderValue::from_str(&s).expect("Invalid CORS origin")); + AllowOrigin::list(origins) + }; let cors = CorsLayer::new() - .allow_origin(Any) + .allow_origin(allowed_origins) .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) .allow_headers([CONTENT_TYPE]); diff --git a/node/src/client/mod.rs b/node/src/client/mod.rs index dd244c74d8..ff5465f027 100644 --- a/node/src/client/mod.rs +++ b/node/src/client/mod.rs @@ -81,6 +81,7 @@ impl> Client { genesis: Block, cdn: Option, dev: Option, + allowed_origins: Vec, ) -> Result { // Initialize the signal handler. let signal_node = Self::handle_signals(); @@ -129,7 +130,7 @@ impl> Client { // Initialize the REST server. if let Some(rest_ip) = rest_ip { - node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()))?); + node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()), allowed_origins)?); } // Initialize the routing. node.initialize_routing().await; diff --git a/node/src/node.rs b/node/src/node.rs index 6049b248f4..d11d8f6aa9 100644 --- a/node/src/node.rs +++ b/node/src/node.rs @@ -48,10 +48,22 @@ impl Node { genesis: Block, cdn: Option, dev: Option, + allowed_origins: Vec, ) -> Result { Ok(Self::Validator(Arc::new( - Validator::new(node_ip, rest_ip, bft_ip, account, trusted_peers, trusted_validators, genesis, cdn, dev) - .await?, + Validator::new( + node_ip, + rest_ip, + bft_ip, + account, + trusted_peers, + trusted_validators, + genesis, + cdn, + dev, + allowed_origins, + ) + .await?, ))) } @@ -75,8 +87,11 @@ impl Node { genesis: Block, cdn: Option, dev: Option, + allowed_origins: Vec, ) -> Result { - Ok(Self::Client(Arc::new(Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev).await?))) + Ok(Self::Client(Arc::new( + Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev, allowed_origins).await?, + ))) } /// Returns the node type. diff --git a/node/src/validator/mod.rs b/node/src/validator/mod.rs index 0c6d74293f..c92256805f 100644 --- a/node/src/validator/mod.rs +++ b/node/src/validator/mod.rs @@ -81,6 +81,7 @@ impl> Validator { genesis: Block, cdn: Option, dev: Option, + allowed_origins: Vec, ) -> Result { // Initialize the signal handler. let signal_node = Self::handle_signals(); @@ -136,7 +137,8 @@ impl> Validator { // Initialize the REST server. if let Some(rest_ip) = rest_ip { - node.rest = Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()))?); + node.rest = + Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()), allowed_origins)?); } // Initialize the routing. node.initialize_routing().await;