task_executor/
lib.rs

1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicIsize, Ordering};
8use std::task::{Context, Poll};
9
10use futures::channel::mpsc;
11use futures::task::AtomicWaker;
12use once_cell::sync::OnceCell;
13use parking_lot::RwLock;
14
15pub use builder::{
16    Builder, ChannelBuilder, GroupBuilder, GroupChannelBuilder, SpawnDefaultExt, SpawnExt,
17};
18pub use exec::{Executor, TaskType};
19pub use local::LocalExecutor;
20pub use local::LocalTaskType;
21pub use local_builder::{
22    ChannelLocalBuilder, GroupChannelLocalBuilder, LocalBuilder, LocalSpawnExt, SyncSender,
23};
24pub use local_spawner::{LocalGroupSpawner, LocalSpawner};
25pub use spawner::{GroupSpawner, Spawner};
26
27mod builder;
28mod close;
29mod exec;
30mod flush;
31mod spawner;
32
33mod local;
34mod local_builder;
35mod local_spawner;
36
/// Shared signed counter backed by an atomic integer.
///
/// Cloning a `Counter` clones the inner `Arc`, so every clone reads and
/// updates the same underlying value.
#[derive(Clone, Debug)]
struct Counter(Arc<AtomicIsize>);

impl Counter {
    /// Memory ordering used for every access to the underlying atomic.
    const ORDER: Ordering = Ordering::SeqCst;

    /// Creates a counter starting at zero.
    #[inline]
    fn new() -> Self {
        let zero = AtomicIsize::new(0);
        Self(Arc::new(zero))
    }

    /// Atomically adds one to the counter.
    #[inline]
    fn inc(&self) {
        let cell: &AtomicIsize = &self.0;
        cell.fetch_add(1, Self::ORDER);
    }

    /// Atomically subtracts one from the counter; the value may go negative.
    #[inline]
    fn dec(&self) {
        let cell: &AtomicIsize = &self.0;
        cell.fetch_sub(1, Self::ORDER);
    }

    /// Returns the current value of the counter.
    #[inline]
    fn value(&self) -> isize {
        let cell: &AtomicIsize = &self.0;
        cell.load(Self::ORDER)
    }
}
61
62#[derive(Clone)]
63struct IndexSet(Arc<RwLock<HashSet<usize, ahash::RandomState>>>);
64
65impl IndexSet {
66    #[inline]
67    fn new() -> Self {
68        Self(Arc::new(RwLock::new(HashSet::default())))
69    }
70
71    #[inline]
72    #[allow(dead_code)]
73    fn len(&self) -> usize {
74        self.0.read().len()
75    }
76
77    #[inline]
78    fn is_empty(&self) -> bool {
79        self.0.read().is_empty()
80    }
81
82    #[inline]
83    fn insert(&self, v: usize) {
84        self.0.write().insert(v);
85    }
86
87    #[inline]
88    fn pop(&self) -> Option<usize> {
89        let mut set = self.0.write();
90        if let Some(idx) = set.iter().next().copied() {
91            set.remove(&idx);
92            Some(idx)
93        } else {
94            None
95        }
96    }
97}
98
99struct PendingOnce {
100    w: Arc<AtomicWaker>,
101    is_ready: bool,
102}
103
104impl PendingOnce {
105    #[inline]
106    fn new(w: Arc<AtomicWaker>) -> Self {
107        Self { w, is_ready: false }
108    }
109}
110
111impl Future for PendingOnce {
112    type Output = ();
113    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
114        if self.is_ready {
115            Poll::Ready(())
116        } else {
117            self.w.register(cx.waker());
118            self.is_ready = true;
119            Poll::Pending
120        }
121    }
122}
123
/// FIFO queue of tasks paired with a "running" flag.
///
/// The flag is set explicitly through [`GroupTaskQueue::set_running`] and is
/// cleared automatically when [`GroupTaskQueue::pop`] finds the queue empty.
struct GroupTaskQueue<TT> {
    /// Pending tasks, oldest at the front.
    tasks: VecDeque<TT>,
    /// Current value of the running flag.
    is_running: bool,
}

impl<TT> GroupTaskQueue<TT> {
    /// Creates an empty queue whose running flag is `false`.
    #[inline]
    fn new() -> Self {
        GroupTaskQueue {
            is_running: false,
            tasks: VecDeque::new(),
        }
    }

    /// Appends `task` to the back of the queue.
    #[inline]
    fn push(&mut self, task: TT) {
        self.tasks.push_back(task);
    }

    /// Removes and returns the oldest task.
    ///
    /// When the queue is empty this returns `None` and also resets the
    /// running flag to `false`.
    #[inline]
    fn pop(&mut self) -> Option<TT> {
        let next = self.tasks.pop_front();
        if next.is_none() {
            self.is_running = false;
        }
        next
    }

    /// Sets the running flag to `b`.
    #[inline]
    fn set_running(&mut self, b: bool) {
        self.is_running = b;
    }

