1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
use std::thread::{self, JoinHandle};

use crossbeam::channel::{self as mpmc, Receiver, Sender};
use once_cell::sync::Lazy;

use crate::{
    error::{Error, InvalidArgumentError},
    sink::{OverflowPolicy, Task},
    sync::*,
    Result,
};

/// A thread pool for processing operations asynchronously.
///
/// Currently only used in [`AsyncPoolSink`].
///
/// # Examples
///
/// ```
/// # use std::sync::Arc;
/// use spdlog::{sink::AsyncPoolSink, ThreadPool};
///
/// # fn main() -> Result<(), spdlog::Error> {
/// # let underlying_sink = spdlog::default_logger().sinks().first().unwrap().clone();
/// let thread_pool: Arc<ThreadPool> = Arc::new(ThreadPool::new()?);
/// let async_pool_sink: AsyncPoolSink = AsyncPoolSink::builder()
///     .sink(underlying_sink)
///     .thread_pool(thread_pool)
///     .build()?;
/// # Ok(()) }
/// ```
///
/// [`AsyncPoolSink`]: crate::sink::AsyncPoolSink
pub struct ThreadPool {
    // One join handle per worker thread. Each entry is `take`n (left as `None`)
    // when that thread is joined during `Drop`.
    threads: Vec<Option<JoinHandle<()>>>,
    // Sending half of the task channel. It is set to `None` at the start of `Drop`,
    // which disconnects the channel so the workers stop after draining it.
    sender: Option<Sender<Task>>,
}

/// The builder of [`ThreadPool`].
#[derive(Debug, Clone)]
pub struct ThreadPoolBuilder {
    // Capacity of the bounded task channel created by `build`.
    capacity: usize,
    // Number of worker threads spawned by `build`.
    threads: usize,
}

/// State moved into each spawned pool thread: that thread's clone of the
/// receiving half of the pool's task channel.
struct Worker {
    receiver: Receiver<Task>,
}

impl ThreadPool {
    /// Constructs a builder of `ThreadPool`.
    #[must_use]
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder {
            capacity: 8192,
            threads: 1,
        }
    }

    /// Constructs a `ThreadPool` with default parameters.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`ThreadPoolBuilder::build`].
    pub fn new() -> Result<Self> {
        Self::builder().build()
    }

    /// Enqueues `task` for execution by a worker thread.
    ///
    /// `overflow_policy` decides what happens when the task channel is full:
    /// [`OverflowPolicy::Block`] waits for free space, and
    /// [`OverflowPolicy::DropIncoming`] gives up immediately.
    ///
    /// # Errors
    ///
    /// Returns an error if the task could not be enqueued. With `Block`, this
    /// happens when the channel is disconnected. With `DropIncoming`, it happens
    /// when the channel is full or disconnected.
    pub(super) fn assign_task(&self, task: Task, overflow_policy: OverflowPolicy) -> Result<()> {
        // `sender` is only taken in `Drop`, so it is always present while `&self`
        // can still be borrowed.
        let sender = self
            .sender
            .as_ref()
            .expect("ThreadPool sender is only taken on drop");

        match overflow_policy {
            OverflowPolicy::Block => sender.send(task).map_err(Error::from_crossbeam_send),
            OverflowPolicy::DropIncoming => sender
                .try_send(task)
                .map_err(Error::from_crossbeam_try_send),
        }
    }
}

impl Drop for ThreadPool {
    /// Shuts the pool down, waiting for every worker thread to finish the tasks
    /// that are still queued.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread panicked, unless the current thread is already
    /// panicking.
    fn drop(&mut self) {
        // Drop our sender so the channel disconnects. The workers then break their
        // loop after receiving and processing the remaining tasks.
        self.sender.take();

        for handle in self.threads.iter_mut().filter_map(Option::take) {
            let joined = handle.join();
            // `join` returns `Err` only if the worker panicked (e.g. a task panicked).
            // Panicking again while already unwinding would abort the whole process,
            // so only surface the failure when we are not already panicking.
            if !thread::panicking() {
                joined.expect("failed to join a thread from pool");
            }
        }
    }
}

