task_exec_queue/
lib.rs

1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::sync::atomic::{AtomicIsize, Ordering};
5use std::sync::Arc;
6
7use futures::channel::mpsc;
8use once_cell::sync::OnceCell;
9use parking_lot::RwLock;
10
11pub use builder::{Builder, SpawnDefaultExt, SpawnExt};
12pub use exec::{TaskExecQueue, TaskType};
13pub use local::LocalTaskExecQueue;
14pub use local::LocalTaskType;
15pub use local_builder::{LocalBuilder, LocalSender, LocalSpawnExt};
16pub use local_spawner::{LocalGroupSpawner, LocalSpawner, TryLocalGroupSpawner, TryLocalSpawner};
17pub use spawner::{GroupSpawner, Spawner, TryGroupSpawner, TrySpawner};
18
19mod builder;
20mod close;
21mod exec;
22mod flush;
23mod spawner;
24
25mod local;
26mod local_builder;
27mod local_spawner;
28
/// A signed counter shared between all of its clones.
///
/// Cloning a `Counter` clones the inner `Arc`, so every handle observes and
/// modifies the same `AtomicIsize`. Every operation uses `Ordering::SeqCst`.
#[derive(Clone, Debug)]
struct Counter(Arc<AtomicIsize>);

impl Counter {
    /// Creates a fresh counter whose value is zero.
    #[inline]
    fn new() -> Self {
        let start = AtomicIsize::new(0);
        Self(Arc::new(start))
    }

    /// Adds one to the shared value.
    #[inline]
    fn inc(&self) {
        let cell: &AtomicIsize = &self.0;
        let _previous = cell.fetch_add(1, Ordering::SeqCst);
    }

    /// Subtracts one from the shared value (it may go negative).
    #[inline]
    fn dec(&self) {
        let cell: &AtomicIsize = &self.0;
        let _previous = cell.fetch_sub(1, Ordering::SeqCst);
    }

    /// Reads the current shared value.
    #[inline]
    fn value(&self) -> isize {
        let cell: &AtomicIsize = &self.0;
        cell.load(Ordering::SeqCst)
    }
}
53
54#[derive(Clone)]
55struct IndexSet(Arc<RwLock<HashSet<usize, ahash::RandomState>>>);
56
57impl IndexSet {
58    #[inline]
59    fn new() -> Self {
60        Self(Arc::new(RwLock::new(HashSet::default())))
61    }
62
63    #[inline]
64    #[allow(dead_code)]
65    fn len(&self) -> usize {
66        self.0.read().len()
67    }
68
69    #[inline]
70    fn is_empty(&self) -> bool {
71        self.0.read().is_empty()
72    }
73
74    #[inline]
75    fn insert(&self, v: usize) {
76        self.0.write().insert(v);
77    }
78
79    #[inline]
80    fn pop(&self) -> Option<usize> {
81        let mut set = self.0.write();
82        if let Some(idx) = set.iter().next().copied() {
83            set.remove(&idx);
84            Some(idx)
85        } else {
86            None
87        }
88    }
89}
90
/// FIFO of pending tasks for a single group, together with a flag that records
/// whether the group is currently marked as running.
struct GroupTaskExecQueue<TT> {
    tasks: VecDeque<TT>,
    is_running: bool,
}

impl<TT> GroupTaskExecQueue<TT> {
    /// Creates an empty queue that is not marked as running.
    #[inline]
    fn new() -> Self {
        GroupTaskExecQueue {
            is_running: false,
            tasks: VecDeque::new(),
        }
    }

    /// Appends `task` to the back of the queue.
    #[inline]
    fn push(&mut self, task: TT) {
        self.tasks.push_back(task)
    }

    /// Takes the oldest task from the front of the queue.
    ///
    /// When the queue is empty this returns `None` and also clears the
    /// running flag.
    #[inline]
    fn pop(&mut self) -> Option<TT> {
        let next = self.tasks.pop_front();
        if next.is_none() {
            // Nothing left: the group is no longer considered running.
            self.set_running(false);
        }
        next
    }

    /// Sets the running flag to `b`.
    #[inline]
    fn set_running(&mut self, b: bool) {
        self.is_running = b
    }

