tanton_engine/threadpool/
mod.rs

1//! Contains the ThreadPool and the individual Threads.
2
3use std::alloc::{alloc_zeroed, dealloc, Layout};
4use std::cell::UnsafeCell;
5use std::ptr::NonNull;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Once;
8use std::thread::{self, JoinHandle};
9use std::{mem, ptr};
10
11use tanton::board::*;
12use tanton::core::piece_move::BitMove;
13use tanton::tools::tanton_arc::Arc;
14use tanton::MoveList;
15
16use search::Searcher;
17use sync::LockLatch;
18use time::uci_timer::*;
19
20use consts::*;
21
22const KILOBYTE: usize = 1000;
23const THREAD_STACK_SIZE: usize = 18000 * KILOBYTE;
24const POOL_SIZE: usize = mem::size_of::<ThreadPool>();
25
26// An object that is the same size as a thread pool.
27type DummyThreadPool = [u8; POOL_SIZE];
28
29// The Global threadpool! Yes, this is *technically* an array the same
30// size as a ThreadPool object. This is a cheap hack to get a global value, as
31// Rust isn't particularily fond of mutable global statics.
32pub static mut THREADPOOL: DummyThreadPool = [0; POOL_SIZE];
33
34// ONCE for the Threadpool
35static THREADPOOL_INIT: Once = Once::new();
36
37// Initializes the threadpool, called once on startup.
38#[cold]
39pub fn init_threadpool() {
40    THREADPOOL_INIT.call_once(|| {
41        unsafe {
42            // We have a spawned thread create all structures, as a stack overflow can
43            // occur otherwise
44            let builder = thread::Builder::new()
45                .name("Starter".to_string())
46                .stack_size(THREAD_STACK_SIZE);
47
48            let handle = builder.spawn(|| {
49                let pool: *mut ThreadPool = mem::transmute(&mut THREADPOOL);
50                ptr::write(pool, ThreadPool::new());
51            });
52            handle.unwrap().join().unwrap();
53        }
54    });
55}
56
57/// Returns access to the global thread pool.
58#[inline(always)]
59pub fn threadpool() -> &'static mut ThreadPool {
60    unsafe { mem::transmute::<&mut DummyThreadPool, &'static mut ThreadPool>(&mut THREADPOOL) }
61}
62
/// Chooses which subset of the pool's threads an operation applies to.
#[derive(Copy, Clone)]
enum ThreadSelection {
    /// Only the main thread (id `0`).
    Main,
    /// Every thread except the main one.
    NonMain,
    /// Every thread.
    All,
}

