1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10
/// Result of spawning a future on a [`Pool`].
///
/// The outer `Result` is `Err` when the task could not be spawned (see
/// [`Pool::spawn_task`]). On success it holds the Tokio `JoinHandle`, whose output is
/// either the future's output or an [`Error`] if the task exceeded its run timeout.
pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
12
/// Identifier attached to a [`Task`] and reported back in [`Error::RunTimeout`].
///
/// Holds either a borrowed `'static` string or an owned `String`, so ids built
/// from string literals need no allocation.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TaskId {
    /// An id backed by a `&'static str` (typically a string literal).
    Static(&'static str),
    /// An id backed by an owned `String`.
    Owned(String),
}

impl From<&'static str> for TaskId {
    #[inline]
    fn from(text: &'static str) -> Self {
        TaskId::Static(text)
    }
}

impl From<String> for TaskId {
    #[inline]
    fn from(text: String) -> Self {
        TaskId::Owned(text)
    }
}

impl TaskId {
    /// Borrows the id as a string slice, regardless of which variant stores it.
    #[inline]
    fn as_str(&self) -> &str {
        match *self {
            Self::Static(text) => text,
            Self::Owned(ref text) => text.as_str(),
        }
    }
}
43
44pub struct Task<T>
48where
49 T: Future + Send + 'static,
50 T::Output: Send + 'static,
51{
52 id: Option<TaskId>,
53 timeout: Option<Duration>,
54 future: T,
55}
56
57impl<T> Task<T>
58where
59 T: Future + Send + 'static,
60 T::Output: Send + 'static,
61{
62 #[inline]
63 pub fn new(future: T) -> Self {
64 Self {
65 id: None,
66 timeout: None,
67 future,
68 }
69 }
70 #[inline]
71 pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
72 self.id = Some(id.into());
73 self
74 }
75 #[inline]
76 pub fn with_timeout(mut self, timeout: Duration) -> Self {
77 self.timeout = Some(timeout);
78 self
79 }
80}
81
82#[derive(Debug, Clone, Eq, PartialEq)]
83pub enum Error {
84 SpawnTimeout,
85 RunTimeout(Option<TaskId>),
86 SpawnSemaphoneAcquireError,
87}
88
89impl fmt::Display for Error {
90 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
91 match self {
92 Error::SpawnTimeout => write!(f, "task spawn timeout"),
93 Error::RunTimeout(id) => {
94 if let Some(i) = id {
95 write!(f, "task {} run timeout", i.as_str())
96 } else {
97 write!(f, "task run timeout")
98 }
99 }
100 Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
101 }
102 }
103}
104
105impl std::error::Error for Error {}
106
107impl From<tokio::sync::AcquireError> for Error {
108 fn from(_: tokio::sync::AcquireError) -> Self {
109 Self::SpawnSemaphoneAcquireError
110 }
111}
112
113#[derive(Debug)]
115pub struct Pool {
116 id: Option<Arc<String>>,
117 spawn_timeout: Option<Duration>,
118 run_timeout: Option<Duration>,
119 limiter: Option<Arc<Semaphore>>,
120 capacity: Option<usize>,
121 #[cfg(feature = "log")]
122 logging_enabled: bool,
123}
124
125impl Default for Pool {
126 fn default() -> Self {
127 Self::unbounded()
128 }
129}
130
131impl Pool {
132 pub fn bounded(capacity: usize) -> Self {
134 Self {
135 id: None,
136 spawn_timeout: None,
137 run_timeout: None,
138 limiter: Some(Arc::new(Semaphore::new(capacity))),
139 capacity: Some(capacity),
140 #[cfg(feature = "log")]
141 logging_enabled: true,
142 }
143 }
144 pub fn unbounded() -> Self {
146 Self {
147 id: None,
148 spawn_timeout: None,
149 run_timeout: None,
150 limiter: None,
151 capacity: None,
152 #[cfg(feature = "log")]
153 logging_enabled: true,
154 }
155 }
156 pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
157 self.id.replace(Arc::new(id.into()));
158 self
159 }
160 pub fn id(&self) -> Option<&str> {
161 self.id.as_deref().map(String::as_str)
162 }
163 #[inline]
167 pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
168 self.spawn_timeout = Some(timeout);
169 self
170 }
171 #[inline]
173 pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
174 self.run_timeout = Some(timeout);
175 self
176 }
177 #[inline]
179 pub fn with_timeout(self, timeout: Duration) -> Self {
180 self.with_spawn_timeout(timeout).with_run_timeout(timeout)
181 }
182 #[cfg(feature = "log")]
183 #[inline]
185 pub fn with_no_logging_enabled(mut self) -> Self {
186 self.logging_enabled = false;
187 self
188 }
189 #[inline]
191 pub fn capacity(&self) -> Option<usize> {
192 self.capacity
193 }
194 #[inline]
196 pub fn available_permits(&self) -> Option<usize> {
197 self.limiter.as_ref().map(|v| v.available_permits())
198 }
199 #[inline]
201 pub fn busy_permits(&self) -> Option<usize> {
202 self.limiter
203 .as_ref()
204 .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
205 }
206 #[inline]
208 pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
209 where
210 T: Future + Send + 'static,
211 T::Output: Send + 'static,
212 {
213 self.spawn_task(Task::new(future))
214 }
215 #[inline]
217 pub fn spawn_with_timeout<T>(
218 &self,
219 future: T,
220 timeout: Duration,
221 ) -> impl Future<Output = SpawnResult<T>> + '_
222 where
223 T: Future + Send + 'static,
224 T::Output: Send + 'static,
225 {
226 self.spawn_task(Task::new(future).with_timeout(timeout))
227 }
228 pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
230 where
231 T: Future + Send + 'static,
232 T::Output: Send + 'static,
233 {
234 #[cfg(feature = "log")]
235 let id = self.id.as_ref().cloned();
236 let perm = if let Some(ref limiter) = self.limiter {
237 if let Some(spawn_timeout) = self.spawn_timeout {
238 Some(
239 tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
240 .await
241 .map_err(|_| Error::SpawnTimeout)??,
242 )
243 } else {
244 Some(limiter.clone().acquire_owned().await?)
245 }
246 } else {
247 None
248 };
249 if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
250 #[cfg(feature = "log")]
251 let logging_enabled = self.logging_enabled;
252 Ok(tokio::spawn(async move {
253 let _p = perm;
254 if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
255 Ok(v)
256 } else {
257 let e = Error::RunTimeout(task.id);
258 #[cfg(feature = "log")]
259 if logging_enabled {
260 error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
261 }
262 Err(e)
263 }
264 }))
265 } else {
266 Ok(tokio::spawn(async move {
267 let _p = perm;
268 Ok(task.future.await)
269 }))
270 }
271 }
272}
273
#[cfg(test)]
mod test {
    use super::{Error, Pool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use tokio::time::sleep;

    /// Every task spawned on a bounded pool runs to completion.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(2)).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        // Await the tasks themselves rather than sleeping for a guessed duration.
        for handle in handles {
            handle.await.unwrap().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// Spawning on a full pool fails with `SpawnTimeout` once the spawn timeout passes.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_secs(1));
        for _ in 1..=5 {
            let (tx, mut rx) = channel(1);
            pool.spawn(async move {
                tx.send(()).await.unwrap();
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
            // Wait until the task has actually started running.
            rx.recv().await;
        }
        // Every slot is now held by a sleeping task.
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let result = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert!(matches!(result, Err(Error::SpawnTimeout)));
    }

    /// A task exceeding the run timeout is cancelled and reports `RunTimeout`.
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_secs(2));
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for i in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(if i == 5 { 3 } else { 1 })).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        assert_eq!(counter.load(Ordering::SeqCst), 4);
        assert!(results[..4].iter().all(|r| r.is_ok()));
        assert_eq!(results[4], Err(Error::RunTimeout(None)));
    }
}
338}