tokio_task_pool/
lib.rs

1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10
11pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
12
/// Identifier attached to a task.
///
/// Can be built from a `&'static str` or from a `String` through `From`/`Into`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TaskId {
    /// Identifier backed by a `'static` string slice
    Static(&'static str),
    /// Identifier backed by an owned `String`
    Owned(String),
}
19
20impl From<&'static str> for TaskId {
21    #[inline]
22    fn from(s: &'static str) -> Self {
23        Self::Static(s)
24    }
25}
26
27impl From<String> for TaskId {
28    #[inline]
29    fn from(s: String) -> Self {
30        Self::Owned(s)
31    }
32}
33
34impl TaskId {
35    #[inline]
36    fn as_str(&self) -> &str {
37        match self {
38            TaskId::Static(v) => v,
39            TaskId::Owned(s) => s.as_str(),
40        }
41    }
42}
43
44/// Task
45///
46/// Contains Future, can contain custom ID and timeout
47pub struct Task<T>
48where
49    T: Future + Send + 'static,
50    T::Output: Send + 'static,
51{
52    id: Option<TaskId>,
53    timeout: Option<Duration>,
54    future: T,
55}
56
57impl<T> Task<T>
58where
59    T: Future + Send + 'static,
60    T::Output: Send + 'static,
61{
62    #[inline]
63    pub fn new(future: T) -> Self {
64        Self {
65            id: None,
66            timeout: None,
67            future,
68        }
69    }
70    #[inline]
71    pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
72        self.id = Some(id.into());
73        self
74    }
75    #[inline]
76    pub fn with_timeout(mut self, timeout: Duration) -> Self {
77        self.timeout = Some(timeout);
78        self
79    }
80}
81
82#[derive(Debug, Clone, Eq, PartialEq)]
83pub enum Error {
84    SpawnTimeout,
85    RunTimeout(Option<TaskId>),
86    SpawnSemaphoneAcquireError,
87    NotAvailable,
88}
89
90impl fmt::Display for Error {
91    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92        match self {
93            Error::SpawnTimeout => write!(f, "task spawn timeout"),
94            Error::RunTimeout(id) => {
95                if let Some(i) = id {
96                    write!(f, "task {} run timeout", i.as_str())
97                } else {
98                    write!(f, "task run timeout")
99                }
100            }
101            Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
102            Error::NotAvailable => write!(f, "no available task slots"),
103        }
104    }
105}
106
107impl std::error::Error for Error {}
108
109impl From<tokio::sync::AcquireError> for Error {
110    fn from(_: tokio::sync::AcquireError) -> Self {
111        Self::SpawnSemaphoneAcquireError
112    }
113}
114
115/// Task pool
116#[derive(Debug)]
117pub struct Pool {
118    id: Option<Arc<String>>,
119    spawn_timeout: Option<Duration>,
120    run_timeout: Option<Duration>,
121    limiter: Option<Arc<Semaphore>>,
122    capacity: Option<usize>,
123    #[cfg(feature = "log")]
124    logging_enabled: bool,
125}
126
127impl Default for Pool {
128    fn default() -> Self {
129        Self::unbounded()
130    }
131}
132
133impl Pool {
134    /// Creates a bounded pool (recommended)
135    pub fn bounded(capacity: usize) -> Self {
136        Self {
137            id: None,
138            spawn_timeout: None,
139            run_timeout: None,
140            limiter: Some(Arc::new(Semaphore::new(capacity))),
141            capacity: Some(capacity),
142            #[cfg(feature = "log")]
143            logging_enabled: true,
144        }
145    }
146    /// Creates an unbounded pool
147    pub fn unbounded() -> Self {
148        Self {
149            id: None,
150            spawn_timeout: None,
151            run_timeout: None,
152            limiter: None,
153            capacity: None,
154            #[cfg(feature = "log")]
155            logging_enabled: true,
156        }
157    }
158    pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
159        self.id.replace(Arc::new(id.into()));
160        self
161    }
162    pub fn id(&self) -> Option<&str> {
163        self.id.as_deref().map(String::as_str)
164    }
165    /// Sets spawn timeout
166    ///
167    /// (ignored for unbounded)
168    #[inline]
169    pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
170        self.spawn_timeout = Some(timeout);
171        self
172    }
173    /// Sets the default task run timeout
174    #[inline]
175    pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
176        self.run_timeout = Some(timeout);
177        self
178    }
179    /// Sets both spawn and run timeouts
180    #[inline]
181    pub fn with_timeout(self, timeout: Duration) -> Self {
182        self.with_spawn_timeout(timeout).with_run_timeout(timeout)
183    }
184    #[cfg(feature = "log")]
185    /// Disables internal error logging
186    #[inline]
187    pub fn with_no_logging_enabled(mut self) -> Self {
188        self.logging_enabled = false;
189        self
190    }
191    /// Returns pool capacity
192    #[inline]
193    pub fn capacity(&self) -> Option<usize> {
194        self.capacity
195    }
196    /// Returns pool available task permits
197    #[inline]
198    pub fn available_permits(&self) -> Option<usize> {
199        self.limiter.as_ref().map(|v| v.available_permits())
200    }
201    /// Returns pool busy task permits
202    #[inline]
203    pub fn busy_permits(&self) -> Option<usize> {
204        self.limiter
205            .as_ref()
206            .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
207    }
208    /// Spawns a future
209    #[inline]
210    pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
211    where
212        T: Future + Send + 'static,
213        T::Output: Send + 'static,
214    {
215        self.spawn_task(Task::new(future))
216    }
217    /// Spawns a future with a custom timeout
218    #[inline]
219    pub fn spawn_with_timeout<T>(
220        &self,
221        future: T,
222        timeout: Duration,
223    ) -> impl Future<Output = SpawnResult<T>> + '_
224    where
225        T: Future + Send + 'static,
226        T::Output: Send + 'static,
227    {
228        self.spawn_task(Task::new(future).with_timeout(timeout))
229    }
230    /// Spawns a task (a future which can have a custom ID and timeout)
231    pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
232    where
233        T: Future + Send + 'static,
234        T::Output: Send + 'static,
235    {
236        #[cfg(feature = "log")]
237        let id = self.id.as_ref().cloned();
238        let perm = if let Some(ref limiter) = self.limiter {
239            if let Some(spawn_timeout) = self.spawn_timeout {
240                Some(
241                    tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
242                        .await
243                        .map_err(|_| Error::SpawnTimeout)??,
244                )
245            } else {
246                Some(limiter.clone().acquire_owned().await?)
247            }
248        } else {
249            None
250        };
251        if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
252            #[cfg(feature = "log")]
253            let logging_enabled = self.logging_enabled;
254            Ok(tokio::spawn(async move {
255                let _p = perm;
256                if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
257                    Ok(v)
258                } else {
259                    let e = Error::RunTimeout(task.id);
260                    #[cfg(feature = "log")]
261                    if logging_enabled {
262                        error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
263                    }
264                    Err(e)
265                }
266            }))
267        } else {
268            Ok(tokio::spawn(async move {
269                let _p = perm;
270                Ok(task.future.await)
271            }))
272        }
273    }
274    /// Tries to spawn a future if there is an available permit. Returns `Error::NotAvailable` if no
275    /// permit available
276    pub fn try_spawn<T>(&self, future: T) -> SpawnResult<T>
277    where
278        T: Future + Send + 'static,
279        T::Output: Send + 'static,
280    {
281        self.try_spawn_task(Task::new(future))
282    }
283    /// Tries to spawn a future with a custom timeout if there is an available permit. Returns
284    /// `Error::NotAvailable` if no permit available
285    pub fn try_spawn_with_timeout<T>(&self, future: T, timeout: Duration) -> SpawnResult<T>
286    where
287        T: Future + Send + 'static,
288        T::Output: Send + 'static,
289    {
290        self.try_spawn_task(Task::new(future).with_timeout(timeout))
291    }
292    /// Spawns a task (a future which can have a custom ID and timeout) if there is an available
293    /// permit. Returns `Error::NotAvailable` if no permit available
294    pub fn try_spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
295    where
296        T: Future + Send + 'static,
297        T::Output: Send + 'static,
298    {
299        #[cfg(feature = "log")]
300        let id = self.id.as_ref().cloned();
301        let perm = if let Some(ref limiter) = self.limiter {
302            Some(
303                limiter
304                    .clone()
305                    .try_acquire_owned()
306                    .map_err(|_| Error::NotAvailable)?,
307            )
308        } else {
309            None
310        };
311        if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
312            #[cfg(feature = "log")]
313            let logging_enabled = self.logging_enabled;
314            Ok(tokio::spawn(async move {
315                let _p = perm;
316                if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
317                    Ok(v)
318                } else {
319                    let e = Error::RunTimeout(task.id);
320                    #[cfg(feature = "log")]
321                    if logging_enabled {
322                        error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
323                    }
324                    Err(e)
325                }
326            }))
327        } else {
328            Ok(tokio::spawn(async move {
329                let _p = perm;
330                Ok(task.future.await)
331            }))
332        }
333    }
334}
335
#[cfg(test)]
mod test {
    use super::{Error, Pool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use tokio::time::sleep;

    /// All tasks spawned into a sufficiently large bounded pool run to completion.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for _ in 1..=5 {
            let counter_c = Arc::clone(&counter);
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(2)).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        // Join the tasks instead of sleeping for a fixed wall-clock interval
        for handle in handles {
            handle.await.unwrap().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// Spawning into a full pool fails with `SpawnTimeout` once the spawn timeout elapses.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_secs(1));
        for _ in 1..=5 {
            let (tx, mut rx) = channel(1);
            pool.spawn(async move {
                tx.send(()).await.unwrap();
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
            // Wait until the task is actually running (and thus holding its permit)
            rx.recv().await;
        }
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let result = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert_eq!(result.err(), Some(Error::SpawnTimeout));
    }

    /// A task exceeding the pool run timeout is cancelled and reports `RunTimeout`.
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_secs(2));
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for i in 1..=5 {
            let counter_c = Arc::clone(&counter);
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(if i == 5 { 3 } else { 1 })).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        for (idx, handle) in handles.into_iter().enumerate() {
            let outcome = handle.await.unwrap();
            if idx == 4 {
                // Only the last task sleeps longer than the run timeout
                assert_eq!(outcome, Err(Error::RunTimeout(None)));
            } else {
                assert_eq!(outcome, Ok(()));
            }
        }
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }
}