scirs2_fft/
planning_parallel.rs

1//! Parallel FFT Planning
2//!
3//! This module extends the planning system with multithreaded planning capabilities,
4//! allowing for parallel plan creation and execution.
5
6use crate::error::{FFTError, FFTResult};
7use crate::planning::{
8    AdvancedFftPlanner, FftPlan, FftPlanExecutor, PlannerBackend, PlanningConfig,
9};
10use crate::worker_pool::WorkerPool;
11
12use scirs2_core::numeric::Complex64;
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant};
15
16/// Configuration options for parallel planning
17#[derive(Debug, Clone)]
18pub struct ParallelPlanningConfig {
19    /// Base planning configuration
20    pub base_config: PlanningConfig,
21
22    /// Maximum number of threads to use
23    pub max_threads: Option<usize>,
24
25    /// Minimum size threshold for parallel planning
26    pub parallel_threshold: usize,
27
28    /// Whether to use work stealing
29    pub use_work_stealing: bool,
30
31    /// Whether to enable parallel execution
32    pub parallel_execution: bool,
33}
34
35impl Default for ParallelPlanningConfig {
36    fn default() -> Self {
37        Self {
38            base_config: PlanningConfig::default(),
39            max_threads: None,        // Use all available threads
40            parallel_threshold: 1024, // Only use parallelism for FFTs >= 1024 elements
41            use_work_stealing: true,
42            parallel_execution: true,
43        }
44    }
45}
46
47/// Result of a parallel plan creation
48// Custom Debug implementation because FftPlan doesn't implement Debug
49pub struct ParallelPlanResult {
50    /// The created plan
51    pub plan: Arc<FftPlan>,
52
53    /// Time taken to create the plan
54    pub creation_time: Duration,
55
56    /// Shape of the FFT for this plan
57    pub shape: Vec<usize>,
58
59    /// Thread ID that created this plan
60    pub thread_id: usize,
61}
62
63impl std::fmt::Debug for ParallelPlanResult {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("ParallelPlanResult")
66            .field("shape", &self.shape)
67            .field("creation_time", &self.creation_time)
68            .field("thread_id", &self.thread_id)
69            .field("plan", &format!("<FftPlan: shape={:?}>", self.shape))
70            .finish()
71    }
72}
73
74/// Parallel planner that can create multiple plans simultaneously
75pub struct ParallelPlanner {
76    /// Base planner
77    base_planner: Arc<Mutex<AdvancedFftPlanner>>,
78
79    /// Configuration
80    config: ParallelPlanningConfig,
81
82    /// Worker pool for parallel execution
83    worker_pool: Arc<WorkerPool>,
84}
85
86impl ParallelPlanner {
87    /// Create a new parallel planner
88    pub fn new(config: Option<ParallelPlanningConfig>) -> Self {
89        let config = config.unwrap_or_default();
90        let base_planner = Arc::new(Mutex::new(AdvancedFftPlanner::with_config(
91            config.base_config.clone(),
92        )));
93
94        let worker_pool = match config.max_threads {
95            Some(threads) => {
96                let worker_config = crate::worker_pool::WorkerConfig {
97                    num_workers: threads,
98                    ..Default::default()
99                };
100                Arc::new(
101                    WorkerPool::with_config(worker_config).unwrap_or_else(|_| WorkerPool::new()),
102                )
103            }
104            None => Arc::new(WorkerPool::new()),
105        };
106
107        Self {
108            base_planner,
109            config,
110            worker_pool,
111        }
112    }
113
114    /// Create a single plan
115    pub fn plan_fft(
116        &self,
117        shape: &[usize],
118        forward: bool,
119        backend: PlannerBackend,
120    ) -> FFTResult<Arc<FftPlan>> {
121        // For small FFTs, use the base planner directly
122        let size = shape.iter().product::<usize>();
123        if size < self.config.parallel_threshold || !self.config.parallel_execution {
124            let mut planner = self.base_planner.lock().unwrap();
125            return planner.plan_fft(shape, forward, backend);
126        }
127
128        // For larger FFTs, use the worker pool
129        let planner_clone = self.base_planner.clone();
130        let shape_clone = shape.to_vec();
131        let backend_clone = backend.clone();
132
133        let result = self.worker_pool.execute(move || {
134            let mut planner = planner_clone.lock().unwrap();
135            planner
136                .plan_fft(&shape_clone, forward, backend_clone)
137                .map_err(|e| format!("FFT planning error: {e}"))
138        });
139
140        match result {
141            Ok(plan) => Ok(plan),
142            Err(err) => Err(FFTError::PlanError(err)),
143        }
144    }
145
146    /// Create multiple plans in parallel
147    pub fn plan_multiple(
148        &self,
149        specs: &[(Vec<usize>, bool, PlannerBackend)],
150    ) -> FFTResult<Vec<ParallelPlanResult>> {
151        // Filter out small FFTs that would be processed serially
152        let (small_specs, large_specs): (Vec<_>, Vec<_>) =
153            specs.iter().enumerate().partition(|(_, (shape__, _, _))| {
154                shape__.iter().product::<usize>() < self.config.parallel_threshold
155            });
156
157        // Process small FFTs serially
158        let mut results = Vec::with_capacity(specs.len());
159        for (idx, (shape, forward, backend)) in small_specs {
160            let start = Instant::now();
161            let plan = {
162                let mut planner = self.base_planner.lock().unwrap();
163                planner.plan_fft(shape, *forward, backend.clone())?
164            };
165            results.push((
166                idx,
167                ParallelPlanResult {
168                    plan,
169                    creation_time: start.elapsed(),
170                    shape: shape.clone(),
171                    thread_id: 0, // Main thread
172                },
173            ));
174        }
175
176        // Process large FFTs in parallel
177        if !large_specs.is_empty() {
178            let planner_clone = self.base_planner.clone();
179
180            // Submit each plan creation as a separate task
181            let plan_futures = large_specs
182                .iter()
183                .map(|(idx, (shape, forward, backend))| {
184                    let planner = planner_clone.clone();
185                    let shape_clone = shape.clone();
186                    let backend_clone = backend.clone();
187                    let forward_val = *forward;
188                    let idx_val = *idx;
189
190                    self.worker_pool.execute(move || {
191                        let thread_id = 0; // Thread ID tracking handled by core parallel abstractions
192                        let start = Instant::now();
193                        let plan = {
194                            let mut planner_guard = planner.lock().unwrap();
195                            planner_guard
196                                .plan_fft(&shape_clone, forward_val, backend_clone)
197                                .map_err(|e| format!("FFT planning error: {e}"))?
198                        };
199
200                        Ok((
201                            idx_val,
202                            ParallelPlanResult {
203                                plan,
204                                creation_time: start.elapsed(),
205                                shape: shape_clone,
206                                thread_id,
207                            },
208                        ))
209                    })
210                })
211                .collect::<Vec<_>>();
212
213            // Plans are computed when executed - directly collect results
214            for result in plan_futures {
215                match result {
216                    Ok((idx, result)) => results.push((idx, result)),
217                    Err(err) => return Err(FFTError::PlanError(err)),
218                }
219            }
220        }
221
222        // Sort results by original index
223        results.sort_by_key(|(idx_, _)| *idx_);
224        Ok(results.into_iter().map(|(_, result)| result).collect())
225    }
226
227    /// Clear the plan cache
228    pub fn clear_cache(&self) {
229        let planner = self.base_planner.lock().unwrap();
230        planner.clear_cache();
231    }
232
233    /// Save plans to disk
234    pub fn save_plans(&self) -> FFTResult<()> {
235        let planner = self.base_planner.lock().unwrap();
236        planner.save_plans()
237    }
238}
239
240/// Executor for parallel FFT operations
241pub struct ParallelExecutor {
242    /// The plan to execute
243    plan: Arc<FftPlan>,
244
245    /// Configuration
246    config: ParallelPlanningConfig,
247
248    /// Worker pool
249    worker_pool: Arc<WorkerPool>,
250}
251
252impl ParallelExecutor {
253    /// Create a new parallel executor
254    pub fn new(plan: Arc<FftPlan>, config: Option<ParallelPlanningConfig>) -> Self {
255        let config = config.unwrap_or_default();
256
257        let worker_pool = match config.max_threads {
258            Some(threads) => {
259                let worker_config = crate::worker_pool::WorkerConfig {
260                    num_workers: threads,
261                    ..Default::default()
262                };
263                Arc::new(
264                    WorkerPool::with_config(worker_config).unwrap_or_else(|_| WorkerPool::new()),
265                )
266            }
267            None => Arc::new(WorkerPool::new()),
268        };
269
270        Self {
271            plan,
272            config,
273            worker_pool,
274        }
275    }
276
277    /// Execute the plan in parallel
278    pub fn execute(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
279        // Use the standard executor
280        let size = self.plan.shape().iter().product::<usize>();
281        if size < self.config.parallel_threshold || !self.config.parallel_execution {
282            let executor = FftPlanExecutor::new(self.plan.clone());
283            return executor.execute(input, output);
284        }
285
286        // For larger FFTs, use parallel execution
287        // This is a simplified implementation - a real one would split the data
288        // and distribute subtasks across threads
289
290        // For now, we'll just offload the execution to a worker thread
291        let plan_clone = self.plan.clone();
292        let input_vec = input.to_vec(); // Copy input for thread safety
293
294        let result = self.worker_pool.execute(move || {
295            let mut output_vec = vec![Complex64::default(); input_vec.len()];
296            let executor = FftPlanExecutor::new(plan_clone);
297
298            executor
299                .execute(&input_vec, &mut output_vec)
300                .map_err(|e| format!("FFT execution error: {e}"))?;
301
302            Ok(output_vec)
303        });
304
305        // Process the result and copy to output
306        match result {
307            Ok(result_vec) => {
308                output.copy_from_slice(&result_vec);
309                Ok(())
310            }
311            Err(err) => Err(FFTError::ComputationError(err)),
312        }
313    }
314
315    /// Execute multiple FFTs in parallel
316    pub fn execute_batch(
317        &self,
318        inputs: &[&[Complex64]],
319        outputs: &mut [&mut [Complex64]],
320    ) -> FFTResult<Vec<Duration>> {
321        if inputs.len() != outputs.len() {
322            return Err(FFTError::ValueError(
323                "Input and output counts must match".to_string(),
324            ));
325        }
326
327        // Verify all inputs/outputs have the correct size
328        let expected_size = self.plan.shape().iter().product::<usize>();
329        for (i, input) in inputs.iter().enumerate() {
330            if input.len() != expected_size {
331                return Err(FFTError::ValueError(format!(
332                    "Input {} has wrong size: expected {}, got {}",
333                    i,
334                    expected_size,
335                    input.len()
336                )));
337            }
338
339            if outputs[i].len() != expected_size {
340                return Err(FFTError::ValueError(format!(
341                    "Output {} has wrong size: expected {}, got {}",
342                    i,
343                    expected_size,
344                    outputs[i].len()
345                )));
346            }
347        }
348
349        // For small batch size, process serially
350        if inputs.len() < 2 || !self.config.parallel_execution {
351            let mut times = Vec::with_capacity(inputs.len());
352            let executor = FftPlanExecutor::new(self.plan.clone());
353
354            for i in 0..inputs.len() {
355                let start = Instant::now();
356                executor.execute(inputs[i], outputs[i])?;
357                times.push(start.elapsed());
358            }
359
360            return Ok(times);
361        }
362
363        // Process batch in parallel
364        let plan_clone = self.plan.clone();
365
366        // Prepare futures for each FFT
367        let futures = inputs
368            .iter()
369            .zip(outputs.iter_mut())
370            .enumerate()
371            .map(|(idx, (input, output))| {
372                let plan = plan_clone.clone();
373                let input_vec = input.to_vec(); // Copy for thread safety
374                let output_len = output.len();
375
376                self.worker_pool.execute(move || {
377                    let mut local_output = vec![Complex64::default(); output_len];
378                    let executor = FftPlanExecutor::new(plan);
379
380                    let start = Instant::now();
381                    executor
382                        .execute(&input_vec, &mut local_output)
383                        .map_err(|e| format!("FFT execution error for batch {idx}: {e}"))?;
384                    let elapsed = start.elapsed();
385
386                    Ok((idx, local_output, elapsed))
387                })
388            })
389            .collect::<Vec<_>>();
390
391        // Collect results
392        let mut times = vec![Duration::from_secs(0); inputs.len()];
393
394        for result in futures {
395            match result {
396                Ok((idx, result_vec, elapsed)) => {
397                    outputs[idx].copy_from_slice(&result_vec);
398                    times[idx] = elapsed;
399                }
400                Err(err) => return Err(FFTError::ComputationError(err)),
401            }
402        }
403
404        Ok(times)
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Config whose threshold is low enough that even tiny FFTs take the worker-pool
    /// paths, so those paths are covered without large inputs.
    fn always_parallel_config() -> ParallelPlanningConfig {
        ParallelPlanningConfig {
            parallel_threshold: 1,
            ..Default::default()
        }
    }

    #[test]
    fn test_parallel_planner() {
        let planner = ParallelPlanner::new(None);

        // Create a plan
        let plan = planner
            .plan_fft(&[64], true, PlannerBackend::default())
            .unwrap();

        // Check the plan properties
        assert_eq!(plan.shape(), &[64]);
    }

    #[test]
    fn test_parallel_executor() {
        let planner = ParallelPlanner::new(None);
        let plan = planner
            .plan_fft(&[64], true, PlannerBackend::default())
            .unwrap();

        let executor = ParallelExecutor::new(plan, None);

        // Create test data
        let input = vec![Complex64::new(1.0, 0.0); 64];
        let mut output = vec![Complex64::default(); 64];

        // Execute the plan
        executor.execute(&input, &mut output).unwrap();

        // Basic validation - output should not be all zeros
        assert!(output.iter().any(|&val| val != Complex64::default()));
    }

    #[test]
    fn test_pooled_path_matches_serial_path() {
        let config = always_parallel_config();

        // Planning goes through the worker pool with this config.
        let planner = ParallelPlanner::new(Some(config.clone()));
        let plan = planner
            .plan_fft(&[64], true, PlannerBackend::default())
            .unwrap();
        assert_eq!(plan.shape(), &[64]);

        let input: Vec<Complex64> = (0..64u32)
            .map(|k| Complex64::new(f64::from(k), -f64::from(k) / 2.0))
            .collect();

        let mut serial = vec![Complex64::default(); 64];
        FftPlanExecutor::new(plan.clone())
            .execute(&input, &mut serial)
            .unwrap();

        let mut pooled = vec![Complex64::default(); 64];
        ParallelExecutor::new(plan, Some(config))
            .execute(&input, &mut pooled)
            .unwrap();

        for (a, b) in serial.iter().zip(&pooled) {
            assert!((*a - *b).norm() < 1e-9);
        }
    }

    #[test]
    fn test_plan_multiple_preserves_order() {
        // Threshold 50: the 64-point spec goes to the worker pool, the others stay serial.
        let config = ParallelPlanningConfig {
            parallel_threshold: 50,
            ..Default::default()
        };
        let planner = ParallelPlanner::new(Some(config));

        let specs = vec![
            (vec![64], true, PlannerBackend::default()),
            (vec![16], false, PlannerBackend::default()),
            (vec![32], true, PlannerBackend::default()),
        ];
        let results = planner.plan_multiple(&specs).unwrap();

        let shapes: Vec<Vec<usize>> = results.iter().map(|r| r.shape.clone()).collect();
        assert_eq!(shapes, vec![vec![64], vec![16], vec![32]]);
        for result in &results {
            assert_eq!(result.plan.shape(), result.shape.as_slice());
        }
    }

    #[test]
    fn test_batch_execution() {
        let planner = ParallelPlanner::new(None);
        let plan = planner
            .plan_fft(&[32], true, PlannerBackend::default())
            .unwrap();

        let executor = ParallelExecutor::new(plan, None);

        // Create multiple test inputs
        let input1 = vec![Complex64::new(1.0, 0.0); 32];
        let input2 = vec![Complex64::new(0.0, 1.0); 32];

        let mut output1 = vec![Complex64::default(); 32];
        let mut output2 = vec![Complex64::default(); 32];

        let inputs = [&input1[..], &input2[..]];
        let mut outputs = [&mut output1[..], &mut output2[..]];

        // Execute batch (two entries, so the worker-pool path is taken)
        let times = executor.execute_batch(&inputs, &mut outputs).unwrap();

        // Check that we got timing information
        assert_eq!(times.len(), 2);

        // Validate outputs
        assert!(output1.iter().any(|&val| val != Complex64::default()));
        assert!(output2.iter().any(|&val| val != Complex64::default()));

        // Outputs should be different since inputs were different
        assert_ne!(output1[0], output2[0]);
    }
}