impl ThreadPoolBuilder {
    /// Specifies the capacity of the operation channel.
    ///
    /// This parameter is **optional**, and defaults to 8192 (The value may
    /// change in the future).
    ///
    /// When a new operation is incoming, but the channel is full, it will be
    /// handled by sink according to the [`OverflowPolicy`] that has been set.
    ///
    /// A capacity of zero is not accepted. This method does not check it, but
    /// [`ThreadPoolBuilder::build`] will return an error.
    pub fn capacity(&mut self, capacity: usize) -> &mut Self {
        self.capacity = capacity;
        self
    }

    // The current Sinks are not beneficial with more than one thread, so the method
    // is not public.
    //
    // If it is ready to be made public in the future, please don't forget to
    // replace the `panic!` in the `build` function with a recoverable error.
    #[allow(dead_code)]
    fn threads(&mut self, threads: usize) -> &mut Self {
        self.threads = threads;
        self
    }

    /// Builds a [`ThreadPool`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgument`] with
    /// [`InvalidArgumentError::ThreadPoolCapacity`] if the capacity is 0.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to spawn a worker thread (see
    /// [`std::thread::spawn`]).
    pub fn build(&self) -> Result<ThreadPool> {
        if self.capacity == 0 {
            return Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity("cannot be 0".to_string()),
            ));
        }

        if self.threads == 0 {
            // Users cannot currently configure this value, so `panic!` is not a problem
            // here.
            panic!("threads of ThreadPool cannot be 0");
        }

        let (sender, receiver) = mpmc::bounded(self.capacity);

        // Every worker gets its own clone of the receiver. The local `receiver` is
        // dropped at the end of this function, so the workers hold the only ones.
        let threads: Vec<_> = (0..self.threads)
            .map(|_| {
                let receiver = receiver.clone();
                Some(thread::spawn(move || Worker { receiver }.run()))
            })
            .collect();

        Ok(ThreadPool {
            threads,
            sender: Some(sender),
        })
    }
}

impl Worker {
    /// Executes tasks one by one until the channel is empty and every sender has
    /// been dropped. Until then, each step blocks waiting for the next task.
    fn run(&self) {
        for task in self.receiver.iter() {
            task.exec();
        }
    }
}

/// Returns the shared default [`ThreadPool`], building it on first use.
///
/// Only a [`Weak`] reference is stored globally. Once every returned `Arc` has
/// been dropped, the pool is dropped, and the next call builds a fresh one.
#[must_use]
pub(crate) fn default_thread_pool() -> Arc<ThreadPool> {
    static POOL_WEAK: Lazy<Mutex<Weak<ThreadPool>>> = Lazy::new(|| Mutex::new(Weak::new()));

    // Held across the upgrade-or-create step so that concurrent callers share one pool.
    let mut pool_weak = POOL_WEAK.lock_expect();

    if let Some(pool) = pool_weak.upgrade() {
        return pool;
    }

    let pool = Arc::new(
        ThreadPool::builder()
            .build()
            .expect("default ThreadPool parameters are always valid"),
    );
    *pool_weak = Arc::downgrade(&pool);
    pool
}

#[cfg(test)]
mod tests {
    use super::*;

    // A zero capacity is reported as a recoverable error, not a panic.
    #[test]
    fn err_capacity_0() {
        assert!(matches!(
            ThreadPool::builder().capacity(0).build(),
            Err(Error::InvalidArgument(
                InvalidArgumentError::ThreadPoolCapacity(_)
            ))
        ));
    }

    // Pin the expected message so an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "threads of ThreadPool cannot be 0")]
    fn panic_thread_0() {
        let _ = ThreadPool::builder().threads(0).build();
    }
}