1use rustfft::FftPlanner;
12use scirs2_core::numeric::Complex64;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs::{self, File};
16use std::io::{BufReader, BufWriter};
17use std::path::{Path, PathBuf};
18use std::time::Instant;
19
20use crate::error::{FFTError, FFTResult};
21use crate::plan_serialization::PlanSerializationManager;
22
/// Inclusive range of transform sizes to benchmark, plus the rule used to
/// step from `min` towards `max`.
#[derive(Debug, Clone)]
pub struct SizeRange {
    /// Smallest size considered (inclusive).
    pub min: usize,
    /// Largest size considered (inclusive).
    pub max: usize,
    /// How candidate sizes are generated between `min` and `max`.
    pub step: SizeStep,
}
33
/// Rule for generating benchmark sizes inside a [`SizeRange`].
#[derive(Debug, Clone)]
pub enum SizeStep {
    /// Start at `min` and add this amount each time; should be non-zero.
    Linear(usize),
    /// Start at `min` and multiply by this factor each time (each term is
    /// truncated to `usize`); should be greater than 1.0 for the sequence to grow.
    Exponential(f64),
    /// Every power of two within `min..=max`.
    PowersOfTwo,
    /// Exactly these sizes, keeping only the ones within `min..=max`.
    Custom(Vec<usize>),
}
46
/// FFT execution strategies the auto-tuner can benchmark and choose between.
/// Stored by name in the tuning database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FftVariant {
    /// A fresh rustfft planner, then `process` on the buffer.
    Standard,
    /// A fresh rustfft planner, then `process_with_scratch` with scratch space
    /// allocated by the caller.
    InPlace,
    /// A plan obtained from `crate::plan_serialization::create_and_time_plan`
    /// (this is what `AutoTuner::run_optimal_fft` executes for this variant).
    Cached,
    /// Named as a split-radix strategy.
    // NOTE(review): `AutoTuner::run_optimal_fft` runs this with the same
    // rustfft planner path as `Standard`; there is no dedicated split-radix
    // implementation in this file.
    SplitRadix,
}
59
/// Settings that control what the auto-tuner benchmarks and where it stores
/// the results.
#[derive(Debug, Clone)]
pub struct AutoTuneConfig {
    /// Transform sizes to benchmark.
    pub sizes: SizeRange,
    /// Number of timed runs per (size, variant, direction).
    pub repetitions: usize,
    /// Number of untimed runs executed before the timed ones.
    pub warmup: usize,
    /// Variants to compare.
    pub variants: Vec<FftVariant>,
    /// JSON file the tuning database is loaded from and saved to.
    pub database_path: PathBuf,
}
74
75impl Default for AutoTuneConfig {
76 fn default() -> Self {
77 Self {
78 sizes: SizeRange {
79 min: 16,
80 max: 8192,
81 step: SizeStep::PowersOfTwo,
82 },
83 repetitions: 10,
84 warmup: 3,
85 variants: vec![FftVariant::Standard, FftVariant::Cached],
86 database_path: PathBuf::from(".fft_tuning_db.json"),
87 }
88 }
89}
90
/// Timing statistics for one (size, variant, direction) benchmark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Transform length that was measured.
    pub size: usize,
    /// Variant that was measured.
    pub variant: FftVariant,
    /// `true` for the forward transform, `false` for the inverse.
    pub forward: bool,
    /// Mean of the timed repetitions, in nanoseconds.
    pub avg_time_ns: u64,
    /// Fastest timed repetition, in nanoseconds.
    pub min_time_ns: u64,
    /// Population standard deviation of the timed repetitions (divides by the
    /// repetition count), in nanoseconds.
    pub std_dev_ns: f64,
    /// Description of the machine that produced this measurement.
    pub system_info: SystemInfo,
}
109
/// Machine description stored with each benchmark result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// CPU model name (currently always `"Unknown"`; not detected yet).
    pub cpu_model: String,
    /// Core count as reported by `num_cpus::get()`.
    pub num_cores: usize,
    /// Target architecture name from `std::env::consts::ARCH`.
    pub architecture: String,
    /// Names of the SIMD features reported by `detect_cpu_features`.
    pub cpu_features: Vec<String>,
}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct TuningDatabase {
126 pub results: Vec<BenchmarkResult>,
128 pub last_updated: u64,
130 pub best_algorithms: HashMap<(usize, bool), FftVariant>,
132}
133
/// Benchmarks FFT variants and remembers which one is fastest for each size
/// and direction.
pub struct AutoTuner {
    /// Benchmark settings and database location.
    config: AutoTuneConfig,
    /// In-memory copy of the tuning database; written back by `save_database`.
    database: TuningDatabase,
    /// When `false`, `run_benchmarks` does nothing and `get_best_variant`
    /// always answers `FftVariant::Standard`.
    enabled: bool,
}
143
144impl Default for AutoTuner {
145 fn default() -> Self {
146 Self::with_config(AutoTuneConfig::default())
147 }
148}
149
150impl AutoTuner {
151 pub fn new() -> Self {
153 Self::default()
154 }
155
156 pub fn with_config(config: AutoTuneConfig) -> Self {
158 let database =
159 Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
160 results: Vec::new(),
161 last_updated: std::time::SystemTime::now()
162 .duration_since(std::time::UNIX_EPOCH)
163 .unwrap_or_default()
164 .as_secs(),
165 best_algorithms: HashMap::new(),
166 });
167
168 Self {
169 config,
170 database,
171 enabled: true,
172 }
173 }
174
175 fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
177 if !path.exists() {
178 return Err(FFTError::IOError(format!(
179 "Tuning database file not found: {}",
180 path.display()
181 )));
182 }
183
184 let file = File::open(path)
185 .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
186
187 let reader = BufReader::new(file);
188 let database: TuningDatabase = serde_json::from_reader(reader)
189 .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
190
191 Ok(database)
192 }
193
194 pub fn save_database(&self) -> FFTResult<()> {
196 if let Some(parent) = self.config.database_path.parent() {
198 fs::create_dir_all(parent).map_err(|e| {
199 FFTError::IOError(format!(
200 "Failed to create directory for tuning database: {e}"
201 ))
202 })?;
203 }
204
205 let file = File::create(&self.config.database_path).map_err(|e| {
206 FFTError::IOError(format!("Failed to create tuning database file: {e}"))
207 })?;
208
209 let writer = BufWriter::new(file);
210 serde_json::to_writer_pretty(writer, &self.database)
211 .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
212
213 Ok(())
214 }
215
216 pub fn set_enabled(&mut self, enabled: bool) {
218 self.enabled = enabled;
219 }
220
221 pub fn is_enabled(&self) -> bool {
223 self.enabled
224 }
225
226 pub fn run_benchmarks(&mut self) -> FFTResult<()> {
228 if !self.enabled {
229 return Ok(());
230 }
231
232 let sizes = self.generate_sizes();
233 let mut results = Vec::new();
234
235 for size in sizes {
236 for &variant in &self.config.variants {
237 let forward_result = self.benchmark_variant(size, variant, true)?;
239 results.push(forward_result);
240
241 let inverse_result = self.benchmark_variant(size, variant, false)?;
243 results.push(inverse_result);
244 }
245 }
246
247 self.database.results.extend(results);
249 self.update_best_algorithms();
250 self.save_database()?;
251
252 Ok(())
253 }
254
255 fn generate_sizes(&self) -> Vec<usize> {
257 let mut sizes = Vec::new();
258
259 match &self.config.sizes.step {
260 SizeStep::Linear(step) => {
261 let mut size = self.config.sizes.min;
262 while size <= self.config.sizes.max {
263 sizes.push(size);
264 size += step;
265 }
266 }
267 SizeStep::Exponential(factor) => {
268 let mut size = self.config.sizes.min as f64;
269 while size <= self.config.sizes.max as f64 {
270 sizes.push(size as usize);
271 size *= factor;
272 }
273 }
274 SizeStep::PowersOfTwo => {
275 let mut size = 1;
276 while size < self.config.sizes.min {
277 size *= 2;
278 }
279 while size <= self.config.sizes.max {
280 sizes.push(size);
281 size *= 2;
282 }
283 }
284 SizeStep::Custom(custom_sizes) => {
285 for &size in custom_sizes {
286 if size >= self.config.sizes.min && size <= self.config.sizes.max {
287 sizes.push(size);
288 }
289 }
290 }
291 }
292
293 sizes
294 }
295
296 fn benchmark_variant(
298 &self,
299 size: usize,
300 variant: FftVariant,
301 forward: bool,
302 ) -> FFTResult<BenchmarkResult> {
303 let mut data = vec![Complex64::new(0.0, 0.0); size];
305 for (i, val) in data.iter_mut().enumerate().take(size) {
306 *val = Complex64::new(i as f64, (i * 2) as f64);
307 }
308
309 for _ in 0..self.config.warmup {
311 match variant {
312 FftVariant::Standard => {
313 let mut planner = FftPlanner::new();
314 let fft = if forward {
315 planner.plan_fft_forward(size)
316 } else {
317 planner.plan_fft_inverse(size)
318 };
319 let mut buffer = data.clone();
320 fft.process(&mut buffer);
321 }
322 FftVariant::InPlace => {
323 let mut planner = FftPlanner::new();
324 let fft = if forward {
325 planner.plan_fft_forward(size)
326 } else {
327 planner.plan_fft_inverse(size)
328 };
329 let mut buffer = data.clone();
331 let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
332 fft.process_with_scratch(&mut buffer, &mut scratch);
333 }
334 FftVariant::Cached => {
335 let manager = PlanSerializationManager::new(&self.config.database_path);
337 let plan_info = manager.create_plan_info(size, forward);
338 let (_, time) = crate::plan_serialization::create_and_time_plan(size, forward);
339 manager.record_plan_usage(&plan_info, time).unwrap_or(());
340 }
341 FftVariant::SplitRadix => {
342 let mut planner = FftPlanner::new();
345 let fft = if forward {
346 planner.plan_fft_forward(size)
347 } else {
348 planner.plan_fft_inverse(size)
349 };
350 let mut buffer = data.clone();
351 fft.process(&mut buffer);
352 }
353 }
354 }
355
356 let mut times = Vec::with_capacity(self.config.repetitions);
358
359 for _ in 0..self.config.repetitions {
360 let start = Instant::now();
361
362 match variant {
363 FftVariant::Standard => {
364 let mut planner = FftPlanner::new();
365 let fft = if forward {
366 planner.plan_fft_forward(size)
367 } else {
368 planner.plan_fft_inverse(size)
369 };
370 let mut buffer = data.clone();
371 fft.process(&mut buffer);
372 }
373 FftVariant::InPlace => {
374 let mut planner = FftPlanner::new();
375 let fft = if forward {
376 planner.plan_fft_forward(size)
377 } else {
378 planner.plan_fft_inverse(size)
379 };
380 let mut buffer = data.clone();
382 let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
383 fft.process_with_scratch(&mut buffer, &mut scratch);
384 }
385 FftVariant::Cached => {
386 let mut planner = FftPlanner::new();
388 let fft = if forward {
389 planner.plan_fft_forward(size)
390 } else {
391 planner.plan_fft_inverse(size)
392 };
393 let mut buffer = data.clone();
394 fft.process(&mut buffer);
395 }
396 FftVariant::SplitRadix => {
397 let mut planner = FftPlanner::new();
399 let fft = if forward {
400 planner.plan_fft_forward(size)
401 } else {
402 planner.plan_fft_inverse(size)
403 };
404 let mut buffer = data.clone();
405 fft.process(&mut buffer);
406 }
407 }
408
409 let elapsed = start.elapsed();
410 times.push(elapsed.as_nanos() as u64);
411 }
412
413 let avg_time = times.iter().sum::<u64>() / times.len() as u64;
415 let min_time = *times.iter().min().unwrap_or(&0);
416
417 let variance = times
419 .iter()
420 .map(|&t| {
421 let diff = t as f64 - avg_time as f64;
422 diff * diff
423 })
424 .sum::<f64>()
425 / times.len() as f64;
426 let std_dev = variance.sqrt();
427
428 Ok(BenchmarkResult {
429 size,
430 variant,
431 forward,
432 avg_time_ns: avg_time,
433 min_time_ns: min_time,
434 std_dev_ns: std_dev,
435 system_info: self.detect_system_info(),
436 })
437 }
438
439 fn detect_system_info(&self) -> SystemInfo {
441 SystemInfo {
444 cpu_model: String::from("Unknown"),
445 num_cores: num_cpus::get(),
446 architecture: std::env::consts::ARCH.to_string(),
447 cpu_features: detect_cpu_features(),
448 }
449 }
450
451 fn update_best_algorithms(&mut self) {
453 self.database.best_algorithms.clear();
455
456 let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
458 for result in &self.database.results {
459 grouped
460 .entry((result.size, result.forward))
461 .or_default()
462 .push(result);
463 }
464
465 for ((size, forward), results) in grouped {
467 if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
468 self.database
469 .best_algorithms
470 .insert((size, forward), best.variant);
471 }
472 }
473 }
474
475 pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
477 if !self.enabled {
478 return FftVariant::Standard;
479 }
480
481 if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
483 return variant;
484 }
485
486 let mut closest_size = 0;
488 let mut min_diff = usize::MAX;
489
490 for &(s, f) in self.database.best_algorithms.keys() {
491 if f == forward {
492 let diff = s.abs_diff(size);
493 if diff < min_diff {
494 min_diff = diff;
495 closest_size = s;
496 }
497 }
498 }
499
500 if closest_size > 0 {
501 if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
502 return variant;
503 }
504 }
505
506 FftVariant::Standard
508 }
509
510 pub fn run_optimal_fft<T>(
512 &self,
513 input: &[T],
514 size: Option<usize>,
515 forward: bool,
516 ) -> FFTResult<Vec<Complex64>>
517 where
518 T: Clone + Into<Complex64>,
519 {
520 let actual_size = size.unwrap_or(input.len());
521 let variant = self.get_best_variant(actual_size, forward);
522
523 let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
525 if buffer.len() < actual_size {
527 buffer.resize(actual_size, Complex64::new(0.0, 0.0));
528 }
529
530 match variant {
531 FftVariant::Standard => {
532 let mut planner = FftPlanner::new();
533 let fft = if forward {
534 planner.plan_fft_forward(actual_size)
535 } else {
536 planner.plan_fft_inverse(actual_size)
537 };
538 fft.process(&mut buffer);
539 }
540 FftVariant::InPlace => {
541 let mut planner = FftPlanner::new();
542 let fft = if forward {
543 planner.plan_fft_forward(actual_size)
544 } else {
545 planner.plan_fft_inverse(actual_size)
546 };
547 let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
548 fft.process_with_scratch(&mut buffer, &mut scratch);
549 }
550 FftVariant::Cached => {
551 let (plan_, _) =
554 crate::plan_serialization::create_and_time_plan(actual_size, forward);
555 plan_.process(&mut buffer);
556 }
557 FftVariant::SplitRadix => {
558 let mut planner = FftPlanner::new();
560 let fft = if forward {
561 planner.plan_fft_forward(actual_size)
562 } else {
563 planner.plan_fft_inverse(actual_size)
564 };
565 fft.process(&mut buffer);
566 }
567 }
568
569 if !forward {
571 let scale = 1.0 / (actual_size as f64);
572 for val in &mut buffer {
573 *val *= scale;
574 }
575 }
576
577 Ok(buffer)
578 }
579}
580
/// Lists the SIMD instruction-set extensions supported by the CPU this
/// process is running on.
///
/// Detection happens at runtime (`is_x86_feature_detected!` /
/// `is_aarch64_feature_detected!`). The report therefore describes the
/// machine that produced a benchmark, rather than the `target_feature` flags
/// the crate happened to be compiled with; a default x86_64 build only enables
/// SSE/SSE2 at compile time. rustfft documents that it selects its SIMD paths
/// at runtime, so this is the set that can affect timings.
/// Returns an empty list on architectures without a probe below.
fn detect_cpu_features() -> Vec<String> {
    #[allow(unused_mut)] // never pushed to on architectures without a probe below
    let mut features = Vec::new();

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let probes = [
            ("sse", is_x86_feature_detected!("sse")),
            ("sse2", is_x86_feature_detected!("sse2")),
            ("sse3", is_x86_feature_detected!("sse3")),
            ("sse4.1", is_x86_feature_detected!("sse4.1")),
            ("sse4.2", is_x86_feature_detected!("sse4.2")),
            ("avx", is_x86_feature_detected!("avx")),
            ("avx2", is_x86_feature_detected!("avx2")),
            ("fma", is_x86_feature_detected!("fma")),
        ];
        for (name, present) in probes {
            if present {
                features.push(name.to_string());
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            features.push("neon".to_string());
        }
    }

    features
}
625
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Checks `generate_sizes` for each `SizeStep` kind.
    // NOTE(review): these tuners use the default `database_path`
    // (`.fft_tuning_db.json` in the working directory), so `with_config` will
    // read that file if it exists; size generation itself does not use it.
    #[test]
    fn test_size_generation() {
        // Powers of two within [8, 64].
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 8,
                max: 64,
                step: SizeStep::PowersOfTwo,
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![8, 16, 32, 64]);

        // Linear steps of 5; `max` is inclusive.
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 30,
                step: SizeStep::Linear(5),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![10, 15, 20, 25, 30]);

        // Doubling from 10 stops before exceeding 100.
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 100,
                step: SizeStep::Exponential(2.0),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![10, 20, 40, 80]);

        // Custom sizes outside [min, max] are dropped.
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 100,
                step: SizeStep::Custom(vec![5, 15, 25, 50, 150]),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![15, 25, 50]);
    }

    /// Runs a tiny benchmark against a temporary database and checks the
    /// database file and the chosen variant.
    #[test]
    fn test_auto_tuner_basic() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_tuning_db.json");

        // Two sizes (16, 32), two variants, minimal repetitions to keep it fast.
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        };

        let mut tuner = AutoTuner::with_config(config);

        // NOTE(review): the `Err` arm only prints, so this test passes even
        // when `run_benchmarks` fails (including when saving the database
        // fails). Consider asserting success once saving is known to work.
        match tuner.run_benchmarks() {
            Ok(_) => {
                assert!(db_path.exists());

                // Only benchmarked variants can be chosen.
                let variant = tuner.get_best_variant(16, true);
                assert!(matches!(
                    variant,
                    FftVariant::Standard | FftVariant::InPlace
                ));
            }
            Err(e) => {
                println!("Benchmark failed: {e}");
            }
        }
    }
}