1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10#[cfg(feature = "tracing")]
11use tracing::{event, Level};
12
13pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
14
/// Identifier attached to a [`Task`] and reported back in [`Error::RunTimeout`].
///
/// Build one from a `&'static str` (no allocation) or an owned `String` through the
/// `From` impls. Equality is variant-sensitive: `Static("a")` and `Owned("a".into())`
/// are *not* equal, and they hash differently, consistent with that.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum TaskId {
    /// A string known at compile time.
    Static(&'static str),
    /// A heap-allocated string.
    Owned(String),
}
21
22impl From<&'static str> for TaskId {
23 #[inline]
24 fn from(s: &'static str) -> Self {
25 Self::Static(s)
26 }
27}
28
29impl From<String> for TaskId {
30 #[inline]
31 fn from(s: String) -> Self {
32 Self::Owned(s)
33 }
34}
35
36impl TaskId {
37 #[inline]
38 fn as_str(&self) -> &str {
39 match self {
40 TaskId::Static(v) => v,
41 TaskId::Owned(s) => s.as_str(),
42 }
43 }
44}
45
46pub struct Task<T>
50where
51 T: Future + Send + 'static,
52 T::Output: Send + 'static,
53{
54 id: Option<TaskId>,
55 timeout: Option<Duration>,
56 future: T,
57}
58
59impl<T> Task<T>
60where
61 T: Future + Send + 'static,
62 T::Output: Send + 'static,
63{
64 #[inline]
65 pub fn new(future: T) -> Self {
66 Self {
67 id: None,
68 timeout: None,
69 future,
70 }
71 }
72 #[inline]
73 pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
74 self.id = Some(id.into());
75 self
76 }
77 #[inline]
78 pub fn with_timeout(mut self, timeout: Duration) -> Self {
79 self.timeout = Some(timeout);
80 self
81 }
82}
83
/// Errors produced by [`Pool`] while spawning or running tasks.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// No slot became free within the pool's spawn timeout.
    SpawnTimeout,
    /// The task did not finish within its run timeout; carries the task's id, if one was set.
    RunTimeout(Option<TaskId>),
    /// Acquiring a slot from the pool's semaphore failed (converted from
    /// `tokio::sync::AcquireError`).
    ///
    /// NOTE: "Semaphone" is a misspelling, but the variant name is public API and is kept as-is.
    SpawnSemaphoneAcquireError,
    /// A non-waiting `try_spawn*` call found no free slot in a bounded pool.
    NotAvailable,
}
91
92impl fmt::Display for Error {
93 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94 match self {
95 Error::SpawnTimeout => write!(f, "task spawn timeout"),
96 Error::RunTimeout(id) => {
97 if let Some(i) = id {
98 write!(f, "task {} run timeout", i.as_str())
99 } else {
100 write!(f, "task run timeout")
101 }
102 }
103 Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
104 Error::NotAvailable => write!(f, "no available task slots"),
105 }
106 }
107}
108
109impl std::error::Error for Error {}
110
111impl From<tokio::sync::AcquireError> for Error {
112 fn from(_: tokio::sync::AcquireError) -> Self {
113 Self::SpawnSemaphoneAcquireError
114 }
115}
116
/// A task spawner with an optional concurrency limit and optional spawn/run timeouts.
///
/// Create one with [`Pool::bounded`] or [`Pool::unbounded`] and configure it with
/// the `with_*` builder methods.
#[derive(Debug)]
pub struct Pool {
    // Optional pool name, included in run-timeout log records when the `log`/`tracing`
    // features are enabled.
    id: Option<Arc<String>>,
    // Maximum time `spawn_task` waits for a free slot (bounded pools only).
    spawn_timeout: Option<Duration>,
    // Default run timeout, used when a task does not set its own.
    run_timeout: Option<Duration>,
    // Semaphore limiting concurrently running tasks; `None` for unbounded pools.
    limiter: Option<Arc<Semaphore>>,
    // Number of permits `limiter` was created with; `Some` exactly when `limiter` is.
    capacity: Option<usize>,
    // Whether run timeouts are logged (only present with the `log`/`tracing` features).
    #[cfg(any(feature = "log", feature = "tracing"))]
    logging_enabled: bool,
}
128
129impl Default for Pool {
130 fn default() -> Self {
131 Self::unbounded()
132 }
133}
134
135impl Pool {
136 pub fn bounded(capacity: usize) -> Self {
138 Self {
139 id: None,
140 spawn_timeout: None,
141 run_timeout: None,
142 limiter: Some(Arc::new(Semaphore::new(capacity))),
143 capacity: Some(capacity),
144 #[cfg(any(feature = "log", feature = "tracing"))]
145 logging_enabled: true,
146 }
147 }
148 pub fn unbounded() -> Self {
150 Self {
151 id: None,
152 spawn_timeout: None,
153 run_timeout: None,
154 limiter: None,
155 capacity: None,
156 #[cfg(any(feature = "log", feature = "tracing"))]
157 logging_enabled: true,
158 }
159 }
160 pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
161 self.id.replace(Arc::new(id.into()));
162 self
163 }
164 pub fn id(&self) -> Option<&str> {
165 self.id.as_deref().map(String::as_str)
166 }
167 #[inline]
171 pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
172 self.spawn_timeout = Some(timeout);
173 self
174 }
175 #[inline]
177 pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
178 self.run_timeout = Some(timeout);
179 self
180 }
181 #[inline]
183 pub fn with_timeout(self, timeout: Duration) -> Self {
184 self.with_spawn_timeout(timeout).with_run_timeout(timeout)
185 }
186 #[cfg(any(feature = "log", feature = "tracing"))]
187 #[inline]
189 pub fn with_no_logging_enabled(mut self) -> Self {
190 self.logging_enabled = false;
191 self
192 }
193 #[inline]
195 pub fn capacity(&self) -> Option<usize> {
196 self.capacity
197 }
198 #[inline]
200 pub fn available_permits(&self) -> Option<usize> {
201 self.limiter.as_ref().map(|v| v.available_permits())
202 }
203 #[inline]
205 pub fn busy_permits(&self) -> Option<usize> {
206 self.limiter
207 .as_ref()
208 .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
209 }
210 #[inline]
212 pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
213 where
214 T: Future + Send + 'static,
215 T::Output: Send + 'static,
216 {
217 self.spawn_task(Task::new(future))
218 }
219 #[inline]
221 pub fn spawn_with_timeout<T>(
222 &self,
223 future: T,
224 timeout: Duration,
225 ) -> impl Future<Output = SpawnResult<T>> + '_
226 where
227 T: Future + Send + 'static,
228 T::Output: Send + 'static,
229 {
230 self.spawn_task(Task::new(future).with_timeout(timeout))
231 }
232 pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
234 where
235 T: Future + Send + 'static,
236 T::Output: Send + 'static,
237 {
238 #[cfg(any(feature = "log", feature = "tracing"))]
239 let id = self.id.as_ref().cloned();
240 let perm = if let Some(ref limiter) = self.limiter {
241 if let Some(spawn_timeout) = self.spawn_timeout {
242 Some(
243 tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
244 .await
245 .map_err(|_| Error::SpawnTimeout)??,
246 )
247 } else {
248 Some(limiter.clone().acquire_owned().await?)
249 }
250 } else {
251 None
252 };
253 if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
254 #[cfg(any(feature = "log", feature = "tracing"))]
255 let logging_enabled = self.logging_enabled;
256 Ok(tokio::spawn(async move {
257 let _p = perm;
258 if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
259 Ok(v)
260 } else {
261 let e = Error::RunTimeout(task.id);
262 #[cfg(any(feature = "log", feature = "tracing"))]
263 if logging_enabled {
264 #[cfg(feature = "log")]
265 error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
266
267 #[cfg(feature = "tracing")]
268 event!(
269 Level::ERROR,
270 error = ?e,
271 id = id.as_deref().map_or("", |v| v.as_str())
272 );
273 }
274 Err(e)
275 }
276 }))
277 } else {
278 Ok(tokio::spawn(async move {
279 let _p = perm;
280 Ok(task.future.await)
281 }))
282 }
283 }
284 pub fn try_spawn<T>(&self, future: T) -> SpawnResult<T>
287 where
288 T: Future + Send + 'static,
289 T::Output: Send + 'static,
290 {
291 self.try_spawn_task(Task::new(future))
292 }
293 pub fn try_spawn_with_timeout<T>(&self, future: T, timeout: Duration) -> SpawnResult<T>
296 where
297 T: Future + Send + 'static,
298 T::Output: Send + 'static,
299 {
300 self.try_spawn_task(Task::new(future).with_timeout(timeout))
301 }
302 pub fn try_spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
305 where
306 T: Future + Send + 'static,
307 T::Output: Send + 'static,
308 {
309 #[cfg(any(feature = "log", feature = "tracing"))]
310 let id = self.id.as_ref().cloned();
311 let perm = if let Some(ref limiter) = self.limiter {
312 Some(
313 limiter
314 .clone()
315 .try_acquire_owned()
316 .map_err(|_| Error::NotAvailable)?,
317 )
318 } else {
319 None
320 };
321 if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
322 #[cfg(any(feature = "log", feature = "tracing"))]
323 let logging_enabled = self.logging_enabled;
324 Ok(tokio::spawn(async move {
325 let _p = perm;
326 if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
327 Ok(v)
328 } else {
329 let e = Error::RunTimeout(task.id);
330 #[cfg(any(feature = "log", feature = "tracing"))]
331 if logging_enabled {
332 #[cfg(feature = "log")]
333 error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
334
335 #[cfg(feature = "tracing")]
336 event!(
337 Level::ERROR,
338 error = ?e,
339 id = id.as_deref().map_or("", |v| v.as_str())
340 );
341 }
342 Err(e)
343 }
344 }))
345 } else {
346 Ok(tokio::spawn(async move {
347 let _p = perm;
348 Ok(task.future.await)
349 }))
350 }
351 }
352}
353
#[cfg(test)]
mod test {
    use super::{Error, Pool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use tokio::time::sleep;

    /// Five tasks fit into a pool of five, and all of them run to completion.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 1..=5 {
            let counter_c = counter.clone();
            pool.spawn(async move {
                sleep(Duration::from_secs(2)).await;
                counter_c.fetch_add(1, Ordering::SeqCst);
            })
            .await
            .unwrap();
        }
        sleep(Duration::from_secs(3)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// Once every slot is held by a running task, a further spawn fails with
    /// `SpawnTimeout` after the spawn timeout elapses.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_secs(1));
        for _ in 1..=5 {
            let (tx, mut rx) = channel(1);
            pool.spawn(async move {
                tx.send(()).await.unwrap();
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
            // Wait until the task is actually running and holding its permit.
            rx.recv().await;
        }
        // All five tasks are still sleeping, so every slot is occupied.
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let result = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert!(matches!(result, Err(Error::SpawnTimeout)));
    }

    /// A task that outlives the pool's run timeout is cancelled before it can
    /// bump the counter; the others finish in time.
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_secs(2));
        let counter = Arc::new(AtomicUsize::new(0));
        for i in 1..=5 {
            let counter_c = counter.clone();
            pool.spawn(async move {
                sleep(Duration::from_secs(if i == 5 { 3 } else { 1 })).await;
                counter_c.fetch_add(1, Ordering::SeqCst);
            })
            .await
            .unwrap();
        }
        sleep(Duration::from_secs(5)).await;
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }
}