spdlog/
thread_pool.rs

1use std::{
2    num::NonZeroUsize,
3    thread::{self, JoinHandle},
4};
5
6use crossbeam::channel::{self as mpmc, Receiver, Sender};
7use once_cell::sync::Lazy;
8
9use crate::{
10    error::Error,
11    sink::{OverflowPolicy, Task},
12    sync::*,
13    Result,
14};
15
16/// A thread pool for processing operations asynchronously.
17///
18/// Currently only used in [`AsyncPoolSink`].
19///
20/// # Examples
21///
22/// ```
23/// # use std::sync::Arc;
24/// use spdlog::{sink::AsyncPoolSink, ThreadPool};
25///
26/// # fn main() -> Result<(), spdlog::Error> {
27/// # let underlying_sink = spdlog::default_logger().sinks().first().unwrap().clone();
28/// let thread_pool = Arc::new(ThreadPool::new()?);
29/// let async_pool_sink = AsyncPoolSink::builder()
30///     .sink(underlying_sink)
31///     .thread_pool(thread_pool)
32///     .build()?;
33/// # Ok(()) }
34/// ```
35///
36/// [`AsyncPoolSink`]: crate::sink::AsyncPoolSink
37pub struct ThreadPool(ArcSwapOption<ThreadPoolInner>);
38
39struct ThreadPoolInner {
40    threads: Vec<Option<JoinHandle<()>>>,
41    sender: Option<Sender<Task>>,
42}
43
/// Shared, thread-safe, argument-less hook; the type of the per-thread
/// spawn/finish callbacks stored in the builder.
type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
45
46#[allow(missing_docs)]
47pub struct ThreadPoolBuilder {
48    capacity: NonZeroUsize,
49    threads: NonZeroUsize,
50    on_thread_spawn: Option<Callback>,
51    on_thread_finish: Option<Callback>,
52}
53
54struct Worker {
55    receiver: Receiver<Task>,
56}
57
58impl ThreadPool {
59    /// Gets a builder of `ThreadPool` with default parameters:
60    ///
61    /// | Parameter          | Default Value                     |
62    /// |--------------------|-----------------------------------|
63    /// | [capacity]         | `8192` (may change in the future) |
64    /// | [on_thread_spawn]  | `None`                            |
65    /// | [on_thread_finish] | `None`                            |
66    ///
67    /// [capacity]: ThreadPoolBuilder::capacity
68    /// [on_thread_spawn]: ThreadPoolBuilder::on_thread_spawn
69    /// [on_thread_finish]: ThreadPoolBuilder::on_thread_finish
70    #[must_use]
71    pub fn builder() -> ThreadPoolBuilder {
72        ThreadPoolBuilder {
73            capacity: NonZeroUsize::new(8192).unwrap(),
74            threads: NonZeroUsize::new(1).unwrap(),
75            on_thread_spawn: None,
76            on_thread_finish: None,
77        }
78    }
79
80    /// Constructs a `ThreadPool` with default parameters (see documentation of
81    /// [`ThreadPool::builder`]).
82    pub fn new() -> Result<Self> {
83        Self::builder().build()
84    }
85
86    pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
87        let inner = self.0.load();
88        let sender = inner.as_ref().unwrap().sender.as_ref().unwrap();
89
90        match overflow_policy {
91            OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
92            OverflowPolicy::DropIncoming => sender
93                .try_send(task)
94                .map_err(Error::from_crossbeam_try_send),
95        }
96    }
97
98    pub(super) fn destroy(&self) {
99        if let Some(mut inner) = self.0.swap(None) {
100            // Or use `Arc::into_inner`, but it requires us to bump MSRV.
101            let inner = Arc::get_mut(&mut inner).unwrap();
102
103            // drop our sender, threads will break the loop after receiving and processing
104            // the remaining tasks
105            inner.sender.take();
106
107            for thread in &mut inner.threads {
108                if let Some(thread) = thread.take() {
109                    thread.join().expect("failed to join a thread from pool");
110                }
111            }
112        }
113    }
114}
115
116impl Drop for ThreadPool {
117    fn drop(&mut self) {
118        self.destroy();
119    }
120}
121
122impl ThreadPoolBuilder {
123    /// Specifies the capacity of the operation channel.
124    ///
125    /// This parameter is **optional**.
126    ///
127    /// When a new operation is incoming, but the channel is full, it will be
128    /// handled by sink according to the [`OverflowPolicy`] that has been set.
129    #[must_use]
130    pub fn capacity(&mut self, capacity: NonZeroUsize) -> &mut Self {
131        self.capacity = capacity;
132        self
133    }
134
135    // The current Sinks are not beneficial with more than one thread, so the method
136    // is not public.
137    #[must_use]
138    #[allow(dead_code)]
139    fn threads(&mut self, threads: NonZeroUsize) -> &mut Self {
140        self.threads = threads;
141        self
142    }
143
144    /// Provide a function that will be called on each thread of the thread pool
145    /// immediately after it is spawned. This can, for example, be used to set
146    /// core affinity for each thread.
147    #[must_use]
148    pub fn on_thread_spawn<F>(&mut self, f: F) -> &mut Self
149    where
150        F: Fn() + Send + Sync + 'static,
151    {
152        self.on_thread_spawn = Some(Arc::new(f));
153        self
154    }
155
156    /// Provide a function that will be called on each thread of the thread pool
157    /// just before the thread finishes.
158    #[must_use]
159    pub fn on_thread_finish<F>(&mut self, f: F) -> &mut Self
160    where
161        F: Fn() + Send + Sync + 'static,
162    {
163        self.on_thread_finish = Some(Arc::new(f));
164        self
165    }
166
167    /// Builds a [`ThreadPool`].
168    pub fn build(&self) -> Result<ThreadPool> {
169        let (sender, receiver) = mpmc::bounded(self.capacity.get());
170
171        let mut threads = Vec::new();
172        threads.resize_with(self.threads.get(), || {
173            let receiver = receiver.clone();
174            let on_thread_spawn = self.on_thread_spawn.clone();
175            let on_thread_finish = self.on_thread_finish.clone();
176
177            Some(thread::spawn(move || {
178                if let Some(f) = on_thread_spawn {
179                    f();
180                }
181
182                Worker { receiver }.run();
183
184                if let Some(f) = on_thread_finish {
185                    f();
186                }
187            }))
188        });
189
190        Ok(ThreadPool(ArcSwapOption::new(Some(Arc::new(
191            ThreadPoolInner {
192                threads,
193                sender: Some(sender),
194            },
195        )))))
196    }
197}
198
199impl Worker {
200    fn run(&self) {
201        while let Ok(task) = self.receiver.recv() {
202            task.exec();
203        }
204    }
205}
206
207#[must_use]
208pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
209    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));
210
211    let mut pool_weak = POOL_WEAK.lock_expect();
212
213    match pool_weak.upgrade() {
214        Some(pool) => pool,
215        None => {
216            let pool = Arc::new(ThreadPool::builder().build().unwrap());
217            *pool_weak = Arc::downgrade(&pool);
218            pool
219        }
220    }
221}