unknownrori_simple_thread_pool/
lib.rs

1pub mod error;
2
3mod message;
4mod worker;
5
6#[cfg(feature = "crossbeam")]
7pub use crossbeam_channel;
8
9#[cfg(feature = "crossbeam")]
10use crossbeam_channel::{unbounded, Sender};
11
12#[cfg(feature = "mpsc")]
13use std::sync::mpsc::{channel, Sender};
14
15#[cfg(feature = "mpsc")]
16use std::sync::{Arc, Mutex};
17
18use error::{FailedToSendJob, FailedToSpawnThread};
19use message::Message;
20use worker::Worker;
21
/// A unit of work handed to a worker: a boxed, one-shot closure that can be
/// moved to another thread. `Box<dyn ...>` without an explicit lifetime defaults
/// to a `'static` bound, so the closure may not borrow local data.
type Job = Box<dyn FnOnce() + Send>;
23
24/// This is where the thread will be pooled
25///
26/// It depend on how you add this package on your project
27/// you can either using Rust standard library
28/// or you can use `crossbeam-channel`, the API is the same even on different feature flag.
29///
30/// ## Examples
31///
32/// ```rust,no_run
33/// use std::{
34///     io::Write,
35///     net::{TcpListener, TcpStream},
36///     thread,
37///     time::Duration,
38/// };
39///
40/// use unknownrori_simple_thread_pool::{error::FailedToSendJob, ThreadPool};
41///
42/// fn handle_connection(mut stream: TcpStream) {
43///     thread::sleep(Duration::from_secs(2));
44///
45///     let response = "HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nHi!";
46///
47///     stream.write_all(response.as_bytes()).unwrap();
48///
49///     thread::sleep(Duration::from_secs(2));
50/// }
51///
52/// fn main() -> Result<(), FailedToSendJob> {
53///     let pool = ThreadPool::new(2).unwrap();
54///
55///     let socket = TcpListener::bind("127.0.0.1:8000").unwrap();
56///     println!("server started at http://127.0.0.1:8000");
57///
58///     for stream in socket.incoming() {
59///         println!("Got stream!");
60///         match stream {
61///             Ok(stream) => pool.execute(|| handle_connection(stream))?,
62///             Err(_) => eprintln!("Something is wrong!"),
63///         }
64///     }
65///
66///     Ok(())
67/// }
68/// ```
69#[derive(Debug)]
70pub struct ThreadPool {
71    sender: Sender<Message>,
72    workers: Vec<Worker>,
73}
74
75impl ThreadPool {
76    /// Creates a new [`ThreadPool`], with passed worker args for how many worker thread to be created
77    ///
78    /// ## Examples
79    ///
80    /// ```rust,no_run
81    /// use std::{thread, time::Duration};
82    ///
83    /// use unknownrori_simple_thread_pool::{
84    ///     crossbeam_channel::unbounded,
85    ///     error::FailedToSendJob,
86    ///     ThreadPool,
87    /// };
88    ///
89    /// fn main() -> Result<(), FailedToSendJob> {
90    ///     let pool = ThreadPool::new(2).unwrap();
91    ///     let (send, recv) = unbounded();
92    ///
93    ///     pool.execute(move || {
94    ///         send.send(40).unwrap();
95    ///     })?;
96    ///
97    ///     assert_eq!(recv.recv().unwrap(), 40);
98    ///
99    ///     Ok(())
100    /// }
101    /// ```
102    ///
103    /// ## Error
104    ///
105    /// It will return an [`Err`] if cannot create thread worker
106    #[cfg(feature = "crossbeam")]
107    pub fn new(worker: usize) -> Result<ThreadPool, FailedToSpawnThread> {
108        let workers = Vec::with_capacity(worker);
109
110        let (sender, receiver) = unbounded();
111
112        let mut threadpool = ThreadPool { workers, sender };
113        for _ in 0..worker {
114            let thread_builder = std::thread::Builder::new();
115
116            let worker = Worker::new(receiver.clone(), thread_builder)
117                .or_else(|_| Err(FailedToSpawnThread))?;
118
119            threadpool.workers.push(worker);
120        }
121
122        Ok(threadpool)
123    }
124
125    /// Creates a new [`ThreadPool`], with passed worker args for how many worker thread to be created
126    ///
127    /// ## Examples
128    ///
129    /// ```rust,no_run
130    /// use std::sync::mpsc::channel;
131    /// use std::{thread, time::Duration};
132    ///
133    /// use unknownrori_simple_thread_pool::{error::FailedToSendJob, ThreadPool};
134    ///
135    /// fn main() -> Result<(), FailedToSendJob> {
136    ///     let pool = ThreadPool::new(2).unwrap();
137    ///     let (send, recv) = channel();
138    ///
139    ///     pool.execute(move || {
140    ///         send.send(40).unwrap();
141    ///     })?;
142    ///
143    ///     assert_eq!(recv.recv().unwrap(), 40);
144    ///
145    ///     Ok(())
146    /// }
147    /// ```
148    ///
149    /// ## Error
150    ///
151    /// It will return an [`Err`] if cannot create thread worker
152    #[cfg(feature = "mpsc")]
153    pub fn new(worker: usize) -> Result<ThreadPool, FailedToSpawnThread> {
154        let workers = Vec::with_capacity(worker);
155
156        let (sender, receiver) = channel();
157        let receiver = Arc::new(Mutex::new(receiver));
158
159        let mut threadpool = ThreadPool { sender, workers };
160        for _ in 0..worker {
161            let thread_builder = std::thread::Builder::new();
162
163            let worker = Worker::new(Arc::clone(&receiver), thread_builder)
164                .or_else(|_| Err(FailedToSpawnThread))?;
165
166            threadpool.workers.push(worker);
167        }
168
169        Ok(threadpool)
170    }
171
172    /// Execute a job to worker thread, it's require Closure with no param and no return
173    ///
174    /// ## Errors
175    ///
176    /// This function will return an [`Err`] if the communication channel between worker thread
177    /// and main thread is closed.
178    pub fn execute<F>(&self, job: F) -> Result<(), FailedToSendJob>
179    where
180        F: FnOnce() + Send + 'static,
181    {
182        self.sender
183            .send(Message::NewJob(Box::new(job)))
184            .or_else(|_| Err(FailedToSendJob))?;
185
186        Ok(())
187    }
188}
189
190impl Drop for ThreadPool {
191    /// Make sure the [`ThreadPool`] do proper clean up with it's thread workers
192    ///
193    /// ## Panic
194    ///
195    /// May Panic if communcation between worker thread and main thread is closed
196    /// or there are panic in worker thread.
197    fn drop(&mut self) {
198        for _ in &self.workers {
199            self.sender.send(Message::Terminate).unwrap();
200        }
201
202        for worker in &mut self.workers {
203            if let Some(thread) = worker.take_thread() {
204                thread.join().unwrap();
205            }
206        }
207    }
208}