zero_pool/
future.rs

1use std::sync::atomic::Ordering;
2use std::sync::{Arc, Condvar, Mutex};
3use std::time::Duration;
4
5use crate::padded_type::PaddedAtomicUsize;
6
7// public work future with arc wrapped fields
8#[derive(Clone)]
9#[repr(align(64))]
10pub struct WorkFuture {
11    remaining: Arc<PaddedAtomicUsize>,
12    state: Arc<(Mutex<()>, Condvar)>,
13}
14
15impl WorkFuture {
16    // create a new work future for the given number of tasks
17    pub fn new(task_count: usize) -> Self {
18        WorkFuture {
19            remaining: Arc::new(PaddedAtomicUsize::new(task_count)),
20            state: Arc::new((Mutex::new(()), Condvar::new())),
21        }
22    }
23
24    // check if all tasks are complete
25    pub fn is_complete(&self) -> bool {
26        self.remaining.load(Ordering::Acquire) == 0
27    }
28
29    // wait for all tasks to complete
30    pub fn wait(self) {
31        if self.is_complete() {
32            return;
33        }
34
35        let (lock, cvar) = &*self.state;
36        let mut guard = lock.lock().unwrap();
37
38        while !self.is_complete() {
39            guard = cvar.wait(guard).unwrap();
40        }
41    }
42
43    // wait for all tasks with timeout
44    pub fn wait_timeout(self, timeout: Duration) -> bool {
45        if self.is_complete() {
46            return true;
47        }
48
49        let (lock, cvar) = &*self.state;
50        let mut guard = lock.lock().unwrap();
51
52        while !self.is_complete() {
53            let (new_guard, timeout_result) = cvar.wait_timeout(guard, timeout).unwrap();
54            guard = new_guard;
55            if timeout_result.timed_out() {
56                return self.is_complete();
57            }
58        }
59        true
60    }
61
62    // get remaining task count
63    pub fn remaining_count(&self) -> usize {
64        self.remaining.load(Ordering::Relaxed)
65    }
66
67    // complets multiple tasks, decrements counter and notifies if all done
68    #[inline]
69    pub fn complete_many(&self, count: usize) {
70        let remaining_count = self.remaining.fetch_sub(count, Ordering::Release);
71
72        // if this completed the last tasks, notify waiters
73        if remaining_count == count {
74            let (lock, cvar) = &*self.state;
75            let _guard = lock.lock().unwrap();
76            cvar.notify_all();
77        }
78    }
79}