    /// Returns the current running flag.
    #[inline]
    fn is_running(&self) -> bool {
        self.is_running
    }
}
163
/// Crate error type, generic over the payload type `T`.
///
/// The three send-related variants wrap an [`ErrorType`] that says why the
/// operation failed and may carry a value of type `T`. The `#[error("...")]`
/// strings are the `Display` messages generated by `thiserror`.
#[derive(thiserror::Error, Debug)]
pub enum Error<T> {
    /// Send failure. Produced by the `From<mpsc::SendError>` conversion in
    /// this file, which never attaches a payload.
    #[error("send error")]
    SendError(ErrorType<T>),
    /// Try-send failure. Produced by the `From<mpsc::TrySendError<T>>`
    /// conversion in this file, which attaches the value taken out of the
    /// channel error.
    #[error("try send error")]
    TrySendError(ErrorType<T>),
    /// NOTE(review): presumably reported by a send-with-timeout operation;
    /// nothing in this file constructs it — confirm against the callers.
    #[error("send timeout error")]
    SendTimeoutError(ErrorType<T>),
    /// Carries no payload. NOTE(review): presumably reported when a task's
    /// result could not be received; nothing in this file constructs it.
    #[error("recv result error")]
    RecvResultError,
}
175
/// Reason attached to the send-related [`Error`] variants.
///
/// Each variant holds an `Option<T>`. In this file the conversion from
/// `mpsc::TrySendError<T>` stores `Some(value)` using the value returned by
/// `into_inner()`, while the conversion from `mpsc::SendError` stores `None`.
#[derive(Debug, Eq, PartialEq)]
pub enum ErrorType<T> {
    /// Matched by [`Error::is_full`].
    Full(Option<T>),
    /// Matched by [`Error::is_closed`].
    Closed(Option<T>),
    /// Matched by [`Error::is_timeout`].
    Timeout(Option<T>),
}
182
183impl<T> Error<T> {
184    #[inline]
185    pub fn is_full(&self) -> bool {
186        matches!(
187            self,
188            Error::SendError(ErrorType::Full(_))
189                | Error::TrySendError(ErrorType::Full(_))
190                | Error::SendTimeoutError(ErrorType::Full(_))
191        )
192    }
193
194    #[inline]
195    pub fn is_closed(&self) -> bool {
196        matches!(
197            self,
198            Error::SendError(ErrorType::Closed(_))
199                | Error::TrySendError(ErrorType::Closed(_))
200                | Error::SendTimeoutError(ErrorType::Closed(_))
201        )
202    }
203
204    #[inline]
205    pub fn is_timeout(&self) -> bool {
206        matches!(
207            self,
208            Error::SendError(ErrorType::Timeout(_))
209                | Error::TrySendError(ErrorType::Timeout(_))
210                | Error::SendTimeoutError(ErrorType::Timeout(_))
211        )
212    }
213}
214
215impl<T> From<mpsc::TrySendError<T>> for Error<T> {
216    fn from(e: mpsc::TrySendError<T>) -> Self {
217        if e.is_full() {
218            Error::TrySendError(ErrorType::Full(Some(e.into_inner())))
219        } else {
220            Error::TrySendError(ErrorType::Closed(Some(e.into_inner())))
221        }
222    }
223}
224
225impl<T> From<mpsc::SendError> for Error<T> {
226    fn from(e: mpsc::SendError) -> Self {
227        if e.is_full() {
228            Error::SendError(ErrorType::Full(None))
229        } else {
230            Error::SendError(ErrorType::Closed(None))
231        }
232    }
233}
234
235// Just a helper function to ensure the futures we're returning all have the
236// right implementations.
237pub(crate) fn assert_future<T, F>(future: F) -> F
238    where
239        F: futures::Future<Output=T>,
240{
241    future
242}
243
/// Process-wide default executor slot.
///
/// Filled by `set_default` or `init_default` and read by `default`; being a
/// `OnceCell`, it can be filled successfully only once.
static DEFAULT_EXECUTOR: OnceCell<Executor> = OnceCell::new();
245
/// Installs `exec` as the default executor.
///
/// # Errors
///
/// Returns `Err(exec)`, handing the executor back to the caller, when a
/// default executor has already been installed.
pub fn set_default(exec: Executor) -> Result<(), Executor> {
    DEFAULT_EXECUTOR.set(exec)
}
249
250pub fn init_default() -> impl futures::Future<Output=()> {
251    let (exec, runner) = Builder::default().workers(100).queue_max(100_000).build();
252    DEFAULT_EXECUTOR.set(exec).ok().unwrap();
253    runner
254}
255
256pub fn default() -> &'static Executor {
257    DEFAULT_EXECUTOR
258        .get()
259        .expect("default executor must be set first")
260}
261
/// Exercises `IndexSet`: emptiness, duplicate insertion, and draining via
/// `pop`, which may return the stored indices in any order.
#[test]
fn test_index_set() {
    let set = IndexSet::new();
    assert!(set.is_empty());
    assert_eq!(set.len(), 0);

    set.insert(1);
    set.insert(10);
    set.insert(100);
    // Re-inserting an existing index must not grow the set.
    set.insert(10);
    assert_eq!(set.len(), 3);
    assert!(!set.is_empty());

    // Drain the set, checking that each pop shrinks it by exactly one.
    let mut popped = Vec::new();
    while let Some(idx) = set.pop() {
        popped.push(idx);
        assert_eq!(set.len(), 3 - popped.len());
    }

    // Pop order is unspecified, so compare the drained values sorted.
    popped.sort_unstable();
    assert_eq!(popped, vec![1, 10, 100]);

    assert!(set.is_empty());
    assert_eq!(set.pop(), None);
}