scirs2_fft/
auto_tuning.rs

1//! Auto-tuning for hardware-specific FFT optimizations
2//!
3//! This module provides functionality to automatically tune FFT parameters
4//! for optimal performance on the current hardware. It includes:
5//!
6//! - Benchmarking different FFT configurations
7//! - Selecting optimal parameters based on timing results
8//! - Persisting tuning results for future use
9//! - Detecting CPU features and adapting algorithms accordingly
10
11use rustfft::FftPlanner;
12use scirs2_core::numeric::Complex64;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::fs::{self, File};
16use std::io::{BufReader, BufWriter};
17use std::path::{Path, PathBuf};
18use std::time::Instant;
19
20use crate::error::{FFTError, FFTResult};
21use crate::plan_serialization::PlanSerializationManager;
22
/// A range of FFT sizes to benchmark
///
/// `min` and `max` bound the candidate sizes inclusively; `step` decides how
/// candidates between the two bounds are produced.
#[derive(Debug, Clone)]
pub struct SizeRange {
    /// Minimum size to test (inclusive)
    pub min: usize,
    /// Maximum size to test (inclusive)
    pub max: usize,
    /// Step between sizes (can be multiplication factor)
    pub step: SizeStep,
}
33
/// Step type for size range
///
/// Describes how candidate sizes are generated between `SizeRange::min` and
/// `SizeRange::max`.
#[derive(Debug, Clone)]
pub enum SizeStep {
    /// Add a constant value, starting from `min`
    Linear(usize),
    /// Multiply by a factor, starting from `min`; fractional sizes are
    /// truncated to an integer
    Exponential(f64),
    /// Use powers of two, starting at the first power of two that is not
    /// below `min`
    PowersOfTwo,
    /// Use specific sizes; entries outside `min..=max` are skipped
    Custom(Vec<usize>),
}
46
/// FFT algorithm variant to benchmark
///
/// The variant is what gets recorded as the "best algorithm" for a given
/// size and direction in the tuning database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FftVariant {
    /// Standard FFT (planner-created plan, executed with `process`)
    Standard,
    /// In-place FFT executed with an explicitly allocated scratch buffer
    InPlace,
    /// Cached-plan FFT (plan obtained through the `plan_serialization` module)
    Cached,
    /// Split-radix FFT
    ///
    /// Currently a placeholder: it runs the same planner-based FFT as
    /// `Standard`.
    SplitRadix,
}
59
/// Configuration for auto-tuning
///
/// Controls which sizes and variants are benchmarked, how often each
/// measurement is repeated, and where results are persisted.
#[derive(Debug, Clone)]
pub struct AutoTuneConfig {
    /// Sizes to benchmark
    pub sizes: SizeRange,
    /// Number of repetitions per test (timed runs per size/variant/direction)
    pub repetitions: usize,
    /// Warm-up iterations (not timed)
    pub warmup: usize,
    /// FFT variants to test
    pub variants: Vec<FftVariant>,
    /// Path to save tuning results; also the file an existing database is
    /// loaded from when the tuner is constructed
    pub database_path: PathBuf,
}
74
75impl Default for AutoTuneConfig {
76    fn default() -> Self {
77        Self {
78            sizes: SizeRange {
79                min: 16,
80                max: 8192,
81                step: SizeStep::PowersOfTwo,
82            },
83            repetitions: 10,
84            warmup: 3,
85            variants: vec![FftVariant::Standard, FftVariant::Cached],
86            database_path: PathBuf::from(".fft_tuning_db.json"),
87        }
88    }
89}
90
/// Results from a single benchmark
///
/// One record is produced per (size, variant, direction) combination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// FFT size
    pub size: usize,
    /// FFT variant
    pub variant: FftVariant,
    /// Whether this is forward or inverse FFT (`true` = forward)
    pub forward: bool,
    /// Average execution time in nanoseconds (integer mean of the timed runs)
    pub avg_time_ns: u64,
    /// Minimum execution time in nanoseconds
    pub min_time_ns: u64,
    /// Standard deviation in nanoseconds
    pub std_dev_ns: f64,
    /// System information when the benchmark was run
    pub system_info: SystemInfo,
}
109
/// System information for result matching
///
/// Stored alongside every benchmark result so that timings can be related
/// to the machine that produced them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// CPU model (the detector in this file currently always stores "Unknown")
    pub cpu_model: String,
    /// Number of cores
    pub num_cores: usize,
    /// Architecture (taken from `std::env::consts::ARCH`)
    pub architecture: String,
    /// CPU features (SIMD instruction sets, etc.)
    pub cpu_features: Vec<String>,
}
122
123/// Database of tuning results
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct TuningDatabase {
126    /// Benchmark results
127    pub results: Vec<BenchmarkResult>,
128    /// Last updated timestamp
129    pub last_updated: u64,
130    /// Best algorithm for each size
131    pub best_algorithms: HashMap<(usize, bool), FftVariant>,
132}
133
/// Auto-tuning manager
///
/// Owns the tuning configuration and the database of benchmark results, and
/// answers "which FFT variant should be used for this size?" queries.
pub struct AutoTuner {
    /// Configuration
    config: AutoTuneConfig,
    /// Database of results
    database: TuningDatabase,
    /// Whether to use tuning; when `false`, benchmarks are skipped and
    /// `FftVariant::Standard` is always selected
    enabled: bool,
}
143
144impl Default for AutoTuner {
145    fn default() -> Self {
146        Self::with_config(AutoTuneConfig::default())
147    }
148}
149
150impl AutoTuner {
151    /// Create a new auto-tuner with default configuration
152    pub fn new() -> Self {
153        Self::default()
154    }
155
156    /// Create a new auto-tuner with custom configuration
157    pub fn with_config(config: AutoTuneConfig) -> Self {
158        let database =
159            Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
160                results: Vec::new(),
161                last_updated: std::time::SystemTime::now()
162                    .duration_since(std::time::UNIX_EPOCH)
163                    .unwrap_or_default()
164                    .as_secs(),
165                best_algorithms: HashMap::new(),
166            });
167
168        Self {
169            config,
170            database,
171            enabled: true,
172        }
173    }
174
175    /// Load the tuning database from disk
176    fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
177        if !path.exists() {
178            return Err(FFTError::IOError(format!(
179                "Tuning database file not found: {}",
180                path.display()
181            )));
182        }
183
184        let file = File::open(path)
185            .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
186
187        let reader = BufReader::new(file);
188        let database: TuningDatabase = serde_json::from_reader(reader)
189            .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
190
191        Ok(database)
192    }
193
194    /// Save the tuning database to disk
195    pub fn save_database(&self) -> FFTResult<()> {
196        // Create parent directories if they don't exist
197        if let Some(parent) = self.config.database_path.parent() {
198            fs::create_dir_all(parent).map_err(|e| {
199                FFTError::IOError(format!(
200                    "Failed to create directory for tuning database: {e}"
201                ))
202            })?;
203        }
204
205        let file = File::create(&self.config.database_path).map_err(|e| {
206            FFTError::IOError(format!("Failed to create tuning database file: {e}"))
207        })?;
208
209        let writer = BufWriter::new(file);
210        serde_json::to_writer_pretty(writer, &self.database)
211            .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
212
213        Ok(())
214    }
215
    /// Enable or disable auto-tuning
    ///
    /// While disabled, `run_benchmarks` returns immediately without
    /// measuring anything and `get_best_variant` always answers
    /// `FftVariant::Standard`.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }
