scirs2_fft/
plan_serialization.rs

1//! FFT Plan Serialization
2//!
3//! This module provides functionality for serializing and deserializing FFT plans,
4//! allowing for plan reuse across program executions. This can significantly improve
5//! performance for repeated FFT operations with the same parameters.
6
7use rustfft::FftPlanner;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fs::{self, File};
11use std::io::{BufReader, BufWriter};
12use std::path::{Path, PathBuf};
13use std::sync::{Arc, Mutex};
14use std::time::{Duration, Instant, SystemTime};
15
16use crate::error::{FFTError, FFTResult};
17
18// Custom serialization for HashMap<PlanInfo, PlanMetrics>
19mod plan_map_serde {
20    use super::{PlanInfo, PlanMetrics};
21    use serde::{Deserialize, Deserializer, Serialize, Serializer};
22    use std::collections::HashMap;
23
24    pub fn serialize<S>(
25        map: &HashMap<PlanInfo, PlanMetrics>,
26        serializer: S,
27    ) -> Result<S::Ok, S::Error>
28    where
29        S: Serializer,
30    {
31        // Convert to a Vec for serialization
32        let vec: Vec<(PlanInfo, PlanMetrics)> =
33            map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
34        vec.serialize(serializer)
35    }
36
37    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<PlanInfo, PlanMetrics>, D::Error>
38    where
39        D: Deserializer<'de>,
40    {
41        // Deserialize as Vec and convert back to HashMap
42        let vec: Vec<(PlanInfo, PlanMetrics)> = Vec::deserialize(deserializer)?;
43        Ok(vec.into_iter().collect())
44    }
45}
46
47/// Information about a serialized plan
48#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
49pub struct PlanInfo {
50    /// Size of the FFT
51    pub size: usize,
52    /// Direction (forward or inverse)
53    pub forward: bool,
54    /// Architecture identifier (to prevent using plans on different architectures)
55    pub arch_id: String,
56    /// Timestamp when the plan was created
57    pub created_at: u64,
58    /// Version of the library when the plan was created
59    pub lib_version: String,
60}
61
62// Custom Hash implementation to ensure we can use PlanInfo as a key in HashMap
63impl std::hash::Hash for PlanInfo {
64    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65        self.size.hash(state);
66        self.forward.hash(state);
67        self.arch_id.hash(state);
68        // Intentionally not hashing created_at or lib_version as they don't affect the plan's identity
69    }
70}
71
72/// Collection of plan information and associated metadata
73#[derive(Serialize, Deserialize, Debug)]
74pub struct PlanDatabase {
75    /// Map of plan info to performance metrics
76    #[serde(with = "plan_map_serde")]
77    pub plans: HashMap<PlanInfo, PlanMetrics>,
78    /// Overall statistics
79    pub stats: PlanDatabaseStats,
80    /// Last update timestamp
81    pub last_updated: u64,
82}
83
84/// Performance metrics for a specific plan
85#[derive(Serialize, Deserialize, Debug, Clone)]
86pub struct PlanMetrics {
87    /// Average execution time (nanoseconds)
88    pub avg_execution_ns: u64,
89    /// Number of times this plan has been used
90    pub usage_count: u64,
91    /// Last used timestamp
92    pub last_used: u64,
93}
94
95/// Statistics for the plan database
96#[derive(Serialize, Deserialize, Debug, Default, Clone)]
97pub struct PlanDatabaseStats {
98    /// Total number of plans created
99    pub total_plans_created: u64,
100    /// Total number of plans loaded
101    pub total_plans_loaded: u64,
102    /// Cumulative time saved by using cached plans (nanoseconds)
103    pub time_saved_ns: u64,
104}
105
106/// Manager for serialized FFT plans
107pub struct PlanSerializationManager {
108    /// Path to the plan database file
109    db_path: PathBuf,
110    /// In-memory database
111    database: Arc<Mutex<PlanDatabase>>,
112    /// Whether plan serialization is enabled
113    enabled: bool,
114}
115
116impl PlanSerializationManager {
117    /// Create a new plan serialization manager
118    pub fn new(dbpath: impl AsRef<Path>) -> Self {
119        let dbpath = dbpath.as_ref().to_path_buf();
120        let database = Self::load_or_create_database(&dbpath).unwrap_or_else(|_| {
121            Arc::new(Mutex::new(PlanDatabase {
122                plans: HashMap::new(),
123                stats: PlanDatabaseStats::default(),
124                last_updated: system_time_as_millis(),
125            }))
126        });
127
128        Self {
129            db_path: dbpath,
130            database,
131            enabled: true,
132        }
133    }
134
135    /// Load an existing database or create a new one
136    fn load_or_create_database(path: &Path) -> FFTResult<Arc<Mutex<PlanDatabase>>> {
137        if path.exists() {
138            let file = File::open(path)
139                .map_err(|e| FFTError::IOError(format!("Failed to open plan database: {e}")))?;
140            let reader = BufReader::new(file);
141            let database: PlanDatabase = serde_json::from_reader(reader)
142                .map_err(|e| FFTError::ValueError(format!("Failed to parse plan database: {e}")))?;
143            Ok(Arc::new(Mutex::new(database)))
144        } else {
145            // Create parent directories if they don't exist
146            if let Some(parent) = path.parent() {
147                fs::create_dir_all(parent).map_err(|e| {
148                    FFTError::IOError(format!("Failed to create directory for plan database: {e}"))
149                })?;
150            }
151
152            // Create a new empty database
153            let database = PlanDatabase {
154                plans: HashMap::new(),
155                stats: PlanDatabaseStats::default(),
156                last_updated: system_time_as_millis(),
157            };
158            Ok(Arc::new(Mutex::new(database)))
159        }
160    }
161
162    /// Detect the current architecture ID
163    pub fn detect_arch_id() -> String {
164        // This is a simple architecture identification method
165        // In a production system, this would include CPU features, etc.
166        let mut arch_id = String::new();
167
168        #[cfg(target_arch = "x86_64")]
169        {
170            arch_id.push_str("x86_64");
171        }
172
173        #[cfg(target_arch = "aarch64")]
174        {
175            arch_id.push_str("aarch64");
176        }
177
178        // Add CPU features if possible
179        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
180        {
181            arch_id.push_str("-avx");
182        }
183
184        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
185        {
186            arch_id.push_str("-avx2");
187        }
188
189        if arch_id.is_empty() {
190            arch_id = format!("unknown-{}", std::env::consts::ARCH);
191        }
192
193        arch_id
194    }
195
196    /// Get the library version for plan compatibility checking
197    fn get_lib_version() -> String {
198        env!("CARGO_PKG_VERSION").to_string()
199    }
200
201    /// Create a plan info object for the given parameters
202    pub fn create_plan_info(&self, size: usize, forward: bool) -> PlanInfo {
203        PlanInfo {
204            size,
205            forward,
206            arch_id: Self::detect_arch_id(),
207            created_at: system_time_as_millis(),
208            lib_version: Self::get_lib_version(),
209        }
210    }
211
212    /// Check if a plan exists in the database with compatible architecture
213    pub fn plan_exists(&self, size: usize, forward: bool) -> bool {
214        if !self.enabled {
215            return false;
216        }
217
218        let arch_id = Self::detect_arch_id();
219        let db = self.database.lock().unwrap();
220
221        db.plans
222            .keys()
223            .any(|info| info.size == size && info.forward == forward && info.arch_id == arch_id)
224    }
225
226    /// Record plan usage in the database
227    pub fn record_plan_usage(&self, plan_info: &PlanInfo, execution_timens: u64) -> FFTResult<()> {
228        if !self.enabled {
229            return Ok(());
230        }
231
232        let mut db = self.database.lock().unwrap();
233
234        // Update or create metrics for this plan
235        let metrics = db
236            .plans
237            .entry(plan_info.clone())
238            .or_insert_with(|| PlanMetrics {
239                avg_execution_ns: execution_timens,
240                usage_count: 0,
241                last_used: system_time_as_millis(),
242            });
243
244        // Update metrics
245        metrics.usage_count += 1;
246        metrics.last_used = system_time_as_millis();
247
248        // Update running average of execution time
249        metrics.avg_execution_ns = if metrics.usage_count > 1 {
250            ((metrics.avg_execution_ns as f64 * (metrics.usage_count - 1) as f64)
251                + execution_timens as f64)
252                / metrics.usage_count as f64
253        } else {
254            execution_timens as f64
255        } as u64;
256
257        // Save database periodically
258        if db.last_updated + 60000 < system_time_as_millis() {
259            // Save every minute
260            self.save_database()?;
261            db.last_updated = system_time_as_millis();
262        }
263
264        Ok(())
265    }
266
267    /// Save the database to disk
268    pub fn save_database(&self) -> FFTResult<()> {
269        if !self.enabled {
270            return Ok(());
271        }
272
273        let db = self.database.lock().unwrap();
274        let file = File::create(&self.db_path)
275            .map_err(|e| FFTError::IOError(format!("Failed to create plan database file: {e}")))?;
276
277        let writer = BufWriter::new(file);
278        serde_json::to_writer_pretty(writer, &*db)
279            .map_err(|e| FFTError::IOError(format!("Failed to serialize plan database: {e}")))?;
280
281        Ok(())
282    }
283
284    /// Enable or disable plan serialization
285    pub fn set_enabled(&mut self, enabled: bool) {
286        self.enabled = enabled;
287    }
288
289    /// Get the best plan metrics for a given size and direction
290    pub fn get_best_plan_metrics(
291        &self,
292        size: usize,
293        forward: bool,
294    ) -> Option<(PlanInfo, PlanMetrics)> {
295        if !self.enabled {
296            return None;
297        }
298
299        let arch_id = Self::detect_arch_id();
300        let db = self.database.lock().unwrap();
301
302        db.plans
303            .iter()
304            .filter(|(info_, _)| {
305                info_.size == size && info_.forward == forward && info_.arch_id == arch_id
306            })
307            .min_by_key(|(_, metrics)| metrics.avg_execution_ns)
308            .map(|(info, metrics)| (info.clone(), metrics.clone()))
309    }
310
311    /// Get statistics about plan serialization
312    pub fn get_stats(&self) -> PlanDatabaseStats {
313        if let Ok(db) = self.database.lock() {
314            db.stats.clone()
315        } else {
316            PlanDatabaseStats::default()
317        }
318    }
319}
320
/// Return the current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than an error. A value too
/// large for `u64` saturates at `u64::MAX` instead of being silently truncated.
fn system_time_as_millis() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::from_secs(0));
    // Clamp before narrowing so the `as` cast can never truncate.
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}
329
330/// Create a plan with timing measurement
331#[allow(dead_code)]
332pub fn create_and_time_plan(size: usize, forward: bool) -> (Arc<dyn rustfft::Fft<f64>>, u64) {
333    let start = Instant::now();
334    let mut planner = FftPlanner::new();
335    let plan = if forward {
336        planner.plan_fft_forward(size)
337    } else {
338        planner.plan_fft_inverse(size)
339    };
340    let elapsed_ns = start.elapsed().as_nanos() as u64;
341
342    (plan, elapsed_ns)
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_plan_serialization_basic() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_plan_db.json");

        let manager = PlanSerializationManager::new(&db_path);
        let plan_info = manager.create_plan_info(1024, true);

        manager.record_plan_usage(&plan_info, 5000).unwrap();
        assert!(manager.plan_exists(1024, true));
        // Nothing was recorded for the other direction or for another size.
        assert!(!manager.plan_exists(1024, false));
        assert!(!manager.plan_exists(2048, true));

        manager.save_database().unwrap();
        assert!(db_path.exists());

        // A fresh manager must see the saved entry after reloading the file,
        // which exercises the `plan_map_serde` round trip.
        let reloaded = PlanSerializationManager::new(&db_path);
        assert!(reloaded.plan_exists(1024, true));
        let (_, metrics) = reloaded
            .get_best_plan_metrics(1024, true)
            .expect("saved plan should be reloaded");
        assert_eq!(metrics.avg_execution_ns, 5000);
        assert_eq!(metrics.usage_count, 1);
    }

    #[test]
    fn test_arch_detection() {
        let arch_id = PlanSerializationManager::detect_arch_id();
        assert!(!arch_id.is_empty());
    }

    #[test]
    fn test_running_average() {
        let temp_dir = tempdir().unwrap();
        let manager = PlanSerializationManager::new(temp_dir.path().join("avg.json"));
        let plan_info = manager.create_plan_info(256, false);

        manager.record_plan_usage(&plan_info, 4000).unwrap();
        manager.record_plan_usage(&plan_info, 6000).unwrap();

        let (_, metrics) = manager
            .get_best_plan_metrics(256, false)
            .expect("plan was recorded");
        assert_eq!(metrics.usage_count, 2);
        assert_eq!(metrics.avg_execution_ns, 5000);
    }

    #[test]
    fn test_disabled_manager_is_inert() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("disabled.json");

        let mut manager = PlanSerializationManager::new(&db_path);
        manager.set_enabled(false);

        let plan_info = manager.create_plan_info(128, true);
        manager.record_plan_usage(&plan_info, 1000).unwrap();
        assert!(!manager.plan_exists(128, true));
        assert!(manager.get_best_plan_metrics(128, true).is_none());

        manager.save_database().unwrap();
        assert!(!db_path.exists());
    }

    #[test]
    fn test_get_best_plan() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_best_plan.json");
        let manager = PlanSerializationManager::new(&db_path);

        // Entries for the same transform that differ only in `created_at` are
        // distinct keys, because `PartialEq` compares every field. Build the
        // second one explicitly instead of sleeping for a new timestamp.
        let plan_info1 = manager.create_plan_info(512, true);
        let plan_info2 = PlanInfo {
            created_at: plan_info1.created_at + 1,
            ..plan_info1.clone()
        };

        let time1 = 8000u64;
        let time2 = 5000u64;
        manager.record_plan_usage(&plan_info1, time1).unwrap();
        manager.record_plan_usage(&plan_info2, time2).unwrap();

        // The faster entry (plan_info2) must win.
        let (best_info, metrics) = manager
            .get_best_plan_metrics(512, true)
            .expect("two plans were recorded");
        assert_eq!(metrics.avg_execution_ns, time2);
        assert_eq!(best_info, plan_info2);
    }
}