1use crate::auto_tuning::{AutoTuneConfig, AutoTuner, FftVariant};
16use crate::backend::BackendContext;
17use crate::error::{FFTError, FFTResult};
18#[cfg(feature = "oxifft")]
19use crate::oxifft_plan_cache;
20use crate::plan_serialization::{PlanMetrics, PlanSerializationManager};
21#[cfg(feature = "oxifft")]
22use oxifft::{Complex as OxiComplex, Direction};
23
24use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
25use scirs2_core::numeric::Complex64;
26use std::collections::HashMap;
27use std::sync::{Arc, Mutex};
28use std::time::{Duration, Instant};
29
/// Policy that decides how `AdvancedFftPlanner::plan_fft` obtains a plan.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum PlanningStrategy {
    /// Build a fresh plan on every request; the result is never stored in the cache.
    AlwaysNew,
    /// Consult the in-memory cache first and build (then cache) a plan on a miss.
    #[default]
    CacheFirst,
    /// Consult the in-memory cache, then the serialized plan database, then build a new plan.
    SerializedFirst,
    /// Build a new plan annotated with the variant preferred by the auto-tuner.
    AutoTuned,
}
43
44#[derive(Debug, Clone)]
46pub struct PlanningConfig {
47 pub strategy: PlanningStrategy,
49 pub measure_performance: bool,
51 pub serialized_db_path: Option<String>,
53 pub auto_tune_config: Option<AutoTuneConfig>,
55 pub max_cached_plans: usize,
57 pub max_plan_age: Duration,
59 pub parallel_planning: bool,
61}
62
63impl Default for PlanningConfig {
64 fn default() -> Self {
65 Self {
66 strategy: PlanningStrategy::default(),
67 measure_performance: true,
68 serialized_db_path: None,
69 auto_tune_config: None,
70 max_cached_plans: 128,
71 max_plan_age: Duration::from_secs(3600), parallel_planning: true,
73 }
74 }
75}
76
77#[derive(Clone)]
79pub struct FftPlan {
80 shape: Vec<usize>,
82 forward: bool,
84 metrics: Option<PlanMetrics>,
86 #[allow(dead_code)]
89 backend: PlannerBackend,
90 pub(crate) auto_tune_info: Option<FftVariant>,
92 pub(crate) last_used: Instant,
94 pub(crate) usage_count: usize,
96}
97
98impl FftPlan {
99 pub fn new(shape: &[usize], forward: bool, backend: PlannerBackend) -> Self {
102 Self {
103 shape: shape.to_vec(),
104 forward,
105 metrics: None,
106 backend,
107 auto_tune_info: None,
108 last_used: Instant::now(),
109 usage_count: 0,
110 }
111 }
112
113 pub fn is_forward(&self) -> bool {
115 self.forward
116 }
117
118 pub fn record_usage(&mut self) {
120 self.usage_count += 1;
121 self.last_used = Instant::now();
122 }
123
124 pub fn shape(&self) -> &[usize] {
126 &self.shape
127 }
128
129 pub fn is_compatible_with(&self, shape: &[usize]) -> bool {
131 if self.shape.len() != shape.len() {
132 return false;
133 }
134
135 self.shape.iter().zip(shape.iter()).all(|(&a, &b)| a == b)
136 }
137
138 pub fn metrics(&self) -> Option<&PlanMetrics> {
140 self.metrics.as_ref()
141 }
142
143 pub fn set_metrics(&mut self, metrics: PlanMetrics) {
145 self.metrics = Some(metrics);
146 }
147}
148
149pub struct AdvancedFftPlanner {
152 config: PlanningConfig,
154 cache: Arc<Mutex<HashMap<PlanKey, FftPlan>>>,
156 serialization_manager: Option<PlanSerializationManager>,
158 auto_tuner: Option<AutoTuner>,
160}
161
/// Execution backend a plan is requested for.
///
/// NOTE: `FftPlanExecutor` in this module always dispatches to OxiFFT,
/// whatever value is recorded here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum PlannerBackend {
    /// OxiFFT backend (the default).
    #[default]
    OxiFFT,
    /// CUDA backend.
    CUDA,
    /// Backend identified by an arbitrary name.
    Custom(String),
}
173
174#[derive(Clone, Debug, Hash, PartialEq, Eq)]
176struct PlanKey {
177 shape: Vec<usize>,
179 forward: bool,
181 backend: PlannerBackend,
183}
184
185impl AdvancedFftPlanner {
187 pub fn new() -> Self {
189 Self::with_config(PlanningConfig::default())
190 }
191
192 pub fn with_config(config: PlanningConfig) -> Self {
194 let serialization_manager = config
195 .serialized_db_path
196 .as_ref()
197 .map(PlanSerializationManager::new);
198
199 let auto_tuner = if config.strategy == PlanningStrategy::AutoTuned {
200 let tuner = AutoTuner::new();
201 if let Some(_autoconfig) = &config.auto_tune_config {
202 }
205 Some(tuner)
206 } else {
207 None
208 };
209
210 Self {
211 config,
212 cache: Arc::new(Mutex::new(HashMap::new())),
213 serialization_manager,
214 auto_tuner,
215 }
216 }
217
218 pub fn clear_cache(&self) {
220 if let Ok(mut cache) = self.cache.lock() {
221 cache.clear();
222 }
223 }
224
225 pub fn plan_fft(
227 &mut self,
228 shape: &[usize],
229 forward: bool,
230 backend: PlannerBackend,
231 ) -> FFTResult<Arc<FftPlan>> {
232 let key = PlanKey {
233 shape: shape.to_vec(),
234 forward,
235 backend: backend.clone(),
236 };
237
238 if self.config.strategy == PlanningStrategy::CacheFirst
240 || self.config.strategy == PlanningStrategy::SerializedFirst
241 {
242 if let Ok(mut cache) = self.cache.lock() {
243 if let Some(plan) = cache.get_mut(&key) {
244 plan.record_usage();
245 return Ok(Arc::new(plan.clone()));
246 }
247 }
248 }
249
250 if self.config.strategy == PlanningStrategy::SerializedFirst {
252 if let Some(manager) = &self.serialization_manager {
253 let size = shape.iter().product();
256
257 if manager.plan_exists(size, forward) {
258 if let Some((_plan_info, metrics)) =
259 manager.get_best_plan_metrics(size, forward)
260 {
261 let mut plan = self.create_new_plan(shape, forward, backend.clone())?;
263 plan.set_metrics(metrics.clone());
264
265 if let Ok(mut cache) = self.cache.lock() {
267 cache.insert(key, plan.clone());
268 }
269
270 return Ok(Arc::new(plan));
271 }
272 }
273 }
274 }
275
276 if self.config.strategy == PlanningStrategy::AutoTuned {
278 if let Some(tuner) = &self.auto_tuner {
279 let size = shape.iter().product();
281 let variant = tuner.get_best_variant(size, forward);
282
283 let mut plan = self.create_new_plan(shape, forward, backend)?;
285 plan.auto_tune_info = Some(variant);
286
287 if let Ok(mut cache) = self.cache.lock() {
289 cache.insert(key, plan.clone());
290 }
291
292 return Ok(Arc::new(plan));
293 }
294 }
295
296 let plan = self.create_new_plan(shape, forward, backend)?;
298
299 if self.config.strategy != PlanningStrategy::AlwaysNew {
301 if let Ok(mut cache) = self.cache.lock() {
302 if cache.len() >= self.config.max_cached_plans {
304 self.evict_old_entries(&mut cache);
305 }
306
307 cache.insert(key, plan.clone());
308 }
309 }
310
311 Ok(Arc::new(plan))
312 }
313
314 fn create_new_plan(
316 &mut self,
317 shape: &[usize],
318 forward: bool,
319 backend: PlannerBackend,
320 ) -> FFTResult<FftPlan> {
321 let start = Instant::now();
323
324 let plan = FftPlan::new(shape, forward, backend);
325
326 let elapsed = start.elapsed();
327
328 if self.config.measure_performance {
330 if let Some(manager) = &self.serialization_manager {
331 let size = shape.iter().product();
334
335 let plan_info = manager.create_plan_info(size, forward);
336 let _ = manager.record_plan_usage(&plan_info, elapsed.as_nanos() as u64);
337 }
338 }
339
340 Ok(plan)
341 }
342
343 fn evict_old_entries(&self, cache: &mut HashMap<PlanKey, FftPlan>) {
345 let max_age = self.config.max_plan_age;
347 cache.retain(|_, v| v.last_used.elapsed() <= max_age);
348
349 let max_entries = self.config.max_cached_plans;
351 while cache.len() >= max_entries {
352 if let Some((key_to_remove_, _)) = cache
353 .iter()
354 .min_by_key(|(_, v)| (v.last_used, v.usage_count))
355 .map(|(k_, _)| (k_.clone(), ()))
356 {
357 cache.remove(&key_to_remove_);
358 } else {
359 break;
360 }
361 }
362 }
363
364 pub fn plan_fft_1d<S, D>(
366 &mut self,
367 arr: &ArrayBase<S, D>,
368 forward: bool,
369 ) -> FFTResult<Arc<FftPlan>>
370 where
371 S: Data<Elem = Complex64>,
372 D: Dimension,
373 {
374 let shape = arr.shape().to_vec();
375 self.plan_fft(&shape, forward, PlannerBackend::default())
376 }
377
378 pub fn plan_fft_2d<S, D>(
380 &mut self,
381 arr: &ArrayBase<S, D>,
382 forward: bool,
383 ) -> FFTResult<Arc<FftPlan>>
384 where
385 S: Data<Elem = Complex64>,
386 D: Dimension,
387 {
388 if arr.ndim() != 2 {
389 return Err(FFTError::ValueError(
390 "Input array must be 2-dimensional".to_string(),
391 ));
392 }
393
394 let shape = arr.shape().to_vec();
395 self.plan_fft(&shape, forward, PlannerBackend::default())
396 }
397
398 pub fn plan_fft_nd<S, D>(
400 &mut self,
401 arr: &ArrayBase<S, D>,
402 forward: bool,
403 ) -> FFTResult<Arc<FftPlan>>
404 where
405 S: Data<Elem = Complex64>,
406 D: Dimension,
407 {
408 let shape = arr.shape().to_vec();
409 self.plan_fft(&shape, forward, PlannerBackend::default())
410 }
411
412 pub fn precompute_commonsizes(&mut self, sizes: &[&[usize]]) -> FFTResult<()> {
414 for &shape in sizes {
415 let _ = self.plan_fft(shape, true, PlannerBackend::default())?;
417 let _ = self.plan_fft(shape, false, PlannerBackend::default())?;
418 }
419
420 Ok(())
421 }
422
423 pub fn save_plans(&self) -> FFTResult<()> {
425 if let Some(manager) = &self.serialization_manager {
426 manager.save_database()?;
427 }
428
429 Ok(())
430 }
431}
432
433impl Default for AdvancedFftPlanner {
434 fn default() -> Self {
435 Self::new()
436 }
437}
438
439static GLOBAL_FFT_PLANNER: std::sync::OnceLock<Mutex<AdvancedFftPlanner>> =
441 std::sync::OnceLock::new();
442
443#[allow(dead_code)]
445pub fn get_global_planner() -> &'static Mutex<AdvancedFftPlanner> {
446 GLOBAL_FFT_PLANNER.get_or_init(|| Mutex::new(AdvancedFftPlanner::new()))
447}
448
449#[allow(dead_code)]
451pub fn init_global_planner(config: PlanningConfig) -> Result<(), &'static str> {
452 GLOBAL_FFT_PLANNER
453 .set(Mutex::new(AdvancedFftPlanner::with_config(config)))
454 .map_err(|_| "Global FFT planner already initialized")
455}
456
457pub struct FftPlanExecutor {
459 plan: Arc<FftPlan>,
461 #[allow(dead_code)]
464 context: Option<BackendContext>,
465}
466
467impl FftPlanExecutor {
468 pub fn new(plan: Arc<FftPlan>) -> Self {
470 Self {
471 plan,
472 context: None,
473 }
474 }
475
476 pub fn with_context(plan: Arc<FftPlan>, context: BackendContext) -> Self {
478 Self {
479 plan,
480 context: Some(context),
481 }
482 }
483
484 pub fn execute(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
487 let expected_size: usize = self.plan.shape().iter().product();
488 if input.len() != expected_size || output.len() != expected_size {
489 return Err(FFTError::ValueError(format!(
490 "Buffer size mismatch: expected {}, got input={}, output={}",
491 expected_size,
492 input.len(),
493 output.len()
494 )));
495 }
496
497 #[cfg(feature = "oxifft")]
498 {
499 let direction = if self.plan.is_forward() {
500 Direction::Forward
501 } else {
502 Direction::Backward
503 };
504 let input_oxi: Vec<OxiComplex<f64>> =
505 input.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
506 let mut output_oxi: Vec<OxiComplex<f64>> =
507 vec![OxiComplex::new(0.0, 0.0); expected_size];
508 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, direction)?;
509 for (i, c) in output_oxi.iter().enumerate() {
510 output[i] = Complex64::new(c.re, c.im);
511 }
512 Ok(())
513 }
514
515 #[cfg(not(feature = "oxifft"))]
516 {
517 Err(FFTError::ComputationError(
518 "No FFT backend available. Enable 'oxifft' feature.".to_string(),
519 ))
520 }
521 }
522
523 pub fn execute_inplace(&self, buffer: &mut [Complex64]) -> FFTResult<()> {
526 let expected_size: usize = self.plan.shape().iter().product();
527 if buffer.len() != expected_size {
528 return Err(FFTError::ValueError(format!(
529 "Buffer size mismatch: expected {}, got {}",
530 expected_size,
531 buffer.len()
532 )));
533 }
534
535 #[cfg(feature = "oxifft")]
536 {
537 let direction = if self.plan.is_forward() {
538 Direction::Forward
539 } else {
540 Direction::Backward
541 };
542 let input_oxi: Vec<OxiComplex<f64>> =
543 buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
544 let mut output_oxi: Vec<OxiComplex<f64>> =
545 vec![OxiComplex::new(0.0, 0.0); expected_size];
546 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, direction)?;
547 for (i, c) in output_oxi.iter().enumerate() {
548 buffer[i] = Complex64::new(c.re, c.im);
549 }
550 Ok(())
551 }
552
553 #[cfg(not(feature = "oxifft"))]
554 {
555 Err(FFTError::ComputationError(
556 "No FFT backend available. Enable 'oxifft' feature.".to_string(),
557 ))
558 }
559 }
560
561 pub fn plan(&self) -> &FftPlan {
563 &self.plan
564 }
565}
566
567pub struct PlanBuilder {
569 config: PlanningConfig,
571 shape: Option<Vec<usize>>,
573 forward: bool,
575 backend: PlannerBackend,
577}
578
579impl PlanBuilder {
580 pub fn new() -> Self {
582 Self {
583 config: PlanningConfig::default(),
584 shape: None,
585 forward: true,
586 backend: PlannerBackend::default(),
587 }
588 }
589
590 pub fn shape(mut self, shape: &[usize]) -> Self {
592 self.shape = Some(shape.to_vec());
593 self
594 }
595
596 pub fn forward(mut self, forward: bool) -> Self {
598 self.forward = forward;
599 self
600 }
601
602 pub fn backend(mut self, backend: PlannerBackend) -> Self {
604 self.backend = backend;
605 self
606 }
607
608 pub fn strategy(mut self, strategy: PlanningStrategy) -> Self {
610 self.config.strategy = strategy;
611 self
612 }
613
614 pub fn measure_performance(mut self, enable: bool) -> Self {
616 self.config.measure_performance = enable;
617 self
618 }
619
620 pub fn serialized_db_path(mut self, path: &str) -> Self {
622 self.config.serialized_db_path = Some(path.to_string());
623 self
624 }
625
626 pub fn auto_tune_config(mut self, config: AutoTuneConfig) -> Self {
628 self.config.auto_tune_config = Some(config);
629 self
630 }
631
632 pub fn max_cached_plans(mut self, max: usize) -> Self {
634 self.config.max_cached_plans = max;
635 self
636 }
637
638 pub fn max_plan_age(mut self, age: Duration) -> Self {
640 self.config.max_plan_age = age;
641 self
642 }
643
644 pub fn parallel_planning(mut self, enable: bool) -> Self {
646 self.config.parallel_planning = enable;
647 self
648 }
649
650 pub fn build(self) -> FFTResult<Arc<FftPlan>> {
652 let shape = self
653 .shape
654 .ok_or_else(|| FFTError::ValueError("Cannot build plan without shape".to_string()))?;
655
656 let mut planner = AdvancedFftPlanner::with_config(self.config);
657 planner.plan_fft(&shape, self.forward, self.backend)
658 }
659}
660
661impl Default for PlanBuilder {
662 fn default() -> Self {
663 Self::new()
664 }
665}
666
667#[allow(dead_code)]
681pub fn plan_ahead_of_time(sizes: &[usize], dbpath: Option<&str>) -> FFTResult<()> {
682 let mut config = PlanningConfig::default();
683 if let Some(_path) = dbpath {
684 config.serialized_db_path = Some(_path.to_string());
685 config.strategy = PlanningStrategy::SerializedFirst;
686 }
687
688 let mut planner = AdvancedFftPlanner::with_config(config);
689
690 let shapes: Vec<Vec<usize>> = sizes.iter().map(|&s| vec![s]).collect();
692
693 for shape in shapes {
694 let _ = planner.plan_fft(&shape, true, PlannerBackend::default())?;
696 let _ = planner.plan_fft(&shape, false, PlannerBackend::default())?;
697 }
698
699 planner.save_plans()?;
701
702 Ok(())
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_plan_basic() {
        let mut planner = AdvancedFftPlanner::new();
        let shape = vec![8, 8];

        let plan = planner
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("Operation failed");

        assert_eq!(plan.shape(), &shape);
        assert!(plan.is_compatible_with(&shape));

        assert!(!plan.is_compatible_with(&[16, 16]));
    }

    // Executing a plan needs a real backend: without the `oxifft` feature
    // `FftPlanExecutor::execute` always returns an error, so the test only
    // makes sense when the feature is enabled.
    #[cfg(feature = "oxifft")]
    #[test]
    fn test_plan_executor() {
        use scirs2_core::numeric::Complex64;

        let mut planner = AdvancedFftPlanner::new();
        let shape = vec![8];

        let plan = planner
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("Operation failed");

        let executor = FftPlanExecutor::new(plan);

        // Unit impulse: its DFT has magnitude 1 in every bin.
        let mut input = vec![Complex64::new(0.0, 0.0); 8];
        input[0] = Complex64::new(1.0, 0.0);
        let mut output = vec![Complex64::default(); 8];

        executor
            .execute(&input, &mut output)
            .expect("Operation failed");

        for val in &output {
            let magnitude = (val.re.powi(2) + val.im.powi(2)).sqrt();
            assert!((magnitude - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_plan_builder() {
        let builder = PlanBuilder::new()
            .shape(&[16])
            .forward(true)
            .strategy(PlanningStrategy::AlwaysNew)
            .measure_performance(true);

        let plan = builder.build().expect("Operation failed");

        assert_eq!(plan.shape(), &[16]);
    }

    #[test]
    fn test_cache_respects_capacity() {
        let config = PlanningConfig {
            max_cached_plans: 2,
            ..PlanningConfig::default()
        };
        let mut planner = AdvancedFftPlanner::with_config(config);

        for n in [4, 8, 16, 32, 64] {
            planner
                .plan_fft(&[n], true, PlannerBackend::default())
                .expect("planning failed");
        }

        let cached = planner.cache.lock().expect("cache lock poisoned").len();
        assert!(cached <= 2, "cache holds {cached} plans, limit is 2");
    }

    #[test]
    fn test_cache_hit_records_usage() {
        let mut planner = AdvancedFftPlanner::new();
        let shape = [16];

        let first = planner
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("planning failed");
        let second = planner
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("planning failed");

        assert_eq!(first.usage_count, 0);
        assert_eq!(second.usage_count, 1);
    }

    #[test]
    fn test_always_new_skips_cache() {
        let config = PlanningConfig {
            strategy: PlanningStrategy::AlwaysNew,
            ..PlanningConfig::default()
        };
        let mut planner = AdvancedFftPlanner::with_config(config);

        planner
            .plan_fft(&[8], true, PlannerBackend::default())
            .expect("planning failed");

        assert!(planner.cache.lock().expect("cache lock poisoned").is_empty());
    }

    #[test]
    fn test_serialization() {
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_plan_db.json");

        let mut config = PlanningConfig::default();
        config.serialized_db_path = Some(db_path.to_str().expect("Operation failed").to_string());
        config.strategy = PlanningStrategy::SerializedFirst;

        let mut planner = AdvancedFftPlanner::with_config(config);

        let shape = vec![32];
        let _ = planner
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("Operation failed");

        planner.save_plans().expect("Operation failed");

        assert!(db_path.exists());
    }

    #[test]
    fn test_global_planner() {
        let planner = get_global_planner();

        let mut planner_guard = planner.lock().expect("Operation failed");
        let shape = vec![64];
        let plan = planner_guard
            .plan_fft(&shape, true, PlannerBackend::default())
            .expect("Operation failed");

        assert_eq!(plan.shape(), &shape);
    }

    #[test]
    fn test_ahead_of_time_planning() {
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("ahead_of_time.json");

        let sizes = [8, 16, 32, 64];
        plan_ahead_of_time(&sizes, Some(db_path.to_str().expect("Operation failed")))
            .expect("Operation failed");

        assert!(db_path.exists());
    }
}