spdlog/
thread_pool.rs

1use std::thread::{self, JoinHandle};
2
3use crossbeam::channel::{self as mpmc, Receiver, Sender};
4use once_cell::sync::Lazy;
5
6use crate::{
7    error::{Error, InvalidArgumentError},
8    sink::{OverflowPolicy, Task},
9    sync::*,
10    Result,
11};
12
13/// A thread pool for processing operations asynchronously.
14///
15/// Currently only used in [`AsyncPoolSink`].
16///
17/// # Examples
18///
19/// ```
20/// # use std::sync::Arc;
21/// use spdlog::{sink::AsyncPoolSink, ThreadPool};
22///
23/// # fn main() -> Result<(), spdlog::Error> {
24/// # let underlying_sink = spdlog::default_logger().sinks().first().unwrap().clone();
25/// let thread_pool = Arc::new(ThreadPool::new()?);
26/// let async_pool_sink = AsyncPoolSink::builder()
27///     .sink(underlying_sink)
28///     .thread_pool(thread_pool)
29///     .build()?;
30/// # Ok(()) }
31/// ```
32///
33/// [`AsyncPoolSink`]: crate::sink::AsyncPoolSink
34pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);
35
36struct ThreadPoolInner {
37    threads: Vec<Option<JoinHandle<()>>>,
38    sender: Option<Sender<Task>>,
39}
40
/// A shareable user hook that runs on pool threads. See
/// [`ThreadPoolBuilder::on_thread_spawn`] and
/// [`ThreadPoolBuilder::on_thread_finish`].
type Callback = Arc<dyn Fn() + Sync + Send + 'static>;
42
43#[allow(missing_docs)]
44pub struct ThreadPoolBuilder {
45    capacity: usize,
46    threads: usize,
47    on_thread_spawn: Option<Callback>,
48    on_thread_finish: Option<Callback>,
49}
50
51struct Worker {
52    receiver: Receiver<Task>,
53}
54
55impl ThreadPool {
56    /// Gets a builder of `ThreadPool` with default parameters:
57    ///
58    /// | Parameter          | Default Value                     |
59    /// |--------------------|-----------------------------------|
60    /// | [capacity]         | `8192` (may change in the future) |
61    /// | [on_thread_spawn]  | `None`                            |
62    /// | [on_thread_finish] | `None`                            |
63    ///
64    /// [capacity]: ThreadPoolBuilder::capacity
65    /// [on_thread_spawn]: ThreadPoolBuilder::on_thread_spawn
66    /// [on_thread_finish]: ThreadPoolBuilder::on_thread_finish
67    #[must_use]
68    pub fn builder() -> ThreadPoolBuilder {
69        ThreadPoolBuilder {
70            capacity: 8192,
71            threads: 1,
72            on_thread_spawn: None,
73            on_thread_finish: None,
74        }
75    }
76
77    /// Constructs a `ThreadPool` with default parameters (see documentation of
78    /// [`ThreadPool::builder`]).
79    pub fn new() -> Result<Self> {
80        Self::builder().build()
81    }
82
83    pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
84        let inner = self.0.load();
85        let sender = inner.as_ref().unwrap().sender.as_ref().unwrap();
86
87        match overflow_policy {
88            OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
89            OverflowPolicy::DropIncoming => sender
90                .try_send(task)
91                .map_err(Error::from_crossbeam_try_send),
92        }
93    }
94
95    pub(super) fn destroy(&self) {
96        if let Some(mut inner) = self.0.swap(None) {
97            // Or use `Arc::into_inner`, but it requires us to bump MSRV.
98            let inner = Arc::get_mut(&mut inner).unwrap();
99
100            // drop our sender, threads will break the loop after receiving and processing
101            // the remaining tasks
102            inner.sender.take();
103
104            for thread in &mut inner.threads {
105                if let Some(thread) = thread.take() {
106                    thread.join().expect("failed to join a thread from pool");
107                }
108            }
109        }
110    }
111}
112
113impl Drop for ThreadPool {
114    fn drop(&mut self) {
115        self.destroy();
116    }
117}
118
119impl ThreadPoolBuilder {
120    /// Specifies the capacity of the operation channel.
121    ///
122    /// This parameter is **optional**.
123    ///
124    /// When a new operation is incoming, but the channel is full, it will be
125    /// handled by sink according to the [`OverflowPolicy`] that has been set.
126    ///
127    /// # Errors
128    ///
129    /// The `build()` will return [`Error::InvalidArgument`] if the value is
130    /// zero.
131    #[must_use]
132    pub fn capacity(&mut self, capacity: usize) -> &mut Self {
133        self.capacity = capacity;
134        self
135    }
136
137    // The current Sinks are not beneficial with more than one thread, so the method
138    // is not public.
139    //
140    // If it is ready to be made public in the future, please don't forget to
141    // replace the `panic!` in the `build` function with a recoverable error.
142    #[must_use]
143    #[allow(dead_code)]
144    fn threads(&mut self, threads: usize) -> &mut Self {
145        self.threads = threads;
146        self
147    }
148
149    /// Provide a function that will be called on each thread of the thread pool
150    /// immediately after it is spawned. This can, for example, be used to set
151    /// core affinity for each thread.
152    #[must_use]
153    pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
154    where
155        F: Fn() + Send + Sync + 'static,
156    {
157        self.on_thread_spawn = Some(Arc::new(f));
158        self
159    }
160
161    /// Provide a function that will be called on each thread of the thread pool
162    /// just before the thread finishes.
163    #[must_use]
164    pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
165    where
166        F: Fn() + Send + Sync + 'static,
167    {
168        self.on_thread_finish = Some(Arc::new(f));
169        self
170    }
171
172    /// Builds a [`ThreadPool`].
173    pub fn build(&self) -> Result<ThreadPool> {
174        if self.capacity < 1 {
175            return Err(Error::InvalidArgument(
176                InvalidArgumentError::ThreadPoolCapacity("cannot be 0".to_string()),
177            ));
178        }
179
180        if self.threads < 1 {
181            // Users cannot currently configure this value, so `panic!` is not a problem
182            // here.
183            panic!("threads of ThreadPool cannot be 0");
184        }
185
186        let (sender, receiver) = mpmc::bounded(self.capacity);
187
188        let mut threads = Vec::new();
189        threads.resize_with(self.threads, || {
190            let receiver = receiver.clone();
191            let on_thread_spawn = self.on_thread_spawn.clone();
192            let on_thread_finish = self.on_thread_finish.clone();
193
194            Some(thread::spawn(move || {
195                if let Some(f) = on_thread_spawn {
196                    f();
197                }
198
199                Worker { receiver }.run();
200
201                if let Some(f) = on_thread_finish {
202                    f();
203                }
204            }))
205        });
206
207        Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
208            ThreadPoolInner {
209                threads,
210                sender: Some(sender),
211            },
212        )))))
213    }
214}
215
216impl Worker {
217    fn run(&self) {
218        while let Ok(task) = self.receiver.recv() {
219            task.exec();
220        }
221    }
222}
223
224#[must_use]
225pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
226    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));
227
228    let mut pool_weak = POOL_WEAK.lock_expect();
229
230    match pool_weak.upgrade() {
231        Some(pool) => pool,
232        None => {
233            let pool = Arc::new(ThreadPool::builder().build().unwrap());
234            *pool_weak = Arc::downgrade(&pool);
235            pool
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    #[test]
    fn error_capacity_0() {
        assert!(matches!(
            ThreadPool::builder().capacity(0).build(),
            Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity(_)
            ))
        ));
    }

    // Pin the message so an unrelated panic inside `build` cannot make this pass.
    #[test]
    #[should_panic(expected = "threads of ThreadPool cannot be 0")]
    fn panic_thread_0() {
        let _ = ThreadPool::builder().threads(0).build();
    }

    #[test]
    fn thread_callbacks_run_once_per_thread() {
        let spawned = Arc::new(AtomicUsize::new(0));
        let finished = Arc::new(AtomicUsize::new(0));

        let pool = {
            let spawned = Arc::clone(&spawned);
            let finished = Arc::clone(&finished);
            ThreadPool::builder()
                .on_thread_spawn(move || {
                    spawned.fetch_add(1, Ordering::SeqCst);
                })
                .on_thread_finish(move || {
                    finished.fetch_add(1, Ordering::SeqCst);
                })
                .build()
                .unwrap()
        };
        // Dropping the pool joins its worker threads, so both hooks have
        // completed by the time `drop` returns.
        drop(pool);

        assert_eq!(spawned.load(Ordering::SeqCst), 1);
        assert_eq!(finished.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn default_thread_pool_is_shared_while_alive() {
        let first = default_thread_pool();
        let second = default_thread_pool();
        assert!(Arc::ptr_eq(&first, &second));
    }
}
259}