Skip to content

Commit

Permalink
feat: allow the restriction of CORS origins
Browse files Browse the repository at this point in the history
Signed-off-by: ljedrz <[email protected]>
  • Loading branch information
ljedrz committed Nov 7, 2023
1 parent 3f845c1 commit a8afed5
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 12 deletions.
8 changes: 6 additions & 2 deletions cli/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ pub struct Start {
/// If the flag is set, the node will not initialize the REST server
#[clap(long)]
pub norest: bool,
/// If present, this will restrict the CORS origins for the REST server to the provided list.
#[clap(long)]
pub allowed_origins: Vec<String>,

/// If the flag is set, the node will not render the display
#[clap(long)]
Expand Down Expand Up @@ -394,6 +397,7 @@ impl Start {
true => None,
false => Some(self.rest),
};
let allowed_origins = std::mem::take(&mut self.allowed_origins);

// If the display is not enabled, render the welcome message.
if self.nodisplay {
Expand Down Expand Up @@ -431,9 +435,9 @@ impl Start {
// Initialize the node.
let bft_ip = if self.dev.is_some() { self.bft } else { None };
match node_type {
NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev).await,
NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev, allowed_origins).await,
NodeType::Prover => Node::new_prover(self.node, account, &trusted_peers, genesis, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev, allowed_origins).await,
}
}

Expand Down
17 changes: 12 additions & 5 deletions node/rest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use snarkvm::{
use anyhow::Result;
use axum::{
extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
http::{header::CONTENT_TYPE, Method, Request, StatusCode},
http::{header::CONTENT_TYPE, HeaderValue, Method, Request, StatusCode},
middleware,
middleware::Next,
response::Response,
Expand All @@ -48,7 +48,7 @@ use parking_lot::Mutex;
use std::{net::SocketAddr, sync::Arc};
use tokio::task::JoinHandle;
use tower_http::{
cors::{Any, CorsLayer},
cors::{AllowOrigin, CorsLayer},
trace::TraceLayer,
};

Expand All @@ -72,11 +72,12 @@ impl<N: Network, C: 'static + ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R>
consensus: Option<Consensus<N>>,
ledger: Ledger<N, C>,
routing: Arc<R>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the server.
let mut server = Self { consensus, ledger, routing, handles: Default::default() };
// Spawn the server.
server.spawn_server(rest_ip);
server.spawn_server(rest_ip, allowed_origins);
// Return the server.
Ok(server)
}
Expand All @@ -95,9 +96,15 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
}

impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
fn spawn_server(&mut self, rest_ip: SocketAddr) {
fn spawn_server(&mut self, rest_ip: SocketAddr, allowed_origins: Vec<String>) {
let allowed_origins = if allowed_origins.is_empty() {
AllowOrigin::any()
} else {
let origins = allowed_origins.iter().map(|s| HeaderValue::from_str(&s).expect("Invalid CORS origin"));
AllowOrigin::list(origins)
};
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_origin(allowed_origins)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([CONTENT_TYPE]);

Expand Down
3 changes: 2 additions & 1 deletion node/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the signal handler.
let signal_node = Self::handle_signals();
Expand Down Expand Up @@ -129,7 +130,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()))?);
node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()), allowed_origins)?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down
21 changes: 18 additions & 3 deletions node/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,22 @@ impl<N: Network> Node<N> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
Ok(Self::Validator(Arc::new(
Validator::new(node_ip, rest_ip, bft_ip, account, trusted_peers, trusted_validators, genesis, cdn, dev)
.await?,
Validator::new(
node_ip,
rest_ip,
bft_ip,
account,
trusted_peers,
trusted_validators,
genesis,
cdn,
dev,
allowed_origins,
)
.await?,
)))
}

Expand All @@ -75,8 +87,11 @@ impl<N: Network> Node<N> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
Ok(Self::Client(Arc::new(Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev).await?)))
Ok(Self::Client(Arc::new(
Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev, allowed_origins).await?,
)))
}

/// Returns the node type.
Expand Down
4 changes: 3 additions & 1 deletion node/src/validator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the signal handler.
let signal_node = Self::handle_signals();
Expand Down Expand Up @@ -136,7 +137,8 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()))?);
node.rest =
Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()), allowed_origins)?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down

0 comments on commit a8afed5

Please sign in to comment.