1use crate::auto_tuning::{AutoTuneConfig, AutoTuner, FftVariant};
16use crate::backend::BackendContext;
17use crate::error::{FFTError, FFTResult};
18use crate::plan_serialization::{PlanMetrics, PlanSerializationManager};
19
20use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
21use scirs2_core::numeric::Complex64;
22use std::collections::HashMap;
23use std::sync::{Arc, Mutex};
24use std::time::{Duration, Instant};
25
/// Policy an [`AdvancedFftPlanner`] follows when it is asked for a plan.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlanningStrategy {
    /// Build a fresh plan for every request; the plan cache is neither read nor written.
    AlwaysNew,
    /// Return a cached plan when one exists, otherwise build and cache one (the default).
    #[default]
    CacheFirst,
    /// Check the cache first; on a miss, attach metrics from the serialized plan database
    /// when it has an entry for the requested size and direction.
    SerializedFirst,
    /// Query the auto-tuner for the preferred FFT variant and record it on the new plan.
    AutoTuned,
}
39
40#[derive(Debug, Clone)]
42pub struct PlanningConfig {
43 pub strategy: PlanningStrategy,
45 pub measure_performance: bool,
47 pub serialized_db_path: Option<String>,
49 pub auto_tune_config: Option<AutoTuneConfig>,
51 pub max_cached_plans: usize,
53 pub max_plan_age: Duration,
55 pub parallel_planning: bool,
57}
58
59impl Default for PlanningConfig {
60 fn default() -> Self {
61 Self {
62 strategy: PlanningStrategy::default(),
63 measure_performance: true,
64 serialized_db_path: None,
65 auto_tune_config: None,
66 max_cached_plans: 128,
67 max_plan_age: Duration::from_secs(3600), parallel_planning: true,
69 }
70 }
71}
72
73#[derive(Clone)]
75pub struct FftPlan {
76 shape: Vec<usize>,
78 #[allow(dead_code)]
81 forward: bool,
82 internal_plan: Arc<dyn rustfft::Fft<f64>>,
84 metrics: Option<PlanMetrics>,
86 #[allow(dead_code)]
89 backend: PlannerBackend,
90 auto_tune_info: Option<FftVariant>,
92 last_used: Instant,
94 usage_count: usize,
96}
97
98impl FftPlan {
99 pub fn new(
101 shape: &[usize],
102 forward: bool,
103 planner: &mut rustfft::FftPlanner<f64>,
104 backend: PlannerBackend,
105 ) -> Self {
106 let size = shape.iter().product();
108
109 let internal_plan = if forward {
110 planner.plan_fft_forward(size)
111 } else {
112 planner.plan_fft_inverse(size)
113 };
114
115 Self {
116 shape: shape.to_vec(),
117 forward,
118 internal_plan,
119 metrics: None,
120 backend,
121 auto_tune_info: None,
122 last_used: Instant::now(),
123 usage_count: 0,
124 }
125 }
126
127 pub fn get_internal(&self) -> Arc<dyn rustfft::Fft<f64>> {
129 self.internal_plan.clone()
130 }
131
132 pub fn record_usage(&mut self) {
134 self.usage_count += 1;
135 self.last_used = Instant::now();
136 }
137
138 pub fn shape(&self) -> &[usize] {
140 &self.shape
141 }
142
143 pub fn is_compatible_with(&self, shape: &[usize]) -> bool {
145 if self.shape.len() != shape.len() {
146 return false;
147 }
148
149 self.shape.iter().zip(shape.iter()).all(|(&a, &b)| a == b)
150 }
151
152 pub fn metrics(&self) -> Option<&PlanMetrics> {
154 self.metrics.as_ref()
155 }
156
157 pub fn set_metrics(&mut self, metrics: PlanMetrics) {
159 self.metrics = Some(metrics);
160 }
161}
162
163pub struct AdvancedFftPlanner {
165 config: PlanningConfig,
167 cache: Arc<Mutex<HashMap<PlanKey, FftPlan>>>,
169 serialization_manager: Option<PlanSerializationManager>,
171 auto_tuner: Option<AutoTuner>,
173 internal_planner: rustfft::FftPlanner<f64>,
175}
176
/// Backend a plan is requested for; part of the cache key.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum PlannerBackend {
    /// The pure-Rust `rustfft` backend (the default).
    #[default]
    RustFFT,
    /// FFTW backend.
    FFTW,
    /// CUDA backend.
    CUDA,
    /// A backend identified by a caller-supplied name.
    Custom(String),
}
190
191#[derive(Clone, Debug, Hash, PartialEq, Eq)]
193struct PlanKey {
194 shape: Vec<usize>,
196 forward: bool,
198 backend: PlannerBackend,
200}
201
202impl AdvancedFftPlanner {
204 pub fn new() -> Self {
206 Self::with_config(PlanningConfig::default())
207 }
208
209 pub fn with_config(config: PlanningConfig) -> Self {
211 let serialization_manager = config
212 .serialized_db_path
213 .as_ref()
214 .map(PlanSerializationManager::new);
215
216 let auto_tuner = if config.strategy == PlanningStrategy::AutoTuned {
217 let tuner = AutoTuner::new();
218 if let Some(_autoconfig) = &config.auto_tune_config {
219 }
222 Some(tuner)
223 } else {
224 None
225 };
226
227 Self {
228 config,
229 cache: Arc::new(Mutex::new(HashMap::new())),
230 serialization_manager,
231 auto_tuner,
232 internal_planner: rustfft::FftPlanner::new(),
233 }
234 }
235
236 pub fn clear_cache(&self) {
238 if let Ok(mut cache) = self.cache.lock() {
239 cache.clear();
240 }
241 }
242
243 pub fn plan_fft(
245 &mut self,
246 shape: &[usize],
247 forward: bool,
248 backend: PlannerBackend,
249 ) -> FFTResult<Arc<FftPlan>> {
250 let key = PlanKey {
251 shape: shape.to_vec(),
252 forward,
253 backend: backend.clone(),
254 };
255
256 if self.config.strategy == PlanningStrategy::CacheFirst
258 || self.config.strategy == PlanningStrategy::SerializedFirst
259 {
260 if let Ok(mut cache) = self.cache.lock() {
261 if let Some(plan) = cache.get_mut(&key) {
262 plan.record_usage();
263 return Ok(Arc::new(plan.clone()));
264 }
265 }
266 }
267
268 if self.config.strategy == PlanningStrategy::SerializedFirst {
270 if let Some(manager) = &self.serialization_manager {
271 let size = shape.iter().product();
274
275 if manager.plan_exists(size, forward) {
276 if let Some((_plan_info, metrics)) =
277 manager.get_best_plan_metrics(size, forward)
278 {
279 let mut plan = self.create_new_plan(shape, forward, backend.clone())?;
281 plan.set_metrics(metrics.clone());
282
283 if let Ok(mut cache) = self.cache.lock() {
285 cache.insert(key, plan.clone());
286 }
287
288 return Ok(Arc::new(plan));
289 }
290 }
291 }
292 }
293
294 if self.config.strategy == PlanningStrategy::AutoTuned {
296 if let Some(tuner) = &self.auto_tuner {
297 let size = shape.iter().product();
299 let variant = tuner.get_best_variant(size, forward);
300
301 let mut plan = self.create_new_plan(shape, forward, backend)?;
303 plan.auto_tune_info = Some(variant);
304
305 if let Ok(mut cache) = self.cache.lock() {
307 cache.insert(key, plan.clone());
308 }
309
310 return Ok(Arc::new(plan));
311 }
312 }
313
314 let plan = self.create_new_plan(shape, forward, backend)?;
316
317 if self.config.strategy != PlanningStrategy::AlwaysNew {
319 if let Ok(mut cache) = self.cache.lock() {
320 if cache.len() >= self.config.max_cached_plans {
322 self.evict_old_entries(&mut cache);
323 }
324
325 cache.insert(key, plan.clone());
326 }
327 }
328
329 Ok(Arc::new(plan))
330 }
331
332 fn create_new_plan(
334 &mut self,
335 shape: &[usize],
336 forward: bool,
337 backend: PlannerBackend,
338 ) -> FFTResult<FftPlan> {
339 let start = Instant::now();
341
342 let plan = FftPlan::new(shape, forward, &mut self.internal_planner, backend);
343
344 let elapsed = start.elapsed();
345
346 if self.config.measure_performance {
348 if let Some(manager) = &self.serialization_manager {
349 let size = shape.iter().product();
352
353 let plan_info = manager.create_plan_info(size, forward);
354 let _ = manager.record_plan_usage(&plan_info, elapsed.as_nanos() as u64);
355 }
356 }
357
358 Ok(plan)
359 }
360
361 fn evict_old_entries(&self, cache: &mut HashMap<PlanKey, FftPlan>) {
363 let max_age = self.config.max_plan_age;
365 cache.retain(|_, v| v.last_used.elapsed() <= max_age);
366
367 let max_entries = self.config.max_cached_plans;
369 while cache.len() >= max_entries {
370 if let Some((key_to_remove_, _)) = cache
371 .iter()
372 .min_by_key(|(_, v)| (v.last_used, v.usage_count))
373 .map(|(k_, _)| (k_.clone(), ()))
374 {
375 cache.remove(&key_to_remove_);
376 } else {
377 break;
378 }
379 }
380 }
381
382 pub fn plan_fft_1d<S, D>(
384 &mut self,
385 arr: &ArrayBase<S, D>,
386 forward: bool,
387 ) -> FFTResult<Arc<FftPlan>>
388 where
389 S: Data<Elem = Complex64>,
390 D: Dimension,
391 {
392 let shape = arr.shape().to_vec();
393 self.plan_fft(&shape, forward, PlannerBackend::default())
394 }
395
396 pub fn plan_fft_2d<S, D>(
398 &mut self,
399 arr: &ArrayBase<S, D>,
400 forward: bool,
401 ) -> FFTResult<Arc<FftPlan>>
402 where
403 S: Data<Elem = Complex64>,
404 D: Dimension,
405 {
406 if arr.ndim() != 2 {
407 return Err(FFTError::ValueError(
408 "Input array must be 2-dimensional".to_string(),
409 ));
410 }
411
412 let shape = arr.shape().to_vec();
413 self.plan_fft(&shape, forward, PlannerBackend::default())
414 }
415
416 pub fn plan_fft_nd<S, D>(
418 &mut self,
419 arr: &ArrayBase<S, D>,
420 forward: bool,
421 ) -> FFTResult<Arc<FftPlan>>
422 where
423 S: Data<Elem = Complex64>,
424 D: Dimension,
425 {
426 let shape = arr.shape().to_vec();
427 self.plan_fft(&shape, forward, PlannerBackend::default())
428 }
429
430 pub fn precompute_commonsizes(&mut self, sizes: &[&[usize]]) -> FFTResult<()> {
432 for &shape in sizes {
433 let _ = self.plan_fft(shape, true, PlannerBackend::default())?;
435 let _ = self.plan_fft(shape, false, PlannerBackend::default())?;
436 }
437
438 Ok(())
439 }
440
441 pub fn save_plans(&self) -> FFTResult<()> {
443 if let Some(manager) = &self.serialization_manager {
444 manager.save_database()?;
445 }
446
447 Ok(())
448 }
449}
450
451impl Default for AdvancedFftPlanner {
452 fn default() -> Self {
453 Self::new()
454 }
455}
456
457static GLOBAL_FFT_PLANNER: std::sync::OnceLock<Mutex<AdvancedFftPlanner>> =
459 std::sync::OnceLock::new();
460
461#[allow(dead_code)]
463pub fn get_global_planner() -> &'static Mutex<AdvancedFftPlanner> {
464 GLOBAL_FFT_PLANNER.get_or_init(|| Mutex::new(AdvancedFftPlanner::new()))
465}
466
467#[allow(dead_code)]
469pub fn init_global_planner(config: PlanningConfig) -> Result<(), &'static str> {
470 GLOBAL_FFT_PLANNER
471 .set(Mutex::new(AdvancedFftPlanner::with_config(config)))
472 .map_err(|_| "Global FFT planner already initialized")
473}
474
475pub struct FftPlanExecutor {
477 plan: Arc<FftPlan>,
479 #[allow(dead_code)]
482 context: Option<BackendContext>,
483}
484
485impl FftPlanExecutor {
486 pub fn new(plan: Arc<FftPlan>) -> Self {
488 Self {
489 plan,
490 context: None,
491 }
492 }
493
494 pub fn with_context(plan: Arc<FftPlan>, context: BackendContext) -> Self {
496 Self {
497 plan,
498 context: Some(context),
499 }
500 }
501
502 pub fn execute(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
504 let internal_plan = self.plan.get_internal();
509
510 let expected_size: usize = self.plan.shape().iter().product();
512 if input.len() != expected_size || output.len() != expected_size {
513 return Err(FFTError::ValueError(format!(
514 "Buffer size mismatch: expected {}, got input={}, output={}",
515 expected_size,
516 input.len(),
517 output.len()
518 )));
519 }
520
521 output.copy_from_slice(input);
524
525 let mut scratch = vec![Complex64::default(); internal_plan.get_inplace_scratch_len()];
527 internal_plan.process_with_scratch(output, &mut scratch);
528
529 Ok(())
530 }
531
532 pub fn execute_inplace(&self, buffer: &mut [Complex64]) -> FFTResult<()> {
534 let internal_plan = self.plan.get_internal();
536
537 let expected_size: usize = self.plan.shape().iter().product();
539 if buffer.len() != expected_size {
540 return Err(FFTError::ValueError(format!(
541 "Buffer size mismatch: expected {}, got {}",
542 expected_size,
543 buffer.len()
544 )));
545 }
546
547 let mut scratch = vec![Complex64::default(); internal_plan.get_inplace_scratch_len()];
549 internal_plan.process_with_scratch(buffer, &mut scratch);
550
551 Ok(())
552 }
553
554 pub fn plan(&self) -> &FftPlan {
556 &self.plan
557 }
558}
559
560pub struct PlanBuilder {
562 config: PlanningConfig,
564 shape: Option<Vec<usize>>,
566 forward: bool,
568 backend: PlannerBackend,
570}
571
572impl PlanBuilder {
573 pub fn new() -> Self {
575 Self {
576 config: PlanningConfig::default(),
577 shape: None,
578 forward: true,
579 backend: PlannerBackend::default(),
580 }
581 }
582
583 pub fn shape(mut self, shape: &[usize]) -> Self {
585 self.shape = Some(shape.to_vec());
586 self
587 }
588
589 pub fn forward(mut self, forward: bool) -> Self {
591 self.forward = forward;
592 self
593 }
594
595 pub fn backend(mut self, backend: PlannerBackend) -> Self {
597 self.backend = backend;
598 self
599 }
600
601 pub fn strategy(mut self, strategy: PlanningStrategy) -> Self {
603 self.config.strategy = strategy;
604 self
605 }
606
607 pub fn measure_performance(mut self, enable: bool) -> Self {
609 self.config.measure_performance = enable;
610 self
611 }
612
613 pub fn serialized_db_path(mut self, path: &str) -> Self {
615 self.config.serialized_db_path = Some(path.to_string());
616 self
617 }
618
619 pub fn auto_tune_config(mut self, config: AutoTuneConfig) -> Self {
621 self.config.auto_tune_config = Some(config);
622 self
623 }
624
625 pub fn max_cached_plans(mut self, max: usize) -> Self {
627 self.config.max_cached_plans = max;
628 self
629 }
630
631 pub fn max_plan_age(mut self, age: Duration) -> Self {
633 self.config.max_plan_age = age;
634 self
635 }
636
637 pub fn parallel_planning(mut self, enable: bool) -> Self {
639 self.config.parallel_planning = enable;
640 self
641 }
642
643 pub fn build(self) -> FFTResult<Arc<FftPlan>> {
645 let shape = self
646 .shape
647 .ok_or_else(|| FFTError::ValueError("Cannot build plan without shape".to_string()))?;
648
649 let mut planner = AdvancedFftPlanner::with_config(self.config);
650 planner.plan_fft(&shape, self.forward, self.backend)
651 }
652}
653
654impl Default for PlanBuilder {
655 fn default() -> Self {
656 Self::new()
657 }
658}
659
660#[allow(dead_code)]
674pub fn plan_ahead_of_time(sizes: &[usize], dbpath: Option<&str>) -> FFTResult<()> {
675 let mut config = PlanningConfig::default();
676 if let Some(_path) = dbpath {
677 config.serialized_db_path = Some(_path.to_string());
678 config.strategy = PlanningStrategy::SerializedFirst;
679 }
680
681 let mut planner = AdvancedFftPlanner::with_config(config);
682
683 let shapes: Vec<Vec<usize>> = sizes.iter().map(|&s| vec![s]).collect();
685
686 for shape in shapes {
687 let _ = planner.plan_fft(&shape, true, PlannerBackend::default())?;
689 let _ = planner.plan_fft(&shape, false, PlannerBackend::default())?;
690 }
691
692 planner.save_plans()?;
694
695 Ok(())
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::Complex64;
    use tempfile::tempdir;

    /// A planned shape is reported back and only matches an identical shape.
    #[test]
    fn test_plan_basic() {
        let dims = vec![8, 8];
        let mut planner = AdvancedFftPlanner::new();

        let plan = planner
            .plan_fft(&dims, true, PlannerBackend::default())
            .unwrap();

        assert_eq!(plan.shape(), &dims);
        assert!(plan.is_compatible_with(&dims));
        assert!(!plan.is_compatible_with(&[16, 16]));
    }

    /// The forward FFT of a unit impulse has magnitude one in every bin.
    #[test]
    fn test_plan_executor() {
        let dims = vec![8];
        let mut planner = AdvancedFftPlanner::new();
        let plan = planner
            .plan_fft(&dims, true, PlannerBackend::default())
            .unwrap();
        let executor = FftPlanExecutor::new(plan);

        // Unit impulse: one at index zero, zeros elsewhere.
        let mut input = vec![Complex64::new(0.0, 0.0); 8];
        input[0] = Complex64::new(1.0, 0.0);
        let mut output = vec![Complex64::default(); 8];

        executor.execute(&input, &mut output).unwrap();

        for bin in output.iter() {
            let magnitude = bin.re.hypot(bin.im);
            assert!((magnitude - 1.0).abs() < 1e-10);
        }
    }

    /// The builder produces a plan for the requested shape.
    #[test]
    fn test_plan_builder() {
        let plan = PlanBuilder::new()
            .shape(&[16])
            .forward(true)
            .strategy(PlanningStrategy::AlwaysNew)
            .measure_performance(true)
            .build()
            .unwrap();

        assert_eq!(plan.shape(), &[16]);
    }

    /// Saving plans with a configured database path creates the database file.
    #[test]
    fn test_serialization() {
        let scratch_dir = tempdir().unwrap();
        let db_path = scratch_dir.path().join("test_plan_db.json");

        let config = PlanningConfig {
            serialized_db_path: Some(db_path.to_str().unwrap().to_string()),
            strategy: PlanningStrategy::SerializedFirst,
            ..PlanningConfig::default()
        };
        let mut planner = AdvancedFftPlanner::with_config(config);

        let dims = vec![32];
        planner
            .plan_fft(&dims, true, PlannerBackend::default())
            .unwrap();
        planner.save_plans().unwrap();

        assert!(db_path.exists());
    }

    /// The global planner can be locked and used to plan.
    #[test]
    fn test_global_planner() {
        let dims = vec![64];

        let mut guard = get_global_planner().lock().unwrap();
        let plan = guard
            .plan_fft(&dims, true, PlannerBackend::default())
            .unwrap();

        assert_eq!(plan.shape(), &dims);
    }

    /// Ahead-of-time planning with a database path writes the database file.
    #[test]
    fn test_ahead_of_time_planning() {
        let scratch_dir = tempdir().unwrap();
        let db_path = scratch_dir.path().join("ahead_of_time.json");
        let db_path_str = db_path.to_str().unwrap();

        let sizes = [8, 16, 32, 64];
        plan_ahead_of_time(&sizes, Some(db_path_str)).unwrap();

        assert!(db_path.exists());
    }
}