// scirs2_fft/worker_pool.rs
use scirs2_core::parallel_ops::*;
use std::env;
use std::sync::{Arc, Mutex, MutexGuard, OnceLock, PoisonError};
use std::thread;
10
/// Settings for the FFT worker pool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Number of workers the pool is configured for.
    pub num_workers: usize,
    /// Whether parallel execution is switched on.
    pub enabled: bool,
    /// Optional stack size for worker threads (presumably in bytes — confirm
    /// against the code that actually spawns threads).
    pub stack_size: Option<usize>,
    /// Prefix for worker thread names.
    pub thread_name_prefix: String,
}

impl Default for WorkerConfig {
    /// Builds the default configuration.
    ///
    /// `num_workers` comes from the `SCIRS2_FFT_WORKERS` environment variable
    /// when it holds a positive integer (surrounding whitespace is ignored).
    /// Otherwise it falls back to [`std::thread::available_parallelism`], or
    /// to `1` if that cannot be determined. The pool starts enabled, with no
    /// explicit stack size and the `"scirs2-fft-worker"` name prefix.
    fn default() -> Self {
        let num_workers = env::var("SCIRS2_FFT_WORKERS")
            .ok()
            .and_then(|raw| raw.trim().parse::<usize>().ok())
            // A zero-sized pool is never useful; treat it like an unset variable.
            .filter(|&count| count > 0)
            .unwrap_or_else(|| thread::available_parallelism().map_or(1, |n| n.get()));

        Self {
            num_workers,
            enabled: true,
            stack_size: None,
            thread_name_prefix: "scirs2-fft-worker".to_string(),
        }
    }
}
45
46pub struct WorkerPool {
49 config: Arc<Mutex<WorkerConfig>>,
50}
51
52impl WorkerPool {
53 pub fn new() -> Self {
55 let config = WorkerConfig::default();
56
57 Self {
58 config: Arc::new(Mutex::new(config)),
59 }
60 }
61
62 pub fn with_config(
64 config: WorkerConfig,
65 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
66 Ok(Self {
67 config: Arc::new(Mutex::new(config)),
68 })
69 }
70
71 pub fn get_workers(&self) -> usize {
75 self.config.lock().unwrap().num_workers
76 }
77
78 pub fn set_workers(
82 &mut self,
83 num_workers: usize,
84 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85 let mut config = self.config.lock().unwrap();
86 config.num_workers = num_workers;
87 Ok(())
88 }
89
90 pub fn is_enabled(&self) -> bool {
92 self.config.lock().unwrap().enabled
93 }
94
95 pub fn set_enabled(&self, enabled: bool) {
97 self.config.lock().unwrap().enabled = enabled;
98 }
99
100 pub fn execute<F, R>(&self, f: F) -> R
103 where
104 F: FnOnce() -> R + Send,
105 R: Send,
106 {
107 f()
110 }
111
112 pub fn execute_with_workers<F, R>(&self, _numworkers: usize, f: F) -> R
115 where
116 F: FnOnce() -> R + Send,
117 R: Send,
118 {
119 f()
122 }
123
124 pub fn get_info(&self) -> WorkerPoolInfo {
126 let config = self.config.lock().unwrap();
127 WorkerPoolInfo {
128 num_workers: config.num_workers,
129 enabled: config.enabled,
130 current_threads: num_threads(), thread_name_prefix: config.thread_name_prefix.clone(),
132 }
133 }
134}
135
136impl Default for WorkerPool {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
/// Point-in-time snapshot of a worker pool's settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerPoolInfo {
    /// Configured number of workers.
    pub num_workers: usize,
    /// Whether parallel execution was flagged as enabled.
    pub enabled: bool,
    /// Thread count reported when the snapshot was taken.
    pub current_threads: usize,
    /// Prefix configured for worker thread names.
    pub thread_name_prefix: String,
}

