// rust_rcs_core/util/thread_pool.rs

// Copyright 2023 宋昊文
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Adapted from the `ThreadPool` example in "The Rust Programming Language".

use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;

/// A fixed-size pool of worker threads that execute queued closures.
///
/// Jobs travel over an `mpsc` channel whose single receiver is shared by all
/// workers behind an `Arc<Mutex<_>>`; whichever worker takes the lock next
/// receives the next message. Dropping the pool sends one `Terminate` message
/// per worker and then joins every worker thread.
pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Message>,
}

/// A unit of work: a boxed closure that is run exactly once on a worker thread.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Messages sent from the pool to its workers.
enum Message {
    /// Run the contained job.
    NewJob(Job),
    /// Leave the worker loop so the thread can be joined.
    Terminate,
}

impl ThreadPool {
    /// Create a new ThreadPool.
    ///
    /// The size is the number of threads in the pool; all of them are spawned
    /// immediately.
    ///
    /// # Panics
    ///
    /// The `new` function will panic if the size is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0, "ThreadPool size must be greater than zero");

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let workers = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();

        ThreadPool { workers, sender }
    }

    /// Queue `f` to run on the next available worker thread.
    ///
    /// Jobs are dequeued in submission order but may run concurrently and
    /// finish in any order.
    ///
    /// # Panics
    ///
    /// Panics if every worker thread has already exited (the channel's
    /// receiver has been dropped), because the job could then never run.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job: Job = Box::new(f);

        self.sender
            .send(Message::NewJob(job))
            .expect("ThreadPool has no running workers to receive the job");
    }
}

impl Drop for ThreadPool {
    /// Ask every worker to stop, then wait for all worker threads to finish.
    ///
    /// Jobs queued before the `Terminate` messages are still run before the
    /// workers exit. This never panics: panicking inside `drop` while the
    /// thread is already unwinding would abort the process.
    fn drop(&mut self) {
        println!("Sending terminate message to all workers.");

        for _ in &self.workers {
            // `send` only fails once every worker has exited and dropped the
            // receiver; nobody is left to notify, so stop instead of panicking.
            if self.sender.send(Message::Terminate).is_err() {
                break;
            }
        }

        println!("Shutting down all workers.");

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                // `join` returns `Err` only if the worker thread itself panicked;
                // that panic already went through the panic hook, so don't
                // re-panic here.
                let _ = thread.join();
            }
        }
    }
}

/// A single pool thread together with the id used in its log messages.
struct Worker {
    id: usize,
    /// `None` once the thread has been joined by `ThreadPool::drop`.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawn a thread that repeatedly receives messages from the shared
    /// `receiver` and runs jobs until it is told to terminate or the channel
    /// is closed.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            // The lock guard is a temporary of this statement, so the lock is
            // released before the job runs and other workers can receive work
            // meanwhile. A poisoned lock still guards a usable receiver, so
            // recover it rather than killing this worker.
            let message = receiver
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .recv();

            match message {
                Ok(Message::NewJob(job)) => {
                    println!("Worker {} got a job; executing.", id);

                    // Contain a panicking job so this worker stays in the pool;
                    // the panic is still reported through the panic hook.
                    let _ = panic::catch_unwind(AssertUnwindSafe(job));
                }
                Ok(Message::Terminate) => {
                    println!("Worker {} was told to terminate.", id);

                    break;
                }
                // Every sender is gone, so no further message can ever arrive.
                Err(_) => break,
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}