1use std::sync::atomic::{AtomicBool, Ordering};
2
3use anyhow::{anyhow, Result};
4
/// Whether [`Sas::init`] has already claimed (started) initialization.
///
/// While this is `true`, further calls to `Sas::init` return `Ok(())` without
/// doing anything.
static IS_INITED: AtomicBool = AtomicBool::new(false);
6
7#[inline]
15pub fn init() {
16 try_init().unwrap()
17}
18
19#[inline]
22pub fn try_init() -> Result<()> {
23 Sas::default()
24 .init()
25 .map_err(|error| anyhow!("failed to init SAS: {error}"))
26}
27
28#[derive(Clone, Debug, Default)]
30#[cfg_attr(feature = "clap", derive(::clap::Parser))]
31#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
32pub struct Sas {
33 #[cfg_attr(
35 feature = "clap",
36 arg(
37 default_value = "SystemType::default()",
38 env = "SAS_SYSTEM_TYPE",
39 long = "sas-system-type",
40 value_name = "TYPE"
41 )
42 )]
43 pub system_type: SystemType,
44}
45
46impl Sas {
47 pub fn init(self) -> Result<()> {
50 if !IS_INITED.swap(true, Ordering::SeqCst) {
51 self.init_unchecked()
52 } else {
53 Ok(())
54 }
55 }
56
57 fn init_unchecked(self) -> Result<()> {
58 #[cfg(feature = "rayon")]
59 {
60 use rayon::ThreadPoolBuilder;
61
62 let (has_multiple_numa_nodes, threads) = prepare_threads()?;
63
64 let mut builder = ThreadPoolBuilder::new().num_threads(threads.len());
65 if matches!(self.system_type, SystemType::Python) {
66 builder = builder.use_current_thread();
67 }
68 builder.build_global()?;
69
70 if has_multiple_numa_nodes {
71 bind_threads(threads)?;
72 }
73 }
74
75 Ok(())
76 }
77}
78
/// Kind of host environment SAS is initialized in.
///
/// In this file only [`SystemType::Python`] changes behavior: the rayon global
/// pool is then built with `use_current_thread`, so the initializing thread
/// becomes one of its workers.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
// NOTE(review): `clap::Parser` models commands/subcommands, but this enum is
// used as a flag *value* by `Sas::system_type`; `clap::ValueEnum` may have been
// intended. Confirm before switching, since `ValueEnum` would change the
// accepted spellings of the value.
#[cfg_attr(feature = "clap", derive(::clap::Parser))]
#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
#[cfg_attr(feature = "strum", derive(::strum::Display, ::strum::EnumString))]
pub enum SystemType {
    /// Default; no special handling.
    #[default]
    Generic,
    /// The initializing thread joins the rayon global pool
    /// (presumably for an embedding Python interpreter; confirm).
    Python,
}
91
/// Builds a fresh hwloc [`Topology`](::hwlocality::Topology) via `Topology::new()`.
///
/// # Errors
///
/// Propagates the hwloc error, converted into an [`anyhow::Error`].
#[cfg(all(feature = "numa", feature = "rayon"))]
#[inline]
fn get_topology() -> Result<::hwlocality::Topology> {
    let topology = ::hwlocality::Topology::new()?;
    Ok(topology)
}
97
/// Thread plan used when NUMA support is disabled.
///
/// Returns `(false, cpus)`, where `cpus` is `0..N` and `N` is
/// [`std::thread::available_parallelism`] (or 1 if that cannot be determined).
/// The leading `false` means "not multiple NUMA nodes", so no pinning happens.
#[cfg(all(not(feature = "numa"), feature = "rayon"))]
fn prepare_threads() -> Result<(bool, Vec<usize>)> {
    let parallelism = std::thread::available_parallelism().map_or(1, |count| count.get());
    let cpu_ids: Vec<usize> = (0..parallelism).collect();
    Ok((false, cpu_ids))
}
107
/// Picks the CPUs the rayon global pool should run on.
///
/// Selects one NUMA node uniformly at random and returns the CPU indices
/// assumed to belong to it. The result also carries a flag telling whether the
/// machine has more than one NUMA node, i.e. whether the workers should be pinned.
///
/// # Errors
///
/// Fails if the hwloc topology cannot be loaded.
#[cfg(all(feature = "numa", feature = "rayon"))]
fn prepare_threads() -> Result<(bool, Vec<usize>)> {
    use rand::{
        distributions::{Distribution, Uniform},
        thread_rng,
    };

    let topology = get_topology()?;

    // Both counts come from the highest set index, so they assume node and CPU
    // indices are dense (no holes for offline CPUs/nodes).
    let num_numa_nodes = topology
        .nodeset()
        .last_set()
        .map(|index| index.into())
        .unwrap_or(0usize)
        + 1;
    let num_cpus = topology
        .cpuset()
        .last_set()
        .map(|index| index.into())
        .unwrap_or(0usize)
        + 1;
    let cpus_per_node = num_cpus / num_numa_nodes;

    // With more NUMA nodes than CPUs (e.g. CPU-less memory nodes) the per-node
    // range would be empty. Rayon would then choose its own thread count, and
    // there would be no CPU to assign to any worker when pinning. Use every CPU
    // without pinning instead.
    if cpus_per_node == 0 {
        return Ok((false, (0..num_cpus).collect()));
    }

    let numa_node = Uniform::new(0usize, num_numa_nodes).sample(&mut thread_rng());

    // NOTE(review): assumes node `i` owns the contiguous CPU range
    // `[i * cpus_per_node, (i + 1) * cpus_per_node)`. hwloc does not guarantee
    // this numbering; the node object's own cpuset would be authoritative. Confirm.
    let cpu_begin = numa_node * cpus_per_node;
    let cpus = (cpu_begin..cpu_begin + cpus_per_node).collect();
    Ok((num_numa_nodes > 1, cpus))
}
138
/// Stand-in used when NUMA support is disabled.
///
/// Ignores the CPU list and leaves every thread unpinned; always succeeds.
#[cfg(all(not(feature = "numa"), feature = "rayon"))]
#[inline]
fn bind_threads(_threads: Vec<usize>) -> Result<()> {
    Ok(())
}
144
/// Pins each worker of the rayon pool to one CPU.
///
/// The closure runs once on every thread of the pool targeted by
/// `rayon::broadcast` (the global pool when called from outside any pool).
/// The thread with broadcast index `i` is bound to CPU `threads[i]` using hwloc
/// thread binding.
///
/// # Errors
///
/// Returns an error, instead of panicking inside a worker, if a thread has no
/// entry in `threads`, if the topology cannot be loaded, or if binding fails.
#[cfg(all(feature = "numa", feature = "rayon"))]
fn bind_threads(threads: Vec<usize>) -> Result<()> {
    use hwlocality::cpu::{binding::CpuBindingFlags, cpuset::CpuSet};

    ::rayon::broadcast(|ctx| -> Result<()> {
        let thread_index = ctx.index();
        // Checked lookup: the pool may be larger than `threads` (rayon picks its
        // own size when asked for 0 threads).
        let cpu = *threads.get(thread_index).ok_or_else(|| {
            anyhow!(
                "no CPU assigned to rayon thread {thread_index} ({} CPUs for {} threads)",
                threads.len(),
                ctx.num_threads()
            )
        })?;

        // Loaded per worker, as before; binding applies to the calling thread.
        let topology =
            get_topology().map_err(|error| anyhow!("failed to load topology: {error:#}"))?;

        let mut cpus = CpuSet::new();
        cpus.set(cpu);
        topology
            .bind_cpu(&cpus, CpuBindingFlags::THREAD)
            .map_err(|error| anyhow!("failed to bind the rayon thread into CPU {cpu}: {error:?}"))
    })
    .into_iter()
    // `()` implements `FromIterator<()>`, so this yields the first error, if any.
    .collect()
}