Skip to main content

scirs2_fft/
auto_tuning.rs

//! Auto-tuning for hardware-specific FFT optimizations
//!
//! This module provides functionality to automatically tune FFT parameters
//! for optimal performance on the current hardware. It includes:
//!
//! - Benchmarking different FFT configurations
//! - Selecting optimal parameters based on timing results
//! - Persisting tuning results for future use
//! - Detecting CPU features and adapting algorithms accordingly
10
11#[cfg(feature = "oxifft")]
12use crate::oxifft_plan_cache;
13#[cfg(feature = "oxifft")]
14use oxifft::{Complex as OxiComplex, Direction};
15use scirs2_core::numeric::Complex64;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::{self, File};
19use std::io::{BufReader, BufWriter};
20use std::path::{Path, PathBuf};
21use std::time::Instant;
22
23use crate::error::{FFTError, FFTResult};
24use crate::plan_serialization::PlanSerializationManager;
25
26/// A range of FFT sizes to benchmark
27#[derive(Debug, Clone)]
28pub struct SizeRange {
29    /// Minimum size to test
30    pub min: usize,
31    /// Maximum size to test
32    pub max: usize,
33    /// Step between sizes (can be multiplication factor)
34    pub step: SizeStep,
35}
36
/// How successive benchmark sizes are derived inside a [`SizeRange`].
#[derive(Debug, Clone)]
pub enum SizeStep {
    /// Advance by adding a fixed increment to the previous size.
    Linear(usize),
    /// Advance by multiplying the previous size by a factor.
    Exponential(f64),
    /// Visit every power of two that falls inside the range.
    PowersOfTwo,
    /// Visit exactly the listed sizes (those inside the range).
    Custom(Vec<usize>),
}
49
50/// FFT algorithm variant to benchmark
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52pub enum FftVariant {
53    /// Standard FFT
54    Standard,
55    /// In-place FFT
56    InPlace,
57    /// Cached-plan FFT
58    Cached,
59    /// Split-radix FFT
60    SplitRadix,
61}
62
63/// Configuration for auto-tuning
64#[derive(Debug, Clone)]
65pub struct AutoTuneConfig {
66    /// Sizes to benchmark
67    pub sizes: SizeRange,
68    /// Number of repetitions per test
69    pub repetitions: usize,
70    /// Warm-up iterations (not timed)
71    pub warmup: usize,
72    /// FFT variants to test
73    pub variants: Vec<FftVariant>,
74    /// Path to save tuning results
75    pub database_path: PathBuf,
76}
77
78impl Default for AutoTuneConfig {
79    fn default() -> Self {
80        Self {
81            sizes: SizeRange {
82                min: 16,
83                max: 8192,
84                step: SizeStep::PowersOfTwo,
85            },
86            repetitions: 10,
87            warmup: 3,
88            variants: vec![FftVariant::Standard, FftVariant::Cached],
89            database_path: PathBuf::from(".fft_tuning_db.json"),
90        }
91    }
92}
93
94/// Results from a single benchmark
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct BenchmarkResult {
97    /// FFT size
98    pub size: usize,
99    /// FFT variant
100    pub variant: FftVariant,
101    /// Whether this is forward or inverse FFT
102    pub forward: bool,
103    /// Average execution time in nanoseconds
104    pub avg_time_ns: u64,
105    /// Minimum execution time in nanoseconds
106    pub min_time_ns: u64,
107    /// Standard deviation in nanoseconds
108    pub std_dev_ns: f64,
109    /// System information when the benchmark was run
110    pub system_info: SystemInfo,
111}
112
113/// System information for result matching
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct SystemInfo {
116    /// CPU model
117    pub cpu_model: String,
118    /// Number of cores
119    pub num_cores: usize,
120    /// Architecture
121    pub architecture: String,
122    /// CPU features (SIMD instruction sets, etc.)
123    pub cpu_features: Vec<String>,
124}
125
126/// Database of tuning results
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct TuningDatabase {
129    /// Benchmark results
130    pub results: Vec<BenchmarkResult>,
131    /// Last updated timestamp
132    pub last_updated: u64,
133    /// Best algorithm for each size
134    pub best_algorithms: HashMap<(usize, bool), FftVariant>,
135}
136
137/// Auto-tuning manager
138pub struct AutoTuner {
139    /// Configuration
140    config: AutoTuneConfig,
141    /// Database of results
142    database: TuningDatabase,
143    /// Whether to use tuning
144    enabled: bool,
145}
146
147impl Default for AutoTuner {
148    fn default() -> Self {
149        Self::with_config(AutoTuneConfig::default())
150    }
151}
152
153impl AutoTuner {
154    /// Create a new auto-tuner with default configuration
155    pub fn new() -> Self {
156        Self::default()
157    }
158
159    /// Create a new auto-tuner with custom configuration
160    pub fn with_config(config: AutoTuneConfig) -> Self {
161        let database =
162            Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
163                results: Vec::new(),
164                last_updated: std::time::SystemTime::now()
165                    .duration_since(std::time::UNIX_EPOCH)
166                    .unwrap_or_default()
167                    .as_secs(),
168                best_algorithms: HashMap::new(),
169            });
170
171        Self {
172            config,
173            database,
174            enabled: true,
175        }
176    }
177
178    /// Load the tuning database from disk
179    fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
180        if !path.exists() {
181            return Err(FFTError::IOError(format!(
182                "Tuning database file not found: {}",
183                path.display()
184            )));
185        }
186
187        let file = File::open(path)
188            .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
189
190        let reader = BufReader::new(file);
191        let database: TuningDatabase = serde_json::from_reader(reader)
192            .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
193
194        Ok(database)
195    }
196
197    /// Save the tuning database to disk
198    pub fn save_database(&self) -> FFTResult<()> {
199        // Create parent directories if they don't exist
200        if let Some(parent) = self.config.database_path.parent() {
201            fs::create_dir_all(parent).map_err(|e| {
202                FFTError::IOError(format!(
203                    "Failed to create directory for tuning database: {e}"
204                ))
205            })?;
206        }
207
208        let file = File::create(&self.config.database_path).map_err(|e| {
209            FFTError::IOError(format!("Failed to create tuning database file: {e}"))
210        })?;
211
212        let writer = BufWriter::new(file);
213        serde_json::to_writer_pretty(writer, &self.database)
214            .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
215
216        Ok(())
217    }
218
219    /// Enable or disable auto-tuning
220    pub fn set_enabled(&mut self, enabled: bool) {
221        self.enabled = enabled;
222    }
223
224    /// Check if auto-tuning is enabled
225    pub fn is_enabled(&self) -> bool {
226        self.enabled
227    }
228
229    /// Run benchmarks for all configured FFT variants and sizes
230    pub fn run_benchmarks(&mut self) -> FFTResult<()> {
231        if !self.enabled {
232            return Ok(());
233        }
234
235        let sizes = self.generate_sizes();
236        let mut results = Vec::new();
237
238        for size in sizes {
239            for &variant in &self.config.variants {
240                // Benchmark forward transform
241                let forward_result = self.benchmark_variant(size, variant, true)?;
242                results.push(forward_result);
243
244                // Benchmark inverse transform
245                let inverse_result = self.benchmark_variant(size, variant, false)?;
246                results.push(inverse_result);
247            }
248        }
249
250        // Update database
251        self.database.results.extend(results);
252        self.update_best_algorithms();
253        self.save_database()?;
254
255        Ok(())
256    }
257
258    /// Generate the list of sizes to benchmark
259    fn generate_sizes(&self) -> Vec<usize> {
260        let mut sizes = Vec::new();
261
262        match &self.config.sizes.step {
263            SizeStep::Linear(step) => {
264                let mut size = self.config.sizes.min;
265                while size <= self.config.sizes.max {
266                    sizes.push(size);
267                    size += step;
268                }
269            }
270            SizeStep::Exponential(factor) => {
271                let mut size = self.config.sizes.min as f64;
272                while size <= self.config.sizes.max as f64 {
273                    sizes.push(size as usize);
274                    size *= factor;
275                }
276            }
277            SizeStep::PowersOfTwo => {
278                let mut size = 1;
279                while size < self.config.sizes.min {
280                    size *= 2;
281                }
282                while size <= self.config.sizes.max {
283                    sizes.push(size);
284                    size *= 2;
285                }
286            }
287            SizeStep::Custom(custom_sizes) => {
288                for &size in custom_sizes {
289                    if size >= self.config.sizes.min && size <= self.config.sizes.max {
290                        sizes.push(size);
291                    }
292                }
293            }
294        }
295
296        sizes
297    }
298
299    /// Benchmark a specific FFT variant for a given size
300    fn benchmark_variant(
301        &self,
302        size: usize,
303        variant: FftVariant,
304        forward: bool,
305    ) -> FFTResult<BenchmarkResult> {
306        // Create test data
307        let mut data = vec![Complex64::new(0.0, 0.0); size];
308        for (i, val) in data.iter_mut().enumerate().take(size) {
309            *val = Complex64::new(i as f64, (i * 2) as f64);
310        }
311
312        // Warm-up phase
313        for _ in 0..self.config.warmup {
314            match variant {
315                FftVariant::Standard => {
316                    #[cfg(feature = "oxifft")]
317                    {
318                        let input_oxi: Vec<OxiComplex<f64>> =
319                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
320                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
321
322                        let direction = if forward {
323                            Direction::Forward
324                        } else {
325                            Direction::Backward
326                        };
327                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
328                    }
329
330                    #[cfg(not(feature = "oxifft"))]
331                    {
332                        #[cfg(feature = "rustfft-backend")]
333                        {
334                            let mut planner = FftPlanner::new();
335                            let fft = if forward {
336                                planner.plan_fft_forward(size)
337                            } else {
338                                planner.plan_fft_inverse(size)
339                            };
340                            let mut buffer = data.clone();
341                            fft.process(&mut buffer);
342                        }
343                    }
344                }
345                FftVariant::InPlace => {
346                    #[cfg(feature = "oxifft")]
347                    {
348                        let input_oxi: Vec<OxiComplex<f64>> =
349                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
350                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
351
352                        let direction = if forward {
353                            Direction::Forward
354                        } else {
355                            Direction::Backward
356                        };
357                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
358                    }
359
360                    #[cfg(not(feature = "oxifft"))]
361                    {
362                        #[cfg(feature = "rustfft-backend")]
363                        {
364                            let mut planner = FftPlanner::new();
365                            let fft = if forward {
366                                planner.plan_fft_forward(size)
367                            } else {
368                                planner.plan_fft_inverse(size)
369                            };
370                            // Use in-place processing with scratch buffer
371                            let mut buffer = data.clone();
372                            let mut scratch =
373                                vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
374                            fft.process_with_scratch(&mut buffer, &mut scratch);
375                        }
376                    }
377                }
378                FftVariant::Cached => {
379                    // Create a plan via the serialization manager
380                    let manager = PlanSerializationManager::new(&self.config.database_path);
381                    let plan_info = manager.create_plan_info(size, forward);
382                    let time = crate::plan_serialization::create_and_time_plan(size, forward);
383                    manager.record_plan_usage(&plan_info, time).unwrap_or(());
384                }
385                FftVariant::SplitRadix => {
386                    #[cfg(feature = "oxifft")]
387                    {
388                        // For now, use OxiFFT's standard algorithm
389                        let input_oxi: Vec<OxiComplex<f64>> =
390                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
391                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
392
393                        let direction = if forward {
394                            Direction::Forward
395                        } else {
396                            Direction::Backward
397                        };
398                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
399                    }
400
401                    #[cfg(not(feature = "oxifft"))]
402                    {
403                        #[cfg(feature = "rustfft-backend")]
404                        {
405                            // For now, this is just an example variant
406                            // In a real implementation, we'd use a specific split-radix algorithm
407                            let mut planner = FftPlanner::new();
408                            let fft = if forward {
409                                planner.plan_fft_forward(size)
410                            } else {
411                                planner.plan_fft_inverse(size)
412                            };
413                            let mut buffer = data.clone();
414                            fft.process(&mut buffer);
415                        }
416                    }
417                }
418            }
419        }
420
421        // Timing phase
422        let mut times = Vec::with_capacity(self.config.repetitions);
423
424        for _ in 0..self.config.repetitions {
425            let start = Instant::now();
426
427            match variant {
428                FftVariant::Standard => {
429                    #[cfg(feature = "oxifft")]
430                    {
431                        let input_oxi: Vec<OxiComplex<f64>> =
432                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
433                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
434
435                        let direction = if forward {
436                            Direction::Forward
437                        } else {
438                            Direction::Backward
439                        };
440                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
441                    }
442
443                    #[cfg(not(feature = "oxifft"))]
444                    {
445                        #[cfg(feature = "rustfft-backend")]
446                        {
447                            let mut planner = FftPlanner::new();
448                            let fft = if forward {
449                                planner.plan_fft_forward(size)
450                            } else {
451                                planner.plan_fft_inverse(size)
452                            };
453                            let mut buffer = data.clone();
454                            fft.process(&mut buffer);
455                        }
456                    }
457                }
458                FftVariant::InPlace => {
459                    #[cfg(feature = "oxifft")]
460                    {
461                        let input_oxi: Vec<OxiComplex<f64>> =
462                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
463                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
464
465                        let direction = if forward {
466                            Direction::Forward
467                        } else {
468                            Direction::Backward
469                        };
470                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
471                    }
472
473                    #[cfg(not(feature = "oxifft"))]
474                    {
475                        #[cfg(feature = "rustfft-backend")]
476                        {
477                            let mut planner = FftPlanner::new();
478                            let fft = if forward {
479                                planner.plan_fft_forward(size)
480                            } else {
481                                planner.plan_fft_inverse(size)
482                            };
483                            // Use in-place processing with scratch buffer
484                            let mut buffer = data.clone();
485                            let mut scratch =
486                                vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
487                            fft.process_with_scratch(&mut buffer, &mut scratch);
488                        }
489                    }
490                }
491                FftVariant::Cached => {
492                    #[cfg(feature = "oxifft")]
493                    {
494                        let input_oxi: Vec<OxiComplex<f64>> =
495                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
496                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
497
498                        let direction = if forward {
499                            Direction::Forward
500                        } else {
501                            Direction::Backward
502                        };
503                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
504                    }
505
506                    #[cfg(not(feature = "oxifft"))]
507                    {
508                        #[cfg(feature = "rustfft-backend")]
509                        {
510                            // Use the plan cache
511                            let mut planner = FftPlanner::new();
512                            let fft = if forward {
513                                planner.plan_fft_forward(size)
514                            } else {
515                                planner.plan_fft_inverse(size)
516                            };
517                            let mut buffer = data.clone();
518                            fft.process(&mut buffer);
519                        }
520                    }
521                }
522                FftVariant::SplitRadix => {
523                    #[cfg(feature = "oxifft")]
524                    {
525                        let input_oxi: Vec<OxiComplex<f64>> =
526                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
527                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
528
529                        let direction = if forward {
530                            Direction::Forward
531                        } else {
532                            Direction::Backward
533                        };
534                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
535                    }
536
537                    #[cfg(not(feature = "oxifft"))]
538                    {
539                        #[cfg(feature = "rustfft-backend")]
540                        {
541                            // Placeholder for split-radix implementation
542                            let mut planner = FftPlanner::new();
543                            let fft = if forward {
544                                planner.plan_fft_forward(size)
545                            } else {
546                                planner.plan_fft_inverse(size)
547                            };
548                            let mut buffer = data.clone();
549                            fft.process(&mut buffer);
550                        }
551                    }
552                }
553            }
554
555            let elapsed = start.elapsed();
556            times.push(elapsed.as_nanos() as u64);
557        }
558
559        // Calculate statistics
560        let avg_time = times.iter().sum::<u64>() / times.len() as u64;
561        let min_time = *times.iter().min().unwrap_or(&0);
562
563        // Calculate standard deviation
564        let variance = times
565            .iter()
566            .map(|&t| {
567                let diff = t as f64 - avg_time as f64;
568                diff * diff
569            })
570            .sum::<f64>()
571            / times.len() as f64;
572        let std_dev = variance.sqrt();
573
574        Ok(BenchmarkResult {
575            size,
576            variant,
577            forward,
578            avg_time_ns: avg_time,
579            min_time_ns: min_time,
580            std_dev_ns: std_dev,
581            system_info: self.detect_system_info(),
582        })
583    }
584
585    /// Detect system information for result matching
586    fn detect_system_info(&self) -> SystemInfo {
587        // This is a simplified version - a real implementation would
588        // detect actual CPU model, features, etc.
589        SystemInfo {
590            cpu_model: String::from("Unknown"),
591            num_cores: num_cpus::get(),
592            architecture: std::env::consts::ARCH.to_string(),
593            cpu_features: detect_cpu_features(),
594        }
595    }
596
597    /// Update the best algorithms based on benchmark results
598    fn update_best_algorithms(&mut self) {
599        // Clear existing best algorithms
600        self.database.best_algorithms.clear();
601
602        // Group results by size and direction
603        let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
604        for result in &self.database.results {
605            grouped
606                .entry((result.size, result.forward))
607                .or_default()
608                .push(result);
609        }
610
611        // Find the best algorithm for each group
612        for ((size, forward), results) in grouped {
613            if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
614                self.database
615                    .best_algorithms
616                    .insert((size, forward), best.variant);
617            }
618        }
619    }
620
621    /// Get the best FFT variant for the given size and direction
622    pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
623        if !self.enabled {
624            return FftVariant::Standard;
625        }
626
627        // Look for exact size match
628        if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
629            return variant;
630        }
631
632        // Look for closest size match
633        let mut closest_size = 0;
634        let mut min_diff = usize::MAX;
635
636        for &(s, f) in self.database.best_algorithms.keys() {
637            if f == forward {
638                let diff = s.abs_diff(size);
639                if diff < min_diff {
640                    min_diff = diff;
641                    closest_size = s;
642                }
643            }
644        }
645
646        if closest_size > 0 {
647            if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
648                return variant;
649            }
650        }
651
652        // Default to standard FFT if no match
653        FftVariant::Standard
654    }
655
656    /// Run FFT with optimal algorithm selection
657    pub fn run_optimal_fft<T>(
658        &self,
659        input: &[T],
660        size: Option<usize>,
661        forward: bool,
662    ) -> FFTResult<Vec<Complex64>>
663    where
664        T: Clone + Into<Complex64>,
665    {
666        let actual_size = size.unwrap_or(input.len());
667        let variant = self.get_best_variant(actual_size, forward);
668
669        // Convert input to complex
670        let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
671        // Pad if necessary
672        if buffer.len() < actual_size {
673            buffer.resize(actual_size, Complex64::new(0.0, 0.0));
674        }
675
676        #[cfg(feature = "oxifft")]
677        {
678            let input_oxi: Vec<OxiComplex<f64>> =
679                buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
680            let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); actual_size];
681
682            let direction = if forward {
683                Direction::Forward
684            } else {
685                Direction::Backward
686            };
687            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
688
689            // Copy result back to buffer
690            for (i, val) in output.iter().enumerate() {
691                buffer[i] = Complex64::new(val.re, val.im);
692            }
693        }
694
695        #[cfg(not(feature = "oxifft"))]
696        {
697            #[cfg(feature = "rustfft-backend")]
698            {
699                match variant {
700                    FftVariant::Standard => {
701                        let mut planner = FftPlanner::new();
702                        let fft = if forward {
703                            planner.plan_fft_forward(actual_size)
704                        } else {
705                            planner.plan_fft_inverse(actual_size)
706                        };
707                        fft.process(&mut buffer);
708                    }
709                    FftVariant::InPlace => {
710                        let mut planner = FftPlanner::new();
711                        let fft = if forward {
712                            planner.plan_fft_forward(actual_size)
713                        } else {
714                            planner.plan_fft_inverse(actual_size)
715                        };
716                        let mut scratch =
717                            vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
718                        fft.process_with_scratch(&mut buffer, &mut scratch);
719                    }
720                    FftVariant::Cached => {
721                        // Use the plan cache via PlanSerializationManager
722                        // Create a plan directly - manager is not needed here
723                        let (plan_, _) =
724                            crate::plan_serialization::create_and_time_plan(actual_size, forward);
725                        plan_.process(&mut buffer);
726                    }
727                    FftVariant::SplitRadix => {
728                        // Placeholder for split-radix FFT
729                        let mut planner = FftPlanner::new();
730                        let fft = if forward {
731                            planner.plan_fft_forward(actual_size)
732                        } else {
733                            planner.plan_fft_inverse(actual_size)
734                        };
735                        fft.process(&mut buffer);
736                    }
737                }
738            }
739
740            {
741                return Err(FFTError::ComputationError(
742                    "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
743                ));
744            }
745        }
746
747        // Scale inverse FFT by 1/N if required
748        if !forward {
749            let scale = 1.0 / (actual_size as f64);
750            for val in &mut buffer {
751                *val *= scale;
752            }
753        }
754
755        Ok(buffer)
756    }
757}
758
/// Lists the SIMD-related CPU features available on the machine the program
/// is running on.
///
/// Detection happens at run time, so the answer describes the actual CPU
/// rather than the feature set the binary was compiled for. On x86_64 the
/// probed features are `sse`, `sse2`, `sse3`, `sse4.1`, `sse4.2`, `avx`,
/// `avx2` and `fma` (reported in that order); on aarch64 only `neon` is
/// probed. Other architectures yield an empty list.
fn detect_cpu_features() -> Vec<String> {
    let mut features: Vec<String> = Vec::new();

    // The detection macros require string literals, hence one probe per line.
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("sse") {
            features.push("sse".to_string());
        }
        if std::is_x86_feature_detected!("sse2") {
            features.push("sse2".to_string());
        }
        if std::is_x86_feature_detected!("sse3") {
            features.push("sse3".to_string());
        }
        if std::is_x86_feature_detected!("sse4.1") {
            features.push("sse4.1".to_string());
        }
        if std::is_x86_feature_detected!("sse4.2") {
            features.push("sse4.2".to_string());
        }
        if std::is_x86_feature_detected!("avx") {
            features.push("avx".to_string());
        }
        if std::is_x86_feature_detected!("avx2") {
            features.push("avx2".to_string());
        }
        if std::is_x86_feature_detected!("fma") {
            features.push("fma".to_string());
        }
    }

    // ARM-specific features
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            features.push("neon".to_string());
        }
    }

    features
}
803
// ============================================================================
// Enhanced Auto-Selection (v0.2.0)
// ============================================================================
807
808/// Integrated auto-selection that combines algorithm selection with auto-tuning
809pub struct IntegratedAutoSelector {
810    /// Algorithm selector for input-characteristic based selection
811    selector: crate::algorithm_selector::AlgorithmSelector,
812    /// Auto-tuner for performance-based selection
813    tuner: AutoTuner,
814    /// Whether to prefer learned performance data
815    prefer_learned: bool,
816}
817
818impl Default for IntegratedAutoSelector {
819    fn default() -> Self {
820        Self::new()
821    }
822}
823
824impl IntegratedAutoSelector {
825    /// Create a new integrated auto-selector
826    pub fn new() -> Self {
827        Self {
828            selector: crate::algorithm_selector::AlgorithmSelector::new(),
829            tuner: AutoTuner::new(),
830            prefer_learned: true,
831        }
832    }
833
834    /// Create with custom configuration
835    pub fn with_config(
836        selector_config: crate::algorithm_selector::SelectionConfig,
837        tuner_config: AutoTuneConfig,
838        prefer_learned: bool,
839    ) -> Self {
840        Self {
841            selector: crate::algorithm_selector::AlgorithmSelector::with_config(selector_config),
842            tuner: AutoTuner::with_config(tuner_config),
843            prefer_learned,
844        }
845    }
846
847    /// Select the best algorithm for the given size
848    pub fn select(&self, size: usize, forward: bool) -> FFTResult<SelectionResult> {
849        // First, check if we have learned performance data
850        if self.prefer_learned && self.tuner.is_enabled() {
851            let variant = self.tuner.get_best_variant(size, forward);
852            if variant != FftVariant::Standard {
853                // We have learned data, use it
854                return Ok(SelectionResult {
855                    algorithm: variant_to_algorithm(variant),
856                    variant,
857                    source: SelectionSource::Learned,
858                    confidence: 0.9,
859                    recommendation: self.selector.select_algorithm(size, forward).ok(),
860                });
861            }
862        }
863
864        // Fall back to input-characteristic based selection
865        let recommendation = self.selector.select_algorithm(size, forward)?;
866        let variant = algorithm_to_variant(recommendation.algorithm);
867
868        Ok(SelectionResult {
869            algorithm: recommendation.algorithm,
870            variant,
871            source: SelectionSource::Characteristic,
872            confidence: recommendation.confidence,
873            recommendation: Some(recommendation),
874        })
875    }
876
877    /// Run auto-tuning for a range of sizes
878    pub fn auto_tune(&mut self, sizes: &[usize]) -> FFTResult<()> {
879        // Generate size range from provided sizes
880        if sizes.is_empty() {
881            return Ok(());
882        }
883
884        let min = *sizes.iter().min().unwrap_or(&16);
885        let max = *sizes.iter().max().unwrap_or(&8192);
886
887        let config = AutoTuneConfig {
888            sizes: SizeRange {
889                min,
890                max,
891                step: SizeStep::Custom(sizes.to_vec()),
892            },
893            ..Default::default()
894        };
895
896        self.tuner = AutoTuner::with_config(config);
897        self.tuner.run_benchmarks()
898    }
899
900    /// Execute FFT with optimal algorithm
901    pub fn execute<T>(
902        &self,
903        input: &[T],
904        size: Option<usize>,
905        forward: bool,
906    ) -> FFTResult<Vec<Complex64>>
907    where
908        T: Clone + Into<Complex64>,
909    {
910        let actual_size = size.unwrap_or(input.len());
911        let selection = self.select(actual_size, forward)?;
912
913        // Use the tuner's run_optimal_fft which handles the actual execution
914        self.tuner.run_optimal_fft(input, size, forward)
915    }
916
917    /// Get the algorithm selector
918    pub fn selector(&self) -> &crate::algorithm_selector::AlgorithmSelector {
919        &self.selector
920    }
921
922    /// Get the auto-tuner
923    pub fn tuner(&self) -> &AutoTuner {
924        &self.tuner
925    }
926}
927
928/// Result of algorithm selection
929#[derive(Debug, Clone)]
930pub struct SelectionResult {
931    /// Selected algorithm
932    pub algorithm: crate::algorithm_selector::FftAlgorithm,
933    /// Corresponding FFT variant
934    pub variant: FftVariant,
935    /// Source of the selection
936    pub source: SelectionSource,
937    /// Confidence in the selection
938    pub confidence: f64,
939    /// Full recommendation (if available)
940    pub recommendation: Option<crate::algorithm_selector::AlgorithmRecommendation>,
941}
942
/// Provenance of an algorithm selection
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SelectionSource {
    /// Chosen from learned performance data
    Learned,
    /// Chosen from the characteristics of the input
    Characteristic,
    /// Imposed by configuration
    Forced,
    /// Fallback when nothing else applied
    Default,
}
955
956/// Convert FftVariant to FftAlgorithm
957fn variant_to_algorithm(variant: FftVariant) -> crate::algorithm_selector::FftAlgorithm {
958    use crate::algorithm_selector::FftAlgorithm;
959    match variant {
960        FftVariant::Standard => FftAlgorithm::MixedRadix,
961        FftVariant::InPlace => FftAlgorithm::InPlace,
962        FftVariant::Cached => FftAlgorithm::MixedRadix,
963        FftVariant::SplitRadix => FftAlgorithm::SplitRadix,
964    }
965}
966
967/// Convert FftAlgorithm to FftVariant
968fn algorithm_to_variant(algorithm: crate::algorithm_selector::FftAlgorithm) -> FftVariant {
969    use crate::algorithm_selector::FftAlgorithm;
970    match algorithm {
971        FftAlgorithm::SplitRadix => FftVariant::SplitRadix,
972        FftAlgorithm::InPlace => FftVariant::InPlace,
973        _ => FftVariant::Standard,
974    }
975}
976
977/// Auto-select the best FFT algorithm for the given input
978///
979/// This is a convenience function that uses the integrated auto-selector
980/// to determine the optimal algorithm based on input characteristics and
981/// learned performance data.
982///
983/// # Arguments
984///
985/// * `size` - FFT size
986/// * `forward` - Whether this is a forward (true) or inverse (false) transform
987///
988/// # Returns
989///
990/// The recommended algorithm and metadata
991///
992/// # Example
993///
994/// ```rust
995/// use scirs2_fft::auto_tuning::auto_select_algorithm;
996///
997/// let result = auto_select_algorithm(1024, true).expect("Selection failed");
998/// println!("Recommended: {:?}", result.algorithm);
999/// ```
1000pub fn auto_select_algorithm(size: usize, forward: bool) -> FFTResult<SelectionResult> {
1001    let selector = IntegratedAutoSelector::new();
1002    selector.select(size, forward)
1003}
1004
1005/// Execute FFT with automatic algorithm selection
1006///
1007/// This function automatically selects the best algorithm based on
1008/// input characteristics and executes the FFT.
1009///
1010/// # Arguments
1011///
1012/// * `input` - Input data
1013/// * `size` - Optional FFT size (if different from input length)
1014/// * `forward` - Whether this is a forward (true) or inverse (false) transform
1015///
1016/// # Returns
1017///
1018/// The FFT result as a vector of complex numbers
1019///
1020/// # Example
1021///
1022/// ```rust
1023/// use scirs2_fft::auto_tuning::auto_fft;
1024///
1025/// let signal = vec![1.0, 2.0, 3.0, 4.0];
1026/// let spectrum = auto_fft(&signal, None, true).expect("FFT failed");
1027/// ```
1028pub fn auto_fft<T>(input: &[T], size: Option<usize>, forward: bool) -> FFTResult<Vec<Complex64>>
1029where
1030    T: Clone + Into<Complex64>,
1031{
1032    let selector = IntegratedAutoSelector::new();
1033    selector.execute(input, size, forward)
1034}
1035
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a tuner whose only non-default setting is the size range.
    fn tuner_for(min: usize, max: usize, step: SizeStep) -> AutoTuner {
        AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange { min, max, step },
            ..Default::default()
        })
    }

    #[test]
    fn test_size_generation() {
        // Powers of two between the bounds
        let tuner = tuner_for(8, 64, SizeStep::PowersOfTwo);
        assert_eq!(tuner.generate_sizes(), vec![8, 16, 32, 64]);

        // Constant additive step
        let tuner = tuner_for(10, 30, SizeStep::Linear(5));
        assert_eq!(tuner.generate_sizes(), vec![10, 15, 20, 25, 30]);

        // Multiplicative step
        let tuner = tuner_for(10, 100, SizeStep::Exponential(2.0));
        assert_eq!(tuner.generate_sizes(), vec![10, 20, 40, 80]);

        // Explicit list: entries outside [min, max] are expected to be dropped
        let tuner = tuner_for(10, 100, SizeStep::Custom(vec![5, 15, 25, 50, 150]));
        assert_eq!(tuner.generate_sizes(), vec![15, 25, 50]);
    }

    #[test]
    fn test_auto_tuner_basic() {
        // Keep the tuning database inside a throwaway directory
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_tuning_db.json");

        // Two sizes, two variants, and very few iterations keep this quick
        let mut tuner = AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        });

        // Benchmarking can fail in some environments; that is logged, not asserted
        if let Err(e) = tuner.run_benchmarks() {
            println!("Benchmark failed: {e}");
            return;
        }

        // A successful run must have written the database file
        assert!(db_path.exists());

        // The winner has to be one of the variants that were benchmarked
        let variant = tuner.get_best_variant(16, true);
        assert!(matches!(
            variant,
            FftVariant::Standard | FftVariant::InPlace
        ));
    }
}