1#[cfg(feature = "oxifft")]
12use crate::oxifft_plan_cache;
13#[cfg(feature = "oxifft")]
14use oxifft::{Complex as OxiComplex, Direction};
15#[cfg(feature = "rustfft-backend")]
16use rustfft::FftPlanner;
17use scirs2_core::numeric::Complex64;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::fs::{self, File};
21use std::io::{BufReader, BufWriter};
22use std::path::{Path, PathBuf};
23use std::time::Instant;
24
25use crate::error::{FFTError, FFTResult};
26use crate::plan_serialization::PlanSerializationManager;
27
28#[derive(Debug, Clone)]
30pub struct SizeRange {
31 pub min: usize,
33 pub max: usize,
35 pub step: SizeStep,
37}
38
39#[derive(Debug, Clone)]
41pub enum SizeStep {
42 Linear(usize),
44 Exponential(f64),
46 PowersOfTwo,
48 Custom(Vec<usize>),
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
54pub enum FftVariant {
55 Standard,
57 InPlace,
59 Cached,
61 SplitRadix,
63}
64
65#[derive(Debug, Clone)]
67pub struct AutoTuneConfig {
68 pub sizes: SizeRange,
70 pub repetitions: usize,
72 pub warmup: usize,
74 pub variants: Vec<FftVariant>,
76 pub database_path: PathBuf,
78}
79
80impl Default for AutoTuneConfig {
81 fn default() -> Self {
82 Self {
83 sizes: SizeRange {
84 min: 16,
85 max: 8192,
86 step: SizeStep::PowersOfTwo,
87 },
88 repetitions: 10,
89 warmup: 3,
90 variants: vec![FftVariant::Standard, FftVariant::Cached],
91 database_path: PathBuf::from(".fft_tuning_db.json"),
92 }
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct BenchmarkResult {
99 pub size: usize,
101 pub variant: FftVariant,
103 pub forward: bool,
105 pub avg_time_ns: u64,
107 pub min_time_ns: u64,
109 pub std_dev_ns: f64,
111 pub system_info: SystemInfo,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct SystemInfo {
118 pub cpu_model: String,
120 pub num_cores: usize,
122 pub architecture: String,
124 pub cpu_features: Vec<String>,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct TuningDatabase {
131 pub results: Vec<BenchmarkResult>,
133 pub last_updated: u64,
135 pub best_algorithms: HashMap<(usize, bool), FftVariant>,
137}
138
139pub struct AutoTuner {
141 config: AutoTuneConfig,
143 database: TuningDatabase,
145 enabled: bool,
147}
148
149impl Default for AutoTuner {
150 fn default() -> Self {
151 Self::with_config(AutoTuneConfig::default())
152 }
153}
154
155impl AutoTuner {
156 pub fn new() -> Self {
158 Self::default()
159 }
160
161 pub fn with_config(config: AutoTuneConfig) -> Self {
163 let database =
164 Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
165 results: Vec::new(),
166 last_updated: std::time::SystemTime::now()
167 .duration_since(std::time::UNIX_EPOCH)
168 .unwrap_or_default()
169 .as_secs(),
170 best_algorithms: HashMap::new(),
171 });
172
173 Self {
174 config,
175 database,
176 enabled: true,
177 }
178 }
179
180 fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
182 if !path.exists() {
183 return Err(FFTError::IOError(format!(
184 "Tuning database file not found: {}",
185 path.display()
186 )));
187 }
188
189 let file = File::open(path)
190 .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
191
192 let reader = BufReader::new(file);
193 let database: TuningDatabase = serde_json::from_reader(reader)
194 .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
195
196 Ok(database)
197 }
198
199 pub fn save_database(&self) -> FFTResult<()> {
201 if let Some(parent) = self.config.database_path.parent() {
203 fs::create_dir_all(parent).map_err(|e| {
204 FFTError::IOError(format!(
205 "Failed to create directory for tuning database: {e}"
206 ))
207 })?;
208 }
209
210 let file = File::create(&self.config.database_path).map_err(|e| {
211 FFTError::IOError(format!("Failed to create tuning database file: {e}"))
212 })?;
213
214 let writer = BufWriter::new(file);
215 serde_json::to_writer_pretty(writer, &self.database)
216 .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
217
218 Ok(())
219 }
220
221 pub fn set_enabled(&mut self, enabled: bool) {
223 self.enabled = enabled;
224 }
225
226 pub fn is_enabled(&self) -> bool {
228 self.enabled
229 }
230
231 pub fn run_benchmarks(&mut self) -> FFTResult<()> {
233 if !self.enabled {
234 return Ok(());
235 }
236
237 let sizes = self.generate_sizes();
238 let mut results = Vec::new();
239
240 for size in sizes {
241 for &variant in &self.config.variants {
242 let forward_result = self.benchmark_variant(size, variant, true)?;
244 results.push(forward_result);
245
246 let inverse_result = self.benchmark_variant(size, variant, false)?;
248 results.push(inverse_result);
249 }
250 }
251
252 self.database.results.extend(results);
254 self.update_best_algorithms();
255 self.save_database()?;
256
257 Ok(())
258 }
259
260 fn generate_sizes(&self) -> Vec<usize> {
262 let mut sizes = Vec::new();
263
264 match &self.config.sizes.step {
265 SizeStep::Linear(step) => {
266 let mut size = self.config.sizes.min;
267 while size <= self.config.sizes.max {
268 sizes.push(size);
269 size += step;
270 }
271 }
272 SizeStep::Exponential(factor) => {
273 let mut size = self.config.sizes.min as f64;
274 while size <= self.config.sizes.max as f64 {
275 sizes.push(size as usize);
276 size *= factor;
277 }
278 }
279 SizeStep::PowersOfTwo => {
280 let mut size = 1;
281 while size < self.config.sizes.min {
282 size *= 2;
283 }
284 while size <= self.config.sizes.max {
285 sizes.push(size);
286 size *= 2;
287 }
288 }
289 SizeStep::Custom(custom_sizes) => {
290 for &size in custom_sizes {
291 if size >= self.config.sizes.min && size <= self.config.sizes.max {
292 sizes.push(size);
293 }
294 }
295 }
296 }
297
298 sizes
299 }
300
301 fn benchmark_variant(
303 &self,
304 size: usize,
305 variant: FftVariant,
306 forward: bool,
307 ) -> FFTResult<BenchmarkResult> {
308 let mut data = vec![Complex64::new(0.0, 0.0); size];
310 for (i, val) in data.iter_mut().enumerate().take(size) {
311 *val = Complex64::new(i as f64, (i * 2) as f64);
312 }
313
314 for _ in 0..self.config.warmup {
316 match variant {
317 FftVariant::Standard => {
318 #[cfg(feature = "oxifft")]
319 {
320 let input_oxi: Vec<OxiComplex<f64>> =
321 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
322 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
323
324 let direction = if forward {
325 Direction::Forward
326 } else {
327 Direction::Backward
328 };
329 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
330 }
331
332 #[cfg(not(feature = "oxifft"))]
333 {
334 #[cfg(feature = "rustfft-backend")]
335 {
336 let mut planner = FftPlanner::new();
337 let fft = if forward {
338 planner.plan_fft_forward(size)
339 } else {
340 planner.plan_fft_inverse(size)
341 };
342 let mut buffer = data.clone();
343 fft.process(&mut buffer);
344 }
345 }
346 }
347 FftVariant::InPlace => {
348 #[cfg(feature = "oxifft")]
349 {
350 let input_oxi: Vec<OxiComplex<f64>> =
351 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
352 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
353
354 let direction = if forward {
355 Direction::Forward
356 } else {
357 Direction::Backward
358 };
359 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
360 }
361
362 #[cfg(not(feature = "oxifft"))]
363 {
364 #[cfg(feature = "rustfft-backend")]
365 {
366 let mut planner = FftPlanner::new();
367 let fft = if forward {
368 planner.plan_fft_forward(size)
369 } else {
370 planner.plan_fft_inverse(size)
371 };
372 let mut buffer = data.clone();
374 let mut scratch =
375 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
376 fft.process_with_scratch(&mut buffer, &mut scratch);
377 }
378 }
379 }
380 FftVariant::Cached => {
381 let manager = PlanSerializationManager::new(&self.config.database_path);
383 let plan_info = manager.create_plan_info(size, forward);
384 let (_, time) = crate::plan_serialization::create_and_time_plan(size, forward);
385 manager.record_plan_usage(&plan_info, time).unwrap_or(());
386 }
387 FftVariant::SplitRadix => {
388 #[cfg(feature = "oxifft")]
389 {
390 let input_oxi: Vec<OxiComplex<f64>> =
392 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
393 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
394
395 let direction = if forward {
396 Direction::Forward
397 } else {
398 Direction::Backward
399 };
400 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
401 }
402
403 #[cfg(not(feature = "oxifft"))]
404 {
405 #[cfg(feature = "rustfft-backend")]
406 {
407 let mut planner = FftPlanner::new();
410 let fft = if forward {
411 planner.plan_fft_forward(size)
412 } else {
413 planner.plan_fft_inverse(size)
414 };
415 let mut buffer = data.clone();
416 fft.process(&mut buffer);
417 }
418 }
419 }
420 }
421 }
422
423 let mut times = Vec::with_capacity(self.config.repetitions);
425
426 for _ in 0..self.config.repetitions {
427 let start = Instant::now();
428
429 match variant {
430 FftVariant::Standard => {
431 #[cfg(feature = "oxifft")]
432 {
433 let input_oxi: Vec<OxiComplex<f64>> =
434 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
435 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
436
437 let direction = if forward {
438 Direction::Forward
439 } else {
440 Direction::Backward
441 };
442 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
443 }
444
445 #[cfg(not(feature = "oxifft"))]
446 {
447 #[cfg(feature = "rustfft-backend")]
448 {
449 let mut planner = FftPlanner::new();
450 let fft = if forward {
451 planner.plan_fft_forward(size)
452 } else {
453 planner.plan_fft_inverse(size)
454 };
455 let mut buffer = data.clone();
456 fft.process(&mut buffer);
457 }
458 }
459 }
460 FftVariant::InPlace => {
461 #[cfg(feature = "oxifft")]
462 {
463 let input_oxi: Vec<OxiComplex<f64>> =
464 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
465 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
466
467 let direction = if forward {
468 Direction::Forward
469 } else {
470 Direction::Backward
471 };
472 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
473 }
474
475 #[cfg(not(feature = "oxifft"))]
476 {
477 #[cfg(feature = "rustfft-backend")]
478 {
479 let mut planner = FftPlanner::new();
480 let fft = if forward {
481 planner.plan_fft_forward(size)
482 } else {
483 planner.plan_fft_inverse(size)
484 };
485 let mut buffer = data.clone();
487 let mut scratch =
488 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
489 fft.process_with_scratch(&mut buffer, &mut scratch);
490 }
491 }
492 }
493 FftVariant::Cached => {
494 #[cfg(feature = "oxifft")]
495 {
496 let input_oxi: Vec<OxiComplex<f64>> =
497 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
498 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
499
500 let direction = if forward {
501 Direction::Forward
502 } else {
503 Direction::Backward
504 };
505 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
506 }
507
508 #[cfg(not(feature = "oxifft"))]
509 {
510 #[cfg(feature = "rustfft-backend")]
511 {
512 let mut planner = FftPlanner::new();
514 let fft = if forward {
515 planner.plan_fft_forward(size)
516 } else {
517 planner.plan_fft_inverse(size)
518 };
519 let mut buffer = data.clone();
520 fft.process(&mut buffer);
521 }
522 }
523 }
524 FftVariant::SplitRadix => {
525 #[cfg(feature = "oxifft")]
526 {
527 let input_oxi: Vec<OxiComplex<f64>> =
528 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
529 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
530
531 let direction = if forward {
532 Direction::Forward
533 } else {
534 Direction::Backward
535 };
536 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
537 }
538
539 #[cfg(not(feature = "oxifft"))]
540 {
541 #[cfg(feature = "rustfft-backend")]
542 {
543 let mut planner = FftPlanner::new();
545 let fft = if forward {
546 planner.plan_fft_forward(size)
547 } else {
548 planner.plan_fft_inverse(size)
549 };
550 let mut buffer = data.clone();
551 fft.process(&mut buffer);
552 }
553 }
554 }
555 }
556
557 let elapsed = start.elapsed();
558 times.push(elapsed.as_nanos() as u64);
559 }
560
561 let avg_time = times.iter().sum::<u64>() / times.len() as u64;
563 let min_time = *times.iter().min().unwrap_or(&0);
564
565 let variance = times
567 .iter()
568 .map(|&t| {
569 let diff = t as f64 - avg_time as f64;
570 diff * diff
571 })
572 .sum::<f64>()
573 / times.len() as f64;
574 let std_dev = variance.sqrt();
575
576 Ok(BenchmarkResult {
577 size,
578 variant,
579 forward,
580 avg_time_ns: avg_time,
581 min_time_ns: min_time,
582 std_dev_ns: std_dev,
583 system_info: self.detect_system_info(),
584 })
585 }
586
587 fn detect_system_info(&self) -> SystemInfo {
589 SystemInfo {
592 cpu_model: String::from("Unknown"),
593 num_cores: num_cpus::get(),
594 architecture: std::env::consts::ARCH.to_string(),
595 cpu_features: detect_cpu_features(),
596 }
597 }
598
599 fn update_best_algorithms(&mut self) {
601 self.database.best_algorithms.clear();
603
604 let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
606 for result in &self.database.results {
607 grouped
608 .entry((result.size, result.forward))
609 .or_default()
610 .push(result);
611 }
612
613 for ((size, forward), results) in grouped {
615 if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
616 self.database
617 .best_algorithms
618 .insert((size, forward), best.variant);
619 }
620 }
621 }
622
623 pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
625 if !self.enabled {
626 return FftVariant::Standard;
627 }
628
629 if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
631 return variant;
632 }
633
634 let mut closest_size = 0;
636 let mut min_diff = usize::MAX;
637
638 for &(s, f) in self.database.best_algorithms.keys() {
639 if f == forward {
640 let diff = s.abs_diff(size);
641 if diff < min_diff {
642 min_diff = diff;
643 closest_size = s;
644 }
645 }
646 }
647
648 if closest_size > 0 {
649 if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
650 return variant;
651 }
652 }
653
654 FftVariant::Standard
656 }
657
658 pub fn run_optimal_fft<T>(
660 &self,
661 input: &[T],
662 size: Option<usize>,
663 forward: bool,
664 ) -> FFTResult<Vec<Complex64>>
665 where
666 T: Clone + Into<Complex64>,
667 {
668 let actual_size = size.unwrap_or(input.len());
669 let variant = self.get_best_variant(actual_size, forward);
670
671 let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
673 if buffer.len() < actual_size {
675 buffer.resize(actual_size, Complex64::new(0.0, 0.0));
676 }
677
678 #[cfg(feature = "oxifft")]
679 {
680 let input_oxi: Vec<OxiComplex<f64>> =
681 buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
682 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); actual_size];
683
684 let direction = if forward {
685 Direction::Forward
686 } else {
687 Direction::Backward
688 };
689 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
690
691 for (i, val) in output.iter().enumerate() {
693 buffer[i] = Complex64::new(val.re, val.im);
694 }
695 }
696
697 #[cfg(not(feature = "oxifft"))]
698 {
699 #[cfg(feature = "rustfft-backend")]
700 {
701 match variant {
702 FftVariant::Standard => {
703 let mut planner = FftPlanner::new();
704 let fft = if forward {
705 planner.plan_fft_forward(actual_size)
706 } else {
707 planner.plan_fft_inverse(actual_size)
708 };
709 fft.process(&mut buffer);
710 }
711 FftVariant::InPlace => {
712 let mut planner = FftPlanner::new();
713 let fft = if forward {
714 planner.plan_fft_forward(actual_size)
715 } else {
716 planner.plan_fft_inverse(actual_size)
717 };
718 let mut scratch =
719 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
720 fft.process_with_scratch(&mut buffer, &mut scratch);
721 }
722 FftVariant::Cached => {
723 let (plan_, _) =
726 crate::plan_serialization::create_and_time_plan(actual_size, forward);
727 plan_.process(&mut buffer);
728 }
729 FftVariant::SplitRadix => {
730 let mut planner = FftPlanner::new();
732 let fft = if forward {
733 planner.plan_fft_forward(actual_size)
734 } else {
735 planner.plan_fft_inverse(actual_size)
736 };
737 fft.process(&mut buffer);
738 }
739 }
740 }
741
742 #[cfg(not(feature = "rustfft-backend"))]
743 {
744 return Err(FFTError::ComputationError(
745 "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
746 ));
747 }
748 }
749
750 if !forward {
752 let scale = 1.0 / (actual_size as f64);
753 for val in &mut buffer {
754 *val *= scale;
755 }
756 }
757
758 Ok(buffer)
759 }
760}
761
/// Returns the names of SIMD instruction-set extensions available on the running CPU.
///
/// Detection happens at run time (`is_x86_feature_detected!` /
/// `is_aarch64_feature_detected!`). The list therefore describes the machine
/// executing the benchmarks, not just the features the binary was compiled to
/// assume; a default x86_64 build only enables sse/sse2 at compile time.
/// Architectures other than x86_64 and aarch64 report no features.
fn detect_cpu_features() -> Vec<String> {
    // Nothing is pushed on architectures without a probe below.
    #[allow(unused_mut)]
    let mut features = Vec::new();

    #[cfg(target_arch = "x86_64")]
    {
        if std::arch::is_x86_feature_detected!("sse") {
            features.push("sse".to_string());
        }
        if std::arch::is_x86_feature_detected!("sse2") {
            features.push("sse2".to_string());
        }
        if std::arch::is_x86_feature_detected!("sse3") {
            features.push("sse3".to_string());
        }
        if std::arch::is_x86_feature_detected!("sse4.1") {
            features.push("sse4.1".to_string());
        }
        if std::arch::is_x86_feature_detected!("sse4.2") {
            features.push("sse4.2".to_string());
        }
        if std::arch::is_x86_feature_detected!("avx") {
            features.push("avx".to_string());
        }
        if std::arch::is_x86_feature_detected!("avx2") {
            features.push("avx2".to_string());
        }
        if std::arch::is_x86_feature_detected!("fma") {
            features.push("fma".to_string());
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            features.push("neon".to_string());
        }
    }

    features
}
806
807pub struct IntegratedAutoSelector {
813 selector: crate::algorithm_selector::AlgorithmSelector,
815 tuner: AutoTuner,
817 prefer_learned: bool,
819}
820
821impl Default for IntegratedAutoSelector {
822 fn default() -> Self {
823 Self::new()
824 }
825}
826
827impl IntegratedAutoSelector {
828 pub fn new() -> Self {
830 Self {
831 selector: crate::algorithm_selector::AlgorithmSelector::new(),
832 tuner: AutoTuner::new(),
833 prefer_learned: true,
834 }
835 }
836
837 pub fn with_config(
839 selector_config: crate::algorithm_selector::SelectionConfig,
840 tuner_config: AutoTuneConfig,
841 prefer_learned: bool,
842 ) -> Self {
843 Self {
844 selector: crate::algorithm_selector::AlgorithmSelector::with_config(selector_config),
845 tuner: AutoTuner::with_config(tuner_config),
846 prefer_learned,
847 }
848 }
849
850 pub fn select(&self, size: usize, forward: bool) -> FFTResult<SelectionResult> {
852 if self.prefer_learned && self.tuner.is_enabled() {
854 let variant = self.tuner.get_best_variant(size, forward);
855 if variant != FftVariant::Standard {
856 return Ok(SelectionResult {
858 algorithm: variant_to_algorithm(variant),
859 variant,
860 source: SelectionSource::Learned,
861 confidence: 0.9,
862 recommendation: self.selector.select_algorithm(size, forward).ok(),
863 });
864 }
865 }
866
867 let recommendation = self.selector.select_algorithm(size, forward)?;
869 let variant = algorithm_to_variant(recommendation.algorithm);
870
871 Ok(SelectionResult {
872 algorithm: recommendation.algorithm,
873 variant,
874 source: SelectionSource::Characteristic,
875 confidence: recommendation.confidence,
876 recommendation: Some(recommendation),
877 })
878 }
879
880 pub fn auto_tune(&mut self, sizes: &[usize]) -> FFTResult<()> {
882 if sizes.is_empty() {
884 return Ok(());
885 }
886
887 let min = *sizes.iter().min().unwrap_or(&16);
888 let max = *sizes.iter().max().unwrap_or(&8192);
889
890 let config = AutoTuneConfig {
891 sizes: SizeRange {
892 min,
893 max,
894 step: SizeStep::Custom(sizes.to_vec()),
895 },
896 ..Default::default()
897 };
898
899 self.tuner = AutoTuner::with_config(config);
900 self.tuner.run_benchmarks()
901 }
902
903 pub fn execute<T>(
905 &self,
906 input: &[T],
907 size: Option<usize>,
908 forward: bool,
909 ) -> FFTResult<Vec<Complex64>>
910 where
911 T: Clone + Into<Complex64>,
912 {
913 let actual_size = size.unwrap_or(input.len());
914 let selection = self.select(actual_size, forward)?;
915
916 self.tuner.run_optimal_fft(input, size, forward)
918 }
919
920 pub fn selector(&self) -> &crate::algorithm_selector::AlgorithmSelector {
922 &self.selector
923 }
924
925 pub fn tuner(&self) -> &AutoTuner {
927 &self.tuner
928 }
929}
930
931#[derive(Debug, Clone)]
933pub struct SelectionResult {
934 pub algorithm: crate::algorithm_selector::FftAlgorithm,
936 pub variant: FftVariant,
938 pub source: SelectionSource,
940 pub confidence: f64,
942 pub recommendation: Option<crate::algorithm_selector::AlgorithmRecommendation>,
944}
945
946#[derive(Debug, Clone, Copy, PartialEq, Eq)]
948pub enum SelectionSource {
949 Learned,
951 Characteristic,
953 Forced,
955 Default,
957}
958
959fn variant_to_algorithm(variant: FftVariant) -> crate::algorithm_selector::FftAlgorithm {
961 use crate::algorithm_selector::FftAlgorithm;
962 match variant {
963 FftVariant::Standard => FftAlgorithm::MixedRadix,
964 FftVariant::InPlace => FftAlgorithm::InPlace,
965 FftVariant::Cached => FftAlgorithm::MixedRadix,
966 FftVariant::SplitRadix => FftAlgorithm::SplitRadix,
967 }
968}
969
970fn algorithm_to_variant(algorithm: crate::algorithm_selector::FftAlgorithm) -> FftVariant {
972 use crate::algorithm_selector::FftAlgorithm;
973 match algorithm {
974 FftAlgorithm::SplitRadix => FftVariant::SplitRadix,
975 FftAlgorithm::InPlace => FftVariant::InPlace,
976 _ => FftVariant::Standard,
977 }
978}
979
980pub fn auto_select_algorithm(size: usize, forward: bool) -> FFTResult<SelectionResult> {
1004 let selector = IntegratedAutoSelector::new();
1005 selector.select(size, forward)
1006}
1007
1008pub fn auto_fft<T>(input: &[T], size: Option<usize>, forward: bool) -> FFTResult<Vec<Complex64>>
1032where
1033 T: Clone + Into<Complex64>,
1034{
1035 let selector = IntegratedAutoSelector::new();
1036 selector.execute(input, size, forward)
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a tuner for the given range (other settings default) and returns
    /// the sizes it would benchmark.
    fn sizes_for(min: usize, max: usize, step: SizeStep) -> Vec<usize> {
        let config = AutoTuneConfig {
            sizes: SizeRange { min, max, step },
            ..Default::default()
        };
        AutoTuner::with_config(config).generate_sizes()
    }

    #[test]
    fn test_size_generation() {
        assert_eq!(sizes_for(8, 64, SizeStep::PowersOfTwo), vec![8, 16, 32, 64]);
        assert_eq!(
            sizes_for(10, 30, SizeStep::Linear(5)),
            vec![10, 15, 20, 25, 30]
        );
        assert_eq!(
            sizes_for(10, 100, SizeStep::Exponential(2.0)),
            vec![10, 20, 40, 80]
        );
        assert_eq!(
            sizes_for(10, 100, SizeStep::Custom(vec![5, 15, 25, 50, 150])),
            vec![15, 25, 50]
        );
    }

    #[test]
    fn test_auto_tuner_basic() {
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_tuning_db.json");

        let mut tuner = AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        });

        // A failed run is reported rather than treated as a test failure.
        if let Err(e) = tuner.run_benchmarks() {
            println!("Benchmark failed: {e}");
            return;
        }

        assert!(db_path.exists());
        let variant = tuner.get_best_variant(16, true);
        assert!(matches!(
            variant,
            FftVariant::Standard | FftVariant::InPlace
        ));
    }
}
1139}