Skip to main content

scirs2_fft/
algorithm_selector.rs

1//! FFT Algorithm Auto-Selection Module
2//!
3//! This module provides intelligent automatic selection of FFT algorithms based on:
4//! - Input size characteristics (power-of-2, prime, smooth numbers)
5//! - Hardware capabilities (cache size, SIMD support, core count)
6//! - Memory constraints
7//! - Historical performance data
8//!
9//! # Features
10//!
11//! - **Input Analysis**: Detects optimal algorithm based on input size properties
12//! - **Cache-Aware Selection**: Considers L1/L2/L3 cache sizes for optimal performance
13//! - **Memory Optimization**: Selects memory-efficient algorithms for large inputs
14//! - **Hardware Detection**: Adapts to available SIMD instructions and core count
15//! - **Performance Profiling**: Tracks and learns from execution history
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! use scirs2_fft::algorithm_selector::{AlgorithmSelector, SelectionConfig};
21//!
22//! let selector = AlgorithmSelector::new();
23//! let recommendation = selector.select_algorithm(1024, true).expect("Selection failed");
24//! println!("Recommended: {:?}", recommendation.algorithm);
25//! ```
26
27use crate::error::{FFTError, FFTResult};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::sync::{Arc, RwLock};
31use std::time::{Duration, Instant};
32
33/// FFT Algorithm variants available for selection
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
35pub enum FftAlgorithm {
36    /// Standard Cooley-Tukey radix-2 FFT (optimal for power-of-2 sizes)
37    CooleyTukeyRadix2,
38    /// Radix-4 FFT (faster for sizes that are powers of 4)
39    Radix4,
40    /// Split-radix FFT (good balance of speed and memory)
41    SplitRadix,
42    /// Mixed-radix FFT (handles non-power-of-2 sizes efficiently)
43    #[default]
44    MixedRadix,
45    /// Bluestein's algorithm (handles prime and arbitrary sizes)
46    Bluestein,
47    /// Rader's algorithm (efficient for prime sizes)
48    Rader,
49    /// Winograd FFT (minimal multiplications)
50    Winograd,
51    /// Good-Thomas PFA (prime factor algorithm)
52    GoodThomas,
53    /// Streaming FFT (memory-efficient for very large inputs)
54    Streaming,
55    /// Cache-oblivious FFT (optimized cache behavior)
56    CacheOblivious,
57    /// In-place FFT (minimal memory overhead)
58    InPlace,
59    /// SIMD-optimized FFT
60    SimdOptimized,
61    /// Parallel FFT (multi-threaded)
62    Parallel,
63    /// Hybrid (combines multiple algorithms)
64    Hybrid,
65}
66
67impl std::fmt::Display for FftAlgorithm {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            Self::CooleyTukeyRadix2 => write!(f, "Cooley-Tukey Radix-2"),
71            Self::Radix4 => write!(f, "Radix-4"),
72            Self::SplitRadix => write!(f, "Split-Radix"),
73            Self::MixedRadix => write!(f, "Mixed-Radix"),
74            Self::Bluestein => write!(f, "Bluestein"),
75            Self::Rader => write!(f, "Rader"),
76            Self::Winograd => write!(f, "Winograd"),
77            Self::GoodThomas => write!(f, "Good-Thomas PFA"),
78            Self::Streaming => write!(f, "Streaming"),
79            Self::CacheOblivious => write!(f, "Cache-Oblivious"),
80            Self::InPlace => write!(f, "In-Place"),
81            Self::SimdOptimized => write!(f, "SIMD-Optimized"),
82            Self::Parallel => write!(f, "Parallel"),
83            Self::Hybrid => write!(f, "Hybrid"),
84        }
85    }
86}
87
88/// Input size characteristics
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90pub enum SizeCharacteristic {
91    /// Size is a power of 2 (e.g., 256, 512, 1024)
92    PowerOf2,
93    /// Size is a power of 4 (e.g., 256, 1024, 4096)
94    PowerOf4,
95    /// Size is a prime number
96    Prime,
97    /// Size is a product of small primes (2, 3, 5, 7, 11)
98    Smooth,
99    /// Size is a product of coprime factors
100    Composite,
101    /// Size has large prime factors
102    HardSize,
103}
104
105/// Detected input characteristics for algorithm selection
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct InputCharacteristics {
108    /// Input size
109    pub size: usize,
110    /// Size characteristic type
111    pub size_type: SizeCharacteristic,
112    /// Whether size is a power of 2
113    pub is_power_of_2: bool,
114    /// Whether size is a power of 4
115    pub is_power_of_4: bool,
116    /// Whether size is prime
117    pub is_prime: bool,
118    /// Prime factorization (factor -> power)
119    pub prime_factors: HashMap<usize, usize>,
120    /// Largest prime factor
121    pub largest_prime_factor: usize,
122    /// Number of distinct prime factors
123    pub num_distinct_factors: usize,
124    /// Whether size is "smooth" (only small prime factors)
125    pub is_smooth: bool,
126    /// Maximum prime factor for smooth classification
127    pub smooth_bound: usize,
128    /// Estimated memory requirement in bytes
129    pub estimated_memory_bytes: usize,
130    /// Fits in L1 cache
131    pub fits_l1_cache: bool,
132    /// Fits in L2 cache
133    pub fits_l2_cache: bool,
134    /// Fits in L3 cache
135    pub fits_l3_cache: bool,
136}
137
138impl InputCharacteristics {
139    /// Analyze input size and return characteristics
140    pub fn analyze(size: usize, cache_info: &CacheInfo) -> Self {
141        let is_power_of_2 = size.is_power_of_two();
142        let is_power_of_4 = is_power_of_2 && (size.trailing_zeros() % 2 == 0);
143        let prime_factors = factorize(size);
144        let is_prime = prime_factors.len() == 1 && prime_factors.get(&size).copied() == Some(1);
145        let largest_prime_factor = *prime_factors.keys().max().unwrap_or(&1);
146        let num_distinct_factors = prime_factors.len();
147
148        // Check if smooth (only factors <= 11)
149        let smooth_bound = 11;
150        let is_smooth = prime_factors.keys().all(|&p| p <= smooth_bound);
151
152        // Estimate memory: complex64 = 16 bytes per element
153        let estimated_memory_bytes = size * 16;
154
155        // Determine size type
156        let size_type = if is_power_of_4 {
157            SizeCharacteristic::PowerOf4
158        } else if is_power_of_2 {
159            SizeCharacteristic::PowerOf2
160        } else if is_prime {
161            SizeCharacteristic::Prime
162        } else if is_smooth {
163            SizeCharacteristic::Smooth
164        } else if largest_prime_factor <= 1000 {
165            SizeCharacteristic::Composite
166        } else {
167            SizeCharacteristic::HardSize
168        };
169
170        Self {
171            size,
172            size_type,
173            is_power_of_2,
174            is_power_of_4,
175            is_prime,
176            prime_factors,
177            largest_prime_factor,
178            num_distinct_factors,
179            is_smooth,
180            smooth_bound,
181            estimated_memory_bytes,
182            fits_l1_cache: estimated_memory_bytes <= cache_info.l1_size,
183            fits_l2_cache: estimated_memory_bytes <= cache_info.l2_size,
184            fits_l3_cache: estimated_memory_bytes <= cache_info.l3_size,
185        }
186    }
187}
188
189/// Hardware information for algorithm selection
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct HardwareInfo {
192    /// Number of physical cores
193    pub num_cores: usize,
194    /// Number of logical processors (including hyperthreading)
195    pub num_logical_processors: usize,
196    /// Cache information
197    pub cache_info: CacheInfo,
198    /// SIMD capabilities
199    pub simd_capabilities: SimdCapabilities,
200    /// CPU architecture
201    pub architecture: String,
202    /// Available memory in bytes
203    pub available_memory: usize,
204}
205
206impl Default for HardwareInfo {
207    fn default() -> Self {
208        Self::detect()
209    }
210}
211
212impl HardwareInfo {
213    /// Detect hardware capabilities
214    pub fn detect() -> Self {
215        let num_cores = num_cpus::get_physical();
216        let num_logical_processors = num_cpus::get();
217
218        Self {
219            num_cores,
220            num_logical_processors,
221            cache_info: CacheInfo::detect(),
222            simd_capabilities: SimdCapabilities::detect(),
223            architecture: std::env::consts::ARCH.to_string(),
224            available_memory: estimate_available_memory(),
225        }
226    }
227}
228
229/// CPU cache information
230#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
231pub struct CacheInfo {
232    /// L1 data cache size in bytes
233    pub l1_size: usize,
234    /// L2 cache size in bytes
235    pub l2_size: usize,
236    /// L3 cache size in bytes
237    pub l3_size: usize,
238    /// Cache line size in bytes
239    pub cache_line_size: usize,
240}
241
242impl Default for CacheInfo {
243    fn default() -> Self {
244        Self::detect()
245    }
246}
247
248impl CacheInfo {
249    /// Detect cache sizes (uses conservative estimates if detection fails)
250    pub fn detect() -> Self {
251        // Conservative default estimates
252        // These are typical for modern desktop CPUs
253        Self {
254            l1_size: 32 * 1024,       // 32 KB
255            l2_size: 256 * 1024,      // 256 KB
256            l3_size: 8 * 1024 * 1024, // 8 MB
257            cache_line_size: 64,      // 64 bytes
258        }
259    }
260
261    /// Create with custom cache sizes
262    pub fn custom(l1: usize, l2: usize, l3: usize, line_size: usize) -> Self {
263        Self {
264            l1_size: l1,
265            l2_size: l2,
266            l3_size: l3,
267            cache_line_size: line_size,
268        }
269    }
270}
271
272/// SIMD capability detection
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct SimdCapabilities {
275    /// SSE support
276    pub has_sse: bool,
277    /// SSE2 support
278    pub has_sse2: bool,
279    /// SSE3 support
280    pub has_sse3: bool,
281    /// SSE4.1 support
282    pub has_sse4_1: bool,
283    /// SSE4.2 support
284    pub has_sse4_2: bool,
285    /// AVX support
286    pub has_avx: bool,
287    /// AVX2 support
288    pub has_avx2: bool,
289    /// AVX-512 support
290    pub has_avx512: bool,
291    /// FMA support
292    pub has_fma: bool,
293    /// ARM NEON support
294    pub has_neon: bool,
295    /// Vector register width in bits
296    pub vector_width: usize,
297}
298
299impl Default for SimdCapabilities {
300    fn default() -> Self {
301        Self::detect()
302    }
303}
304
305impl SimdCapabilities {
306    /// Detect SIMD capabilities
307    pub fn detect() -> Self {
308        let mut caps = Self {
309            has_sse: false,
310            has_sse2: false,
311            has_sse3: false,
312            has_sse4_1: false,
313            has_sse4_2: false,
314            has_avx: false,
315            has_avx2: false,
316            has_avx512: false,
317            has_fma: false,
318            has_neon: false,
319            vector_width: 64, // Scalar default
320        };
321
322        #[cfg(target_arch = "x86_64")]
323        {
324            #[cfg(target_feature = "sse")]
325            {
326                caps.has_sse = true;
327                caps.vector_width = 128;
328            }
329            #[cfg(target_feature = "sse2")]
330            {
331                caps.has_sse2 = true;
332            }
333            #[cfg(target_feature = "sse3")]
334            {
335                caps.has_sse3 = true;
336            }
337            #[cfg(target_feature = "sse4.1")]
338            {
339                caps.has_sse4_1 = true;
340            }
341            #[cfg(target_feature = "sse4.2")]
342            {
343                caps.has_sse4_2 = true;
344            }
345            #[cfg(target_feature = "avx")]
346            {
347                caps.has_avx = true;
348                caps.vector_width = 256;
349            }
350            #[cfg(target_feature = "avx2")]
351            {
352                caps.has_avx2 = true;
353            }
354            #[cfg(target_feature = "fma")]
355            {
356                caps.has_fma = true;
357            }
358        }
359
360        #[cfg(target_arch = "aarch64")]
361        {
362            #[cfg(target_feature = "neon")]
363            {
364                caps.has_neon = true;
365                caps.vector_width = 128;
366            }
367        }
368
369        caps
370    }
371
372    /// Check if SIMD is available
373    pub fn simd_available(&self) -> bool {
374        self.has_sse2 || self.has_avx || self.has_neon
375    }
376
377    /// Get optimal vector width for complex f64
378    pub fn optimal_complex_vector_count(&self) -> usize {
379        // Complex f64 = 16 bytes = 128 bits
380        self.vector_width / 128
381    }
382}
383
384/// Algorithm recommendation with metadata
385#[derive(Debug, Clone, Serialize, Deserialize)]
386pub struct AlgorithmRecommendation {
387    /// Recommended primary algorithm
388    pub algorithm: FftAlgorithm,
389    /// Fallback algorithm if primary is not suitable
390    pub fallback: Option<FftAlgorithm>,
391    /// Confidence score (0.0 to 1.0)
392    pub confidence: f64,
393    /// Estimated execution time (nanoseconds)
394    pub estimated_time_ns: Option<u64>,
395    /// Estimated memory usage (bytes)
396    pub estimated_memory_bytes: usize,
397    /// Reasoning for the selection
398    pub reasoning: Vec<String>,
399    /// Whether to use parallel execution
400    pub use_parallel: bool,
401    /// Recommended number of threads
402    pub recommended_threads: usize,
403    /// Whether to use SIMD optimization
404    pub use_simd: bool,
405    /// Whether to use in-place computation
406    pub use_inplace: bool,
407    /// Input characteristics that led to this recommendation
408    pub input_characteristics: InputCharacteristics,
409}
410
411/// Performance history entry
412#[derive(Debug, Clone, Serialize, Deserialize)]
413pub struct PerformanceEntry {
414    /// FFT size
415    pub size: usize,
416    /// Algorithm used
417    pub algorithm: FftAlgorithm,
418    /// Whether it was a forward transform
419    pub forward: bool,
420    /// Execution time in nanoseconds
421    pub execution_time_ns: u64,
422    /// Peak memory usage in bytes
423    pub peak_memory_bytes: usize,
424    /// Timestamp
425    pub timestamp: u64,
426    /// Hardware info hash for matching
427    pub hardware_hash: u64,
428}
429
430/// Configuration for algorithm selection
431#[derive(Debug, Clone)]
432pub struct SelectionConfig {
433    /// Prefer memory efficiency over speed
434    pub prefer_memory_efficiency: bool,
435    /// Maximum memory budget in bytes (0 = unlimited)
436    pub max_memory_bytes: usize,
437    /// Minimum parallel size threshold
438    pub min_parallel_size: usize,
439    /// Enable performance learning
440    pub enable_learning: bool,
441    /// Maximum threads to use
442    pub max_threads: usize,
443    /// Force specific algorithm (None = auto-select)
444    pub force_algorithm: Option<FftAlgorithm>,
445    /// Enable SIMD optimization
446    pub enable_simd: bool,
447    /// Prefer in-place computation
448    pub prefer_inplace: bool,
449    /// Cache-aware selection
450    pub cache_aware: bool,
451}
452
453impl Default for SelectionConfig {
454    fn default() -> Self {
455        Self {
456            prefer_memory_efficiency: false,
457            max_memory_bytes: 0,
458            min_parallel_size: 65536, // 64K elements
459            enable_learning: true,
460            max_threads: 0, // 0 = auto
461            force_algorithm: None,
462            enable_simd: true,
463            prefer_inplace: false,
464            cache_aware: true,
465        }
466    }
467}
468
469/// Performance history database
470#[derive(Debug, Default)]
471pub struct PerformanceHistory {
472    /// History entries indexed by (size, algorithm, forward)
473    entries: HashMap<(usize, FftAlgorithm, bool), Vec<PerformanceEntry>>,
474    /// Best known algorithm for each size
475    best_algorithms: HashMap<(usize, bool), FftAlgorithm>,
476}
477
478impl PerformanceHistory {
479    /// Create new empty history
480    pub fn new() -> Self {
481        Self::default()
482    }
483
484    /// Record a performance measurement
485    pub fn record(&mut self, entry: PerformanceEntry) {
486        let key = (entry.size, entry.algorithm, entry.forward);
487        self.entries.entry(key).or_default().push(entry.clone());
488
489        // Update best algorithm if this is faster
490        let size_key = (entry.size, entry.forward);
491        let should_update = match self.best_algorithms.get(&size_key) {
492            None => true,
493            Some(&best_algo) => {
494                let best_key = (entry.size, best_algo, entry.forward);
495                if let Some(best_entries) = self.entries.get(&best_key) {
496                    let best_avg = Self::average_time(best_entries);
497                    let current_avg = Self::average_time(std::slice::from_ref(&entry));
498                    current_avg < best_avg
499                } else {
500                    true
501                }
502            }
503        };
504
505        if should_update {
506            self.best_algorithms.insert(size_key, entry.algorithm);
507        }
508    }
509
510    /// Get best algorithm for a size
511    pub fn get_best(&self, size: usize, forward: bool) -> Option<FftAlgorithm> {
512        // First check exact match
513        if let Some(&algo) = self.best_algorithms.get(&(size, forward)) {
514            return Some(algo);
515        }
516
517        // Find closest size
518        let mut closest_size = 0;
519        let mut min_diff = usize::MAX;
520
521        for &(s, f) in self.best_algorithms.keys() {
522            if f == forward {
523                let diff = s.abs_diff(size);
524                if diff < min_diff {
525                    min_diff = diff;
526                    closest_size = s;
527                }
528            }
529        }
530
531        if closest_size > 0 {
532            self.best_algorithms.get(&(closest_size, forward)).copied()
533        } else {
534            None
535        }
536    }
537
538    /// Get average execution time for entries
539    fn average_time(entries: &[PerformanceEntry]) -> u64 {
540        if entries.is_empty() {
541            return u64::MAX;
542        }
543        entries.iter().map(|e| e.execution_time_ns).sum::<u64>() / entries.len() as u64
544    }
545
546    /// Get statistics for an algorithm at a size
547    pub fn get_stats(
548        &self,
549        size: usize,
550        algorithm: FftAlgorithm,
551        forward: bool,
552    ) -> Option<PerformanceStats> {
553        let key = (size, algorithm, forward);
554        self.entries.get(&key).map(|entries| {
555            let times: Vec<u64> = entries.iter().map(|e| e.execution_time_ns).collect();
556            let avg = Self::average_time(entries);
557            let min = times.iter().min().copied().unwrap_or(0);
558            let max = times.iter().max().copied().unwrap_or(0);
559
560            let variance = if times.len() > 1 {
561                let avg_f = avg as f64;
562                times
563                    .iter()
564                    .map(|&t| {
565                        let diff = t as f64 - avg_f;
566                        diff * diff
567                    })
568                    .sum::<f64>()
569                    / (times.len() - 1) as f64
570            } else {
571                0.0
572            };
573
574            PerformanceStats {
575                sample_count: times.len(),
576                avg_time_ns: avg,
577                min_time_ns: min,
578                max_time_ns: max,
579                std_dev_ns: variance.sqrt(),
580            }
581        })
582    }
583}
584
/// Summary statistics over a set of timing measurements.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// How many samples the statistics cover.
    pub sample_count: usize,
    /// Mean execution time in nanoseconds.
    pub avg_time_ns: u64,
    /// Fastest execution time in nanoseconds.
    pub min_time_ns: u64,
    /// Slowest execution time in nanoseconds.
    pub max_time_ns: u64,
    /// Standard deviation of the execution time in nanoseconds.
    pub std_dev_ns: f64,
}
599
600/// Main algorithm selector
601pub struct AlgorithmSelector {
602    /// Configuration
603    config: SelectionConfig,
604    /// Hardware information
605    hardware: HardwareInfo,
606    /// Performance history
607    history: Arc<RwLock<PerformanceHistory>>,
608}
609
610impl Default for AlgorithmSelector {
611    fn default() -> Self {
612        Self::new()
613    }
614}
615
616impl AlgorithmSelector {
617    /// Create a new algorithm selector with default configuration
618    pub fn new() -> Self {
619        Self::with_config(SelectionConfig::default())
620    }
621
622    /// Create a new algorithm selector with custom configuration
623    pub fn with_config(config: SelectionConfig) -> Self {
624        Self {
625            config,
626            hardware: HardwareInfo::detect(),
627            history: Arc::new(RwLock::new(PerformanceHistory::new())),
628        }
629    }
630
631    /// Select the best algorithm for the given input size
632    pub fn select_algorithm(
633        &self,
634        size: usize,
635        forward: bool,
636    ) -> FFTResult<AlgorithmRecommendation> {
637        // Check for forced algorithm
638        if let Some(algo) = self.config.force_algorithm {
639            let chars = InputCharacteristics::analyze(size, &self.hardware.cache_info);
640            return Ok(AlgorithmRecommendation {
641                algorithm: algo,
642                fallback: None,
643                confidence: 1.0,
644                estimated_time_ns: None,
645                estimated_memory_bytes: chars.estimated_memory_bytes,
646                reasoning: vec!["Algorithm forced by configuration".to_string()],
647                use_parallel: false,
648                recommended_threads: 1,
649                use_simd: self.config.enable_simd,
650                use_inplace: self.config.prefer_inplace,
651                input_characteristics: chars,
652            });
653        }
654
655        // Analyze input characteristics
656        let chars = InputCharacteristics::analyze(size, &self.hardware.cache_info);
657
658        // Check performance history first
659        if self.config.enable_learning {
660            if let Ok(history) = self.history.read() {
661                if let Some(best_algo) = history.get_best(size, forward) {
662                    let stats = history.get_stats(size, best_algo, forward);
663                    return Ok(self.build_recommendation(
664                        best_algo,
665                        &chars,
666                        0.95, // High confidence from learned data
667                        stats.as_ref().map(|s| s.avg_time_ns),
668                        vec!["Selected based on historical performance data".to_string()],
669                    ));
670                }
671            }
672        }
673
674        // Select based on input characteristics
675        let (algorithm, fallback, reasoning) = self.select_by_characteristics(&chars);
676
677        // Determine confidence based on how well we match the algorithm's optimal case
678        let confidence = self.calculate_confidence(&chars, algorithm);
679
680        // Estimate execution time (rough model)
681        let estimated_time = self.estimate_execution_time(size, algorithm);
682
683        Ok(self.build_recommendation(
684            algorithm,
685            &chars,
686            confidence,
687            Some(estimated_time),
688            reasoning,
689        ))
690    }
691
692    /// Select algorithm based on input characteristics
693    fn select_by_characteristics(
694        &self,
695        chars: &InputCharacteristics,
696    ) -> (FftAlgorithm, Option<FftAlgorithm>, Vec<String>) {
697        let mut reasoning = Vec::new();
698        let size = chars.size;
699
700        // Memory constraint check
701        if self.config.max_memory_bytes > 0
702            && chars.estimated_memory_bytes > self.config.max_memory_bytes
703        {
704            reasoning.push(format!(
705                "Memory constraint: {} bytes required, {} bytes available",
706                chars.estimated_memory_bytes, self.config.max_memory_bytes
707            ));
708            return (
709                FftAlgorithm::Streaming,
710                Some(FftAlgorithm::InPlace),
711                reasoning,
712            );
713        }
714
715        // Very large sizes - use streaming or parallel
716        if size > 16 * 1024 * 1024 {
717            reasoning.push(format!(
718                "Very large size ({}): using streaming for memory efficiency",
719                size
720            ));
721            return (
722                FftAlgorithm::Streaming,
723                Some(FftAlgorithm::Parallel),
724                reasoning,
725            );
726        }
727
728        // Cache-aware selection
729        if self.config.cache_aware {
730            if chars.fits_l1_cache {
731                reasoning.push("Data fits in L1 cache".to_string());
732            } else if chars.fits_l2_cache {
733                reasoning.push("Data fits in L2 cache".to_string());
734            } else if chars.fits_l3_cache {
735                reasoning
736                    .push("Data fits in L3 cache, using cache-oblivious algorithm".to_string());
737                if !chars.is_power_of_2 {
738                    return (
739                        FftAlgorithm::CacheOblivious,
740                        Some(FftAlgorithm::MixedRadix),
741                        reasoning,
742                    );
743                }
744            } else {
745                reasoning
746                    .push("Data exceeds L3 cache, considering streaming or parallel".to_string());
747            }
748        }
749
750        // Parallel threshold check
751        let use_parallel = size >= self.config.min_parallel_size && self.hardware.num_cores > 1;
752        if use_parallel {
753            reasoning.push(format!(
754                "Size {} exceeds parallel threshold {}, {} cores available",
755                size, self.config.min_parallel_size, self.hardware.num_cores
756            ));
757        }
758
759        // SIMD check
760        let use_simd = self.config.enable_simd && self.hardware.simd_capabilities.simd_available();
761        if use_simd {
762            reasoning.push("SIMD optimization enabled".to_string());
763        }
764
765        // Select based on size characteristics
766        match chars.size_type {
767            SizeCharacteristic::PowerOf4 => {
768                reasoning.push(format!(
769                    "Size {} is a power of 4: optimal for Radix-4",
770                    size
771                ));
772                if use_parallel {
773                    (
774                        FftAlgorithm::Parallel,
775                        Some(FftAlgorithm::Radix4),
776                        reasoning,
777                    )
778                } else if use_simd {
779                    (
780                        FftAlgorithm::SimdOptimized,
781                        Some(FftAlgorithm::Radix4),
782                        reasoning,
783                    )
784                } else {
785                    (
786                        FftAlgorithm::Radix4,
787                        Some(FftAlgorithm::CooleyTukeyRadix2),
788                        reasoning,
789                    )
790                }
791            }
792            SizeCharacteristic::PowerOf2 => {
793                reasoning.push(format!(
794                    "Size {} is a power of 2: optimal for Radix-2",
795                    size
796                ));
797                if use_parallel && size >= 262144 {
798                    (
799                        FftAlgorithm::Parallel,
800                        Some(FftAlgorithm::SplitRadix),
801                        reasoning,
802                    )
803                } else if use_simd {
804                    (
805                        FftAlgorithm::SimdOptimized,
806                        Some(FftAlgorithm::SplitRadix),
807                        reasoning,
808                    )
809                } else {
810                    (
811                        FftAlgorithm::SplitRadix,
812                        Some(FftAlgorithm::CooleyTukeyRadix2),
813                        reasoning,
814                    )
815                }
816            }
817            SizeCharacteristic::Prime => {
818                reasoning.push(format!("Size {} is prime: using Bluestein or Rader", size));
819                // Rader is better for small primes, Bluestein for large
820                if size < 1000 {
821                    (
822                        FftAlgorithm::Rader,
823                        Some(FftAlgorithm::Bluestein),
824                        reasoning,
825                    )
826                } else {
827                    (
828                        FftAlgorithm::Bluestein,
829                        Some(FftAlgorithm::MixedRadix),
830                        reasoning,
831                    )
832                }
833            }
834            SizeCharacteristic::Smooth => {
835                reasoning.push(format!(
836                    "Size {} is {}-smooth: good for mixed-radix",
837                    size, chars.smooth_bound
838                ));
839                if chars.num_distinct_factors <= 3 && are_coprime(&chars.prime_factors) {
840                    reasoning.push("Factors are coprime: Good-Thomas PFA applicable".to_string());
841                    (
842                        FftAlgorithm::GoodThomas,
843                        Some(FftAlgorithm::MixedRadix),
844                        reasoning,
845                    )
846                } else {
847                    (
848                        FftAlgorithm::MixedRadix,
849                        Some(FftAlgorithm::Bluestein),
850                        reasoning,
851                    )
852                }
853            }
854            SizeCharacteristic::Composite => {
855                reasoning.push(format!(
856                    "Size {} is composite with largest factor {}: using mixed-radix",
857                    size, chars.largest_prime_factor
858                ));
859                (
860                    FftAlgorithm::MixedRadix,
861                    Some(FftAlgorithm::Bluestein),
862                    reasoning,
863                )
864            }
865            SizeCharacteristic::HardSize => {
866                reasoning.push(format!(
867                    "Size {} has large prime factor {}: using Bluestein",
868                    size, chars.largest_prime_factor
869                ));
870                (
871                    FftAlgorithm::Bluestein,
872                    Some(FftAlgorithm::MixedRadix),
873                    reasoning,
874                )
875            }
876        }
877    }
878
879    /// Build a recommendation structure
880    fn build_recommendation(
881        &self,
882        algorithm: FftAlgorithm,
883        chars: &InputCharacteristics,
884        confidence: f64,
885        estimated_time_ns: Option<u64>,
886        reasoning: Vec<String>,
887    ) -> AlgorithmRecommendation {
888        let use_parallel =
889            chars.size >= self.config.min_parallel_size && self.hardware.num_cores > 1;
890        let recommended_threads = if use_parallel {
891            if self.config.max_threads > 0 {
892                self.config.max_threads.min(self.hardware.num_cores)
893            } else {
894                // Use sqrt(cores) for good parallelism without overhead
895                ((self.hardware.num_cores as f64).sqrt().ceil() as usize).max(2)
896            }
897        } else {
898            1
899        };
900
901        AlgorithmRecommendation {
902            algorithm,
903            fallback: None,
904            confidence,
905            estimated_time_ns,
906            estimated_memory_bytes: chars.estimated_memory_bytes,
907            reasoning,
908            use_parallel,
909            recommended_threads,
910            use_simd: self.config.enable_simd && self.hardware.simd_capabilities.simd_available(),
911            use_inplace: self.config.prefer_inplace,
912            input_characteristics: chars.clone(),
913        }
914    }
915
916    /// Calculate confidence score for algorithm selection
917    fn calculate_confidence(&self, chars: &InputCharacteristics, algorithm: FftAlgorithm) -> f64 {
918        let base_confidence = match (chars.size_type, algorithm) {
919            (SizeCharacteristic::PowerOf4, FftAlgorithm::Radix4) => 0.95,
920            (SizeCharacteristic::PowerOf4, FftAlgorithm::SimdOptimized) => 0.93,
921            (SizeCharacteristic::PowerOf2, FftAlgorithm::SplitRadix) => 0.92,
922            (SizeCharacteristic::PowerOf2, FftAlgorithm::CooleyTukeyRadix2) => 0.90,
923            (SizeCharacteristic::PowerOf2, FftAlgorithm::SimdOptimized) => 0.91,
924            (SizeCharacteristic::Prime, FftAlgorithm::Rader) => 0.85,
925            (SizeCharacteristic::Prime, FftAlgorithm::Bluestein) => 0.80,
926            (SizeCharacteristic::Smooth, FftAlgorithm::GoodThomas) => 0.88,
927            (SizeCharacteristic::Smooth, FftAlgorithm::MixedRadix) => 0.85,
928            (SizeCharacteristic::Composite, FftAlgorithm::MixedRadix) => 0.80,
929            (SizeCharacteristic::HardSize, FftAlgorithm::Bluestein) => 0.75,
930            _ => 0.70,
931        };
932
933        // Adjust based on cache fit
934        let cache_bonus: f64 = if chars.fits_l1_cache {
935            0.02
936        } else if chars.fits_l2_cache {
937            0.01
938        } else {
939            -0.02
940        };
941
942        (base_confidence + cache_bonus).clamp(0.0, 1.0)
943    }
944
945    /// Estimate execution time (rough model based on O(n log n))
946    fn estimate_execution_time(&self, size: usize, algorithm: FftAlgorithm) -> u64 {
947        if size == 0 {
948            return 0;
949        }
950
951        let n = size as f64;
952        let log_n = n.log2();
953        let base_ops = n * log_n;
954
955        // Algorithm-specific coefficients (nanoseconds per operation)
956        let coeff = match algorithm {
957            FftAlgorithm::Radix4 => 0.8,
958            FftAlgorithm::CooleyTukeyRadix2 => 1.0,
959            FftAlgorithm::SplitRadix => 0.9,
960            FftAlgorithm::SimdOptimized => 0.5,
961            FftAlgorithm::Parallel => 0.4,
962            FftAlgorithm::MixedRadix => 1.2,
963            FftAlgorithm::Bluestein => 2.0,
964            FftAlgorithm::Rader => 1.5,
965            FftAlgorithm::GoodThomas => 1.1,
966            FftAlgorithm::Winograd => 1.3,
967            FftAlgorithm::Streaming => 1.5,
968            FftAlgorithm::CacheOblivious => 1.1,
969            FftAlgorithm::InPlace => 1.0,
970            FftAlgorithm::Hybrid => 1.0,
971        };
972
973        (base_ops * coeff) as u64
974    }
975
976    /// Record performance measurement for learning
977    pub fn record_performance(&self, entry: PerformanceEntry) -> FFTResult<()> {
978        if !self.config.enable_learning {
979            return Ok(());
980        }
981
982        let mut history = self
983            .history
984            .write()
985            .map_err(|e| FFTError::ValueError(format!("Failed to acquire write lock: {e}")))?;
986
987        history.record(entry);
988        Ok(())
989    }
990
991    /// Run a benchmark for a specific size and algorithm
992    pub fn benchmark(
993        &self,
994        size: usize,
995        algorithm: FftAlgorithm,
996        forward: bool,
997    ) -> FFTResult<PerformanceEntry> {
998        use scirs2_core::numeric::Complex64;
999
1000        #[cfg(feature = "oxifft")]
1001        {
1002            use oxifft::{Complex as OxiComplex, Direction, Flags, Plan};
1003
1004            // Create test data
1005            let data: Vec<OxiComplex<f64>> = (0..size)
1006                .map(|i| OxiComplex::new(i as f64, (i * 2) as f64))
1007                .collect();
1008
1009            // Create plan
1010            let direction = if forward {
1011                Direction::Forward
1012            } else {
1013                Direction::Backward
1014            };
1015
1016            let plan = Plan::dft_1d(size, direction, Flags::ESTIMATE).ok_or_else(|| {
1017                FFTError::ComputationError(format!("Failed to create FFT plan for size {}", size))
1018            })?;
1019
1020            // Warm-up
1021            for _ in 0..3 {
1022                let mut warm_data = data.clone();
1023                let mut output = vec![OxiComplex::default(); size];
1024                plan.execute(&warm_data, &mut output);
1025            }
1026
1027            // Benchmark
1028            let mut input = data;
1029            let mut output = vec![OxiComplex::default(); size];
1030            let start = Instant::now();
1031            plan.execute(&input, &mut output);
1032            let elapsed = start.elapsed();
1033
1034            Ok(PerformanceEntry {
1035                size,
1036                algorithm,
1037                forward,
1038                execution_time_ns: elapsed.as_nanos() as u64,
1039                peak_memory_bytes: size * 16, // Complex64 = 16 bytes
1040                timestamp: std::time::SystemTime::now()
1041                    .duration_since(std::time::UNIX_EPOCH)
1042                    .unwrap_or(Duration::ZERO)
1043                    .as_secs(),
1044                hardware_hash: 0, // Simplified for now
1045            })
1046        }
1047
1048        #[cfg(all(feature = "rustfft-backend", not(feature = "oxifft")))]
1049        {
1050            use rustfft::FftPlanner;
1051
1052            // Create test data
1053            let mut data: Vec<Complex64> = (0..size)
1054                .map(|i| Complex64::new(i as f64, (i * 2) as f64))
1055                .collect();
1056
1057            // Create planner
1058            let mut planner = FftPlanner::new();
1059            let fft = if forward {
1060                planner.plan_fft_forward(size)
1061            } else {
1062                planner.plan_fft_inverse(size)
1063            };
1064
1065            // Warm-up
1066            for _ in 0..3 {
1067                fft.process(&mut data.clone());
1068            }
1069
1070            // Benchmark
1071            let start = Instant::now();
1072            fft.process(&mut data);
1073            let elapsed = start.elapsed();
1074
1075            Ok(PerformanceEntry {
1076                size,
1077                algorithm,
1078                forward,
1079                execution_time_ns: elapsed.as_nanos() as u64,
1080                peak_memory_bytes: size * 16, // Complex64 = 16 bytes
1081                timestamp: std::time::SystemTime::now()
1082                    .duration_since(std::time::UNIX_EPOCH)
1083                    .unwrap_or(Duration::ZERO)
1084                    .as_secs(),
1085                hardware_hash: 0, // Simplified for now
1086            })
1087        }
1088
1089        #[cfg(not(any(feature = "oxifft", feature = "rustfft-backend")))]
1090        {
1091            Err(FFTError::ValueError(
1092                "No FFT backend available for benchmarking. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
1093            ))
1094        }
1095    }
1096
1097    /// Get configuration
1098    pub fn config(&self) -> &SelectionConfig {
1099        &self.config
1100    }
1101
1102    /// Get hardware info
1103    pub fn hardware(&self) -> &HardwareInfo {
1104        &self.hardware
1105    }
1106
1107    /// Get performance history
1108    pub fn history(&self) -> Arc<RwLock<PerformanceHistory>> {
1109        self.history.clone()
1110    }
1111}
1112
1113// Helper functions
1114
/// Factorize a number into its prime factors by trial division.
///
/// Returns a map from each prime factor to its exponent, so that the product
/// of `p.pow(e)` over all entries equals `n`. For `n <= 1` the map is empty.
///
/// The loop bound is written as `i <= n / i` rather than `i * i <= n`: the two
/// are equivalent for positive integers, but the division form cannot overflow
/// `usize` when `n` is close to `usize::MAX`.
fn factorize(mut n: usize) -> HashMap<usize, usize> {
    let mut factors = HashMap::new();

    if n <= 1 {
        return factors;
    }

    // Strip every factor of 2 at once; `n` is non-zero here, so the count of
    // trailing zero bits is exactly the exponent of 2.
    let twos = n.trailing_zeros() as usize;
    if twos > 0 {
        factors.insert(2, twos);
        n >>= twos;
    }

    // Check odd candidates up to sqrt(n), where `n` shrinks as factors are
    // divided out.
    let mut i = 3;
    while i <= n / i {
        let mut count = 0;
        while n % i == 0 {
            count += 1;
            n /= i;
        }
        if count > 0 {
            factors.insert(i, count);
        }
        i += 2;
    }

    // Whatever remains above 1 has no divisor up to its square root, so it is
    // itself a prime factor.
    if n > 1 {
        factors.insert(n, 1);
    }

    factors
}
1154
1155/// Check if factors are pairwise coprime
1156fn are_coprime(factors: &HashMap<usize, usize>) -> bool {
1157    // All prime factors are coprime by definition
1158    // We check if the product of prime powers are coprime
1159    let powers: Vec<usize> = factors.iter().map(|(&p, &e)| p.pow(e as u32)).collect();
1160
1161    for i in 0..powers.len() {
1162        for j in (i + 1)..powers.len() {
1163            if gcd(powers[i], powers[j]) != 1 {
1164                return false;
1165            }
1166        }
1167    }
1168    true
1169}
1170
/// Greatest common divisor of `a` and `b` (Euclid's algorithm).
///
/// `gcd(a, 0)` and `gcd(0, a)` both return `a`; `gcd(0, 0)` returns 0.
fn gcd(mut a: usize, mut b: usize) -> usize {
    while b != 0 {
        // Replace the pair (a, b) with (b, a mod b) until the remainder is 0.
        (a, b) = (b, a % b);
    }
    a
}
1180
/// Estimate the amount of memory available, in bytes.
///
/// This is a fixed, conservative placeholder of 1 GiB; it does not query the
/// operating system.
fn estimate_available_memory() -> usize {
    const ASSUMED_AVAILABLE_BYTES: usize = 1 << 30; // 1 GiB
    ASSUMED_AVAILABLE_BYTES
}
1187
/// Unit tests for input analysis, algorithm selection and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_factorize() {
        let factors = factorize(12);
        assert_eq!(factors.get(&2), Some(&2));
        assert_eq!(factors.get(&3), Some(&1));

        let factors = factorize(1024);
        assert_eq!(factors.get(&2), Some(&10));
        assert_eq!(factors.len(), 1);

        let factors = factorize(17);
        assert_eq!(factors.get(&17), Some(&1));
        assert_eq!(factors.len(), 1);
    }

    #[test]
    fn test_factorize_trivial_inputs() {
        // 0 and 1 have no prime factors.
        assert!(factorize(0).is_empty());
        assert!(factorize(1).is_empty());
    }

    #[test]
    fn test_factorize_reconstructs_input() {
        for &n in &[2usize, 12, 49, 97, 360, 1009, 1024, 30030] {
            let product: usize = factorize(n)
                .iter()
                .map(|(&p, &e)| p.pow(e as u32))
                .product();
            assert_eq!(product, n, "factorization of {} does not multiply back", n);
        }
    }

    #[test]
    fn test_gcd() {
        assert_eq!(gcd(12, 18), 6);
        assert_eq!(gcd(18, 12), 6);
        assert_eq!(gcd(17, 5), 1);
        assert_eq!(gcd(0, 7), 7);
        assert_eq!(gcd(7, 0), 7);
        assert_eq!(gcd(0, 0), 0);
    }

    #[test]
    fn test_are_coprime() {
        // Prime powers from a real factorization never share a divisor.
        assert!(are_coprime(&factorize(360)));
        assert!(are_coprime(&HashMap::new()));

        // A hand-built map whose keys share a factor is rejected.
        let mut shared = HashMap::new();
        shared.insert(2, 1);
        shared.insert(4, 1);
        assert!(!are_coprime(&shared));
    }

    #[test]
    fn test_input_characteristics_power_of_2() {
        let cache_info = CacheInfo::default();
        let chars = InputCharacteristics::analyze(1024, &cache_info);

        assert!(chars.is_power_of_2);
        assert!(chars.is_power_of_4);
        assert!(!chars.is_prime);
        assert!(chars.is_smooth);
        assert_eq!(chars.size_type, SizeCharacteristic::PowerOf4);
    }

    #[test]
    fn test_input_characteristics_prime() {
        let cache_info = CacheInfo::default();
        let chars = InputCharacteristics::analyze(1009, &cache_info);

        assert!(!chars.is_power_of_2);
        assert!(chars.is_prime);
        assert_eq!(chars.largest_prime_factor, 1009);
        assert_eq!(chars.size_type, SizeCharacteristic::Prime);
    }

    #[test]
    fn test_input_characteristics_smooth() {
        let cache_info = CacheInfo::default();
        let chars = InputCharacteristics::analyze(360, &cache_info); // 2^3 * 3^2 * 5

        assert!(!chars.is_power_of_2);
        assert!(!chars.is_prime);
        assert!(chars.is_smooth);
        assert_eq!(chars.size_type, SizeCharacteristic::Smooth);
    }

    #[test]
    fn test_algorithm_selector_power_of_2() {
        let selector = AlgorithmSelector::new();
        let rec = selector
            .select_algorithm(1024, true)
            .expect("Selection failed");

        // Power of 4 should recommend Radix-4, SIMD optimized or Parallel
        assert!(
            matches!(
                rec.algorithm,
                FftAlgorithm::Radix4 | FftAlgorithm::SimdOptimized | FftAlgorithm::Parallel
            ),
            "Expected Radix-4, SIMD or Parallel for power of 4, got {:?}",
            rec.algorithm
        );
        assert!(rec.confidence > 0.8);
    }

    #[test]
    fn test_algorithm_selector_prime() {
        let selector = AlgorithmSelector::new();
        let rec = selector
            .select_algorithm(1009, true)
            .expect("Selection failed");

        // Prime should recommend Rader or Bluestein
        assert!(
            matches!(rec.algorithm, FftAlgorithm::Rader | FftAlgorithm::Bluestein),
            "Expected Rader or Bluestein for prime, got {:?}",
            rec.algorithm
        );
    }

    #[test]
    fn test_algorithm_selector_large_size() {
        let selector = AlgorithmSelector::new();
        let rec = selector
            .select_algorithm(16 * 1024 * 1024 + 1, true)
            .expect("Selection failed");

        // Very large size should recommend streaming
        assert_eq!(rec.algorithm, FftAlgorithm::Streaming);
        assert!(rec.reasoning.iter().any(|r| r.contains("streaming")));
    }

    #[test]
    fn test_hardware_detection() {
        let hw = HardwareInfo::detect();
        assert!(hw.num_cores > 0);
        assert!(hw.cache_info.l1_size > 0);
    }

    #[test]
    fn test_simd_capabilities() {
        let caps = SimdCapabilities::detect();
        // Just check it doesn't panic
        let _ = caps.simd_available();
        let _ = caps.optimal_complex_vector_count();
    }

    #[test]
    fn test_performance_history() {
        let mut history = PerformanceHistory::new();

        let entry = PerformanceEntry {
            size: 1024,
            algorithm: FftAlgorithm::Radix4,
            forward: true,
            execution_time_ns: 1000,
            peak_memory_bytes: 16384,
            timestamp: 0,
            hardware_hash: 0,
        };

        history.record(entry);
        assert_eq!(history.get_best(1024, true), Some(FftAlgorithm::Radix4));
    }

    #[test]
    #[cfg(any(feature = "oxifft", feature = "rustfft-backend"))]
    fn test_benchmark() {
        let selector = AlgorithmSelector::new();
        let entry = selector
            .benchmark(256, FftAlgorithm::MixedRadix, true)
            .expect("Benchmark failed");

        assert_eq!(entry.size, 256);
        assert!(entry.execution_time_ns > 0);
    }

    #[test]
    fn test_selection_config_forced_algorithm() {
        let config = SelectionConfig {
            force_algorithm: Some(FftAlgorithm::Bluestein),
            ..Default::default()
        };
        let selector = AlgorithmSelector::with_config(config);
        let rec = selector
            .select_algorithm(1024, true)
            .expect("Selection failed");

        assert_eq!(rec.algorithm, FftAlgorithm::Bluestein);
        assert_eq!(rec.confidence, 1.0);
    }

    #[test]
    fn test_memory_constraint() {
        let config = SelectionConfig {
            max_memory_bytes: 1024, // Very small
            ..Default::default()
        };
        let selector = AlgorithmSelector::with_config(config);
        let rec = selector
            .select_algorithm(1024, true)
            .expect("Selection failed");

        // Should select streaming due to memory constraint
        assert_eq!(rec.algorithm, FftAlgorithm::Streaming);
    }
}