// spdlog/thread_pool.rs
use std::{
    num::NonZeroUsize,
    thread::{self, JoinHandle},
};

use crossbeam::channel::{self as mpmc, Receiver, Sender};
use once_cell::sync::Lazy;

use crate::{
    error::Error,
    sink::{OverflowPolicy, Task},
    sync::*,
    Result,
};
15
/// A thread pool for processing operations asynchronously.
///
/// Currently only used in [`AsyncPoolSink`].
///
/// # Examples
///
/// ```
/// # use std::sync::Arc;
/// use spdlog::{sink::AsyncPoolSink, ThreadPool};
///
/// # fn main() -> Result<(), spdlog::Error> {
/// # let underlying_sink = spdlog::default_logger().sinks().first().unwrap().clone();
/// let thread_pool = Arc::new(ThreadPool::new()?);
/// let async_pool_sink = AsyncPoolSink::builder()
///     .sink(underlying_sink)
///     .thread_pool(thread_pool)
///     .build()?;
/// # Ok(()) }
/// ```
///
/// [`AsyncPoolSink`]: crate::sink::AsyncPoolSink
// Holds `Some(inner)` while the pool is alive; `destroy` swaps it to `None`
// so concurrent `assign_task` calls can detect a destroyed pool (issue #120).
pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);
38
// Shared state behind the `ThreadPool` handle.
struct ThreadPoolInner {
    // Worker thread handles; each slot is `take`n (set to `None`) as the
    // thread is joined during `destroy`.
    threads: Vec<Option<JoinHandle<()>>>,
    // Sending side of the task channel. `take`n (i.e. dropped) in `destroy`
    // to disconnect the channel so workers exit their receive loop.
    sender: Option<Sender<Task>>,
}
43
// Shared, thread-safe hook invoked on worker-thread spawn / finish.
type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
45
#[allow(missing_docs)]
pub struct ThreadPoolBuilder {
    // Bounded capacity of the task channel.
    capacity: NonZeroUsize,
    // Number of worker threads (currently fixed at 1; see the private
    // `threads` method).
    threads: NonZeroUsize,
    // Optional hook run on each worker thread right after it is spawned.
    on_thread_spawn: Option<Callback>,
    // Optional hook run on each worker thread just before it finishes.
    on_thread_finish: Option<Callback>,
}
53
// Per-thread receive-loop state: the consuming end of the task channel.
struct Worker {
    receiver: Receiver<Task>,
}
57
impl ThreadPool {
    /// Gets a builder of `ThreadPool` with default parameters:
    ///
    /// | Parameter          | Default Value                     |
    /// |--------------------|-----------------------------------|
    /// | [capacity]         | `8192` (may change in the future) |
    /// | [on_thread_spawn]  | `None`                            |
    /// | [on_thread_finish] | `None`                            |
    ///
    /// [capacity]: ThreadPoolBuilder::capacity
    /// [on_thread_spawn]: ThreadPoolBuilder::on_thread_spawn
    /// [on_thread_finish]: ThreadPoolBuilder::on_thread_finish
    #[must_use]
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder {
            capacity: NonZeroUsize::new(8192).unwrap(),
            threads: NonZeroUsize::new(1).unwrap(),
            on_thread_spawn: None,
            on_thread_finish: None,
        }
    }

    /// Constructs a `ThreadPool` with default parameters (see documentation of
    /// [`ThreadPool::builder`]).
    pub fn new() -> Result<Self> {
        Self::builder().build()
    }

    // Sends a task to the worker threads. When the bounded channel is full,
    // `overflow_policy` decides the behavior: `Block` waits for free space,
    // `DropIncoming` fails immediately with an error.
    pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
        let inner = self.0.load();
        if let Some(inner) = inner.as_ref() {
            // Invariant: `sender` is only taken in `destroy`, and only after
            // the inner has been swapped out of `self.0` AND this was the sole
            // strong reference — so any inner loaded here still holds `Some`.
            let sender = inner.sender.as_ref().unwrap();

            match overflow_policy {
                OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
                OverflowPolicy::DropIncoming => sender
                    .try_send(task)
                    .map_err(Error::from_crossbeam_try_send),
            }
        } else {
            // https://github.com/SpriteOvO/spdlog-rs/issues/120
            //
            // The thread pool has been destroyed
            //
            // TODO: Return an error and perform the task directly on the current thread.
            Ok(())
        }
    }

    // Shuts the pool down: disconnects the task channel and joins all worker
    // threads, letting them drain any already-queued tasks first.
    pub(super) fn destroy(&self) {
        if let Some(inner) = self.0.swap(None) {
            // https://github.com/SpriteOvO/spdlog-rs/issues/120
            //
            // If a task is being assigned, there will be more than one strong reference,
            // causing `into_inner` to return `None`.
            //
            // TODO: Skip it if it's None. This avoids panic, but might introduce a memory
            // leak? However, it's not a big deal since this isn't a frequent operation.
            // Anyway, we should eventually fix it.
            if let Some(mut inner) = Arc::into_inner(inner) {
                // drop our sender, threads will break the loop after receiving and processing
                // the remaining tasks
                inner.sender.take();

                for thread in &mut inner.threads {
                    if let Some(thread) = thread.take() {
                        thread.join().expect("failed to join a thread from pool");
                    }
                }
            }
        }
    }
}
131
132impl Drop for ThreadPool {
133    fn drop(&mut self) {
134        self.destroy();
135    }
136}
137
impl ThreadPoolBuilder {
    /// Specifies the capacity of the operation channel.
    ///
    /// This parameter is **optional**, and defaults to `8192` (may change in
    /// the future).
    ///
    /// When a new operation is incoming, but the channel is full, it will be
    /// handled by sink according to the [`OverflowPolicy`] that has been set.
    #[must_use]
    pub fn capacity(&mut self, capacity: NonZeroUsize) -> &mut Self {
        self.capacity = capacity;
        self
    }

