diff --git a/src/rust/catmem/mod.rs b/src/rust/catmem/mod.rs index b36e5ad61..2c1ef814a 100644 --- a/src/rust/catmem/mod.rs +++ b/src/rust/catmem/mod.rs @@ -220,16 +220,16 @@ impl SharedCatmemLibOS { pub fn wait_next_n bool>( &mut self, mut acceptor: Acceptor, - timeout: Duration - ) -> Result<(), Fail> - { - self.runtime.clone().wait_next_n( - |qt, qd, result| acceptor(self.create_result(result, qd, qt)), timeout) + timeout: Duration, + ) -> Result<(), Fail> { + self.runtime + .clone() + .wait_next_n(|qt, qd, result| acceptor(self.create_result(result, qd, qt)), timeout) } /// Waits for any operation in an I/O queue. - pub fn poll(&mut self) { - self.runtime.poll() + pub fn poll_task(&mut self, qt: QToken) { + self.runtime.poll_task(qt) } pub fn create_result(&self, result: OperationResult, qd: QDesc, qt: QToken) -> demi_qresult_t { diff --git a/src/rust/catnap/win/overlapped.rs b/src/rust/catnap/win/overlapped.rs index d7f38486e..284a47612 100644 --- a/src/rust/catnap/win/overlapped.rs +++ b/src/rust/catnap/win/overlapped.rs @@ -689,7 +689,7 @@ mod tests { let result: OperationResult = loop { iocp.get_mut().process_events()?; - runtime.poll(); + runtime.poll_task(server_task); if let Some((_, result)) = runtime.get_completed_task(&server_task) { break result; } diff --git a/src/rust/demikernel/libos/memory.rs b/src/rust/demikernel/libos/memory.rs index bb7221bfe..c5f17d7b5 100644 --- a/src/rust/demikernel/libos/memory.rs +++ b/src/rust/demikernel/libos/memory.rs @@ -119,9 +119,8 @@ impl MemoryLibOS { pub fn wait_next_n bool>( &mut self, acceptor: Acceptor, - timeout: Duration - ) -> Result<(), Fail> - { + timeout: Duration, + ) -> Result<(), Fail> { trace!("wait_next_n(): acceptor, timeout={:?}", timeout); match self { #[cfg(feature = "catmem-libos")] @@ -130,7 +129,6 @@ impl MemoryLibOS { } } - /// Allocates a scatter-gather array. 
#[allow(unreachable_patterns, unused_variables)] pub fn sgaalloc(&self, size: usize) -> Result { @@ -153,10 +151,10 @@ impl MemoryLibOS { /// Waits for any operation in an I/O queue. #[allow(unreachable_patterns, unused_variables)] - pub fn poll(&mut self) { + pub fn poll_task(&mut self, qt: QToken) { match self { #[cfg(feature = "catmem-libos")] - MemoryLibOS::Catmem(libos) => libos.poll(), + MemoryLibOS::Catmem(libos) => libos.poll_task(qt), _ => unreachable!("unknown memory libos"), } } diff --git a/src/rust/demikernel/libos/mod.rs b/src/rust/demikernel/libos/mod.rs index dca73d062..eb6997d20 100644 --- a/src/rust/demikernel/libos/mod.rs +++ b/src/rust/demikernel/libos/mod.rs @@ -182,53 +182,41 @@ impl LibOS { /// Creates a new memory queue and connect to consumer end. #[allow(unused_variables)] pub fn create_pipe(&mut self, name: &str) -> Result { - let result: Result = { - timer!("demikernel::create_pipe"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(_) => Err(Fail::new( - libc::ENOTSUP, - "create_pipe() is not supported on network liboses", - )), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(libos) => libos.create_pipe(name), - } - }; - - self.poll(); - - result + timer!("demikernel::create_pipe"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(_) => Err(Fail::new( + libc::ENOTSUP, + "create_pipe() is not supported on network liboses", + )), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(libos) => libos.create_pipe(name), + } } /// Opens an existing memory queue and connects to producer end. 
#[allow(unused_variables)] pub fn open_pipe(&mut self, name: &str) -> Result { - let result: Result = { - timer!("demikernel::open_pipe"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(_) => Err(Fail::new( - libc::ENOTSUP, - "open_pipe() is not supported on network liboses", - )), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(libos) => libos.open_pipe(name), - } - }; - - self.poll(); - - result + timer!("demikernel::open_pipe"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(_) => Err(Fail::new( + libc::ENOTSUP, + "open_pipe() is not supported on network liboses", + )), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(libos) => libos.open_pipe(name), + } } /// Creates a socket. @@ -239,150 +227,114 @@ impl LibOS { socket_type: libc::c_int, protocol: libc::c_int, ) -> Result { - let result: Result = { - timer!("demikernel::socket"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.socket(domain, socket_type, protocol), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "socket() is not supported on memory liboses")), - } - }; - - self.poll(); - - result + timer!("demikernel::socket"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.socket(domain, socket_type, protocol), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "socket() is not supported on memory liboses")), + } } /// Sets an SO_* option on the socket referenced by [sockqd]. 
pub fn set_socket_option(&mut self, sockqd: QDesc, option: SocketOption) -> Result<(), Fail> { - let result: Result<(), Fail> = { - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.set_socket_option(sockqd, option), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => { - let cause: String = format!("Socket options are not supported on memory liboses"); - error!("get_socket_option(): {}", cause); - Err(Fail::new(libc::ENOTSUP, &cause)) - }, - } - }; - - self.poll(); - - result + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.set_socket_option(sockqd, option), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => { + let cause: String = format!("Socket options are not supported on memory liboses"); + error!("get_socket_option(): {}", cause); + Err(Fail::new(libc::ENOTSUP, &cause)) + }, + } } /// Gets a SO_* option on the socket referenced by [sockqd]. 
pub fn get_socket_option(&mut self, sockqd: QDesc, option: SocketOption) -> Result { - let result: Result = { - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.get_socket_option(sockqd, option), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => { - let cause: String = format!("Socket options are not supported on memory liboses"); - error!("get_socket_option(): {}", cause); - Err(Fail::new(libc::ENOTSUP, &cause)) - }, - } - }; - - self.poll(); - - result + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.get_socket_option(sockqd, option), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => { + let cause: String = format!("Socket options are not supported on memory liboses"); + error!("get_socket_option(): {}", cause); + Err(Fail::new(libc::ENOTSUP, &cause)) + }, + } } pub fn getpeername(&mut self, sockqd: QDesc) -> Result { - let result: Result = { - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.getpeername(sockqd), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => { - let cause: String = format!("Peername is not supported on memory liboses"); - error!("getpeername(): {}", cause); - Err(Fail::new(libc::ENOTSUP, &cause)) - }, - } - }; - - self.poll(); - - result + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.getpeername(sockqd), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => { + let cause: String = format!("Peername is not supported on memory liboses"); + error!("getpeername(): 
{}", cause); + Err(Fail::new(libc::ENOTSUP, &cause)) + }, + } } /// Binds a socket to a local address. #[allow(unused_variables)] pub fn bind(&mut self, sockqd: QDesc, local: SocketAddr) -> Result<(), Fail> { - let result: Result<(), Fail> = { - timer!("demikernel::bind"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.bind(sockqd, local), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "bind() is not supported on memory liboses")), - } - }; - - self.poll(); - - result + timer!("demikernel::bind"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.bind(sockqd, local), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "bind() is not supported on memory liboses")), + } } /// Marks a socket as a passive one. 
#[allow(unused_variables)] pub fn listen(&mut self, sockqd: QDesc, backlog: usize) -> Result<(), Fail> { - let result: Result<(), Fail> = { - timer!("demikernel::listen"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.listen(sockqd, backlog), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "listen() is not supported on memory liboses")), - } - }; - - self.poll(); - - result + timer!("demikernel::listen"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.listen(sockqd, backlog), + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "listen() is not supported on memory liboses")), + } } /// Accepts an incoming connection on a TCP socket. #[allow(unused_variables)] pub fn accept(&mut self, sockqd: QDesc) -> Result { - let result: Result = { + let qt: QToken = { timer!("demikernel::accept"); match self { #[cfg(any( @@ -395,17 +347,17 @@ impl LibOS { #[cfg(feature = "catmem-libos")] LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "accept() is not supported on memory liboses")), } - }; + }?; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Initiates a connection with a remote TCP socket. #[allow(unused_variables)] pub fn connect(&mut self, sockqd: QDesc, remote: SocketAddr) -> Result { - let result: Result = { + let qt: QToken = { timer!("demikernel::connect"); match self { #[cfg(any( @@ -418,50 +370,44 @@ impl LibOS { #[cfg(feature = "catmem-libos")] LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "connect() is not supported on memory liboses")), } - }; + }?; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Closes an I/O queue. 
/// async_close() + wait() achieves the same effect as synchronous close. pub fn close(&mut self, qd: QDesc) -> Result<(), Fail> { - let result: Result<(), Fail> = { - timer!("demikernel::close"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => match libos.async_close(qd) { - Ok(qt) => match self.wait(qt, None) { - Ok(_) => Ok(()), - Err(e) => Err(e), - }, + timer!("demikernel::close"); + match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => match libos.async_close(qd) { + Ok(qt) => match self.wait(qt, None) { + Ok(_) => Ok(()), Err(e) => Err(e), }, - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(libos) => match libos.async_close(qd) { - Ok(qt) => match self.wait(qt, None) { - Ok(_) => Ok(()), - Err(e) => Err(e), - }, + Err(e) => Err(e), + }, + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(libos) => match libos.async_close(qd) { + Ok(qt) => match self.wait(qt, None) { + Ok(_) => Ok(()), Err(e) => Err(e), }, - } - }; - - self.poll(); - - result + Err(e) => Err(e), + }, + } } pub fn async_close(&mut self, qd: QDesc) -> Result { - let result: Result = { + let qt: QToken = { timer!("demikernel::async_close"); match self { #[cfg(any( @@ -474,16 +420,16 @@ impl LibOS { #[cfg(feature = "catmem-libos")] LibOS::MemoryLibOS(libos) => libos.async_close(qd), } - }; + }?; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Pushes a scatter-gather array to an I/O queue. 
pub fn push(&mut self, qd: QDesc, sga: &demi_sgarray_t) -> Result { - let result: Result = { + let qt: QToken = { timer!("demikernel::push"); match self { #[cfg(any( @@ -496,41 +442,40 @@ impl LibOS { #[cfg(feature = "catmem-libos")] LibOS::MemoryLibOS(libos) => libos.push(qd, sga), } - }; + }?; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Pushes a scatter-gather array to a UDP socket. #[allow(unused_variables)] pub fn pushto(&mut self, qd: QDesc, sga: &demi_sgarray_t, to: SocketAddr) -> Result { - let result: Result = { - timer!("demikernel::pushto"); - match self { - #[cfg(any( - feature = "catnap-libos", - feature = "catnip-libos", - feature = "catpowder-libos", - feature = "catloop-libos" - ))] - LibOS::NetworkLibOS(libos) => libos.pushto(qd, sga, to), - #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(_) => Err(Fail::new(libc::ENOTSUP, "pushto() is not supported on memory liboses")), - } + timer!("demikernel::pushto"); + let qt: QToken = match self { + #[cfg(any( + feature = "catnap-libos", + feature = "catnip-libos", + feature = "catpowder-libos", + feature = "catloop-libos" + ))] + LibOS::NetworkLibOS(libos) => libos.pushto(qd, sga, to)?, + #[cfg(feature = "catmem-libos")] + LibOS::MemoryLibOS(_) => { + return Err(Fail::new(libc::ENOTSUP, "pushto() is not supported on memory liboses")) + }, }; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Pops data from a an I/O queue. pub fn pop(&mut self, qd: QDesc, size: Option) -> Result { - let result: Result = { - timer!("demikernel::pop"); - + timer!("demikernel::pop"); + let qt: QToken = { // Check if this is a fixed-size pop. if let Some(size) = size { // Check if size is valid. @@ -552,11 +497,11 @@ impl LibOS { #[cfg(feature = "catmem-libos")] LibOS::MemoryLibOS(libos) => libos.pop(qd, size), } - }; + }?; - self.poll(); + self.poll_task(qt); - result + Ok(qt) } /// Waits for a pending I/O operation to complete or a timeout to expire. 
@@ -655,7 +600,7 @@ impl LibOS { result } - pub fn poll(&mut self) { + pub fn poll_task(&mut self, qt: QToken) { timer!("demikernel::poll"); match self { #[cfg(any( @@ -664,9 +609,9 @@ impl LibOS { feature = "catpowder-libos", feature = "catloop-libos" ))] - LibOS::NetworkLibOS(libos) => libos.poll(), + LibOS::NetworkLibOS(libos) => libos.poll_task(qt), #[cfg(feature = "catmem-libos")] - LibOS::MemoryLibOS(libos) => libos.poll(), + LibOS::MemoryLibOS(libos) => libos.poll_task(qt), } } } diff --git a/src/rust/demikernel/libos/network/libos.rs b/src/rust/demikernel/libos/network/libos.rs index 43d9f8509..c15fa2a42 100644 --- a/src/rust/demikernel/libos/network/libos.rs +++ b/src/rust/demikernel/libos/network/libos.rs @@ -638,8 +638,8 @@ impl SharedNetworkLibOS { } /// Runs all runnable coroutines. - pub fn poll(&mut self) { - self.runtime.poll() + pub fn poll_task(&mut self, qt: QToken) { + self.runtime.poll_task(qt) } /// Releases a scatter-gather array. diff --git a/src/rust/demikernel/libos/network/mod.rs b/src/rust/demikernel/libos/network/mod.rs index 13816b7ae..f34bf5eb0 100644 --- a/src/rust/demikernel/libos/network/mod.rs +++ b/src/rust/demikernel/libos/network/mod.rs @@ -305,16 +305,16 @@ impl NetworkLibOSWrapper { } /// Waits for any operation in an I/O queue. 
- pub fn poll(&mut self) { + pub fn poll_task(&mut self, qt: QToken) { match self { #[cfg(feature = "catpowder-libos")] - NetworkLibOSWrapper::Catpowder(libos) => libos.poll(), + NetworkLibOSWrapper::Catpowder(libos) => libos.poll_task(qt), #[cfg(all(feature = "catnap-libos"))] - NetworkLibOSWrapper::Catnap(libos) => libos.poll(), + NetworkLibOSWrapper::Catnap(libos) => libos.poll_task(qt), #[cfg(feature = "catnip-libos")] - NetworkLibOSWrapper::Catnip(libos) => libos.poll(), + NetworkLibOSWrapper::Catnip(libos) => libos.poll_task(qt), #[cfg(feature = "catloop-libos")] - NetworkLibOSWrapper::Catloop(libos) => libos.poll(), + NetworkLibOSWrapper::Catloop(libos) => libos.poll_task(qt), } } diff --git a/src/rust/inetstack/protocols/arp/tests.rs b/src/rust/inetstack/protocols/arp/tests.rs index c6ac07e14..03c1650ec 100644 --- a/src/rust/inetstack/protocols/arp/tests.rs +++ b/src/rust/inetstack/protocols/arp/tests.rs @@ -80,7 +80,7 @@ fn arp_immediate_reply() -> Result<()> { // Move clock forward and poll the engine. now += Duration::from_micros(1); engine.advance_clock(now); - engine.poll(); + engine.poll_background(); // Check if the ARP cache outputs a reply message. let buffers: VecDeque = engine.pop_all_frames(); @@ -121,7 +121,7 @@ fn arp_no_reply() -> Result<()> { // Move clock forward and poll the engine. now += Duration::from_micros(1); engine.advance_clock(now); - engine.poll(); + engine.poll_background(); // Ensure that no reply message is output. let buffers: VecDeque = engine.pop_all_frames(); @@ -150,7 +150,7 @@ fn arp_cache_update() -> Result<()> { // Move clock forward and poll the engine. now += Duration::from_micros(1); engine.advance_clock(now); - engine.poll(); + engine.poll_background(); // Check if the ARP cache has been updated. 
let cache: HashMap = engine.get_transport().export_arp_cache(); @@ -185,9 +185,11 @@ fn arp_cache_timeout() -> Result<()> { let mut engine: SharedEngine = new_engine(now, test_helpers::ALICE_CONFIG_PATH)?; let coroutine = Box::pin(engine.clone().arp_query(other_remote_ipv4).fuse()); - let qt: QToken = engine.get_runtime().clone().insert_coroutine("arp query", coroutine)?; - engine.poll(); - engine.poll(); + let qt: QToken = engine + .get_runtime() + .clone() + .insert_test_coroutine("arp query", coroutine)?; + engine.poll_task(qt); for _ in 0..(ARP_RETRY_COUNT + 1) { // Check if the ARP cache outputs a reply message. @@ -197,8 +199,7 @@ fn arp_cache_timeout() -> Result<()> { // Move clock forward and poll the engine. now += ARP_REQUEST_TIMEOUT; engine.advance_clock(now); - engine.poll(); - engine.poll(); + engine.poll_task(qt); } // Check if the ARP cache outputs a reply message. diff --git a/src/rust/inetstack/protocols/tcp/tests/simulator.rs b/src/rust/inetstack/protocols/tcp/tests/simulator.rs index 261f914f2..d99af3ce8 100644 --- a/src/rust/inetstack/protocols/tcp/tests/simulator.rs +++ b/src/rust/inetstack/protocols/tcp/tests/simulator.rs @@ -542,7 +542,7 @@ impl Simulation { self.inflight = Some(push_qt); // We need an extra poll because we now perform all work for the push inside the asynchronous coroutine. // TODO: Remove this once we separate the poll and advance clock functions. - self.engine.poll(); + self.engine.poll_task(push_qt); Ok(()) }, @@ -755,7 +755,7 @@ impl Simulation { let buf: DemiBuffer = Self::serialize_segment(segment); self.engine.receive(buf)?; - self.engine.poll(); + self.engine.poll_background(); Ok(()) } @@ -848,21 +848,20 @@ impl Simulation { /// Runs an outgoing packet. 
fn run_outgoing_packet(&mut self, tcp_packet: &TcpPacket) -> Result<()> { - let mut n = 0; - let frames = loop { - let frames = self.engine.pop_all_frames(); + let mut frames: VecDeque = VecDeque::::new(); + for _ in 0..5 { + frames = self.engine.pop_all_frames(); if frames.is_empty() { - if n > 5 { - anyhow::bail!("did not emit a frame after 5 loops"); - } else { - self.engine.poll(); - n += 1; - } + self.engine.poll_io(); + self.engine.poll_background(); } else { // FIXME: We currently do not support multi-frame segments. crate::ensure_eq!(frames.len(), 1); - break frames; + break; } + } + if frames.is_empty() { + anyhow::bail!("did not emit a frame after 5 loops"); }; let bytes = &frames[0]; let (eth2_header, eth2_payload) = Ethernet2Header::parse(bytes.clone())?; diff --git a/src/rust/inetstack/protocols/udp/tests.rs b/src/rust/inetstack/protocols/udp/tests.rs index 70f7c741a..caa06bfc0 100644 --- a/src/rust/inetstack/protocols/udp/tests.rs +++ b/src/rust/inetstack/protocols/udp/tests.rs @@ -204,7 +204,7 @@ fn udp_ping_pong() -> Result<()> { // Receive data from Bob. carrie.receive(bob.pop_frame()).unwrap(); let carrie_qt: QToken = carrie.udp_pop(carrie_fd)?; - carrie.poll(); + carrie.poll_task(carrie_qt); let (remote_addr, received_buf_a): (Option, DemiBuffer) = match carrie.wait(carrie_qt, DEFAULT_TIMEOUT)? { @@ -223,7 +223,7 @@ fn udp_ping_pong() -> Result<()> { (_, OperationResult::Push) => {}, _ => anyhow::bail!("Push failed"), }; - carrie.poll(); + carrie.poll_background(); now += Duration::from_micros(1); // Receive data from Carrie. 
@@ -331,7 +331,7 @@ fn udp_loop2_push_pop() -> Result<()> { (_, OperationResult::Push) => {}, _ => anyhow::bail!("Push failed"), }; - bob.poll(); + bob.poll_background(); now += Duration::from_micros(1); @@ -395,7 +395,7 @@ fn udp_loop2_ping_pong() -> Result<()> { (_, OperationResult::Push) => {}, _ => anyhow::bail!("Push failed"), }; - bob.poll(); + bob.poll_background(); now += Duration::from_micros(1); @@ -546,7 +546,7 @@ fn udp_pop_not_bound() -> Result<()> { (_, OperationResult::Push) => {}, _ => anyhow::bail!("Push failed"), }; - bob.poll(); + bob.poll_background(); now += Duration::from_micros(1); diff --git a/src/rust/inetstack/test_helpers/engine.rs b/src/rust/inetstack/test_helpers/engine.rs index 9a5e766ea..68a3aedde 100644 --- a/src/rust/inetstack/test_helpers/engine.rs +++ b/src/rust/inetstack/test_helpers/engine.rs @@ -49,7 +49,7 @@ use ::std::{ }; /// A default amount of time to wait on an operation to complete. This was chosen arbitrarily. -pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120); +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20); #[derive(Clone)] pub struct SharedEngine(SharedNetworkLibOS>); @@ -87,8 +87,8 @@ impl SharedEngine { // We no longer do processing in this function, so we will not know if the packet is dropped or not. self.get_transport().receive(bytes)?; // So poll the scheduler to do the processing. 
- self.get_runtime().poll(); - self.get_runtime().poll(); + self.get_runtime().poll_background(); + self.get_runtime().poll_background(); Ok(()) } @@ -163,8 +163,16 @@ impl SharedEngine { self.get_transport().export_arp_cache() } - pub fn poll(&self) { - self.get_runtime().poll() + pub fn poll_io(&self) { + self.get_runtime().poll_io() + } + + pub fn poll_background(&self) { + self.get_runtime().poll_background() + } + + pub fn poll_task(&self, qt: QToken) { + self.get_runtime().poll_task(qt) } pub fn wait(&self, qt: QToken, timeout: Duration) -> Result<(QDesc, OperationResult), Fail> { diff --git a/src/rust/runtime/mod.rs b/src/rust/runtime/mod.rs index ffe7fd039..5a9dbebec 100644 --- a/src/rust/runtime/mod.rs +++ b/src/rust/runtime/mod.rs @@ -111,6 +111,10 @@ pub struct DemiRuntime { ts_iters: usize, /// Tasks that have been completed and removed from the completed_tasks: HashMap, + /// Background tasks. + background_task_group: TaskId, + /// Foreground I/O tasks. + io_task_group: TaskId, } #[derive(Clone)] @@ -139,15 +143,9 @@ impl DemiRuntime { impl SharedDemiRuntime { #[cfg(test)] pub fn new(now: Instant) -> Self { + let me: Self = Self::default(); timer::global_set_time(now); - Self(SharedObject::::new(DemiRuntime { - qtable: IoQueueTable::default(), - scheduler: SharedScheduler::default(), - ephemeral_ports: EphemeralPorts::default(), - network_table: NetworkQueueTable::default(), - ts_iters: 0, - completed_tasks: HashMap::::new(), - })) + me } /// Inserts the `coroutine` named `task_name` into the scheduler. 
@@ -156,7 +154,7 @@ impl SharedDemiRuntime { task_name: &'static str, coroutine: Pin>, ) -> Result { - self.insert_coroutine(task_name, coroutine) + self.insert_coroutine(task_name, coroutine, self.io_task_group) } /// Inserts the background `coroutine` named `task_name` into the scheduler @@ -165,23 +163,46 @@ impl SharedDemiRuntime { task_name: &'static str, coroutine: Pin>, ) -> Result { - self.insert_coroutine(task_name, coroutine) + self.insert_coroutine(task_name, coroutine, self.background_task_group) } /// Inserts a coroutine of type T and task - pub fn insert_coroutine( + fn insert_coroutine( &mut self, task_name: &'static str, coroutine: Pin>, + group_id: TaskId, ) -> Result where F::Output: Unpin + Clone + Any, { - trace!("Inserting coroutine: {:?}", task_name); #[cfg(feature = "profiler")] let coroutine = coroutine_timer!(task_name, coroutine); let task: TaskWithResult = TaskWithResult::::new(task_name, coroutine); - match self.scheduler.insert_task(task) { + match self.scheduler.insert_task(group_id, task) { + Some(task_id) => Ok(task_id.into()), + None => { + let cause: String = format!("cannot schedule coroutine (task_name={:?})", &task_name); + error!("insert_background_coroutine(): {}", cause); + Err(Fail::new(libc::EAGAIN, &cause)) + }, + } + } + + #[cfg(test)] + pub fn insert_test_coroutine( + &mut self, + task_name: &'static str, + coroutine: Pin>, + ) -> Result + where + F::Output: Unpin + Clone + Any, + { + #[cfg(feature = "profiler")] + let coroutine = coroutine_timer!(task_name, coroutine); + let group_id: TaskId = self.io_task_group; + let task: TaskWithResult = TaskWithResult::::new(task_name, coroutine); + match self.scheduler.insert_task(group_id, task) { Some(task_id) => Ok(task_id.into()), None => { let cause: String = format!("cannot schedule coroutine (task_name={:?})", &task_name); @@ -215,24 +236,29 @@ impl SharedDemiRuntime { // 2. None of the tasks have already completed, so start a timer and move the clock. 
self.advance_clock_to_now(); + let group_id: TaskId = self.io_task_group; loop { - if let Some(boxed_task) = self.scheduler.get_next_completed_task(TIMER_RESOLUTION) { - // Perform bookkeeping for the completed and removed task. - trace!("Removing coroutine: {:?}", boxed_task.get_name()); - let completed_qt: QToken = boxed_task.get_id().into(); - // If an operation task (and not a background task), then check the task to see if it is one of ours. - if let Ok(mut operation_task) = OperationTask::try_from(boxed_task.as_any()) { - let (qd, result): (QDesc, OperationResult) = - expect_some!(operation_task.get_result(), "coroutine not finished"); - - // Check whether it matches any of the queue tokens that we are waiting on. - if completed_qt == qt { - return Ok((qd, result)); + match self.scheduler.get_next_completed_task(group_id, TIMER_RESOLUTION) { + (_, Some(boxed_task)) => { + // Perform bookkeeping for the completed and removed task. + trace!("Removing coroutine: {:?}", boxed_task.get_name()); + let completed_qt: QToken = boxed_task.get_id().into(); + // If an operation task (and not a background task), then check the task to see if it is one of ours. + if let Ok(mut operation_task) = OperationTask::try_from(boxed_task.as_any()) { + let (qd, result): (QDesc, OperationResult) = + expect_some!(operation_task.get_result(), "coroutine not finished"); + + // Check whether it matches any of the queue tokens that we are waiting on. + if completed_qt == qt { + return Ok((qd, result)); + } + + // If not a queue token that we are waiting on, then insert into our list of completed tasks. + self.completed_tasks.insert(qt, (qd, result)); } - - // If not a queue token that we are waiting on, then insert into our list of completed tasks. - self.completed_tasks.insert(qt, (qd, result)); - } + }, + (i, None) if i < TIMER_RESOLUTION => self.poll_background(), + _ => (), } // Check the timeout. if let Some(abstime) = abstime { @@ -317,11 +343,12 @@ impl SharedDemiRuntime { // 3. 
Invoke the scheduler and run some tasks. loop { - // Run for one quanta and if one of our queue tokens completed, then return. + // Run for one quanta and if the acceptor condition is not met, then return. if let Some((qt, qd, result)) = self.run_next(remaining_time) { if acceptor(qt, qd, result) == false { return Ok(()); } + } else { } // Otherwise, move time forward. self.advance_clock_to_now(); @@ -362,27 +389,50 @@ impl SharedDemiRuntime { timeout if timeout.as_secs() > 0 => TIMER_RESOLUTION, _ => TIMER_FINER_RESOLUTION, }; - if let Some(boxed_task) = self.scheduler.get_next_completed_task(iterations) { - // Perform bookkeeping for the completed and removed task. - trace!("Removing coroutine: {:?}", boxed_task.get_name()); + let group_id: TaskId = self.io_task_group; + match self.scheduler.get_next_completed_task(group_id, iterations) { + (_, Some(boxed_task)) => { + // Perform bookkeeping for the completed and removed task. + trace!("Removing coroutine: {:?}", boxed_task.get_name()); + let qt: QToken = boxed_task.get_id().into(); + + // If an operation task, then take a look at the result. + if let Ok(mut operation_task) = OperationTask::try_from(boxed_task.as_any()) { + let (qd, result): (QDesc, OperationResult) = + expect_some!(operation_task.get_result(), "coroutine not finished"); + + Some((qt, qd, result)) + } else { + None + } + }, + (i, None) if i < iterations => { + self.poll_background(); + None + }, + _ => None, + } + } + + /// Performs a single poll on the underlying scheduler of a single task. + pub fn poll_task(&mut self, qt: QToken) { + // For all ready tasks that were removed from the scheduler, add to our completed task list. + if let Some(boxed_task) = self.scheduler.poll_task(qt.into()) { + trace!("Completed while polling coroutine: {:?}", boxed_task.get_name()); let qt: QToken = boxed_task.get_id().into(); - // If an operation task, then take a look at the result. 
if let Ok(mut operation_task) = OperationTask::try_from(boxed_task.as_any()) { let (qd, result): (QDesc, OperationResult) = expect_some!(operation_task.get_result(), "coroutine not finished"); - - return Some((qt, qd, result)); + self.completed_tasks.insert(qt, (qd, result)); } } - - None } - /// Performs a single pool on the underlying scheduler. - pub fn poll(&mut self) { + /// Performs a single poll on the underlying scheduler of a group of tasks. + pub fn poll_group(&mut self, group_id: TaskId) { // For all ready tasks that were removed from the scheduler, add to our completed task list. - for boxed_task in self.scheduler.poll_all() { + for boxed_task in self.scheduler.poll_group(&group_id) { trace!("Completed while polling coroutine: {:?}", boxed_task.get_name()); let qt: QToken = boxed_task.get_id().into(); @@ -394,6 +444,16 @@ impl SharedDemiRuntime { } } + /// Polls all of the background coroutines. + pub fn poll_background(&mut self) { + self.poll_group(self.background_task_group) + } + + /// Polls all of the foreground I/O coroutines. + pub fn poll_io(&mut self) { + self.poll_group(self.io_task_group) + } + /// Allocates a queue of type `T` and returns the associated queue descriptor. 
pub fn alloc_queue(&mut self, queue: T) -> QDesc { let qd: QDesc = self.qtable.alloc::(queue); @@ -591,13 +651,19 @@ pub async fn poll_yield() { impl Default for SharedDemiRuntime { fn default() -> Self { timer::global_set_time(Instant::now()); + let mut scheduler: SharedScheduler = SharedScheduler::default(); + let background_task_group: TaskId = scheduler.create_group(); + let io_task_group: TaskId = scheduler.create_group(); + Self(SharedObject::::new(DemiRuntime { qtable: IoQueueTable::default(), - scheduler: SharedScheduler::default(), + scheduler, ephemeral_ports: EphemeralPorts::default(), network_table: NetworkQueueTable::default(), ts_iters: 0, completed_tasks: HashMap::::new(), + background_task_group, + io_task_group, })) } } diff --git a/src/rust/runtime/scheduler/group.rs b/src/rust/runtime/scheduler/group.rs index b2ca84707..3b7d6c79d 100644 --- a/src/rust/runtime/scheduler/group.rs +++ b/src/rust/runtime/scheduler/group.rs @@ -34,6 +34,7 @@ use crate::{ use ::bit_iter::BitIter; use ::futures::Future; use ::std::{ + collections::VecDeque, pin::Pin, ptr::NonNull, task::{ @@ -56,6 +57,8 @@ pub struct TaskGroup { tasks: PinSlab>, /// Holds the waker bits for controlling task scheduling. waker_page_refs: Vec, + // The current set of ready tasks in the group. + ready_tasks: VecDeque, } //====================================================================================================================== @@ -91,17 +94,23 @@ impl TaskGroup { } /// Insert a new task into our scheduler returning a handle corresponding to it. - pub fn insert(&mut self, task: Box) -> Option { + pub fn insert(&mut self, task_id: TaskId, task: Box) -> bool { let task_name: &'static str = task.get_name(); // The pin slab index can be reverse-computed in a page index and an offset within the page. 
- let pin_slab_index: usize = self.tasks.insert(task)?; - let task_id: TaskId = self.ids.insert_with_new_id(pin_slab_index.into()); + let pin_slab_index: usize = match self.tasks.insert(task) { + Some(index) => index, + None => return false, + }; + self.ids.insert(task_id, pin_slab_index.into()); self.add_new_pages_up_to_pin_slab_index(pin_slab_index.into()); // Initialize the appropriate page offset. let (waker_page_ref, waker_page_offset): (&WakerPageRef, usize) = { - let (waker_page_index, waker_page_offset) = self.get_waker_page_index_and_offset(pin_slab_index)?; + let (waker_page_index, waker_page_offset) = match self.get_waker_page_index_and_offset(pin_slab_index) { + Some(result) => result, + None => return false, + }; (&self.waker_page_refs[waker_page_index], waker_page_offset) }; waker_page_ref.initialize(waker_page_offset); @@ -112,9 +121,9 @@ impl TaskGroup { task_id, pin_slab_index ); - // Set this task's id. + // Set this task's id. Expect is safe here because we just allocated this slot. expect_some!(self.tasks.get_pin_mut(pin_slab_index), "just allocated!").set_id(task_id); - Some(task_id) + true } /// Computes the page and page offset of a given task based on its total offset. @@ -149,29 +158,23 @@ impl TaskGroup { (waker_page_index << WAKER_BIT_LENGTH_SHIFT) + waker_page_offset } - pub fn get_offsets_for_ready_tasks(&mut self) -> Vec { - let mut result: Vec = vec![]; + pub fn update_offsets_for_ready_tasks(&mut self) { for i in 0..self.get_num_waker_pages() { // Grab notified bits. let notified: u64 = self.waker_page_refs[i].take_notified(); // Turn into bit iter. - let mut offset: Vec = BitIter::from(notified) + let mut offset: VecDeque = BitIter::from(notified) .map(|x| Self::get_pin_slab_index(i, x).into()) .collect(); - result.append(&mut offset); + self.ready_tasks.append(&mut offset); } - result } /// Translates an internal task id to an external one. Expects the task to exist. 
- pub fn unchecked_internal_to_external_id(&self, internal_id: InternalId) -> TaskId { + fn unchecked_internal_to_external_id(&self, internal_id: InternalId) -> TaskId { expect_some!(self.tasks.get(internal_id.into()), "Invalid offset: {:?}", internal_id).get_id() } - pub fn unchecked_external_to_internal_id(&self, task_id: &TaskId) -> InternalId { - expect_some!(self.ids.get(task_id), "Invalid id: {:?}", task_id) - } - fn get_pinned_task_ptr(&mut self, pin_slab_index: usize) -> Pin<&mut Box> { // Get the pinned ref. expect_some!( @@ -188,6 +191,57 @@ impl TaskGroup { Some(unsafe { Waker::from_raw(WakerRef::new(raw_waker).into()) }) } + /// Poll a single task. This function polls the task regardless of if it is ready. + pub fn poll_task(&mut self, task_id: TaskId) -> Option> { + // Safe to expect here because we must have found the task_id to find the group id. + let internal_id: InternalId = self.ids.get(&task_id).expect("Task should exist"); + // Grab the waker page and assume that it is set + let (waker_page_ref, waker_page_offset): (&WakerPageRef, usize) = { + let (waker_page_index, waker_page_offset) = self.get_waker_page_index_and_offset(internal_id.into())?; + (&self.waker_page_refs[waker_page_index], waker_page_offset) + }; + waker_page_ref.clear(waker_page_offset); + self.poll_notified_task_and_remove_if_ready(internal_id) + } + + /// Does a single sweep of ready bits and runs all ready tasks. + pub fn poll_all(&mut self) -> Vec> { + let mut completed_tasks: Vec> = vec![]; + // Grab all ready tasks. + self.update_offsets_for_ready_tasks(); + + while let Some(task_id) = self.ready_tasks.pop_front() { + if let Some(task) = self.poll_notified_task_and_remove_if_ready(task_id) { + completed_tasks.push(task); + } + } + completed_tasks + } + + /// Runs coroutines in this group until a completed one is found. Returns the number of coroutines that were polled. + /// and a completed task if one completed. 
+ pub fn get_next_completed_task(&mut self, max_iterations: usize) -> (usize, Option>) { + for i in 0..max_iterations { + let task_id: InternalId = match self.ready_tasks.pop_front() { + Some(id) => id, + None => { + self.update_offsets_for_ready_tasks(); + if let Some(id) = self.ready_tasks.pop_front() { + id + } else { + return (i, None); + } + }, + }; + trace!("polling task: {:?}", task_id); + + if let Some(task) = self.poll_notified_task_and_remove_if_ready(task_id) { + return (i + 1, Some(task)); + } + } + (max_iterations, None) + } + pub fn poll_notified_task_and_remove_if_ready(&mut self, internal_task_id: InternalId) -> Option> { // Perform the actual work of running the task. let poll_result: Poll<()> = { diff --git a/src/rust/runtime/scheduler/scheduler.rs b/src/rust/runtime/scheduler/scheduler.rs index 588e19167..52f375706 100644 --- a/src/rust/runtime/scheduler/scheduler.rs +++ b/src/rust/runtime/scheduler/scheduler.rs @@ -13,7 +13,6 @@ use crate::{ collections::id_map::IdMap, - expect_some, runtime::{ scheduler::{ group::TaskGroup, @@ -24,12 +23,9 @@ use crate::{ }, }; use ::slab::Slab; -use ::std::{ - ops::{ - Deref, - DerefMut, - }, - task::Waker, +use ::std::ops::{ + Deref, + DerefMut, }; //====================================================================================================================== @@ -48,18 +44,6 @@ pub struct Scheduler { // A group of tasks used for resource management. Currently all of our tasks are in a single group but we will // eventually break them up by Demikernel queue for fairness and performance isolation. groups: Slab, - // Track the currently running task id. This is entirely for external use. If there are no coroutines running (i. - // e.g, we did not enter the scheduler through a wait), this MUST be set to none because we cannot yield or wake /// - // unless inside a task/async coroutine. - current_running_task: Box>, - - // These global variables are for our scheduling policy. 
For now, we simply use round robin. - // The index of the current or last task that we ran. - current_task_id: InternalId, - // The group index of the current or last task that we ran. - current_group_id: InternalId, - // The current set of ready tasks in the group. - current_ready_tasks: Vec, } #[derive(Clone)] @@ -73,18 +57,8 @@ impl Scheduler { /// Creates a new task group. Returns an identifier for the group. pub fn create_group(&mut self) -> TaskId { let internal_id: InternalId = self.groups.insert(TaskGroup::default()).into(); - self.ids.insert_with_new_id(internal_id) - } - - /// Switch to a different task group. Returns true if the group has been switched. - pub fn switch_group(&mut self, group_id: TaskId) -> bool { - if let Some(internal_id) = self.ids.get(&group_id) { - if self.groups.contains(internal_id.into()) { - self.current_group_id = internal_id; - return true; - } - } - false + let external_id: TaskId = self.ids.insert_with_new_id(internal_id); + external_id } /// Get a reference to the task group using the id. @@ -116,30 +90,19 @@ impl Scheduler { /// Insert a task into a task group. The parent id can either be the id of the group or another task in the same /// group. - pub fn insert_task(&mut self, task: T) -> Option { - // Use the currently running task id to find the task group for this task. - let group: &mut TaskGroup = self.groups.get_mut(self.current_group_id.into())?; - // Insert the task into the task group. - let new_task_id: TaskId = group.insert(Box::new(task))?; - // Add a mapping so we can use this new task id to find the task in the future. - if let Some(existing) = self.ids.insert(new_task_id, self.current_group_id) { - panic!("should not exist an id: {:?}", existing); - } - Some(new_task_id) - } - - /// Insert a task into a task group. The parent id can either be the id of the group or another task in the same - /// group. 
- pub fn insert_task_with_group_id(&mut self, group_id: TaskId, task: T) -> Option { + pub fn insert_task(&mut self, group_id: TaskId, task: T) -> Option { // Get the internal id of the parent task or group. - let group_id: InternalId = self.ids.get(&group_id)?; + let internal_group_id: InternalId = self.ids.get(&group_id)?; + // Allocate a new task id for the task. + let new_task_id: TaskId = self.ids.insert_with_new_id(internal_group_id); // Use that to find the task group for this task. - let group: &mut TaskGroup = self.groups.get_mut(group_id.into())?; + let group: &mut TaskGroup = self.groups.get_mut(internal_group_id.into())?; // Insert the task into the task group. - let new_task_id: TaskId = group.insert(Box::new(task))?; - // Add a mapping so we can use this new task id to find the task in the future. - self.ids.insert(new_task_id, group_id); - Some(new_task_id) + if group.insert(new_task_id, Box::new(task)) { + Some(new_task_id) + } else { + None + } } pub fn remove_task(&mut self, task_id: TaskId) -> Option> { @@ -152,100 +115,37 @@ impl Scheduler { Some(task) } - fn poll_notified_task_and_remove_if_ready(&mut self) -> Option> { - let group: &mut TaskGroup = expect_some!( - self.groups.get_mut(self.current_group_id.into()), - "task group should exist: " - ); - assert!(self.current_running_task.is_none()); - *self.current_running_task = Some(group.unchecked_internal_to_external_id(self.current_task_id)); - assert!(self.current_running_task.is_some()); - let result: Option> = group.poll_notified_task_and_remove_if_ready(self.current_task_id); - assert!(self.current_running_task.is_some()); - *self.current_running_task = None; - assert!(self.current_running_task.is_none()); - result - } - - /// Poll all tasks which are ready to run for [max_iterations]. 
This does the same thing as get_next_completed task - /// but does not stop until it has reached [max_iterations] and collects all of the - pub fn poll_all(&mut self) -> Vec> { - let mut completed_tasks: Vec> = vec![]; - let start_group = self.current_group_id; - loop { - self.current_task_id = { - match self.current_ready_tasks.pop() { - Some(index) => index, - None => { - self.next_runnable_group(); - if self.current_group_id == start_group { - break; - } - match self.current_ready_tasks.pop() { - Some(index) => index, - None => return completed_tasks, - } - }, - } - }; + /// A high-level call for polling a single task. We only use this to poll a task synchronously immediately after + /// scheduling it. + pub fn poll_task(&mut self, task_id: TaskId) -> Option> { + // Use that to find the task group for this task. + let group: &mut TaskGroup = self.get_mut_group(&task_id)?; + group.poll_task(task_id) + } - // Now that we have a runnable task, actually poll it. - if let Some(task) = self.poll_notified_task_and_remove_if_ready() { - completed_tasks.push(task); - } - } - completed_tasks + /// Poll all tasks which are ready to run in a group at least once. + pub fn poll_group(&mut self, group_id: &TaskId) -> Vec> { + // Use that to find the task group for this task. + let group: &mut TaskGroup = match self.get_mut_group(&group_id) { + Some(group) => group, + None => return vec![], + }; + group.poll_all() } /// Poll all tasks until one completes. Remove that task and return it or fail after polling [max_iteration] number /// of tasks. - pub fn get_next_completed_task(&mut self, max_iterations: usize) -> Option> { - for _ in 0..max_iterations { - self.current_task_id = { - match self.current_ready_tasks.pop() { - Some(index) => index, - None => { - self.next_runnable_group(); - match self.current_ready_tasks.pop() { - Some(index) => index, - None => return None, - } - }, - } - }; - - // Now that we have a runnable task, actually poll it. 
- if let Some(task) = self.poll_notified_task_and_remove_if_ready() { - return Some(task); - } - } - None - } - - /// Poll over all of the groups looking for a group with runnable tasks. Sets the current_group_id to the next - /// runnable task group and current_ready_tasks to a list of tasks that are runnable in that group. - fn next_runnable_group(&mut self) { - let starting_group_index: InternalId = self.current_group_id; - self.current_group_id = self.get_next_group_index(); - - loop { - self.current_ready_tasks = self.groups[self.current_group_id.into()].get_offsets_for_ready_tasks(); - if !self.current_ready_tasks.is_empty() { - return; - } - // If we reach this point, then we have looped all the way around without finding any runnable tasks. - if self.current_group_id == starting_group_index { - return; - } - // Update the current_group_id - self.current_group_id = self.get_next_group_index(); - } - } - - /// Choose the index of the next group to run. - fn get_next_group_index(&self) -> InternalId { - // For now, we just choose the next group in the list. - InternalId::from((usize::from(self.current_group_id) + 1) % self.groups.len()) + pub fn get_next_completed_task( + &mut self, + group_id: TaskId, + max_iterations: usize, + ) -> (usize, Option>) { + // Use that to find the task group for this task. + let group: &mut TaskGroup = match self.get_mut_group(&group_id) { + Some(group) => group, + None => return (0, None), + }; + group.get_next_completed_task(max_iterations) } #[allow(unused)] @@ -258,18 +158,6 @@ impl Scheduler { } } - /// Returns the current running task id if we are in the scheduler, otherwise None. 
- pub fn get_task_id(&self) -> Option { - *self.current_running_task.clone() - } - - #[allow(unused)] - pub fn get_waker(&self, task_id: TaskId) -> Option { - let group: &TaskGroup = self.get_group(&task_id)?; - let internal_id: InternalId = group.unchecked_external_to_internal_id(&task_id); - group.get_waker(internal_id) - } - #[cfg(test)] pub fn num_tasks(&self) -> usize { let mut num_tasks: usize = 0; @@ -286,20 +174,9 @@ impl Scheduler { impl Default for Scheduler { fn default() -> Self { - let group: TaskGroup = TaskGroup::default(); - let mut ids: IdMap = IdMap::::default(); - let mut groups: Slab = Slab::::default(); - let internal_id: InternalId = groups.insert(group).into(); - // Use 0 as a special task id for the root. - let current_task: TaskId = TaskId::from(0); - ids.insert(current_task, internal_id); Self { - ids, - groups, - current_running_task: Box::new(None), - current_group_id: internal_id, - current_task_id: InternalId(0), - current_ready_tasks: vec![], + ids: IdMap::::default(), + groups: Slab::::default(), } } } @@ -416,16 +293,17 @@ mod tests { #[test] fn insert_creates_unique_tasks_ids() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Insert a task and make sure the task id is not a simple counter. let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; // Insert another task and make sure the task id is not sequentially after the previous one. 
let task2: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id2) = scheduler.insert_task(task2) else { + let Some(task_id2) = scheduler.insert_task(group_id, task2) else { anyhow::bail!("insert() failed") }; @@ -437,19 +315,22 @@ mod tests { #[test] fn poll_once_with_one_small_task_completes_it() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Insert a single future in the scheduler. This future shall complete with a single poll operation. let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; // All futures are inserted in the scheduler with notification flag set. // By polling once, our future should complete. - if let Some(task) = scheduler.get_next_completed_task(1) { - crate::ensure_eq!(task.get_id(), task_id); - } else { - anyhow::bail!("task should have completed"); + match scheduler.get_next_completed_task(group_id, 1) { + (i, Some(task)) => { + crate::ensure_eq!(task.get_id(), task_id); + crate::ensure_eq!(i, 1); + }, + (_, None) => anyhow::bail!("task should have completed"), } Ok(()) } @@ -457,46 +338,51 @@ mod tests { #[test] fn poll_next_with_one_small_task_completes_it() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Insert a single future in the scheduler. This future shall complete with a single poll operation. let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; // All futures are inserted in the scheduler with notification flag set. 
// By polling once, our future should complete. - if let Some(task) = scheduler.get_next_completed_task(MAX_ITERATIONS) { - crate::ensure_eq!(task_id, task.get_id()); - Ok(()) - } else { - anyhow::bail!("task should have completed") + match scheduler.get_next_completed_task(group_id, MAX_ITERATIONS) { + (i, Some(task)) => { + crate::ensure_eq!(task.get_id(), task_id); + crate::ensure_eq!(i, 1); + }, + (_, None) => anyhow::bail!("task should have completed"), } + Ok(()) } #[test] fn poll_twice_with_one_long_task_completes_it() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Insert a single future in the scheduler. This future shall complete // with two poll operations. let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(1).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; // All futures are inserted in the scheduler with notification flag set. // By polling once, this future should make a transition. - // All futures are inserted in the scheduler with notification flag set. - // By polling once, our future should complete. - let result = scheduler.get_next_completed_task(1); + let (iterations, result) = scheduler.get_next_completed_task(group_id, 1); crate::ensure_eq!(result.is_some(), false); + crate::ensure_eq!(iterations, 1); // This shall make the future ready. 
- if let Some(task) = scheduler.get_next_completed_task(1) { - crate::ensure_eq!(task.get_id(), task_id); - } else { - anyhow::bail!("task should have completed"); + match scheduler.get_next_completed_task(group_id, 1) { + (i, Some(task)) => { + crate::ensure_eq!(task.get_id(), task_id); + crate::ensure_eq!(i, 1); + }, + (_, None) => anyhow::bail!("task should have completed"), } Ok(()) } @@ -504,43 +390,49 @@ mod tests { #[test] fn poll_next_with_one_long_task_completes_it() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Insert a single future in the scheduler. This future shall complete with a single poll operation. let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; // All futures are inserted in the scheduler with notification flag set. // By polling until the task completes, our future should complete. - if let Some(task) = scheduler.get_next_completed_task(MAX_ITERATIONS) { - crate::ensure_eq!(task_id, task.get_id()); - Ok(()) - } else { - anyhow::bail!("task should have completed") + match scheduler.get_next_completed_task(group_id, MAX_ITERATIONS) { + (i, Some(task)) => { + crate::ensure_eq!(task.get_id(), task_id); + crate::ensure_eq!(i, 1); + }, + (_, None) => anyhow::bail!("task should have completed"), } + Ok(()) } /// Tests if consecutive tasks are not assigned the same task id. #[test] fn insert_consecutive_creates_unique_task_ids() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Create and run a task. 
let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { anyhow::bail!("insert() failed") }; - if let Some(task) = scheduler.get_next_completed_task(1) { - crate::ensure_eq!(task.get_id(), task_id); - } else { - anyhow::bail!("task should have completed"); + match scheduler.get_next_completed_task(group_id, 1) { + (i, Some(task)) => { + crate::ensure_eq!(task.get_id(), task_id); + crate::ensure_eq!(i, 1); + }, + (_, None) => anyhow::bail!("task should have completed"), } // Create another task. let task2: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(0).fuse())); - let Some(task_id2) = scheduler.insert_task(task2) else { + let Some(task_id2) = scheduler.insert_task(group_id, task2) else { anyhow::bail!("insert() failed") }; @@ -553,6 +445,7 @@ mod tests { #[test] fn remove_removes_task_id() -> Result<()> { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); // Arbitrarily large number. 
const NUM_TASKS: usize = 8192; @@ -562,7 +455,7 @@ mod tests { for val in 0..NUM_TASKS { let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(val).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { panic!("insert() failed"); }; task_ids.push(task_id); @@ -589,10 +482,14 @@ mod tests { #[bench] fn benchmark_insert(b: &mut Bencher) { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); b.iter(|| { let task: DummyTask = DummyTask::new("testing", Box::pin(black_box(DummyCoroutine::default().fuse()))); - let task_id: TaskId = expect_some!(scheduler.insert_task(task), "couldn't insert future in scheduler"); + let task_id: TaskId = expect_some!( + scheduler.insert_task(group_id, task), + "couldn't insert future in scheduler" + ); black_box(task_id); }); } @@ -602,36 +499,39 @@ mod tests { let mut scheduler: Scheduler = Scheduler::default(); const NUM_TASKS: usize = 1024; let mut task_ids: Vec = Vec::::with_capacity(NUM_TASKS); + let group_id: TaskId = scheduler.create_group(); for val in 0..NUM_TASKS { let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(val).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = scheduler.insert_task(group_id, task) else { panic!("insert() failed"); }; task_ids.push(task_id); } b.iter(|| { - black_box(scheduler.poll_all()); + black_box(scheduler.poll_group(&group_id)); }); } #[bench] fn benchmark_next(b: &mut Bencher) { let mut scheduler: Scheduler = Scheduler::default(); + let group_id: TaskId = scheduler.create_group(); + const NUM_TASKS: usize = 1024; let mut task_ids: Vec = Vec::::with_capacity(NUM_TASKS); for val in 0..NUM_TASKS { let task: DummyTask = DummyTask::new("testing", Box::pin(DummyCoroutine::new(val).fuse())); - let Some(task_id) = scheduler.insert_task(task) else { + let Some(task_id) = 
scheduler.insert_task(group_id, task) else { panic!("insert() failed"); }; task_ids.push(task_id); } b.iter(|| { - black_box(scheduler.get_next_completed_task(MAX_ITERATIONS)); + black_box(scheduler.get_next_completed_task(group_id, MAX_ITERATIONS)); }); } }