    /// Reports the current value of the running flag.
    #[inline]
    fn is_running(&self) -> bool {
        self.is_running
    }
}
130
131#[derive(thiserror::Error, Debug)]
132pub enum Error<T> {
133    #[error("send error")]
134    SendError(ErrorType<T>),
135    #[error("try send error")]
136    TrySendError(ErrorType<T>),
137    #[error("send timeout error")]
138    SendTimeoutError(ErrorType<T>),
139    #[error("recv result error")]
140    RecvResultError,
141}
142
143#[derive(Debug, Eq, PartialEq)]
144pub enum ErrorType<T> {
145    Full(Option<T>),
146    Closed(Option<T>),
147    Timeout(Option<T>),
148}
149
150impl<T> Error<T> {
151    #[inline]
152    pub fn is_full(&self) -> bool {
153        matches!(
154            self,
155            Error::SendError(ErrorType::Full(_))
156                | Error::TrySendError(ErrorType::Full(_))
157                | Error::SendTimeoutError(ErrorType::Full(_))
158        )
159    }
160
161    #[inline]
162    pub fn is_closed(&self) -> bool {
163        matches!(
164            self,
165            Error::SendError(ErrorType::Closed(_))
166                | Error::TrySendError(ErrorType::Closed(_))
167                | Error::SendTimeoutError(ErrorType::Closed(_))
168        )
169    }
170
171    #[inline]
172    pub fn is_timeout(&self) -> bool {
173        matches!(
174            self,
175            Error::SendError(ErrorType::Timeout(_))
176                | Error::TrySendError(ErrorType::Timeout(_))
177                | Error::SendTimeoutError(ErrorType::Timeout(_))
178        )
179    }
180}
181
182impl<T> From<mpsc::TrySendError<T>> for Error<T> {
183    fn from(e: mpsc::TrySendError<T>) -> Self {
184        if e.is_full() {
185            Error::TrySendError(ErrorType::Full(Some(e.into_inner())))
186        } else {
187            Error::TrySendError(ErrorType::Closed(Some(e.into_inner())))
188        }
189    }
190}
191
192impl<T> From<mpsc::SendError> for Error<T> {
193    fn from(e: mpsc::SendError) -> Self {
194        if e.is_full() {
195            Error::SendError(ErrorType::Full(None))
196        } else {
197            Error::SendError(ErrorType::Closed(None))
198        }
199    }
200}
201
202// Just a helper function to ensure the futures we're returning all have the
203// right implementations.
204pub(crate) fn assert_future<T, F>(future: F) -> F
205where
206    F: futures::Future<Output = T>,
207{
208    future
209}
210
211static DEFAULT_EXEC_QUEUE: OnceCell<TaskExecQueue> = OnceCell::new();
212
213pub fn set_default(queue: TaskExecQueue) -> Result<(), TaskExecQueue> {
214    DEFAULT_EXEC_QUEUE.set(queue)
215}
216
217pub fn init_default() -> impl futures::Future<Output = ()> {
218    let (queue, runner) = Builder::default().workers(100).queue_max(100_000).build();
219    DEFAULT_EXEC_QUEUE.set(queue).ok().unwrap();
220    runner
221}
222
223pub fn default() -> &'static TaskExecQueue {
224    DEFAULT_EXEC_QUEUE
225        .get()
226        .expect("default task execution queue must be set first")
227}
228
/// Exercises `IndexSet`:
/// - insert, including a duplicate that must not grow the set;
/// - `len` / `is_empty`;
/// - draining with `pop` until `None`.
#[test]
fn test_index_set() {
    let set = IndexSet::new();
    assert!(set.is_empty());
    assert_eq!(set.pop(), None);

    set.insert(1);
    set.insert(10);
    set.insert(100);
    // Re-inserting an existing index must not grow the set.
    set.insert(10);
    assert_eq!(set.len(), 3);
    assert!(!set.is_empty());

    // `pop` returns an arbitrary element (HashSet order is unspecified).
    let first = set.pop().expect("set holds three indices");
    assert!(matches!(first, 1 | 10 | 100));
    assert_eq!(set.len(), 2);

    // Drain the rest; every inserted index must come back exactly once.
    let mut popped = vec![first];
    popped.extend(std::iter::from_fn(|| set.pop()));
    popped.sort_unstable();
    assert_eq!(popped, vec![1, 10, 100]);

    assert_eq!(set.len(), 0);
    assert!(set.is_empty());
    assert_eq!(set.pop(), None);
}