// tokio_thread_pool/lib.rs
1use std::sync::Arc;
2
3use tokio::{
4 runtime::{self, Runtime},
5 sync::Semaphore,
6 task::JoinHandle,
7};
8
9/// A small wrapper around the tokio runtime supporting multithreading with max concurrency limits
10///
11/// ```rust
12/// use tokio_thread_pool::ThreadPool;
13///
14/// // Create a pool with default settings
15/// let my_pool = ThreadPool::new(
16/// None, // optional max task concurrency (usize),
17/// None, // optional max number of threads defaulting to the number of CPU cores on the system (usize),
18/// None, // an optional tokio runtime that you provide with your own custom settings (tokio::Runtime)
19/// );
20///
21/// // Create a pool with a limit on task concurrency
22/// let my_pool = ThreadPool::new(Some(10), None, None); // maximimum of ten concurrent tasks running at once
23///
24/// // Create a pool with a limit on spawned threads
25/// let my_pool = ThreadPool::new(None, Some(4), None); // maximimum of four threads for task allocation
26///
27/// // Create a pool with your own runtime provided
28/// let my_pool = ThreadPool::new(
29/// None,
30/// None,
31/// Some(tokio::runtime::Builder::new_multi_thread().build().unwrap())
32/// );
33///
34/// // Spawn async tasks
35/// let handle = my_pool.spawn(async move || {}); // return any value
36/// // Spawn sync tasks
37/// let handle = my_pool.spawn_blocking(move || {}); // return any value
38///
39/// // Get result
40/// let result = handle.await;
41/// ```
42pub struct ThreadPool {
43 pub pool: Runtime,
44 semaphore: Arc<Semaphore>,
45}
46
47impl ThreadPool {
48 /// Constructs a new ThreadPool instance
49 ///
50 /// #### Arguments
51 ///
52 /// `max_concurrency` (`Option<usize>`): optional max task concurrency
53 ///
54 /// `max_threads` (`Option<usize>`): optional max number of threads defaulting to the number of CPU cores on the system
55 ///
56 /// `pool_override` (`Option<tokio::Runtime>`): an optional tokio runtime that you provide with your own custom settings
57 ///
58 /// ```rust
59 /// // Create a pool with default settings
60 /// let my_pool = ThreadPool::new(None, None, None);
61 ///
62 /// // Create a pool with a limit on task concurrency
63 /// let my_pool = ThreadPool::new(Some(10), None, None); // maximimum of ten concurrent tasks running at once
64 ///
65 /// // Create a pool with a limit on spawned threads
66 /// let my_pool = ThreadPool::new(None, Some(4), None); // maximimum of four threads for task allocation
67 ///
68 /// // Create a pool with your own runtime provided
69 /// let my_pool = ThreadPool::new(
70 /// None,
71 /// None,
72 /// Some(tokio::runtime::Builder::new_multi_thread().build().unwrap())
73 /// );
74 /// ```
75 pub fn new(
76 max_concurrency: Option<usize>,
77 max_threads: Option<usize>,
78 pool_override: Option<Runtime>,
79 ) -> ThreadPool {
80 let pool = pool_override.unwrap_or(ThreadPool::create_pool(max_threads));
81 let semaphore = Arc::new(Semaphore::new(
82 max_concurrency.unwrap_or(Semaphore::MAX_PERMITS),
83 ));
84 ThreadPool { pool, semaphore }
85 }
86
87 /// Spawns an async task and returns its `Handler<T>`
88 ///
89 /// #### Arguments
90 ///
91 /// `task` (`(Fn() -> T)`): The task to execute inside of the thread pool
92 ///
93 /// ```rust
94 /// // Create a pool with default settings
95 /// let my_pool = ThreadPool::new(None, None, None);
96 ///
97 /// let my_handle = my_pool.spawn(async move || {});
98 ///
99 /// let result = my_handle.await;
100 /// ```
101 pub fn spawn<T: Send + 'static, F: (Fn() -> T) + 'static + Send>(
102 &mut self,
103 task: F,
104 ) -> JoinHandle<T> {
105 let concurrecy = self.semaphore.clone();
106 self.pool.spawn(async move {
107 let _ticket = concurrecy.acquire().await.unwrap();
108 task()
109 })
110 }
111
112 /// Spawns a synchronous task and returns its `Handler<T>`
113 ///
114 /// #### Arguments
115 ///
116 /// `task` (`(Fn() -> T)`): The task to execute inside of the thread pool
117 ///
118 /// ```rust
119 /// // Create a pool with default settings
120 /// let my_pool = ThreadPool::new(None, None, None);
121 ///
122 /// let my_handle = my_pool.spawn_blocking(async move || {});
123 ///
124 /// let result = my_handle.await;
125 /// ```
126 pub fn spawn_blocking<T: Send + 'static, F: (Fn() -> T) + 'static + Send>(
127 &mut self,
128 task: F,
129 ) -> JoinHandle<T> {
130 self.pool.spawn_blocking(task)
131 }
132
133 fn create_pool(threads: Option<usize>) -> Runtime {
134 let mut pool = runtime::Builder::new_multi_thread();
135 pool.enable_all();
136 match threads {
137 Some(size) => pool.worker_threads(size),
138 None => &pool,
139 };
140 pool.build().unwrap()
141 }
142}