// rayon_core/registry.rs

use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::tlv::Tlv;
use crate::unwind;
use crate::{
    AcquireThreadHandler, DeadlockHandler, ErrorKind, ExitHandler, PanicHandler,
    ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::thread;

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
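///
/// A minimal sketch of how a `ThreadBuilder` reaches user code (assuming the
/// public `ThreadPoolBuilder::spawn_handler` entry point; the pool name is
/// arbitrary):
///
/// ```ignore
/// let pool = rayon::ThreadPoolBuilder::new()
///     .num_threads(2)
///     .spawn_handler(|thread| {
///         // `thread` is a `ThreadBuilder`; the handler must arrange for
///         // `thread.run()` to be called, or the pool will never start.
///         std::thread::Builder::new()
///             .name(format!("my-pool-{}", thread.index()))
///             .spawn(|| thread.run())?;
///         Ok(())
///     })
///     .build()?;
/// ```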
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::thread_name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

pub struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    pub(crate) deadlock_handler: Option<Box<DeadlockHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,
    pub(crate) acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
    pub(crate) release_thread_handler: Option<Box<ReleaseThreadHandler>>,

    // When this count reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| {
            // SAFETY: we only create a shared reference to `THE_REGISTRY` after the `call_once`
            // that initializes it, and there will be no more mutable accesses at all.
            debug_assert!(THE_REGISTRY_SET.is_completed());
            let the_registry = unsafe { &*ptr::addr_of!(THE_REGISTRY) };
            the_registry.as_ref().ok_or(err)
        })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
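///
/// A sketch of the public call path that lands here (assuming the `rayon`
/// facade's `build_global`):
///
/// ```ignore
/// rayon::ThreadPoolBuilder::new()
///     .num_threads(4)
///     .build_global()?;
/// ```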
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry().map(|registry: Arc<Registry>| {
            // SAFETY: this is the only mutable access to `THE_REGISTRY`, thanks to `Once`, and
            // `global_registry()` only takes a shared reference **after** this `call_once`.
            unsafe {
                ptr::addr_of_mut!(THE_REGISTRY).write(Some(registry));
                (*ptr::addr_of!(THE_REGISTRY)).as_ref().unwrap_unchecked()
            }
        })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new()
            .num_threads(1)
            .spawn_handler(|thread| {
                // Rather than starting a new thread, we're just taking over the current thread
                // *without* running the main loop, so we can still return from here.
                // The WorkerThread is leaked, but we never shut down the global pool anyway.
                let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
                let registry = &*worker_thread.registry;
                let index = worker_thread.index;

                unsafe {
                    WorkerThread::set_current(worker_thread);

                    // let registry know we are ready to do work
                    Latch::set(&registry.thread_infos[index].primed);
                }

                Ok(())
            });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
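    /// Builds the registry described by `builder`: creates the worker and
    /// broadcast deques, then spawns one thread per worker. If any spawn
    /// fails, the `Terminator` guard below terminates the threads that were
    /// already started before the error is returned.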
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            deadlock_handler: builder.take_deadlock_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
            acquire_thread_handler: builder.take_acquire_thread_handler(),
            release_thread_handler: builder.take_release_thread_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

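    /// Returns the registry of the current worker thread, or the global
    /// registry if the calling thread is not a worker.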
    pub fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry.  This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc` reference count.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

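    /// Runs `f`, catching any panic and forwarding it to this registry's
    /// `panic_handler`. If there is no handler, or if the handler itself
    /// panics, the process aborts.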
    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running.  This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    pub(super) fn wait_until_stopped(&self) {
        self.release_thread();
        for info in &self.thread_infos {
            info.stopped.wait();
        }
        self.acquire_thread();
    }

    pub(crate) fn acquire_thread(&self) {
        if let Some(ref acquire_thread_handler) = self.acquire_thread_handler {
            acquire_thread_handler();
        }
    }

    pub(crate) fn release_thread(&self) {
        if let Some(ref release_thread_handler) = self.release_thread_handler {
            release_thread_handler();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    pub(crate) fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well. The second argument passed to `op` is `true` if
    /// injection was performed, `false` if `op` executed directly.
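    ///
    /// A hedged sketch of the calling pattern (internal API; `pool_registry`
    /// and `compute` are stand-ins for the caller's own names):
    ///
    /// ```ignore
    /// let result = pool_registry.in_worker(|worker, injected| {
    ///     // `worker` is the thread now running the closure; `injected`
    ///     // reports whether it had to be injected from outside the pool.
    ///     compute(worker.index())
    /// });
    /// ```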
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                Tlv::null(),
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());
            self.release_thread();
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
            self.acquire_thread();

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            Tlv::null(),
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
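    ///
    /// A sketch of the required pairing (illustrative only):
    ///
    /// ```ignore
    /// registry.increment_terminate_count();
    /// // ... hand off asynchronous work that outlives the caller ...
    /// // and when that work completes, balance the increment:
    /// registry.terminate();
    /// ```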
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(previous != usize::MAX, "overflow in registry ref count");
    }

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

/// Mark a Rayon worker thread as blocked. This triggers the deadlock handler
/// if no other worker thread is active.
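///
/// A hedged sketch of how an embedder might bracket a blocking wait
/// (`channel` is a stand-in for whatever is actually blocking):
///
/// ```ignore
/// let registry = Registry::current();
/// mark_blocked();               // this worker is about to park
/// let msg = channel.recv();     // the blocking call itself
/// mark_unblocked(&registry);    // runnable again
/// ```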
#[inline]
pub fn mark_blocked() {
    let worker_thread = WorkerThread::current();
    assert!(!worker_thread.is_null());
    unsafe {
        let registry = &(*worker_thread).registry;
        registry.sleep.mark_blocked(&registry.deadlock_handler)
    }
}

/// Mark a previously blocked Rayon worker thread as unblocked.
#[inline]
pub fn mark_unblocked(registry: &Registry) {
    registry.sleep.mark_unblocked()
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    terminate: OnceLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    pub(crate) index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    pub(crate) registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns null if
    /// this is not a worker thread. This pointer is valid anywhere on
    /// the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `thread` as the `WorkerThread` for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    pub(super) fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
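    ///
    /// A minimal sketch (internal API; the latch would be set by a job that
    /// was published earlier):
    ///
    /// ```ignore
    /// let latch = SpinLatch::new(worker_thread);
    /// // ... push a job that calls `Latch::set(&latch)` when it finishes ...
    /// worker_thread.wait_until(&latch); // executes other jobs while waiting
    /// ```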
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourselves idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                self.execute(job);
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    self.execute(job);
                    // The job might have injected local work, so go back to the outer loop.
                    continue 'outer;
                } else {
                    self.registry
                        .sleep
                        .no_work_found(&mut idle_state, latch, &self)
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        registry.acquire_thread();
        self.wait_until(&registry.thread_infos[index].terminate);

        // Should not be any work left in our queue.
        debug_assert!(self.take_local_job().is_none());

        // Let registry know we are done
        Latch::set(&registry.thread_infos[index].stopped);
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job())
    }

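    /// Executes one pending job from anywhere in the pool (local queue,
    /// steal, or injected), if one can be found, and reports whether any
    /// progress was made.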
    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

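    /// Like `yield_now`, but only considers jobs in this thread's own queues.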
    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    worker_thread.wait_until_out_of_work();

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }

    registry.release_thread();
}

/// If already in a worker-thread, just execute `op`.  Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument passed to `op`
/// is `true` if injection was performed, `false` if `op` executed directly.
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
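///
/// A small usage sketch. Note that `next_usize` reduces by a plain modulo,
/// so it carries a slight bias when `n` is not a power of two -- acceptable
/// here, where it only picks steal victims:
///
/// ```ignore
/// let rng = XorShift64Star::new();
/// let victim = rng.next_usize(num_threads); // roughly uniform in 0..num_threads
/// ```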
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}