tanton_engine/threadpool/
mod.rs1use std::alloc::{alloc_zeroed, dealloc, Layout};
4use std::cell::UnsafeCell;
5use std::ptr::NonNull;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Once;
8use std::thread::{self, JoinHandle};
9use std::{mem, ptr};
10
11use tanton::board::*;
12use tanton::core::piece_move::BitMove;
13use tanton::tools::tanton_arc::Arc;
14use tanton::MoveList;
15
16use search::Searcher;
17use sync::LockLatch;
18use time::uci_timer::*;
19
20use consts::*;
21
22const KILOBYTE: usize = 1000;
23const THREAD_STACK_SIZE: usize = 18000 * KILOBYTE;
24const POOL_SIZE: usize = mem::size_of::<ThreadPool>();
25
26type DummyThreadPool = [u8; POOL_SIZE];
28
29pub static mut THREADPOOL: DummyThreadPool = [0; POOL_SIZE];
33
34static THREADPOOL_INIT: Once = Once::new();
36
37#[cold]
39pub fn init_threadpool() {
40 THREADPOOL_INIT.call_once(|| {
41 unsafe {
42 let builder = thread::Builder::new()
45 .name("Starter".to_string())
46 .stack_size(THREAD_STACK_SIZE);
47
48 let handle = builder.spawn(|| {
49 let pool: *mut ThreadPool = mem::transmute(&mut THREADPOOL);
50 ptr::write(pool, ThreadPool::new());
51 });
52 handle.unwrap().join().unwrap();
53 }
54 });
55}
56
57#[inline(always)]
59pub fn threadpool() -> &'static mut ThreadPool {
60 unsafe { mem::transmute::<&mut DummyThreadPool, &'static mut ThreadPool>(&mut THREADPOOL) }
61}
62
/// Selects a subset of searcher threads by id; id `0` is the main searcher.
#[derive(Clone, Copy)]
enum ThreadSelection {
    /// Only the searcher with id `0`.
    Main,
    /// Every searcher except id `0`.
    NonMain,
    /// Every searcher.
    All,
}