impl ThreadSelection {
    /// Reports whether the thread with the given `id` belongs to this selection.
    #[inline(always)]
    pub fn is_selection(self, id: usize) -> bool {
        // The main thread is always the one with id 0.
        let is_main = id == 0;
        match self {
            ThreadSelection::All => true,
            ThreadSelection::Main => is_main,
            ThreadSelection::NonMain => !is_main,
        }
    }
}
80
81// Dummy struct to allow us to pass a pointer into a spawned thread.
82struct SearcherPtr {
83    ptr: UnsafeCell<*mut Searcher>,
84}
85
86unsafe impl Sync for SearcherPtr {}
87unsafe impl Send for SearcherPtr {}
88
89/// The thread-pool for the chess engine.
90pub struct ThreadPool {
91    /// Access to each thread's Structure
92    pub threads: Vec<UnsafeCell<*mut Searcher>>,
93    /// Handles of each thread
94    handles: Vec<JoinHandle<()>>,
95    /// Condition for the main thread to start.
96    pub main_cond: Arc<LockLatch>,
97    /// Condition for all non-main threads
98    pub thread_cond: Arc<LockLatch>,
99    /// Stop condition, if true the threads should halt.
100    pub stop: AtomicBool,
101}
102
103// Okay, this all looks like madness, but there is some reason to it all.
104// Basically, `ThreadPool` manages spawning and despawning threads, as well
105// as passing state to / from those threads, telling them to stop, go, drop,
106// and lastly determining the "best move" from all the threads.
107///
108// While we spawn all the other threads, We mostly communicate with the
109// MainThread to do anything useful. The mainthread handles anything fun.
110// The goal of the ThreadPool is to be NON BLOCKING, unless we want to await a
111// result.
112impl ThreadPool {
113    /// Creates a new `ThreadPool`
114    pub fn new() -> Self {
115        let mut pool: ThreadPool = ThreadPool {
116            threads: Vec::new(),
117            handles: Vec::new(),
118            main_cond: Arc::new(LockLatch::new()),
119            thread_cond: Arc::new(LockLatch::new()),
120            stop: AtomicBool::new(true),
121        };
122        // Lock both the cond variables
123        pool.main_cond.lock();
124        pool.thread_cond.lock();
125
126        // spawn the main thread
127        pool.attach_thread();
128        pool
129    }
130
131    /// Spawns a new thread and appends it to our vector of JoinHandles.
132    fn attach_thread(&mut self) {
133        unsafe {
134            let thread_ptr: SearcherPtr = self.create_thread();
135            let builder = thread::Builder::new()
136                .name(self.size().to_string())
137                .stack_size(THREAD_STACK_SIZE);
138
139            let handle = builder
140                .spawn(move || {
141                    let thread = &mut **thread_ptr.ptr.get();
142                    thread.cond.lock();
143                    thread.idle_loop();
144                })
145                .unwrap();
146            self.handles.push(handle);
147        };
148    }
149
150    /// Allocates a thread structure and pushes it to the threadstack.
151    ///
152    /// This does not spawn a thread, just creates the structure in the heap for the thread.
153    ///
154    /// Only to be called by `attach_thread()`.
155    fn create_thread(&mut self) -> SearcherPtr {
156        let len: usize = self.threads.len();
157        let layout = Layout::new::<Searcher>();
158        let cond = if len == 0 {
159            self.main_cond.clone()
160        } else {
161            self.thread_cond.clone()
162        };
163        unsafe {
164            let result = alloc_zeroed(layout);
165            let new_ptr: *mut Searcher = result.cast() as *mut Searcher;
166            ptr::write(new_ptr, Searcher::new(len, cond));
167            self.threads.push(UnsafeCell::new(new_ptr));
168            SearcherPtr {
169                ptr: UnsafeCell::new(new_ptr),
170            }
171        }
172    }
173
174    /// Returns the number of threads
175    #[inline(always)]
176    pub fn size(&self) -> usize {
177        self.threads.len()
178    }
179
180    /// Returns a pointer to the main thread.
181    fn main(&mut self) -> &mut Searcher {
182        unsafe {
183            let main_thread: *mut Searcher = *self.threads.get_unchecked(0).get();
184            &mut *main_thread
185        }
186    }
187
188    /// Sets the use of standard out. This can be changed mid search as well.
189    #[inline(always)]
190    pub fn stdout(&mut self, use_stdout: bool) {
191        USE_STDOUT.store(use_stdout, Ordering::Relaxed);
192    }
193
194    /// Sets the thread count of the pool. If num is less than 1, nothing will happen.
195    ///
196    /// # Safety
197    ///
198    /// Completely unsafe to use when the pool is searching.
199    pub fn set_thread_count(&mut self, mut num: usize) {
200        if num >= 1 {
201            num = num.min(MAX_THREADS);
202            self.wait_for_finish();
203            self.kill_all();
204
205            while self.size() < num {
206                self.attach_thread();
207            }
208        }
209    }
210
211    /// Kills and de-allocates all the threads that are running. This function will also
212    /// block on waiting for the search to finish.
213    pub fn kill_all(&mut self) {
214        self.stop.store(true, Ordering::Relaxed);
215        self.wait_for_finish();
216        let mut join_handles = Vec::with_capacity(self.size());
217        unsafe {
218            // tell each thread to drop
219            self.threads
220                .iter()
221                .map(|s| &**s.get())
222                .for_each(|s: &Searcher| s.kill.store(true, Ordering::SeqCst));
223
224            // If awaiting a signal, wake up each thread so each can drop
225            self.threads
226                .iter()
227                .map(|s| &**s.get())
228                .for_each(|s: &Searcher| {
229                    s.cond.set();
230                });
231
232            // Start connecting each join handle. We don't unwrap here, as if one of the
233            // threads fail, the other threads remain un-joined.
234            while let Some(handle) = self.handles.pop() {
235                join_handles.push(handle.join());
236            }
237
238            // De-allocate each thread.
239            while let Some(unc) = self.threads.pop() {
240                let th: *mut Searcher = *unc.get();
241                let ptr: NonNull<u8> = mem::transmute(NonNull::new_unchecked(th));
242                let layout = Layout::new::<Searcher>();
243                dealloc(ptr.as_ptr(), layout);
244            }
245        }
246
247        // Unwrap the results from each `thread::join`,
248        while let Some(handle_result) = join_handles.pop() {
249            handle_result.unwrap_or_else(|e| println!("Thread failed: {:?}", e));
250        }
251    }
252
253    /// Sets the threads to stop (or not!).
254    #[inline(always)]
255    pub fn set_stop(&mut self, stop: bool) {
256        self.stop.store(stop, Ordering::Relaxed);
257    }
258
259    /// Waits for all the threads to finish
260    pub fn wait_for_finish(&self) {
261        self.await_search_cond(ThreadSelection::All, false);
262    }
263
264    /// Waits for all the threads to start.
265    pub fn wait_for_start(&self) {
266        self.await_search_cond(ThreadSelection::All, true);
267    }
268
269    /// Waits for all non-main threads to finish.
270    pub fn wait_for_non_main(&self) {
271        self.await_search_cond(ThreadSelection::NonMain, false);
272    }
273
274    /// Waits for all the non-main threads to start running
275    pub fn wait_for_main_start(&self) {
276        self.await_search_cond(ThreadSelection::Main, true);
277    }
278
279    fn await_search_cond(&self, thread_sel: ThreadSelection, await_search: bool) {
280        self.threads
281            .iter()
282            .map(|s| unsafe { &**s.get() })
283            .filter(|t| thread_sel.is_selection(t.id))
284            .for_each(|t: &Searcher| {
285                t.searching.wait(await_search);
286            });
287    }
288
289    pub fn clear_all(&mut self) {
290        self.threads
291            .iter_mut()
292            .map(|thread_ptr| unsafe { &mut **(*thread_ptr).get() })
293            .for_each(|t| t.clear());
294    }
295
296    /// Starts a UCI search. The result will be printed to stdout if the stdout setting
297    /// is true.
298    pub fn uci_search(&mut self, board: &Board, limits: &Limits) {
299        // Start the timer!
300        if let Some(uci_timer) = limits.use_time_management() {
301            timer().init(limits.start, &uci_timer, board.turn(), board.moves_played());
302        } else {
303            timer().start_timer(limits.start);
304        }
305
306        let root_moves: MoveList = board.generate_moves();
307
308        assert!(!root_moves.is_empty());
309        self.wait_for_finish();
310        self.stop.store(false, Ordering::Relaxed);
311
312        for thread_ptr in self.threads.iter_mut() {
313            let thread: &mut Searcher = unsafe { &mut **(*thread_ptr).get() };
314            thread.nodes.store(0, Ordering::Relaxed);
315            thread.depth_completed = 0;
316            thread.board = board.shallow_clone();
317            thread.limit = limits.clone();
318            thread.root_moves().replace(&root_moves);
319        }
320
321        self.main_cond.set();
322        self.wait_for_main_start();
323        self.main_cond.lock();
324    }
325
326    /// Performs a standard search, and blocks waiting for a returned `BitMove`.
327    pub fn search(&mut self, board: &Board, limits: &Limits) -> BitMove {
328        self.uci_search(board, limits);
329        self.wait_for_finish();
330        self.best_move()
331    }
332
333    /// Returns the best move of a search
334    pub fn best_move(&mut self) -> BitMove {
335        self.main().root_moves().get(0).unwrap().bit_move
336    }
337
338    /// Returns total number of nodes searched so far.
339    pub fn nodes(&self) -> u64 {
340        self.threads
341            .iter()
342            .map(|s| unsafe { &**s.get() })
343            .map(|s: &Searcher| s.nodes.load(Ordering::Relaxed))
344            .sum()
345    }
346}
347
348impl Drop for ThreadPool {
349    fn drop(&mut self) {
350        self.kill_all();
351    }
352}