impl std::fmt::Display for WorkerPoolInfo {
    /// Formats the snapshot as a single human-readable line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Worker Pool: {} workers (current: {}), enabled: {}, prefix: {}",
            self.num_workers, self.current_threads, self.enabled, self.thread_name_prefix
        )
    }
}
160
161static GLOBAL_WORKER_POOL: OnceLock<WorkerPool> = OnceLock::new();
163
164#[allow(dead_code)]
166pub fn get_global_pool() -> &'static WorkerPool {
167 GLOBAL_WORKER_POOL.get_or_init(WorkerPool::new)
168}
169
170#[allow(dead_code)]
172pub fn init_global_pool(config: WorkerConfig) -> Result<(), &'static str> {
173 GLOBAL_WORKER_POOL
174 .set(WorkerPool::with_config(config).map_err(|_| "Failed to create worker pool")?)
175 .map_err(|_| "Global worker pool already initialized")
176}
177
178pub struct WorkerContext {
180 #[allow(dead_code)]
181 previous_workers: usize,
182 #[allow(dead_code)]
183 pool: &'static WorkerPool,
184}
185
186impl WorkerContext {
187 pub fn new(_numworkers: usize) -> Self {
189 let pool = get_global_pool();
190 let previous_workers = pool.get_workers();
191
192 Self {
196 previous_workers,
197 pool,
198 }
199 }
200}
201
202impl Drop for WorkerContext {
203 fn drop(&mut self) {
204 }
207}
208
209#[allow(dead_code)]
211pub fn set_workers(_n: usize) -> Result<(), &'static str> {
212 let _pool = get_global_pool();
213 Ok(())
216}
217
218#[allow(dead_code)]
220pub fn get_workers() -> usize {
221 get_global_pool().get_workers()
222}
223
224#[allow(dead_code)]
226pub fn with_workers<F, R>(num_workers: usize, f: F) -> R
227where
228 F: FnOnce() -> R + Send,
229 R: Send,
230{
231 get_global_pool().execute_with_workers(num_workers, f)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool has at least one worker and starts enabled.
    #[test]
    fn test_default_worker_pool() {
        let pool = WorkerPool::new();
        assert!(pool.get_workers() > 0);
        assert!(pool.is_enabled());
    }

    /// An explicit configuration is reflected by the pool built from it.
    #[test]
    fn test_worker_config() {
        let config = WorkerConfig {
            num_workers: 4,
            enabled: true,
            stack_size: Some(2 * 1024 * 1024),
            thread_name_prefix: "test-worker".to_string(),
        };

        let pool = WorkerPool::with_config(config).unwrap();
        assert_eq!(pool.get_workers(), 4);
    }

    /// The enabled flag can be toggled through a shared reference.
    #[test]
    fn test_enable_disable() {
        let pool = WorkerPool::new();
        assert!(pool.is_enabled());

        pool.set_enabled(false);
        assert!(!pool.is_enabled());

        pool.set_enabled(true);
        assert!(pool.is_enabled());
    }

    /// Closures run and return their value whether or not the pool is enabled.
    #[test]
    fn test_execute() {
        let pool = WorkerPool::new();

        let result = pool.execute(|| 42);
        assert_eq!(result, 42);

        // Disabling the pool must not stop the closure from running.
        pool.set_enabled(false);
        let result = pool.execute(|| 84);
        assert_eq!(result, 84);
    }

    /// The closure passed to `execute_with_workers` runs and its result is returned.
    #[test]
    fn test_execute_with_workers() {
        let pool = WorkerPool::new();

        let result = pool.execute_with_workers(2, num_threads);

        // A new pool is always enabled, so the thread count is checked unconditionally.
        assert!(result > 0);
    }

    /// The info snapshot agrees with the individual getters.
    #[test]
    fn test_worker_info() {
        let pool = WorkerPool::new();
        let info = pool.get_info();

        assert_eq!(info.num_workers, pool.get_workers());
        assert_eq!(info.enabled, pool.is_enabled());
    }
}