impl ThreadSelection {
    /// Returns `true` when the searcher with the given `id` is part of this selection.
    #[inline(always)]
    pub fn is_selection(self, id: usize) -> bool {
        let is_main = id == 0;
        match self {
            ThreadSelection::All => true,
            ThreadSelection::Main => is_main,
            ThreadSelection::NonMain => !is_main,
        }
    }
}
80
81struct SearcherPtr {
83 ptr: UnsafeCell<*mut Searcher>,
84}
85
86unsafe impl Sync for SearcherPtr {}
87unsafe impl Send for SearcherPtr {}
88
89pub struct ThreadPool {
91 pub threads: Vec<UnsafeCell<*mut Searcher>>,
93 handles: Vec<JoinHandle<()>>,
95 pub main_cond: Arc<LockLatch>,
97 pub thread_cond: Arc<LockLatch>,
99 pub stop: AtomicBool,
101}
102
103impl ThreadPool {
113 pub fn new() -> Self {
115 let mut pool: ThreadPool = ThreadPool {
116 threads: Vec::new(),
117 handles: Vec::new(),
118 main_cond: Arc::new(LockLatch::new()),
119 thread_cond: Arc::new(LockLatch::new()),
120 stop: AtomicBool::new(true),
121 };
122 pool.main_cond.lock();
124 pool.thread_cond.lock();
125
126 pool.attach_thread();
128 pool
129 }
130
131 fn attach_thread(&mut self) {
133 unsafe {
134 let thread_ptr: SearcherPtr = self.create_thread();
135 let builder = thread::Builder::new()
136 .name(self.size().to_string())
137 .stack_size(THREAD_STACK_SIZE);
138
139 let handle = builder
140 .spawn(move || {
141 let thread = &mut **thread_ptr.ptr.get();
142 thread.cond.lock();
143 thread.idle_loop();
144 })
145 .unwrap();
146 self.handles.push(handle);
147 };
148 }
149
150 fn create_thread(&mut self) -> SearcherPtr {
156 let len: usize = self.threads.len();
157 let layout = Layout::new::<Searcher>();
158 let cond = if len == 0 {
159 self.main_cond.clone()
160 } else {
161 self.thread_cond.clone()
162 };
163 unsafe {
164 let result = alloc_zeroed(layout);
165 let new_ptr: *mut Searcher = result.cast() as *mut Searcher;
166 ptr::write(new_ptr, Searcher::new(len, cond));
167 self.threads.push(UnsafeCell::new(new_ptr));
168 SearcherPtr {
169 ptr: UnsafeCell::new(new_ptr),
170 }
171 }
172 }
173
174 #[inline(always)]
176 pub fn size(&self) -> usize {
177 self.threads.len()
178 }
179
180 fn main(&mut self) -> &mut Searcher {
182 unsafe {
183 let main_thread: *mut Searcher = *self.threads.get_unchecked(0).get();
184 &mut *main_thread
185 }
186 }
187
188 #[inline(always)]
190 pub fn stdout(&mut self, use_stdout: bool) {
191 USE_STDOUT.store(use_stdout, Ordering::Relaxed);
192 }
193
194 pub fn set_thread_count(&mut self, mut num: usize) {
200 if num >= 1 {
201 num = num.min(MAX_THREADS);
202 self.wait_for_finish();
203 self.kill_all();
204
205 while self.size() < num {
206 self.attach_thread();
207 }
208 }
209 }
210
211 pub fn kill_all(&mut self) {
214 self.stop.store(true, Ordering::Relaxed);
215 self.wait_for_finish();
216 let mut join_handles = Vec::with_capacity(self.size());
217 unsafe {
218 self.threads
220 .iter()
221 .map(|s| &**s.get())
222 .for_each(|s: &Searcher| s.kill.store(true, Ordering::SeqCst));
223
224 self.threads
226 .iter()
227 .map(|s| &**s.get())
228 .for_each(|s: &Searcher| {
229 s.cond.set();
230 });
231
232 while let Some(handle) = self.handles.pop() {
235 join_handles.push(handle.join());
236 }
237
238 while let Some(unc) = self.threads.pop() {
240 let th: *mut Searcher = *unc.get();
241 let ptr: NonNull<u8> = mem::transmute(NonNull::new_unchecked(th));
242 let layout = Layout::new::<Searcher>();
243 dealloc(ptr.as_ptr(), layout);
244 }
245 }
246
247 while let Some(handle_result) = join_handles.pop() {
249 handle_result.unwrap_or_else(|e| println!("Thread failed: {:?}", e));
250 }
251 }
252
253 #[inline(always)]
255 pub fn set_stop(&mut self, stop: bool) {
256 self.stop.store(stop, Ordering::Relaxed);
257 }
258
259 pub fn wait_for_finish(&self) {
261 self.await_search_cond(ThreadSelection::All, false);
262 }
263
264 pub fn wait_for_start(&self) {
266 self.await_search_cond(ThreadSelection::All, true);
267 }
268
269 pub fn wait_for_non_main(&self) {
271 self.await_search_cond(ThreadSelection::NonMain, false);
272 }
273
274 pub fn wait_for_main_start(&self) {
276 self.await_search_cond(ThreadSelection::Main, true);
277 }
278
279 fn await_search_cond(&self, thread_sel: ThreadSelection, await_search: bool) {
280 self.threads
281 .iter()
282 .map(|s| unsafe { &**s.get() })
283 .filter(|t| thread_sel.is_selection(t.id))
284 .for_each(|t: &Searcher| {
285 t.searching.wait(await_search);
286 });
287 }
288
289 pub fn clear_all(&mut self) {
290 self.threads
291 .iter_mut()
292 .map(|thread_ptr| unsafe { &mut **(*thread_ptr).get() })
293 .for_each(|t| t.clear());
294 }
295
296 pub fn uci_search(&mut self, board: &Board, limits: &Limits) {
299 if let Some(uci_timer) = limits.use_time_management() {
301 timer().init(limits.start, &uci_timer, board.turn(), board.moves_played());
302 } else {
303 timer().start_timer(limits.start);
304 }
305
306 let root_moves: MoveList = board.generate_moves();
307
308 assert!(!root_moves.is_empty());
309 self.wait_for_finish();
310 self.stop.store(false, Ordering::Relaxed);
311
312 for thread_ptr in self.threads.iter_mut() {
313 let thread: &mut Searcher = unsafe { &mut **(*thread_ptr).get() };
314 thread.nodes.store(0, Ordering::Relaxed);
315 thread.depth_completed = 0;
316 thread.board = board.shallow_clone();
317 thread.limit = limits.clone();
318 thread.root_moves().replace(&root_moves);
319 }
320
321 self.main_cond.set();
322 self.wait_for_main_start();
323 self.main_cond.lock();
324 }
325
326 pub fn search(&mut self, board: &Board, limits: &Limits) -> BitMove {
328 self.uci_search(board, limits);
329 self.wait_for_finish();
330 self.best_move()
331 }
332
333 pub fn best_move(&mut self) -> BitMove {
335 self.main().root_moves().get(0).unwrap().bit_move
336 }
337
338 pub fn nodes(&self) -> u64 {
340 self.threads
341 .iter()
342 .map(|s| unsafe { &**s.get() })
343 .map(|s: &Searcher| s.nodes.load(Ordering::Relaxed))
344 .sum()
345 }
346}
347
348impl Drop for ThreadPool {
349 fn drop(&mut self) {
350 self.kill_all();
351 }
352}