    // The current Sinks are not beneficial with more than one thread, so the method
    // is not public.
    #[must_use]
    #[allow(dead_code)]
    fn threads(&mut self, threads: NonZeroUsize) -> &mut Self {
        self.threads = threads;
        self
    }

    /// Provide a function that will be called on each thread of the thread pool
    /// immediately after it is spawned. This can, for example, be used to set
    /// core affinity for each thread.
    #[must_use]
    pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_thread_spawn = Some(Arc::new(f));
        self
    }

    /// Provide a function that will be called on each thread of the thread pool
    /// just before the thread finishes.
    #[must_use]
    pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_thread_finish = Some(Arc::new(f));
        self
    }

    /// Builds a [`ThreadPool`].
    pub fn build(&self) -> Result<ThreadPool> {
        // Bounded MPMC channel; when full, `OverflowPolicy` decides how new
        // tasks are handled (see `ThreadPool::assign_task`).
        let (sender, receiver) = mpmc::bounded(self.capacity.get());

        let mut threads = Vec::new();
        threads.resize_with(self.threads.get(), || {
            // Each worker thread gets its own receiver handle and clones of
            // the lifecycle hooks.
            let receiver = receiver.clone();
            let on_thread_spawn = self.on_thread_spawn.clone();
            let on_thread_finish = self.on_thread_finish.clone();

            Some(thread::spawn(move || {
                if let Some(f) = on_thread_spawn {
                    f();
                }

                // Runs until the channel is disconnected by `destroy`.
                Worker { receiver }.run();

                if let Some(f) = on_thread_finish {
                    f();
                }
            }))
        });

        Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
            ThreadPoolInner {
                threads,
                sender: Some(sender),
            },
        )))))
    }

    /// Builds an `Arc<ThreadPool>`.
    ///
    /// This is a shorthand method for `.build().map(Arc::new)`.
    pub fn build_arc(&self) -> Result<Arc<ThreadPool>> {
        self.build().map(Arc::new)
    }
}
222
223impl Worker {
224    fn run(&self) {
225        while let Ok(task) = self.receiver.recv() {
226            task.exec();
227        }
228    }
229}
230
231#[must_use]
232pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
233    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));
234
235    let mut pool_weak = POOL_WEAK.lock_expect();
236
237    match pool_weak.upgrade() {
238        Some(pool) => pool,
239        None => {
240            let pool = ThreadPool::builder().build_arc().unwrap();
241            *pool_weak = Arc::downgrade(&pool);
242            pool
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use std::{thread::sleep, time::Duration};

    use super::*;

    // Regression test for https://github.com/SpriteOvO/spdlog-rs/issues/120
    //
    // While tasks are being assigned concurrently, `destroy` can observe more
    // than one strong reference to the inner `Arc`; it must neither panic nor
    // deadlock, and all blocked assignments must still complete.
    #[test]
    fn inner_arc_multiple_strong_refs() {
        // Capacity 1 so that follow-up `Block` assignments actually block.
        let thread_pool = ThreadPool::builder()
            .capacity(1.try_into().unwrap())
            .build_arc()
            .unwrap();

        // Each task keeps the single worker thread busy for 1 second.
        let task = || Task::__ForTestUse {
            sleep: Some(Duration::from_secs(1)),
        };

        // First task: picked up by the worker right away.
        thread_pool
            .assign_task(task(), OverflowPolicy::Block)
            .unwrap();

        let (first_blocked_assign, second_blocked_assign, destroy, third_assign) =
            std::thread::scope(|s| {
                // Fills the 1-slot channel (or blocks until it can).
                let first_blocked_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                // Blocks on the full channel while holding a loaded inner,
                // keeping an extra strong reference alive.
                let second_blocked_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                sleep(Duration::from_millis(200));
                // Destroy now races with the still-blocked assignments above.
                let destroy = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool.destroy();
                    }
                });
                // May run before or after `destroy`; must succeed either way
                // (a destroyed pool currently ignores the task, see TODO in
                // `assign_task`).
                let third_assign = s.spawn({
                    let thread_pool = thread_pool.clone();
                    move || {
                        thread_pool
                            .assign_task(task(), OverflowPolicy::Block)
                            .unwrap();
                    }
                });
                (
                    first_blocked_assign.join(),
                    second_blocked_assign.join(),
                    destroy.join(),
                    third_assign.join(),
                )
            });
        first_blocked_assign.unwrap();
        second_blocked_assign.unwrap();
        destroy.unwrap();
        third_assign.unwrap();
    }
}
314}