// tract_linalg/multithread.rs

1use std::cell::RefCell;
2use std::sync::{Arc, Mutex};
3
4use rayon::{ThreadPool, ThreadPoolBuilder};
5
6#[derive(Debug, Clone, Default)]
7pub enum Executor {
8    #[default]
9    SingleThread,
10    MultiThread(Arc<ThreadPool>),
11}
12
13impl Executor {
14    pub fn multithread(n: usize) -> Executor {
15        Executor::multithread_with_name(n, "tract-default")
16    }
17
18    pub fn multithread_with_name(n: usize, name: &str) -> Executor {
19        let name = name.to_string();
20        let pool = ThreadPoolBuilder::new()
21            .thread_name(move |n| format!("{name}-{n}"))
22            .num_threads(n)
23            .build()
24            .unwrap();
25        Executor::MultiThread(Arc::new(pool))
26    }
27}
28
29static DEFAULT_EXECUTOR: Mutex<Executor> = Mutex::new(Executor::SingleThread);
30
31thread_local! {
32    static TLS_EXECUTOR_OVERRIDE: RefCell<Option<Executor>> = Default::default();
33}
34
35pub fn current_tract_executor() -> Executor {
36    if let Some(over_ride) = TLS_EXECUTOR_OVERRIDE.with_borrow(|tls| tls.clone()) {
37        over_ride
38    } else {
39        DEFAULT_EXECUTOR.lock().unwrap().clone()
40    }
41}
42
43pub fn set_default_executor(executor: Executor) {
44    *DEFAULT_EXECUTOR.lock().unwrap() = executor;
45}
46
47pub fn multithread_tract_scope<R, F: FnOnce() -> R>(pool: Executor, f: F) -> R {
48    let previous = TLS_EXECUTOR_OVERRIDE.replace(Some(pool));
49    let result = f();
50    TLS_EXECUTOR_OVERRIDE.set(previous);
51    result
52}