1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10
/// Outcome of a spawn request.
///
/// The outer `Result` reports whether the task was admitted to the pool
/// (`Error::SpawnTimeout`, `Error::NotAvailable`, or
/// `Error::SpawnSemaphoneAcquireError` on failure). On success it carries the
/// tokio `JoinHandle`, whose output is either the future's value or
/// `Error::RunTimeout` when a run timeout elapsed before the future finished.
pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
12
/// Identifier attached to a task and reported back in `Error::RunTimeout`
/// when that task exceeds its run timeout.
///
/// Note: the derived `PartialEq` also compares the variant, so
/// `TaskId::Static("a")` and `TaskId::Owned("a".to_string())` are NOT equal.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TaskId {
    /// A borrowed `'static` string (e.g. a string literal); no allocation.
    Static(&'static str),
    /// An owned, heap-allocated string.
    Owned(String),
}
19
20impl From<&'static str> for TaskId {
21 #[inline]
22 fn from(s: &'static str) -> Self {
23 Self::Static(s)
24 }
25}
26
27impl From<String> for TaskId {
28 #[inline]
29 fn from(s: String) -> Self {
30 Self::Owned(s)
31 }
32}
33
34impl TaskId {
35 #[inline]
36 fn as_str(&self) -> &str {
37 match self {
38 TaskId::Static(v) => v,
39 TaskId::Owned(s) => s.as_str(),
40 }
41 }
42}
43
44pub struct Task<T>
48where
49 T: Future + Send + 'static,
50 T::Output: Send + 'static,
51{
52 id: Option<TaskId>,
53 timeout: Option<Duration>,
54 future: T,
55}
56
57impl<T> Task<T>
58where
59 T: Future + Send + 'static,
60 T::Output: Send + 'static,
61{
62 #[inline]
63 pub fn new(future: T) -> Self {
64 Self {
65 id: None,
66 timeout: None,
67 future,
68 }
69 }
70 #[inline]
71 pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
72 self.id = Some(id.into());
73 self
74 }
75 #[inline]
76 pub fn with_timeout(mut self, timeout: Duration) -> Self {
77 self.timeout = Some(timeout);
78 self
79 }
80}
81
/// Errors produced when spawning or running tasks on a `Pool`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// A bounded pool could not obtain a free slot within its spawn timeout.
    SpawnTimeout,
    /// The task did not finish within its run timeout; carries the task's
    /// id if one was set.
    RunTimeout(Option<TaskId>),
    /// Acquiring a slot from the pool's semaphore failed (converted from
    /// `tokio::sync::AcquireError`).
    /// NOTE(review): "Semaphone" is a misspelling of "Semaphore"; renaming the
    /// variant would break the public API.
    SpawnSemaphoneAcquireError,
    /// `try_spawn*` could not obtain a slot from a bounded pool immediately.
    NotAvailable,
}
89
90impl fmt::Display for Error {
91 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
92 match self {
93 Error::SpawnTimeout => write!(f, "task spawn timeout"),
94 Error::RunTimeout(id) => {
95 if let Some(i) = id {
96 write!(f, "task {} run timeout", i.as_str())
97 } else {
98 write!(f, "task run timeout")
99 }
100 }
101 Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
102 Error::NotAvailable => write!(f, "no available task slots"),
103 }
104 }
105}
106
/// Marks [`Error`] as a standard error type so it can be boxed or propagated
/// through `dyn std::error::Error`.
impl std::error::Error for Error {
    // No variant stores an underlying error value, so the default `source()`
    // (which returns `None`) is used.
}
108
109impl From<tokio::sync::AcquireError> for Error {
110 fn from(_: tokio::sync::AcquireError) -> Self {
111 Self::SpawnSemaphoneAcquireError
112 }
113}
114
/// Spawns futures onto the tokio runtime, optionally limiting how many run at
/// once and applying spawn/run timeouts.
#[derive(Debug)]
pub struct Pool {
    /// Optional pool name: returned by `id()` and used as the prefix of
    /// run-timeout log messages.
    id: Option<Arc<String>>,
    /// Maximum time `spawn*` waits for a free slot (only consulted when
    /// `limiter` is `Some`).
    spawn_timeout: Option<Duration>,
    /// Default run timeout for tasks that do not set their own.
    run_timeout: Option<Duration>,
    /// Semaphore bounding concurrency; `None` for an unbounded pool.
    limiter: Option<Arc<Semaphore>>,
    /// Total number of slots; set by the constructors to `Some` exactly when
    /// `limiter` is `Some`.
    capacity: Option<usize>,
    /// Whether run-timeout errors are logged (only with the `log` feature).
    #[cfg(feature = "log")]
    logging_enabled: bool,
}
126
127impl Default for Pool {
128 fn default() -> Self {
129 Self::unbounded()
130 }
131}
132
133impl Pool {
134 pub fn bounded(capacity: usize) -> Self {
136 Self {
137 id: None,
138 spawn_timeout: None,
139 run_timeout: None,
140 limiter: Some(Arc::new(Semaphore::new(capacity))),
141 capacity: Some(capacity),
142 #[cfg(feature = "log")]
143 logging_enabled: true,
144 }
145 }
146 pub fn unbounded() -> Self {
148 Self {
149 id: None,
150 spawn_timeout: None,
151 run_timeout: None,
152 limiter: None,
153 capacity: None,
154 #[cfg(feature = "log")]
155 logging_enabled: true,
156 }
157 }
158 pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
159 self.id.replace(Arc::new(id.into()));
160 self
161 }
162 pub fn id(&self) -> Option<&str> {
163 self.id.as_deref().map(String::as_str)
164 }
165 #[inline]
169 pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
170 self.spawn_timeout = Some(timeout);
171 self
172 }
173 #[inline]
175 pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
176 self.run_timeout = Some(timeout);
177 self
178 }
179 #[inline]
181 pub fn with_timeout(self, timeout: Duration) -> Self {
182 self.with_spawn_timeout(timeout).with_run_timeout(timeout)
183 }
184 #[cfg(feature = "log")]
185 #[inline]
187 pub fn with_no_logging_enabled(mut self) -> Self {
188 self.logging_enabled = false;
189 self
190 }
191 #[inline]
193 pub fn capacity(&self) -> Option<usize> {
194 self.capacity
195 }
196 #[inline]
198 pub fn available_permits(&self) -> Option<usize> {
199 self.limiter.as_ref().map(|v| v.available_permits())
200 }
201 #[inline]
203 pub fn busy_permits(&self) -> Option<usize> {
204 self.limiter
205 .as_ref()
206 .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
207 }
208 #[inline]
210 pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
211 where
212 T: Future + Send + 'static,
213 T::Output: Send + 'static,
214 {
215 self.spawn_task(Task::new(future))
216 }
217 #[inline]
219 pub fn spawn_with_timeout<T>(
220 &self,
221 future: T,
222 timeout: Duration,
223 ) -> impl Future<Output = SpawnResult<T>> + '_
224 where
225 T: Future + Send + 'static,
226 T::Output: Send + 'static,
227 {
228 self.spawn_task(Task::new(future).with_timeout(timeout))
229 }
230 pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
232 where
233 T: Future + Send + 'static,
234 T::Output: Send + 'static,
235 {
236 #[cfg(feature = "log")]
237 let id = self.id.as_ref().cloned();
238 let perm = if let Some(ref limiter) = self.limiter {
239 if let Some(spawn_timeout) = self.spawn_timeout {
240 Some(
241 tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
242 .await
243 .map_err(|_| Error::SpawnTimeout)??,
244 )
245 } else {
246 Some(limiter.clone().acquire_owned().await?)
247 }
248 } else {
249 None
250 };
251 if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
252 #[cfg(feature = "log")]
253 let logging_enabled = self.logging_enabled;
254 Ok(tokio::spawn(async move {
255 let _p = perm;
256 if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
257 Ok(v)
258 } else {
259 let e = Error::RunTimeout(task.id);
260 #[cfg(feature = "log")]
261 if logging_enabled {
262 error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
263 }
264 Err(e)
265 }
266 }))
267 } else {
268 Ok(tokio::spawn(async move {
269 let _p = perm;
270 Ok(task.future.await)
271 }))
272 }
273 }
274 pub fn try_spawn<T>(&self, future: T) -> SpawnResult<T>
277 where
278 T: Future + Send + 'static,
279 T::Output: Send + 'static,
280 {
281 self.try_spawn_task(Task::new(future))
282 }
283 pub fn try_spawn_with_timeout<T>(&self, future: T, timeout: Duration) -> SpawnResult<T>
286 where
287 T: Future + Send + 'static,
288 T::Output: Send + 'static,
289 {
290 self.try_spawn_task(Task::new(future).with_timeout(timeout))
291 }
292 pub fn try_spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
295 where
296 T: Future + Send + 'static,
297 T::Output: Send + 'static,
298 {
299 #[cfg(feature = "log")]
300 let id = self.id.as_ref().cloned();
301 let perm = if let Some(ref limiter) = self.limiter {
302 Some(
303 limiter
304 .clone()
305 .try_acquire_owned()
306 .map_err(|_| Error::NotAvailable)?,
307 )
308 } else {
309 None
310 };
311 if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
312 #[cfg(feature = "log")]
313 let logging_enabled = self.logging_enabled;
314 Ok(tokio::spawn(async move {
315 let _p = perm;
316 if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
317 Ok(v)
318 } else {
319 let e = Error::RunTimeout(task.id);
320 #[cfg(feature = "log")]
321 if logging_enabled {
322 error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
323 }
324 Err(e)
325 }
326 }))
327 } else {
328 Ok(tokio::spawn(async move {
329 let _p = perm;
330 Ok(task.future.await)
331 }))
332 }
333 }
334}
335
#[cfg(test)]
mod test {
    use super::{Error, Pool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use tokio::time::{sleep, timeout};

    /// Five 2-second tasks in a 5-slot pool must run concurrently, so all of
    /// them finish well inside a 3-second budget.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for _ in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(2)).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        timeout(Duration::from_secs(3), async {
            for handle in handles {
                handle.await.unwrap().unwrap();
            }
        })
        .await
        .expect("pooled tasks did not run concurrently");
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// With every slot busy, a further spawn must fail with `SpawnTimeout`
    /// once the 1-second spawn timeout elapses.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_secs(1));
        for _ in 1..=5 {
            let (tx, mut rx) = channel(1);
            pool.spawn(async move {
                tx.send(()).await.unwrap();
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
            // Make sure the task has started before spawning the next one.
            rx.recv().await;
        }
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let res = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert!(matches!(res, Err(Error::SpawnTimeout)));
    }

    /// With a 2-second run timeout, the four 1-second tasks complete and the
    /// 3-second task is cut off (its future dropped before incrementing).
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_secs(2));
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for i in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(if i == 5 { 3 } else { 1 })).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        let mut results = Vec::with_capacity(5);
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        assert!(results[..4].iter().all(Result::is_ok));
        assert_eq!(results[4], Err(Error::RunTimeout(None)));
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }
}