sas/
lib.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2
3use anyhow::{anyhow, Result};
4
/// Initialization guard: `Sas::init` swaps this to `true` and skips
/// initialization when it was already set.
static IS_INITED: AtomicBool = AtomicBool::new(false);
6
7/// Automatically collects system topology and optimizes key operations on the fly.
8///
9/// ## Panics
10///
11/// Panies when failing to init SAS.
12/// Note that reinitialization is not supported.
13///
14#[inline]
15pub fn init() {
16    try_init().unwrap()
17}
18
19/// Automatically collects system topology and optimizes key operations on the fly.
20///
21#[inline]
22pub fn try_init() -> Result<()> {
23    Sas::default()
24        .init()
25        .map_err(|error| anyhow!("failed to init SAS: {error}"))
26}
27
28/// SAS optimization arguments.
29#[derive(Clone, Debug, Default)]
30#[cfg_attr(feature = "clap", derive(::clap::Parser))]
31#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
32pub struct Sas {
33    /// Runtime system type.
34    #[cfg_attr(
35        feature = "clap",
36        arg(
37            default_value = "SystemType::default()",
38            env = "SAS_SYSTEM_TYPE",
39            long = "sas-system-type",
40            value_name = "TYPE"
41        )
42    )]
43    pub system_type: SystemType,
44}
45
46impl Sas {
47    /// Optimizes key operations with given arguments.
48    ///
49    pub fn init(self) -> Result<()> {
50        if !IS_INITED.swap(true, Ordering::SeqCst) {
51            self.init_unchecked()
52        } else {
53            Ok(())
54        }
55    }
56
57    fn init_unchecked(self) -> Result<()> {
58        #[cfg(feature = "rayon")]
59        {
60            use rayon::ThreadPoolBuilder;
61
62            let (has_multiple_numa_nodes, threads) = prepare_threads()?;
63
64            let mut builder = ThreadPoolBuilder::new().num_threads(threads.len());
65            if matches!(self.system_type, SystemType::Python) {
66                builder = builder.use_current_thread();
67            }
68            builder.build_global()?;
69
70            if has_multiple_numa_nodes {
71                bind_threads(threads)?;
72            }
73        }
74
75        Ok(())
76    }
77}
78
79/// Runtime system type.
80#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
81#[cfg_attr(feature = "clap", derive(::clap::Parser))]
82#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
83#[cfg_attr(feature = "strum", derive(::strum::Display, ::strum::EnumString))]
84pub enum SystemType {
85    /// Use all threads without the main thread
86    #[default]
87    Generic,
88    /// Use all threads even with the main thread
89    Python,
90}
91
/// Loads the hwloc topology of the current machine.
///
/// Any hwloc failure is converted into an [`anyhow::Error`].
#[cfg(all(feature = "numa", feature = "rayon"))]
#[inline]
fn get_topology() -> Result<::hwlocality::Topology> {
    let topology = ::hwlocality::Topology::new()?;
    Ok(topology)
}
97
/// Chooses the CPU slots for the global rayon pool when NUMA support is disabled.
///
/// Returns `(false, 0..n)`, where `n` is the available parallelism reported by
/// the standard library (or `1` if it cannot be determined). `false` means no
/// NUMA binding will be attempted.
#[cfg(all(not(feature = "numa"), feature = "rayon"))]
fn prepare_threads() -> Result<(bool, Vec<usize>)> {
    use std::{num::NonZeroUsize, thread};

    let worker_count = thread::available_parallelism().map_or(1, NonZeroUsize::get);
    let workers: Vec<usize> = (0..worker_count).collect();
    Ok((false, workers))
}
107
/// Chooses the CPUs the global rayon pool should run on.
///
/// Returns `(has_multiple_numa_nodes, cpus)`:
/// - `has_multiple_numa_nodes` says whether hwloc reports more than one NUMA node
///   that owns CPUs.
/// - `cpus` holds the OS indices of the CPUs of one randomly chosen such node.
///   One rayon worker is created per entry.
///
/// The per-node CPU sets come straight from hwloc instead of slicing `0..num_cpus`
/// into equal ranges. CPU numbering is often interleaved across nodes (e.g. SMT
/// siblings numbered after every physical core), and equal index ranges would then
/// mix CPUs from different nodes.
#[cfg(all(feature = "numa", feature = "rayon"))]
fn prepare_threads() -> Result<(bool, Vec<usize>)> {
    use hwlocality::object::types::ObjectType;
    use rand::{
        distributions::{Distribution, Uniform},
        thread_rng,
    };

    let topology = get_topology()?;

    // CPUs owned by each NUMA node; nodes without CPUs (memory-only) are skipped.
    let mut nodes: Vec<Vec<usize>> = topology
        .objects_with_type(ObjectType::NUMANode)
        .filter_map(|node| node.cpuset())
        .map(|cpuset| cpuset.iter_set().map(Into::into).collect::<Vec<usize>>())
        .filter(|cpus| !cpus.is_empty())
        .collect();
    let has_multiple_numa_nodes = nodes.len() > 1;

    let cpus: Vec<usize> = if nodes.is_empty() {
        // No NUMA node with CPUs attached: fall back to every CPU of the machine.
        topology.cpuset().iter_set().map(Into::into).collect()
    } else {
        // Pick one CPU-owning NUMA node at random.
        let picked = Uniform::new(0usize, nodes.len()).sample(&mut thread_rng());
        nodes.swap_remove(picked)
    };

    Ok((has_multiple_numa_nodes, cpus))
}
138
/// Thread binding is a no-op when NUMA support is disabled.
#[cfg(all(not(feature = "numa"), feature = "rayon"))]
#[inline]
fn bind_threads(_threads: Vec<usize>) -> Result<()> {
    Ok(())
}
144
/// Pins every worker of the global rayon pool to a single CPU.
///
/// Worker `i`, identified by its broadcast index, is bound to CPU `threads[i]`.
///
/// # Errors
///
/// Returns the first failure in worker-index order. A failure is one of:
/// - a worker that has no CPU assigned in `threads`;
/// - a topology that cannot be loaded;
/// - a binding request rejected by hwloc.
///
/// Failures are reported through the returned `Result` instead of panicking
/// inside the pool.
#[cfg(all(feature = "numa", feature = "rayon"))]
fn bind_threads(threads: Vec<usize>) -> Result<()> {
    use hwlocality::cpu::{binding::CpuBindingFlags, cpuset::CpuSet};

    // `broadcast` runs the closure once on every worker of the current pool (the
    // global one, since this is not called from inside another pool) and returns
    // the per-worker results ordered by worker index.
    ::rayon::broadcast(|ctx| -> Result<()> {
        let index = ctx.index();
        let cpu = *threads
            .get(index)
            .ok_or_else(|| anyhow!("no CPU assigned to rayon thread #{index}"))?;

        // Each worker loads its own topology handle for the binding call.
        let topology = get_topology()?;
        let mut cpus = CpuSet::new();
        cpus.set(cpu);
        topology
            .bind_cpu(&cpus, CpuBindingFlags::THREAD)
            .map_err(|error| anyhow!("failed to bind the rayon thread into CPU {cpu}: {error}"))
    })
    .into_iter()
    .collect()
}