Skip to main content

zero_pool/
task_future.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::thread::{self, Thread};
4use std::time::{Duration, Instant};
5
/// A future that tracks completion of submitted tasks.
///
/// `TaskFuture` provides both blocking and non-blocking ways to wait for
/// task completion. Tasks can be checked for completion, waited on
/// indefinitely, or waited on with a timeout.
///
/// `TaskFuture` is cheaply cloneable: every clone shares the same
/// remaining-task counter (an `Arc<AtomicUsize>`). However, it captures the
/// thread handle of the thread that created it, and completion unparks only
/// that thread.
///
/// If sharing the future with other threads, `is_complete()` is safe to call
/// from anywhere.
///
/// **Important:** `wait()` and `wait_timeout()` **must** be called from the
/// same thread that created the `TaskFuture`. Calling these methods from a
/// different thread will panic in debug builds and may cause the calling thread
/// to hang indefinitely in release builds.
#[derive(Clone, Debug)]
pub struct TaskFuture {
    /// Number of tasks that have not yet been reported via `complete_many`.
    count: Arc<AtomicUsize>,
    /// Thread that created the future; it is the thread unparked on completion.
    owner_thread: Thread,
}

impl TaskFuture {
    /// Creates a future that completes once `task_count` tasks have been
    /// reported through `complete_many`.
    ///
    /// The calling thread becomes the owner: it is the only thread that may
    /// `wait` on the future and the one unparked on completion. A `task_count`
    /// of zero yields a future that is already complete.
    pub(crate) fn new(task_count: usize) -> Self {
        TaskFuture {
            count: Arc::new(AtomicUsize::new(task_count)),
            owner_thread: thread::current(),
        }
    }

    /// Check if all tasks are complete without blocking.
    ///
    /// Returns `true` if all tasks have finished execution.
    /// This is a non-blocking operation using a single atomic load. The
    /// `Acquire` ordering pairs with the `Release` decrement in
    /// `complete_many`. Once this returns `true`, writes made by the workers
    /// before they reported completion are visible to the caller.
    pub fn is_complete(&self) -> bool {
        self.count.load(Ordering::Acquire) == 0
    }

    /// Wait for all tasks to complete.
    ///
    /// First checks completion with an atomic load. If the tasks are
    /// incomplete, it parks the owner thread until `complete_many` unparks it.
    /// Completion is re-checked after every wake-up, so spurious wake-ups and
    /// stale unpark tokens are harmless.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if called from a thread other than the one that
    /// created this `TaskFuture`.
    pub fn wait(&self) {
        debug_assert_eq!(
            self.owner_thread.id(),
            thread::current().id(),
            "TaskFuture::wait() must be called from the thread that created it."
        );

        while !self.is_complete() {
            thread::park();
        }
    }

    /// Wait for all tasks to complete with a timeout.
    ///
    /// First checks completion with an atomic load. If the tasks are
    /// incomplete, it parks the owner thread for at most the remaining time.
    /// Returns `true` if all tasks completed within the timeout,
    /// `false` if the timeout was reached first.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if called from a thread other than the one that
    /// created this `TaskFuture`.
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        debug_assert_eq!(
            self.owner_thread.id(),
            thread::current().id(),
            "TaskFuture::wait_timeout() must be called from the thread that created it."
        );

        // Compare elapsed time against `timeout` instead of computing a
        // deadline with `Instant::now() + timeout`. That addition panics on
        // overflow for very large timeouts.
        let start = Instant::now();
        loop {
            if self.is_complete() {
                return true;
            }
            let elapsed = start.elapsed();
            if elapsed >= timeout {
                return false;
            }
            // `park_timeout` can return early (spurious wake-up or a leftover
            // unpark token). The loop re-checks and parks for whatever time is left.
            thread::park_timeout(timeout - elapsed);
        }
    }

    /// Marks `count` tasks as finished.
    ///
    /// Returns `true` if this call brought the remaining-task counter to zero,
    /// in which case the owner thread is unparked. Returns `false` otherwise.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `count` exceeds the number of tasks still
    /// outstanding. Atomic subtraction wraps on underflow. Without this check,
    /// over-completion would silently leave the counter near `usize::MAX`, and
    /// `wait()` would never return.
    pub(crate) fn complete_many(&self, count: usize) -> bool {
        // `Release` publishes the finished tasks' writes to the `Acquire` load
        // performed by `is_complete`.
        let previous = self.count.fetch_sub(count, Ordering::Release);
        debug_assert!(
            previous >= count,
            "TaskFuture::complete_many({}) called with only {} task(s) outstanding",
            count,
            previous
        );

        let finished = previous == count;
        if finished {
            self.owner_thread.unpark();
        }
        finished
    }
}