1use crate::error::{FFTError, FFTResult};
7use crate::planning::{
8 AdvancedFftPlanner, FftPlan, FftPlanExecutor, PlannerBackend, PlanningConfig,
9};
10use crate::worker_pool::WorkerPool;
11
12use scirs2_core::numeric::Complex64;
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15
16#[derive(Debug, Clone)]
18pub struct ParallelPlanningConfig {
19 pub base_config: PlanningConfig,
21
22 pub max_threads: Option<usize>,
24
25 pub parallel_threshold: usize,
27
28 pub use_work_stealing: bool,
30
31 pub parallel_execution: bool,
33}
34
35impl Default for ParallelPlanningConfig {
36 fn default() -> Self {
37 Self {
38 base_config: PlanningConfig::default(),
39 max_threads: None, parallel_threshold: 1024, use_work_stealing: true,
42 parallel_execution: true,
43 }
44 }
45}
46
47pub struct ParallelPlanResult {
50 pub plan: Arc<FftPlan>,
52
53 pub creation_time: Duration,
55
56 pub shape: Vec<usize>,
58
59 pub thread_id: usize,
61}
62
63impl std::fmt::Debug for ParallelPlanResult {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("ParallelPlanResult")
66 .field("shape", &self.shape)
67 .field("creation_time", &self.creation_time)
68 .field("thread_id", &self.thread_id)
69 .field("plan", &format!("<FftPlan: shape={:?}>", self.shape))
70 .finish()
71 }
72}
73
74pub struct ParallelPlanner {
76 base_planner: Arc<Mutex<AdvancedFftPlanner>>,
78
79 config: ParallelPlanningConfig,
81
82 worker_pool: Arc<WorkerPool>,
84}
85
86impl ParallelPlanner {
87 pub fn new(config: Option<ParallelPlanningConfig>) -> Self {
89 let config = config.unwrap_or_default();
90 let base_planner = Arc::new(Mutex::new(AdvancedFftPlanner::with_config(
91 config.base_config.clone(),
92 )));
93
94 let worker_pool = match config.max_threads {
95 Some(threads) => {
96 let worker_config = crate::worker_pool::WorkerConfig {
97 num_workers: threads,
98 ..Default::default()
99 };
100 Arc::new(
101 WorkerPool::with_config(worker_config).unwrap_or_else(|_| WorkerPool::new()),
102 )
103 }
104 None => Arc::new(WorkerPool::new()),
105 };
106
107 Self {
108 base_planner,
109 config,
110 worker_pool,
111 }
112 }
113
114 pub fn plan_fft(
116 &self,
117 shape: &[usize],
118 forward: bool,
119 backend: PlannerBackend,
120 ) -> FFTResult<Arc<FftPlan>> {
121 let size = shape.iter().product::<usize>();
123 if size < self.config.parallel_threshold || !self.config.parallel_execution {
124 let mut planner = self.base_planner.lock().unwrap();
125 return planner.plan_fft(shape, forward, backend);
126 }
127
128 let planner_clone = self.base_planner.clone();
130 let shape_clone = shape.to_vec();
131 let backend_clone = backend.clone();
132
133 let result = self.worker_pool.execute(move || {
134 let mut planner = planner_clone.lock().unwrap();
135 planner
136 .plan_fft(&shape_clone, forward, backend_clone)
137 .map_err(|e| format!("FFT planning error: {e}"))
138 });
139
140 match result {
141 Ok(plan) => Ok(plan),
142 Err(err) => Err(FFTError::PlanError(err)),
143 }
144 }
145
146 pub fn plan_multiple(
148 &self,
149 specs: &[(Vec<usize>, bool, PlannerBackend)],
150 ) -> FFTResult<Vec<ParallelPlanResult>> {
151 let (small_specs, large_specs): (Vec<_>, Vec<_>) =
153 specs.iter().enumerate().partition(|(_, (shape__, _, _))| {
154 shape__.iter().product::<usize>() < self.config.parallel_threshold
155 });
156
157 let mut results = Vec::with_capacity(specs.len());
159 for (idx, (shape, forward, backend)) in small_specs {
160 let start = Instant::now();
161 let plan = {
162 let mut planner = self.base_planner.lock().unwrap();
163 planner.plan_fft(shape, *forward, backend.clone())?
164 };
165 results.push((
166 idx,
167 ParallelPlanResult {
168 plan,
169 creation_time: start.elapsed(),
170 shape: shape.clone(),
171 thread_id: 0, },
173 ));
174 }
175
176 if !large_specs.is_empty() {
178 let planner_clone = self.base_planner.clone();
179
180 let plan_futures = large_specs
182 .iter()
183 .map(|(idx, (shape, forward, backend))| {
184 let planner = planner_clone.clone();
185 let shape_clone = shape.clone();
186 let backend_clone = backend.clone();
187 let forward_val = *forward;
188 let idx_val = *idx;
189
190 self.worker_pool.execute(move || {
191 let thread_id = 0; let start = Instant::now();
193 let plan = {
194 let mut planner_guard = planner.lock().unwrap();
195 planner_guard
196 .plan_fft(&shape_clone, forward_val, backend_clone)
197 .map_err(|e| format!("FFT planning error: {e}"))?
198 };
199
200 Ok((
201 idx_val,
202 ParallelPlanResult {
203 plan,
204 creation_time: start.elapsed(),
205 shape: shape_clone,
206 thread_id,
207 },
208 ))
209 })
210 })
211 .collect::<Vec<_>>();
212
213 for result in plan_futures {
215 match result {
216 Ok((idx, result)) => results.push((idx, result)),
217 Err(err) => return Err(FFTError::PlanError(err)),
218 }
219 }
220 }
221
222 results.sort_by_key(|(idx_, _)| *idx_);
224 Ok(results.into_iter().map(|(_, result)| result).collect())
225 }
226
227 pub fn clear_cache(&self) {
229 let planner = self.base_planner.lock().unwrap();
230 planner.clear_cache();
231 }
232
233 pub fn save_plans(&self) -> FFTResult<()> {
235 let planner = self.base_planner.lock().unwrap();
236 planner.save_plans()
237 }
238}
239
240pub struct ParallelExecutor {
242 plan: Arc<FftPlan>,
244
245 config: ParallelPlanningConfig,
247
248 worker_pool: Arc<WorkerPool>,
250}
251
252impl ParallelExecutor {
253 pub fn new(plan: Arc<FftPlan>, config: Option<ParallelPlanningConfig>) -> Self {
255 let config = config.unwrap_or_default();
256
257 let worker_pool = match config.max_threads {
258 Some(threads) => {
259 let worker_config = crate::worker_pool::WorkerConfig {
260 num_workers: threads,
261 ..Default::default()
262 };
263 Arc::new(
264 WorkerPool::with_config(worker_config).unwrap_or_else(|_| WorkerPool::new()),
265 )
266 }
267 None => Arc::new(WorkerPool::new()),
268 };
269
270 Self {
271 plan,
272 config,
273 worker_pool,
274 }
275 }
276
277 pub fn execute(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
279 let size = self.plan.shape().iter().product::<usize>();
281 if size < self.config.parallel_threshold || !self.config.parallel_execution {
282 let executor = FftPlanExecutor::new(self.plan.clone());
283 return executor.execute(input, output);
284 }
285
286 let plan_clone = self.plan.clone();
292 let input_vec = input.to_vec(); let result = self.worker_pool.execute(move || {
295 let mut output_vec = vec![Complex64::default(); input_vec.len()];
296 let executor = FftPlanExecutor::new(plan_clone);
297
298 executor
299 .execute(&input_vec, &mut output_vec)
300 .map_err(|e| format!("FFT execution error: {e}"))?;
301
302 Ok(output_vec)
303 });
304
305 match result {
307 Ok(result_vec) => {
308 output.copy_from_slice(&result_vec);
309 Ok(())
310 }
311 Err(err) => Err(FFTError::ComputationError(err)),
312 }
313 }
314
315 pub fn execute_batch(
317 &self,
318 inputs: &[&[Complex64]],
319 outputs: &mut [&mut [Complex64]],
320 ) -> FFTResult<Vec<Duration>> {
321 if inputs.len() != outputs.len() {
322 return Err(FFTError::ValueError(
323 "Input and output counts must match".to_string(),
324 ));
325 }
326
327 let expected_size = self.plan.shape().iter().product::<usize>();
329 for (i, input) in inputs.iter().enumerate() {
330 if input.len() != expected_size {
331 return Err(FFTError::ValueError(format!(
332 "Input {} has wrong size: expected {}, got {}",
333 i,
334 expected_size,
335 input.len()
336 )));
337 }
338
339 if outputs[i].len() != expected_size {
340 return Err(FFTError::ValueError(format!(
341 "Output {} has wrong size: expected {}, got {}",
342 i,
343 expected_size,
344 outputs[i].len()
345 )));
346 }
347 }
348
349 if inputs.len() < 2 || !self.config.parallel_execution {
351 let mut times = Vec::with_capacity(inputs.len());
352 let executor = FftPlanExecutor::new(self.plan.clone());
353
354 for i in 0..inputs.len() {
355 let start = Instant::now();
356 executor.execute(inputs[i], outputs[i])?;
357 times.push(start.elapsed());
358 }
359
360 return Ok(times);
361 }
362
363 let plan_clone = self.plan.clone();
365
366 let futures = inputs
368 .iter()
369 .zip(outputs.iter_mut())
370 .enumerate()
371 .map(|(idx, (input, output))| {
372 let plan = plan_clone.clone();
373 let input_vec = input.to_vec(); let output_len = output.len();
375
376 self.worker_pool.execute(move || {
377 let mut local_output = vec![Complex64::default(); output_len];
378 let executor = FftPlanExecutor::new(plan);
379
380 let start = Instant::now();
381 executor
382 .execute(&input_vec, &mut local_output)
383 .map_err(|e| format!("FFT execution error for batch {idx}: {e}"))?;
384 let elapsed = start.elapsed();
385
386 Ok((idx, local_output, elapsed))
387 })
388 })
389 .collect::<Vec<_>>();
390
391 let mut times = vec![Duration::from_secs(0); inputs.len()];
393
394 for result in futures {
395 match result {
396 Ok((idx, result_vec, elapsed)) => {
397 outputs[idx].copy_from_slice(&result_vec);
398 times[idx] = elapsed;
399 }
400 Err(err) => return Err(FFTError::ComputationError(err)),
401 }
402 }
403
404 Ok(times)
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when at least one element differs from the default value.
    fn has_non_default(values: &[Complex64]) -> bool {
        values.iter().any(|&value| value != Complex64::default())
    }

    /// Planning a 64-point transform yields a plan with the requested shape.
    #[test]
    fn test_parallel_planner() {
        let planner = ParallelPlanner::new(None);
        let plan = planner
            .plan_fft(&[64], true, PlannerBackend::default())
            .unwrap();

        assert_eq!(plan.shape(), &[64]);
    }

    /// Executing a plan on a constant signal writes something into the output.
    #[test]
    fn test_parallel_executor() {
        let planner = ParallelPlanner::new(None);
        let plan = planner
            .plan_fft(&[64], true, PlannerBackend::default())
            .unwrap();
        let executor = ParallelExecutor::new(plan, None);

        let input = vec![Complex64::new(1.0, 0.0); 64];
        let mut output = vec![Complex64::default(); 64];

        executor.execute(&input, &mut output).unwrap();

        assert!(has_non_default(&output));
    }

    /// Batch execution times every entry and produces distinct outputs for
    /// distinct inputs.
    #[test]
    fn test_batch_execution() {
        let planner = ParallelPlanner::new(None);
        let plan = planner
            .plan_fft(&[32], true, PlannerBackend::default())
            .unwrap();
        let executor = ParallelExecutor::new(plan, None);

        let first_in = vec![Complex64::new(1.0, 0.0); 32];
        let second_in = vec![Complex64::new(0.0, 1.0); 32];
        let mut first_out = vec![Complex64::default(); 32];
        let mut second_out = vec![Complex64::default(); 32];

        let inputs = [first_in.as_slice(), second_in.as_slice()];
        let mut outputs = [first_out.as_mut_slice(), second_out.as_mut_slice()];

        let times = executor.execute_batch(&inputs, &mut outputs).unwrap();

        assert_eq!(times.len(), 2);

        assert!(has_non_default(&first_out));
        assert!(has_non_default(&second_out));

        assert_ne!(first_out[0], second_out[0]);
    }
}