tokio_task_pool/
lib.rs

1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10#[cfg(feature = "tracing")]
11use tracing::{event, Level};
12
/// Result of spawning the future `T` on a [`Pool`].
///
/// The outer `Result` reports spawn failures (e.g. [`Error::SpawnTimeout`] or
/// [`Error::NotAvailable`]). On success, the [`JoinHandle`] resolves to the future's output, or
/// to [`Error::RunTimeout`] if the task did not finish within its run timeout.
pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
14
15/// Task ID, can be created from &'static str or String
16#[derive(Debug, Clone, Eq, PartialEq)]
17pub enum TaskId {
18    Static(&'static str),
19    Owned(String),
20}
21
22impl From<&'static str> for TaskId {
23    #[inline]
24    fn from(s: &'static str) -> Self {
25        Self::Static(s)
26    }
27}
28
29impl From<String> for TaskId {
30    #[inline]
31    fn from(s: String) -> Self {
32        Self::Owned(s)
33    }
34}
35
36impl TaskId {
37    #[inline]
38    fn as_str(&self) -> &str {
39        match self {
40            TaskId::Static(v) => v,
41            TaskId::Owned(s) => s.as_str(),
42        }
43    }
44}
45
46/// Task
47///
48/// Contains Future, can contain custom ID and timeout
49pub struct Task<T>
50where
51    T: Future + Send + 'static,
52    T::Output: Send + 'static,
53{
54    id: Option<TaskId>,
55    timeout: Option<Duration>,
56    future: T,
57}
58
59impl<T> Task<T>
60where
61    T: Future + Send + 'static,
62    T::Output: Send + 'static,
63{
64    #[inline]
65    pub fn new(future: T) -> Self {
66        Self {
67            id: None,
68            timeout: None,
69            future,
70        }
71    }
72    #[inline]
73    pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
74        self.id = Some(id.into());
75        self
76    }
77    #[inline]
78    pub fn with_timeout(mut self, timeout: Duration) -> Self {
79        self.timeout = Some(timeout);
80        self
81    }
82}
83
/// Errors produced by [`Pool`] spawn methods and by the tasks they spawn
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// The spawn timeout elapsed before a pool slot (semaphore permit) became available
    SpawnTimeout,
    /// The task did not finish within its run timeout; carries the task ID if one was set
    RunTimeout(Option<TaskId>),
    /// Acquiring the semaphore permit failed (converted from `tokio::sync::AcquireError`).
    /// NOTE(review): the variant name misspells "Semaphore"; kept as-is for API compatibility.
    SpawnSemaphoneAcquireError,
    /// A `try_spawn*` call found no available permit in the pool
    NotAvailable,
}
91
92impl fmt::Display for Error {
93    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94        match self {
95            Error::SpawnTimeout => write!(f, "task spawn timeout"),
96            Error::RunTimeout(id) => {
97                if let Some(i) = id {
98                    write!(f, "task {} run timeout", i.as_str())
99                } else {
100                    write!(f, "task run timeout")
101                }
102            }
103            Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
104            Error::NotAvailable => write!(f, "no available task slots"),
105        }
106    }
107}
108
/// Marks [`Error`] as a standard error type; `source()` keeps its default (no underlying cause).
impl std::error::Error for Error {}
110
111impl From<tokio::sync::AcquireError> for Error {
112    fn from(_: tokio::sync::AcquireError) -> Self {
113        Self::SpawnSemaphoneAcquireError
114    }
115}
116
/// Task pool
///
/// Spawns futures on the tokio runtime, optionally limiting how many run concurrently
/// (bounded pools) and applying spawn/run timeouts.
#[derive(Debug)]
pub struct Pool {
    // Optional pool ID; included in internal error log records (log/tracing features)
    id: Option<Arc<String>>,
    // Max time to wait for a free permit when spawning (only consulted for bounded pools)
    spawn_timeout: Option<Duration>,
    // Default run timeout, used when a task has no timeout of its own
    run_timeout: Option<Duration>,
    // Concurrency limiter; `None` for unbounded pools
    limiter: Option<Arc<Semaphore>>,
    // Number of permits the limiter was created with; `None` for unbounded pools
    capacity: Option<usize>,
    // Whether run timeouts are reported through log/tracing
    #[cfg(any(feature = "log", feature = "tracing"))]
    logging_enabled: bool,
}
128
129impl Default for Pool {
130    fn default() -> Self {
131        Self::unbounded()
132    }
133}
134
135impl Pool {
136    /// Creates a bounded pool (recommended)
137    pub fn bounded(capacity: usize) -> Self {
138        Self {
139            id: None,
140            spawn_timeout: None,
141            run_timeout: None,
142            limiter: Some(Arc::new(Semaphore::new(capacity))),
143            capacity: Some(capacity),
144            #[cfg(any(feature = "log", feature = "tracing"))]
145            logging_enabled: true,
146        }
147    }
148    /// Creates an unbounded pool
149    pub fn unbounded() -> Self {
150        Self {
151            id: None,
152            spawn_timeout: None,
153            run_timeout: None,
154            limiter: None,
155            capacity: None,
156            #[cfg(any(feature = "log", feature = "tracing"))]
157            logging_enabled: true,
158        }
159    }
160    pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
161        self.id.replace(Arc::new(id.into()));
162        self
163    }
164    pub fn id(&self) -> Option<&str> {
165        self.id.as_deref().map(String::as_str)
166    }
167    /// Sets spawn timeout
168    ///
169    /// (ignored for unbounded)
170    #[inline]
171    pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
172        self.spawn_timeout = Some(timeout);
173        self
174    }
175    /// Sets the default task run timeout
176    #[inline]
177    pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
178        self.run_timeout = Some(timeout);
179        self
180    }
181    /// Sets both spawn and run timeouts
182    #[inline]
183    pub fn with_timeout(self, timeout: Duration) -> Self {
184        self.with_spawn_timeout(timeout).with_run_timeout(timeout)
185    }
186    #[cfg(any(feature = "log", feature = "tracing"))]
187    /// Disables internal error logging
188    #[inline]
189    pub fn with_no_logging_enabled(mut self) -> Self {
190        self.logging_enabled = false;
191        self
192    }
193    /// Returns pool capacity
194    #[inline]
195    pub fn capacity(&self) -> Option<usize> {
196        self.capacity
197    }
198    /// Returns pool available task permits
199    #[inline]
200    pub fn available_permits(&self) -> Option<usize> {
201        self.limiter.as_ref().map(|v| v.available_permits())
202    }
203    /// Returns pool busy task permits
204    #[inline]
205    pub fn busy_permits(&self) -> Option<usize> {
206        self.limiter
207            .as_ref()
208            .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
209    }
210    /// Spawns a future
211    #[inline]
212    pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
213    where
214        T: Future + Send + 'static,
215        T::Output: Send + 'static,
216    {
217        self.spawn_task(Task::new(future))
218    }
219    /// Spawns a future with a custom timeout
220    #[inline]
221    pub fn spawn_with_timeout<T>(
222        &self,
223        future: T,
224        timeout: Duration,
225    ) -> impl Future<Output = SpawnResult<T>> + '_
226    where
227        T: Future + Send + 'static,
228        T::Output: Send + 'static,
229    {
230        self.spawn_task(Task::new(future).with_timeout(timeout))
231    }
232    /// Spawns a task (a future which can have a custom ID and timeout)
233    pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
234    where
235        T: Future + Send + 'static,
236        T::Output: Send + 'static,
237    {
238        #[cfg(any(feature = "log", feature = "tracing"))]
239        let id = self.id.as_ref().cloned();
240        let perm = if let Some(ref limiter) = self.limiter {
241            if let Some(spawn_timeout) = self.spawn_timeout {
242                Some(
243                    tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
244                        .await
245                        .map_err(|_| Error::SpawnTimeout)??,
246                )
247            } else {
248                Some(limiter.clone().acquire_owned().await?)
249            }
250        } else {
251            None
252        };
253        if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
254            #[cfg(any(feature = "log", feature = "tracing"))]
255            let logging_enabled = self.logging_enabled;
256            Ok(tokio::spawn(async move {
257                let _p = perm;
258                if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
259                    Ok(v)
260                } else {
261                    let e = Error::RunTimeout(task.id);
262                    #[cfg(any(feature = "log", feature = "tracing"))]
263                    if logging_enabled {
264                        #[cfg(feature = "log")]
265                        error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
266
267                        #[cfg(feature = "tracing")]
268                        event!(
269                            Level::ERROR,
270                            error = ?e,
271                            id = id.as_deref().map_or("", |v| v.as_str())
272                        );
273                    }
274                    Err(e)
275                }
276            }))
277        } else {
278            Ok(tokio::spawn(async move {
279                let _p = perm;
280                Ok(task.future.await)
281            }))
282        }
283    }
284    /// Tries to spawn a future if there is an available permit. Returns `Error::NotAvailable` if no
285    /// permit available
286    pub fn try_spawn<T>(&self, future: T) -> SpawnResult<T>
287    where
288        T: Future + Send + 'static,
289        T::Output: Send + 'static,
290    {
291        self.try_spawn_task(Task::new(future))
292    }
293    /// Tries to spawn a future with a custom timeout if there is an available permit. Returns
294    /// `Error::NotAvailable` if no permit available
295    pub fn try_spawn_with_timeout<T>(&self, future: T, timeout: Duration) -> SpawnResult<T>
296    where
297        T: Future + Send + 'static,
298        T::Output: Send + 'static,
299    {
300        self.try_spawn_task(Task::new(future).with_timeout(timeout))
301    }
302    /// Spawns a task (a future which can have a custom ID and timeout) if there is an available
303    /// permit. Returns `Error::NotAvailable` if no permit available
304    pub fn try_spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
305    where
306        T: Future + Send + 'static,
307        T::Output: Send + 'static,
308    {
309        #[cfg(any(feature = "log", feature = "tracing"))]
310        let id = self.id.as_ref().cloned();
311        let perm = if let Some(ref limiter) = self.limiter {
312            Some(
313                limiter
314                    .clone()
315                    .try_acquire_owned()
316                    .map_err(|_| Error::NotAvailable)?,
317            )
318        } else {
319            None
320        };
321        if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
322            #[cfg(any(feature = "log", feature = "tracing"))]
323            let logging_enabled = self.logging_enabled;
324            Ok(tokio::spawn(async move {
325                let _p = perm;
326                if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
327                    Ok(v)
328                } else {
329                    let e = Error::RunTimeout(task.id);
330                    #[cfg(any(feature = "log", feature = "tracing"))]
331                    if logging_enabled {
332                        #[cfg(feature = "log")]
333                        error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
334
335                        #[cfg(feature = "tracing")]
336                        event!(
337                            Level::ERROR,
338                            error = ?e,
339                            id = id.as_deref().map_or("", |v| v.as_str())
340                        );
341                    }
342                    Err(e)
343                }
344            }))
345        } else {
346            Ok(tokio::spawn(async move {
347                let _p = perm;
348                Ok(task.future.await)
349            }))
350        }
351    }
352}
353
#[cfg(test)]
mod test {
    use super::{Error, Pool, Task, TaskId};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Every task spawned into a bounded pool runs to completion.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for _ in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_millis(200)).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        // Await the handles instead of sleeping a fixed time, so the test is not timing-based.
        for handle in handles {
            handle.await.unwrap().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// Spawning into a full pool fails with `SpawnTimeout` once the spawn timeout elapses.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_millis(200));
        for _ in 1..=5 {
            pool.spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
        }
        // Every permit is held by a still-sleeping task.
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let res = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert!(matches!(res, Err(Error::SpawnTimeout)));
    }

    /// Tasks exceeding the pool run timeout are cancelled and report `RunTimeout`.
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_millis(500));
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for i in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_millis(if i == 5 { 2000 } else { 100 })).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        let mut results = Vec::with_capacity(5);
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        assert!(results[..4].iter().all(Result::is_ok));
        assert_eq!(results[4], Err(Error::RunTimeout(None)));
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }

    /// `try_spawn` fails immediately with `NotAvailable` when the pool is full.
    #[tokio::test]
    async fn test_try_spawn_not_available() {
        let pool = Pool::bounded(1);
        let first = pool.try_spawn(async move {
            sleep(Duration::from_secs(2)).await;
        });
        assert!(first.is_ok());
        let second = pool.try_spawn(async move {});
        assert!(matches!(second, Err(Error::NotAvailable)));
    }

    /// A task's own ID and timeout are used and reported in `RunTimeout`.
    #[tokio::test]
    async fn test_task_run_timeout_reports_id() {
        let pool = Pool::unbounded();
        let task = Task::new(async move {
            sleep(Duration::from_secs(2)).await;
        })
        .with_id("slow")
        .with_timeout(Duration::from_millis(100));
        let res = pool.spawn_task(task).await.unwrap().await.unwrap();
        assert_eq!(res, Err(Error::RunTimeout(Some(TaskId::Static("slow")))));
    }
}