220
    /// Check if auto-tuning is enabled
    ///
    /// Returns the flag last set through `set_enabled` (tuners start enabled).
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
225
226    /// Run benchmarks for all configured FFT variants and sizes
227    pub fn run_benchmarks(&mut self) -> FFTResult<()> {
228        if !self.enabled {
229            return Ok(());
230        }
231
232        let sizes = self.generate_sizes();
233        let mut results = Vec::new();
234
235        for size in sizes {
236            for &variant in &self.config.variants {
237                // Benchmark forward transform
238                let forward_result = self.benchmark_variant(size, variant, true)?;
239                results.push(forward_result);
240
241                // Benchmark inverse transform
242                let inverse_result = self.benchmark_variant(size, variant, false)?;
243                results.push(inverse_result);
244            }
245        }
246
247        // Update database
248        self.database.results.extend(results);
249        self.update_best_algorithms();
250        self.save_database()?;
251
252        Ok(())
253    }
254
255    /// Generate the list of sizes to benchmark
256    fn generate_sizes(&self) -> Vec<usize> {
257        let mut sizes = Vec::new();
258
259        match &self.config.sizes.step {
260            SizeStep::Linear(step) => {
261                let mut size = self.config.sizes.min;
262                while size <= self.config.sizes.max {
263                    sizes.push(size);
264                    size += step;
265                }
266            }
267            SizeStep::Exponential(factor) => {
268                let mut size = self.config.sizes.min as f64;
269                while size <= self.config.sizes.max as f64 {
270                    sizes.push(size as usize);
271                    size *= factor;
272                }
273            }
274            SizeStep::PowersOfTwo => {
275                let mut size = 1;
276                while size < self.config.sizes.min {
277                    size *= 2;
278                }
279                while size <= self.config.sizes.max {
280                    sizes.push(size);
281                    size *= 2;
282                }
283            }
284            SizeStep::Custom(custom_sizes) => {
285                for &size in custom_sizes {
286                    if size >= self.config.sizes.min && size <= self.config.sizes.max {
287                        sizes.push(size);
288                    }
289                }
290            }
291        }
292
293        sizes
294    }
295
296    /// Benchmark a specific FFT variant for a given size
297    fn benchmark_variant(
298        &self,
299        size: usize,
300        variant: FftVariant,
301        forward: bool,
302    ) -> FFTResult<BenchmarkResult> {
303        // Create test data
304        let mut data = vec![Complex64::new(0.0, 0.0); size];
305        for (i, val) in data.iter_mut().enumerate().take(size) {
306            *val = Complex64::new(i as f64, (i * 2) as f64);
307        }
308
309        // Warm-up phase
310        for _ in 0..self.config.warmup {
311            match variant {
312                FftVariant::Standard => {
313                    let mut planner = FftPlanner::new();
314                    let fft = if forward {
315                        planner.plan_fft_forward(size)
316                    } else {
317                        planner.plan_fft_inverse(size)
318                    };
319                    let mut buffer = data.clone();
320                    fft.process(&mut buffer);
321                }
322                FftVariant::InPlace => {
323                    let mut planner = FftPlanner::new();
324                    let fft = if forward {
325                        planner.plan_fft_forward(size)
326                    } else {
327                        planner.plan_fft_inverse(size)
328                    };
329                    // Use in-place processing with scratch buffer
330                    let mut buffer = data.clone();
331                    let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
332                    fft.process_with_scratch(&mut buffer, &mut scratch);
333                }
334                FftVariant::Cached => {
335                    // Create a plan via the serialization manager
336                    let manager = PlanSerializationManager::new(&self.config.database_path);
337                    let plan_info = manager.create_plan_info(size, forward);
338                    let (_, time) = crate::plan_serialization::create_and_time_plan(size, forward);
339                    manager.record_plan_usage(&plan_info, time).unwrap_or(());
340                }
341                FftVariant::SplitRadix => {
342                    // For now, this is just an example variant
343                    // In a real implementation, we'd use a specific split-radix algorithm
344                    let mut planner = FftPlanner::new();
345                    let fft = if forward {
346                        planner.plan_fft_forward(size)
347                    } else {
348                        planner.plan_fft_inverse(size)
349                    };
350                    let mut buffer = data.clone();
351                    fft.process(&mut buffer);
352                }
353            }
354        }
355
356        // Timing phase
357        let mut times = Vec::with_capacity(self.config.repetitions);
358
359        for _ in 0..self.config.repetitions {
360            let start = Instant::now();
361
362            match variant {
363                FftVariant::Standard => {
364                    let mut planner = FftPlanner::new();
365                    let fft = if forward {
366                        planner.plan_fft_forward(size)
367                    } else {
368                        planner.plan_fft_inverse(size)
369                    };
370                    let mut buffer = data.clone();
371                    fft.process(&mut buffer);
372                }
373                FftVariant::InPlace => {
374                    let mut planner = FftPlanner::new();
375                    let fft = if forward {
376                        planner.plan_fft_forward(size)
377                    } else {
378                        planner.plan_fft_inverse(size)
379                    };
380                    // Use in-place processing with scratch buffer
381                    let mut buffer = data.clone();
382                    let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
383                    fft.process_with_scratch(&mut buffer, &mut scratch);
384                }
385                FftVariant::Cached => {
386                    // Use the plan cache
387                    let mut planner = FftPlanner::new();
388                    let fft = if forward {
389                        planner.plan_fft_forward(size)
390                    } else {
391                        planner.plan_fft_inverse(size)
392                    };
393                    let mut buffer = data.clone();
394                    fft.process(&mut buffer);
395                }
396                FftVariant::SplitRadix => {
397                    // Placeholder for split-radix implementation
398                    let mut planner = FftPlanner::new();
399                    let fft = if forward {
400                        planner.plan_fft_forward(size)
401                    } else {
402                        planner.plan_fft_inverse(size)
403                    };
404                    let mut buffer = data.clone();
405                    fft.process(&mut buffer);
406                }
407            }
408
409            let elapsed = start.elapsed();
410            times.push(elapsed.as_nanos() as u64);
411        }
412
413        // Calculate statistics
414        let avg_time = times.iter().sum::<u64>() / times.len() as u64;
415        let min_time = *times.iter().min().unwrap_or(&0);
416
417        // Calculate standard deviation
418        let variance = times
419            .iter()
420            .map(|&t| {
421                let diff = t as f64 - avg_time as f64;
422                diff * diff
423            })
424            .sum::<f64>()
425            / times.len() as f64;
426        let std_dev = variance.sqrt();
427
428        Ok(BenchmarkResult {
429            size,
430            variant,
431            forward,
432            avg_time_ns: avg_time,
433            min_time_ns: min_time,
434            std_dev_ns: std_dev,
435            system_info: self.detect_system_info(),
436        })
437    }
438
439    /// Detect system information for result matching
440    fn detect_system_info(&self) -> SystemInfo {
441        // This is a simplified version - a real implementation would
442        // detect actual CPU model, features, etc.
443        SystemInfo {
444            cpu_model: String::from("Unknown"),
445            num_cores: num_cpus::get(),
446            architecture: std::env::consts::ARCH.to_string(),
447            cpu_features: detect_cpu_features(),
448        }
449    }
450
451    /// Update the best algorithms based on benchmark results
452    fn update_best_algorithms(&mut self) {
453        // Clear existing best algorithms
454        self.database.best_algorithms.clear();
455
456        // Group results by size and direction
457        let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
458        for result in &self.database.results {
459            grouped
460                .entry((result.size, result.forward))
461                .or_default()
462                .push(result);
463        }
464
465        // Find the best algorithm for each group
466        for ((size, forward), results) in grouped {
467            if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
468                self.database
469                    .best_algorithms
470                    .insert((size, forward), best.variant);
471            }
472        }
473    }
474
475    /// Get the best FFT variant for the given size and direction
476    pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
477        if !self.enabled {
478            return FftVariant::Standard;
479        }
480
481        // Look for exact size match
482        if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
483            return variant;
484        }
485
486        // Look for closest size match
487        let mut closest_size = 0;
488        let mut min_diff = usize::MAX;
489
490        for &(s, f) in self.database.best_algorithms.keys() {
491            if f == forward {
492                let diff = s.abs_diff(size);
493                if diff < min_diff {
494                    min_diff = diff;
495                    closest_size = s;
496                }
497            }
498        }
499
500        if closest_size > 0 {
501            if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
502                return variant;
503            }
504        }
505
506        // Default to standard FFT if no match
507        FftVariant::Standard
508    }
509
510    /// Run FFT with optimal algorithm selection
511    pub fn run_optimal_fft<T>(
512        &self,
513        input: &[T],
514        size: Option<usize>,
515        forward: bool,
516    ) -> FFTResult<Vec<Complex64>>
517    where
518        T: Clone + Into<Complex64>,
519    {
520        let actual_size = size.unwrap_or(input.len());
521        let variant = self.get_best_variant(actual_size, forward);
522
523        // Convert input to complex
524        let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
525        // Pad if necessary
526        if buffer.len() < actual_size {
527            buffer.resize(actual_size, Complex64::new(0.0, 0.0));
528        }
529
530        match variant {
531            FftVariant::Standard => {
532                let mut planner = FftPlanner::new();
533                let fft = if forward {
534                    planner.plan_fft_forward(actual_size)
535                } else {
536                    planner.plan_fft_inverse(actual_size)
537                };
538                fft.process(&mut buffer);
539            }
540            FftVariant::InPlace => {
541                let mut planner = FftPlanner::new();
542                let fft = if forward {
543                    planner.plan_fft_forward(actual_size)
544                } else {
545                    planner.plan_fft_inverse(actual_size)
546                };
547                let mut scratch = vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
548                fft.process_with_scratch(&mut buffer, &mut scratch);
549            }
550            FftVariant::Cached => {
551                // Use the plan cache via PlanSerializationManager
552                // Create a plan directly - manager is not needed here
553                let (plan_, _) =
554                    crate::plan_serialization::create_and_time_plan(actual_size, forward);
555                plan_.process(&mut buffer);
556            }
557            FftVariant::SplitRadix => {
558                // Placeholder for split-radix FFT
559                let mut planner = FftPlanner::new();
560                let fft = if forward {
561                    planner.plan_fft_forward(actual_size)
562                } else {
563                    planner.plan_fft_inverse(actual_size)
564                };
565                fft.process(&mut buffer);
566            }
567        }
568
569        // Scale inverse FFT by 1/N if required
570        if !forward {
571            let scale = 1.0 / (actual_size as f64);
572            for val in &mut buffer {
573                *val *= scale;
574            }
575        }
576
577        Ok(buffer)
578    }
579}
580
/// Detect CPU features for result matching
///
/// Features are probed at run time on the machine executing the code, so the
/// result describes the actual hardware rather than the flags the binary
/// happened to be compiled with.
///
/// Returns the detected feature names in a fixed order. On architectures
/// without a probe below the list is empty.
fn detect_cpu_features() -> Vec<String> {
    // Only mutated inside the per-architecture blocks below, so it is
    // legitimately never mutated on other targets.
    #[allow(unused_mut)]
    let mut features: Vec<String> = Vec::new();

    // x86-64 SIMD extensions
    #[cfg(target_arch = "x86_64")]
    {
        // The detection macro requires string literals, hence the table.
        let probes = [
            ("sse", std::arch::is_x86_feature_detected!("sse")),
            ("sse2", std::arch::is_x86_feature_detected!("sse2")),
            ("sse3", std::arch::is_x86_feature_detected!("sse3")),
            ("sse4.1", std::arch::is_x86_feature_detected!("sse4.1")),
            ("sse4.2", std::arch::is_x86_feature_detected!("sse4.2")),
            ("avx", std::arch::is_x86_feature_detected!("avx")),
            ("avx2", std::arch::is_x86_feature_detected!("avx2")),
            ("fma", std::arch::is_x86_feature_detected!("fma")),
        ];
        features.extend(
            probes
                .iter()
                .filter(|(_, present)| *present)
                .map(|(name, _)| (*name).to_string()),
        );
    }

    // ARM-specific features
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            features.push("neon".to_string());
        }
    }

    // Add more architecture-specific features if needed

    features
}
625
// Unit tests for size generation and a minimal end-to-end tuning run.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Each `SizeStep` flavour yields the expected sizes for a small range.
    ///
    /// These tuners use the default database path; only `generate_sizes` is
    /// exercised, so nothing is written to disk.
    #[test]
    fn test_size_generation() {
        // Test powers of two
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 8,
                max: 64,
                step: SizeStep::PowersOfTwo,
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![8, 16, 32, 64]);

        // Test linear steps
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 30,
                step: SizeStep::Linear(5),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![10, 15, 20, 25, 30]);

        // Test exponential steps (the next value, 160, exceeds max)
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 100,
                step: SizeStep::Exponential(2.0),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![10, 20, 40, 80]);

        // Test custom sizes (5 and 150 fall outside 10..=100 and are dropped)
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 10,
                max: 100,
                step: SizeStep::Custom(vec![5, 15, 25, 50, 150]),
            },
            ..Default::default()
        };
        let tuner = AutoTuner::with_config(config);
        let sizes = tuner.generate_sizes();
        assert_eq!(sizes, vec![15, 25, 50]);
    }

    /// Runs a tiny benchmark against a temporary database file and checks
    /// that a best variant can be queried afterwards.
    ///
    /// NOTE(review): a failing `run_benchmarks` is only printed, so this
    /// test also passes when benchmarking or saving the database fails.
    #[test]
    fn test_auto_tuner_basic() {
        // Create a temporary directory for test
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_tuning_db.json");

        // Create configuration with minimal benchmarking
        let config = AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        };

        let mut tuner = AutoTuner::with_config(config);

        // Run minimal benchmarks (this is fast enough for a test)
        match tuner.run_benchmarks() {
            Ok(_) => {
                // Verify database file was created
                assert!(db_path.exists());

                // Test getting a best variant
                let variant = tuner.get_best_variant(16, true);
                assert!(matches!(
                    variant,
                    FftVariant::Standard | FftVariant::InPlace
                ));
            }
            Err(e) => {
                // Benchmark may fail in some environments, just log and continue
                println!("Benchmark failed: {e}");
            }
        }
    }
}
726}