diff --git a/node/Cargo.toml b/node/Cargo.toml index 05e2487b5a..85d01542ea 100644 --- a/node/Cargo.toml +++ b/node/Cargo.toml @@ -44,6 +44,9 @@ version = "2.0" [dependencies.num_cpus] version = "1" +[dependencies.once_cell] +version = "1" + [dependencies.parking_lot] version = "0.12" diff --git a/node/src/beacon/mod.rs b/node/src/beacon/mod.rs index fa99dfccf0..89a2a78caa 100644 --- a/node/src/beacon/mod.rs +++ b/node/src/beacon/mod.rs @@ -97,6 +97,9 @@ impl> Beacon { ) -> Result { let timer = timer!("Beacon::new"); + // Initialize the signal handler. + let signal_node = Self::handle_signals(); + // Initialize the ledger. let ledger = Ledger::load(genesis, dev)?; lap!(timer, "Initialize the ledger"); @@ -155,8 +158,8 @@ impl> Beacon { node.initialize_routing().await; // Initialize the block production. node.initialize_block_production().await; - // Initialize the signal handler. - node.handle_signals(); + // Pass the node to the signal handler. + let _ = signal_node.set(node.clone()); lap!(timer, "Initialize the handlers"); finish!(timer); diff --git a/node/src/client/mod.rs b/node/src/client/mod.rs index 33aa25e0b8..21b7e6a736 100644 --- a/node/src/client/mod.rs +++ b/node/src/client/mod.rs @@ -60,6 +60,9 @@ impl> Client { genesis: Block, dev: Option, ) -> Result { + // Initialize the signal handler. + let signal_node = Self::handle_signals(); + // Initialize the node router. let router = Router::new( node_ip, @@ -83,8 +86,8 @@ impl> Client { }; // Initialize the routing. node.initialize_routing().await; - // Initialize the signal handler. - node.handle_signals(); + // Pass the node to the signal handler. + let _ = signal_node.set(node.clone()); // Return the node. Ok(node) } diff --git a/node/src/prover/mod.rs b/node/src/prover/mod.rs index ed0f14ac8e..120ba1fe7d 100644 --- a/node/src/prover/mod.rs +++ b/node/src/prover/mod.rs @@ -77,6 +77,9 @@ impl> Prover { genesis: Block, dev: Option, ) -> Result { + // Initialize the signal handler. + let signal_node = Self::handle_signals(); + // Initialize the node router. let router = Router::new( node_ip, @@ -108,8 +111,8 @@ impl> Prover { node.initialize_routing().await; // Initialize the coinbase puzzle. node.initialize_coinbase_puzzle().await; - // Initialize the signal handler. - node.handle_signals(); + // Pass the node to the signal handler. + let _ = signal_node.set(node.clone()); // Return the node. Ok(node) } diff --git a/node/src/traits.rs b/node/src/traits.rs index 173bafffe4..7675f45295 100644 --- a/node/src/traits.rs +++ b/node/src/traits.rs @@ -16,6 +16,9 @@ use snarkos_node_messages::NodeType; use snarkos_node_router::Routing; use snarkvm::prelude::{Address, Network, PrivateKey, ViewKey}; +use once_cell::sync::OnceCell; +use std::sync::Arc; + #[async_trait] pub trait NodeInterface: Routing { /// Returns the node type. @@ -45,17 +48,25 @@ pub trait NodeInterface: Routing { /// Handles OS signals for the node to intercept and perform a clean shutdown. /// Note: Only Ctrl-C is supported; it should work on both Unix-family systems and Windows. - fn handle_signals(&self) { - let node = self.clone(); + fn handle_signals() -> Arc> { + // In order for the signal handler to be started as early as possible, a reference to the node needs + // to be passed to it at a later time. + let node: Arc> = Default::default(); + + let node_clone = node.clone(); tokio::task::spawn(async move { match tokio::signal::ctrl_c().await { Ok(()) => { - node.shut_down().await; + if let Some(node) = node_clone.get() { + node.shut_down().await; + } std::process::exit(0); } Err(error) => error!("tokio::signal::ctrl_c encountered an error: {}", error), } }); + + node } /// Shuts down the node. diff --git a/node/src/validator/mod.rs b/node/src/validator/mod.rs index 6eda7ad5b9..682fbd0fd9 100644 --- a/node/src/validator/mod.rs +++ b/node/src/validator/mod.rs @@ -72,6 +72,9 @@ impl> Validator { cdn: Option, dev: Option, ) -> Result { + // Initialize the signal handler. + let signal_node = Self::handle_signals(); + // Initialize the ledger. let ledger = Ledger::load(genesis, dev)?; // Initialize the CDN. @@ -114,8 +117,8 @@ impl> Validator { node.initialize_sync()?; // Initialize the routing. node.initialize_routing().await; - // Initialize the signal handler. - node.handle_signals(); + // Pass the node to the signal handler. + let _ = signal_node.set(node.clone()); // Return the node. Ok(node) }