Skip to main content

tract_linalg/
multithread.rs

1use std::cell::RefCell;
2#[cfg(feature = "multithread-mm")]
3use std::sync::atomic::{AtomicUsize, Ordering};
4#[allow(unused_imports)]
5use std::sync::{Arc, Mutex};
6
7#[cfg(feature = "multithread-mm")]
8use rayon::{ThreadPool, ThreadPoolBuilder};
9
/// Execution strategy for tract's parallelisable work (e.g. the rayon MMM dispatcher).
///
/// The strategy in effect on a given thread is obtained with [`current_tract_executor`]:
/// a thread-local override installed by [`multithread_tract_scope`] wins, otherwise the
/// process-wide default configured with [`set_default_executor`] applies.
///
/// Cloning is cheap: the only owned payload any variant carries is an `Arc`.
#[derive(Clone, Debug, Default)]
pub enum Executor {
    /// Single-threaded execution. This is the `Default` variant.
    #[default]
    SingleThread,
    /// A dedicated rayon thread pool, shared through an `Arc`.
    ///
    /// Built by `Executor::multithread` / `Executor::multithread_with_name`.
    #[cfg(feature = "multithread-mm")]
    MultiThread(Arc<ThreadPool>),
    /// Use rayon's *global* thread pool: on `wasm32-unknown-unknown` this is the pool set
    /// up by `wasm_bindgen_rayon::init_thread_pool`, on native targets it is rayon's
    /// auto-initialised default pool.
    ///
    /// This variant is needed because an `Arc<rayon::ThreadPool>` cannot be built on
    /// `wasm32-unknown-unknown`: rayon's default `spawn_handler` calls `std::thread::spawn`,
    /// which that target does not support. The global pool, reached directly through
    /// `into_par_iter`, is the only route that works there.
    #[cfg(feature = "multithread-mm")]
    RayonGlobal,
}
27
28impl Executor {
29    #[cfg(feature = "multithread-mm")]
30    pub fn multithread(n: usize) -> Executor {
31        Executor::multithread_with_name(n, "tract-default")
32    }
33
34    #[cfg(feature = "multithread-mm")]
35    pub fn multithread_with_name(n: usize, name: &str) -> Executor {
36        let name = name.to_string();
37        let pool = ThreadPoolBuilder::new()
38            .thread_name(move |n| format!("{name}-{n}"))
39            .num_threads(n)
40            .build()
41            .unwrap();
42        Executor::MultiThread(Arc::new(pool))
43    }
44}
45
46static DEFAULT_EXECUTOR: Mutex<Executor> = Mutex::new(Executor::SingleThread);
47
48thread_local! {
49    static TLS_EXECUTOR_OVERRIDE: RefCell<Option<Executor>> = Default::default();
50}
51
52pub fn current_tract_executor() -> Executor {
53    if let Some(over_ride) = TLS_EXECUTOR_OVERRIDE.with_borrow(|tls| tls.clone()) {
54        over_ride
55    } else {
56        DEFAULT_EXECUTOR.lock().unwrap().clone()
57    }
58}
59
60pub fn set_default_executor(executor: Executor) {
61    *DEFAULT_EXECUTOR.lock().unwrap() = executor;
62}
63
64pub fn multithread_tract_scope<R, F: FnOnce() -> R>(pool: Executor, f: F) -> R {
65    let previous = TLS_EXECUTOR_OVERRIDE.replace(Some(pool));
66    let result = f();
67    TLS_EXECUTOR_OVERRIDE.set(previous);
68    result
69}
70
/// Minimum MMM size, counted in panels, at which the rayon MMM dispatcher goes parallel.
///
/// Smaller products run inline on the calling thread, because below this size the per-call
/// dispatch overhead (~5 µs natively, ~50 µs on a wasm-bindgen-rayon worker) is larger than
/// what parallelism would gain.
///
/// Defaults to `64`. Raise it for workloads made of many small MMMs (mobile vision, streaming
/// RNNs); lower it for transformer-class workloads where every MMM is large. A value of `0`
/// disables the gate entirely, so the dispatcher always threads.
#[cfg(feature = "multithread-mm")]
static THREADING_PANEL_THRESHOLD: AtomicUsize = AtomicUsize::new(64);
82/// Read the current MMM panel-count threshold for the rayon path.
83#[cfg(feature = "multithread-mm")]
84pub fn current_threading_panel_threshold() -> usize {
85    THREADING_PANEL_THRESHOLD.load(Ordering::Relaxed)
86}
87
/// Sets the MMM panel-count threshold for the rayon path (the initial value is `64`).
///
/// The value is stored in a process-wide static, so it applies to every thread. Passing `0`
/// removes the size gate: every MMM is threaded regardless of size.
#[cfg(feature = "multithread-mm")]
pub fn set_threading_panel_threshold(panels: usize) {
    // Relaxed: a standalone tuning knob with no accompanying data to synchronise.
    THREADING_PANEL_THRESHOLD.store(panels, Ordering::Relaxed);
}