Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOB] allow the restriction of CORS origins #2825

Open
wants to merge 1 commit into
base: testnet3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cli/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ pub struct Start {
/// If the flag is set, the node will not initialize the REST server
#[clap(long)]
pub norest: bool,
/// If present, this will restrict the CORS origins for the REST server to the provided list.
#[clap(long)]
pub allowed_origins: Vec<String>,

/// If the flag is set, the node will not render the display
#[clap(long)]
Expand Down Expand Up @@ -394,6 +397,7 @@ impl Start {
true => None,
false => Some(self.rest),
};
let allowed_origins = std::mem::take(&mut self.allowed_origins);

// If the display is not enabled, render the welcome message.
if self.nodisplay {
Expand Down Expand Up @@ -431,9 +435,9 @@ impl Start {
// Initialize the node.
let bft_ip = if self.dev.is_some() { self.bft } else { None };
match node_type {
NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev).await,
NodeType::Validator => Node::new_validator(self.node, rest_ip, bft_ip, account, &trusted_peers, &trusted_validators, genesis, cdn, self.dev, allowed_origins).await,
NodeType::Prover => Node::new_prover(self.node, account, &trusted_peers, genesis, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev).await,
NodeType::Client => Node::new_client(self.node, rest_ip, account, &trusted_peers, genesis, cdn, self.dev, allowed_origins).await,
}
}

Expand Down
17 changes: 12 additions & 5 deletions node/rest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use snarkvm::{
use anyhow::Result;
use axum::{
extract::{ConnectInfo, DefaultBodyLimit, Path, Query, State},
http::{header::CONTENT_TYPE, Method, Request, StatusCode},
http::{header::CONTENT_TYPE, HeaderValue, Method, Request, StatusCode},
middleware,
middleware::Next,
response::Response,
Expand All @@ -48,7 +48,7 @@ use parking_lot::Mutex;
use std::{net::SocketAddr, sync::Arc};
use tokio::task::JoinHandle;
use tower_http::{
cors::{Any, CorsLayer},
cors::{AllowOrigin, CorsLayer},
trace::TraceLayer,
};

Expand All @@ -72,11 +72,12 @@ impl<N: Network, C: 'static + ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R>
consensus: Option<Consensus<N>>,
ledger: Ledger<N, C>,
routing: Arc<R>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the server.
let mut server = Self { consensus, ledger, routing, handles: Default::default() };
// Spawn the server.
server.spawn_server(rest_ip);
server.spawn_server(rest_ip, allowed_origins);
// Return the server.
Ok(server)
}
Expand All @@ -95,9 +96,15 @@ impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
}

impl<N: Network, C: ConsensusStorage<N>, R: Routing<N>> Rest<N, C, R> {
fn spawn_server(&mut self, rest_ip: SocketAddr) {
fn spawn_server(&mut self, rest_ip: SocketAddr, allowed_origins: Vec<String>) {
let allowed_origins = if allowed_origins.is_empty() {
AllowOrigin::any()
} else {
let origins = allowed_origins.iter().map(|s| HeaderValue::from_str(s).expect("Invalid CORS origin"));
AllowOrigin::list(origins)
};
SplittyDev marked this conversation as resolved.
Show resolved Hide resolved
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_origin(allowed_origins)
.allow_methods([Method::GET, Method::POST, Method::OPTIONS])
.allow_headers([CONTENT_TYPE]);

Expand Down
3 changes: 2 additions & 1 deletion node/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the signal handler.
let signal_node = Self::handle_signals();
Expand Down Expand Up @@ -129,7 +130,7 @@ impl<N: Network, C: ConsensusStorage<N>> Client<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()))?);
node.rest = Some(Rest::start(rest_ip, None, ledger.clone(), Arc::new(node.clone()), allowed_origins)?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down
21 changes: 18 additions & 3 deletions node/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,22 @@ impl<N: Network> Node<N> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
Ok(Self::Validator(Arc::new(
Validator::new(node_ip, rest_ip, bft_ip, account, trusted_peers, trusted_validators, genesis, cdn, dev)
.await?,
Validator::new(
node_ip,
rest_ip,
bft_ip,
account,
trusted_peers,
trusted_validators,
genesis,
cdn,
dev,
allowed_origins,
)
.await?,
)))
}

Expand All @@ -75,8 +87,11 @@ impl<N: Network> Node<N> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
Ok(Self::Client(Arc::new(Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev).await?)))
Ok(Self::Client(Arc::new(
Client::new(node_ip, rest_ip, account, trusted_peers, genesis, cdn, dev, allowed_origins).await?,
)))
}

/// Returns the node type.
Expand Down
5 changes: 4 additions & 1 deletion node/src/validator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {
genesis: Block<N>,
cdn: Option<String>,
dev: Option<u16>,
allowed_origins: Vec<String>,
) -> Result<Self> {
// Initialize the signal handler.
let signal_node = Self::handle_signals();
Expand Down Expand Up @@ -136,7 +137,8 @@ impl<N: Network, C: ConsensusStorage<N>> Validator<N, C> {

// Initialize the REST server.
if let Some(rest_ip) = rest_ip {
node.rest = Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()))?);
node.rest =
Some(Rest::start(rest_ip, Some(consensus), ledger.clone(), Arc::new(node.clone()), allowed_origins)?);
}
// Initialize the routing.
node.initialize_routing().await;
Expand Down Expand Up @@ -486,6 +488,7 @@ mod tests {
genesis,
None,
dev,
vec![],
)
.await
.unwrap();
Expand Down
2 changes: 2 additions & 0 deletions node/tests/common/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub async fn client() -> Client<CurrentNetwork, ConsensusMemory<CurrentNetwork>>
sample_genesis_block(),
None, // No CDN.
None,
vec![],
)
.await
.expect("couldn't create client instance")
Expand Down Expand Up @@ -56,6 +57,7 @@ pub async fn validator() -> Validator<CurrentNetwork, ConsensusMemory<CurrentNet
sample_genesis_block(), // Should load the current network's genesis block.
None, // No CDN.
None,
vec![],
)
.await
.expect("couldn't create validator instance")
Expand Down