Skip to content

Commit

Permalink
Improve Gate::close to support repeated calls
Browse files Browse the repository at this point in the history
If Gate::close has a timeout, closing the gate is left in a dangling
state with no way to safely resume waiting. This change makes it safe in
that calling close on a closing Gate still waits for any outstanding
tasks. This way you can repeatedly try calling gate close with a timeout
set while safely waiting for all tasks to complete.

This is technically an observable change in that a previous call to close
when still in "closing" would have returned an error but now waits for the
currently running tasks instead. I can't imagine anyone actually relies
too much on this.
  • Loading branch information
vlovich committed May 11, 2024
1 parent 83d3023 commit d6b2fe2
Showing 1 changed file with 105 additions and 12 deletions.
117 changes: 105 additions & 12 deletions glommio/src/sync/gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,26 @@ impl Gate {
)
}

/// Close the gate, and wait for all spawned tasks to complete
/// Close the gate, and wait for all spawned tasks to complete. If the gate is currently closing, this will wait
/// for it to close before returning a success. This is particularly useful if you might have a timeout on the close
/// - the would otherwise be no safe way to retry & wait for remaining tasks to finish.
pub async fn close(&self) -> Result<(), GlommioError<()>> {
self.inner.close().await
}

/// Whether the gate is open or not
/// Whether the gate is open or not.
pub fn is_open(&self) -> bool {
self.inner.is_open()
}

/// This returns true only if [Self::close] has been called and all spawned tasks are complete. If it returns false,
/// you may call [Self::close] without it returning an error and it'll wait for all spawned tasks to complete.
///
/// NOTE: multiple concurrent calls to [Self::close] may be a performance issue since each invocation to close will
/// allocate some nominal amount of memory for the channel underneath.
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -125,24 +136,44 @@ impl GateInner {
}

pub async fn close(&self) -> Result<(), GlommioError<()>> {
if self.is_open() {
if self.count.get() == 0 {
*self.state.borrow_mut() = State::Closed;
} else {
match self.state.replace(State::Closed) {
State::Open => {
if self.count.get() != 0 {
let (sender, receiver) = local_channel::new_bounded(1);
self.state.replace(State::Closing(sender));
receiver.recv().await;
}
Ok(())
}
State::Closing(previous_closer) => {
assert!(
self.count.get() != 0,
"If count is 0 then the state should have been marked as closed"
);
assert!(
!previous_closer.is_full(),
"Already notified that the gate is closed!"
);

let (sender, receiver) = local_channel::new_bounded(1);
*self.state.borrow_mut() = State::Closing(sender);
self.state.replace(State::Closing(sender));

receiver.recv().await;
let _ = previous_closer.try_send(true);
Ok(())
}
Ok(())
} else {
Err(GlommioError::Closed(ResourceType::Gate))
State::Closed => Err(GlommioError::Closed(ResourceType::Gate)),
}
}

pub fn is_open(&self) -> bool {
matches!(*self.state.borrow(), State::Open)
}

pub fn is_closed(&self) -> bool {
matches!(*self.state.borrow(), State::Closed)
}

pub fn notify_closed(&self) {
if let State::Closing(sender) = self.state.replace(State::Closed) {
sender.try_send(true).unwrap();
Expand All @@ -154,10 +185,11 @@ impl GateInner {

#[cfg(test)]
mod tests {
use crate::{enclose, LocalExecutor};

use super::*;
use crate::sync::Semaphore;
use crate::{enclose, timer::timeout, LocalExecutor};
use futures::join;
use std::time::Duration;

#[test]
fn test_immediate_close() {
Expand Down Expand Up @@ -229,4 +261,65 @@ mod tests {
assert!(!running.get());
})
}

#[test]
fn test_concurrent_close() {
LocalExecutor::default().run(async {
let gate = &Gate::new();
let gate_closures = &Semaphore::new(0);
let closed = &RefCell::new(false);

let pass = gate.enter().unwrap();

join!(
async {
gate_closures.signal(1);
gate.close().await.unwrap();
assert!(*closed.borrow());
},
async {
gate_closures.signal(1);
gate.close().await.unwrap();
assert!(*closed.borrow());
},
async {
gate_closures.acquire(2).await.unwrap();
drop(pass);
closed.replace(true);
},
);
})
}

#[test]
fn test_close_after_timed_out_close() {
LocalExecutor::default().run(async {
let gate = Gate::new();
let gate = &gate;
let gate_closed_once = Rc::new(Semaphore::new(0));
let task_gate = gate_closed_once.clone();

let _task = gate
.spawn(async move {
task_gate.acquire(1).await.unwrap();
})
.unwrap();

timeout(Duration::from_millis(1), async move {
gate.close().await.unwrap();
Ok(())
})
.await
.expect_err("Should have timed out");

assert!(
!gate.is_closed(),
"Should still be waiting for a task that hasn't finished"
);

gate_closed_once.signal(1);

gate.close().await.unwrap();
})
}
}

0 comments on commit d6b2fe2

Please sign in to comment.