Skip to main content

scirs2_fft/
plan_serialization.rs

1//! FFT Plan Serialization
2//!
3//! This module provides functionality for serializing and deserializing FFT plans,
4//! allowing for plan reuse across program executions. This can significantly improve
5//! performance for repeated FFT operations with the same parameters.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{BufReader, BufWriter};
11use std::path::{Path, PathBuf};
12use std::sync::{Arc, Mutex};
13use std::time::{Duration, Instant, SystemTime};
14
15use crate::error::{FFTError, FFTResult};
16
17// Custom serialization for HashMap<PlanInfo, PlanMetrics>
18mod plan_map_serde {
19    use super::{PlanInfo, PlanMetrics};
20    use serde::{Deserialize, Deserializer, Serialize, Serializer};
21    use std::collections::HashMap;
22
23    pub fn serialize<S>(
24        map: &HashMap<PlanInfo, PlanMetrics>,
25        serializer: S,
26    ) -> Result<S::Ok, S::Error>
27    where
28        S: Serializer,
29    {
30        // Convert to a Vec for serialization
31        let vec: Vec<(PlanInfo, PlanMetrics)> =
32            map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
33        vec.serialize(serializer)
34    }
35
36    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<PlanInfo, PlanMetrics>, D::Error>
37    where
38        D: Deserializer<'de>,
39    {
40        // Deserialize as Vec and convert back to HashMap
41        let vec: Vec<(PlanInfo, PlanMetrics)> = Vec::deserialize(deserializer)?;
42        Ok(vec.into_iter().collect())
43    }
44}
45
46/// Information about a serialized plan
47#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
48pub struct PlanInfo {
49    /// Size of the FFT
50    pub size: usize,
51    /// Direction (forward or inverse)
52    pub forward: bool,
53    /// Architecture identifier (to prevent using plans on different architectures)
54    pub arch_id: String,
55    /// Timestamp when the plan was created
56    pub created_at: u64,
57    /// Version of the library when the plan was created
58    pub lib_version: String,
59}
60
61// Custom Hash implementation to ensure we can use PlanInfo as a key in HashMap
62impl std::hash::Hash for PlanInfo {
63    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
64        self.size.hash(state);
65        self.forward.hash(state);
66        self.arch_id.hash(state);
67        // Intentionally not hashing created_at or lib_version as they don't affect the plan's identity
68    }
69}
70
71/// Collection of plan information and associated metadata
72#[derive(Serialize, Deserialize, Debug)]
73pub struct PlanDatabase {
74    /// Map of plan info to performance metrics
75    #[serde(with = "plan_map_serde")]
76    pub plans: HashMap<PlanInfo, PlanMetrics>,
77    /// Overall statistics
78    pub stats: PlanDatabaseStats,
79    /// Last update timestamp
80    pub last_updated: u64,
81}
82
83/// Performance metrics for a specific plan
84#[derive(Serialize, Deserialize, Debug, Clone)]
85pub struct PlanMetrics {
86    /// Average execution time (nanoseconds)
87    pub avg_execution_ns: u64,
88    /// Number of times this plan has been used
89    pub usage_count: u64,
90    /// Last used timestamp
91    pub last_used: u64,
92}
93
94/// Statistics for the plan database
95#[derive(Serialize, Deserialize, Debug, Default, Clone)]
96pub struct PlanDatabaseStats {
97    /// Total number of plans created
98    pub total_plans_created: u64,
99    /// Total number of plans loaded
100    pub total_plans_loaded: u64,
101    /// Cumulative time saved by using cached plans (nanoseconds)
102    pub time_saved_ns: u64,
103}
104
105/// Manager for serialized FFT plans
106pub struct PlanSerializationManager {
107    /// Path to the plan database file
108    db_path: PathBuf,
109    /// In-memory database
110    database: Arc<Mutex<PlanDatabase>>,
111    /// Whether plan serialization is enabled
112    enabled: bool,
113}
114
115impl PlanSerializationManager {
116    /// Create a new plan serialization manager
117    pub fn new(dbpath: impl AsRef<Path>) -> Self {
118        let dbpath = dbpath.as_ref().to_path_buf();
119        let database = Self::load_or_create_database(&dbpath).unwrap_or_else(|_| {
120            Arc::new(Mutex::new(PlanDatabase {
121                plans: HashMap::new(),
122                stats: PlanDatabaseStats::default(),
123                last_updated: system_time_as_millis(),
124            }))
125        });
126
127        Self {
128            db_path: dbpath,
129            database,
130            enabled: true,
131        }
132    }
133
134    /// Load an existing database or create a new one
135    fn load_or_create_database(path: &Path) -> FFTResult<Arc<Mutex<PlanDatabase>>> {
136        if path.exists() {
137            let file = File::open(path)
138                .map_err(|e| FFTError::IOError(format!("Failed to open plan database: {e}")))?;
139            let reader = BufReader::new(file);
140            let database: PlanDatabase = serde_json::from_reader(reader)
141                .map_err(|e| FFTError::ValueError(format!("Failed to parse plan database: {e}")))?;
142            Ok(Arc::new(Mutex::new(database)))
143        } else {
144            // Create parent directories if they don't exist
145            if let Some(parent) = path.parent() {
146                fs::create_dir_all(parent).map_err(|e| {
147                    FFTError::IOError(format!("Failed to create directory for plan database: {e}"))
148                })?;
149            }
150
151            // Create a new empty database
152            let database = PlanDatabase {
153                plans: HashMap::new(),
154                stats: PlanDatabaseStats::default(),
155                last_updated: system_time_as_millis(),
156            };
157            Ok(Arc::new(Mutex::new(database)))
158        }
159    }
160
161    /// Detect the current architecture ID
162    pub fn detect_arch_id() -> String {
163        // This is a simple architecture identification method
164        // In a production system, this would include CPU features, etc.
165        let mut arch_id = String::new();
166
167        #[cfg(target_arch = "x86_64")]
168        {
169            arch_id.push_str("x86_64");
170        }
171
172        #[cfg(target_arch = "aarch64")]
173        {
174            arch_id.push_str("aarch64");
175        }
176
177        // Add CPU features if possible
178        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
179        {
180            arch_id.push_str("-avx");
181        }
182
183        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
184        {
185            arch_id.push_str("-avx2");
186        }
187
188        if arch_id.is_empty() {
189            arch_id = format!("unknown-{}", std::env::consts::ARCH);
190        }
191
192        arch_id
193    }
194
195    /// Get the library version for plan compatibility checking
196    fn get_lib_version() -> String {
197        env!("CARGO_PKG_VERSION").to_string()
198    }
199
200    /// Create a plan info object for the given parameters
201    pub fn create_plan_info(&self, size: usize, forward: bool) -> PlanInfo {
202        PlanInfo {
203            size,
204            forward,
205            arch_id: Self::detect_arch_id(),
206            created_at: system_time_as_millis(),
207            lib_version: Self::get_lib_version(),
208        }
209    }
210
211    /// Check if a plan exists in the database with compatible architecture
212    pub fn plan_exists(&self, size: usize, forward: bool) -> bool {
213        if !self.enabled {
214            return false;
215        }
216
217        let arch_id = Self::detect_arch_id();
218        let db = self.database.lock().expect("Operation failed");
219
220        db.plans
221            .keys()
222            .any(|info| info.size == size && info.forward == forward && info.arch_id == arch_id)
223    }
224
225    /// Record plan usage in the database
226    pub fn record_plan_usage(&self, plan_info: &PlanInfo, execution_timens: u64) -> FFTResult<()> {
227        if !self.enabled {
228            return Ok(());
229        }
230
231        let mut db = self.database.lock().expect("Operation failed");
232
233        // Update or create metrics for this plan
234        let metrics = db
235            .plans
236            .entry(plan_info.clone())
237            .or_insert_with(|| PlanMetrics {
238                avg_execution_ns: execution_timens,
239                usage_count: 0,
240                last_used: system_time_as_millis(),
241            });
242
243        // Update metrics
244        metrics.usage_count += 1;
245        metrics.last_used = system_time_as_millis();
246
247        // Update running average of execution time
248        metrics.avg_execution_ns = if metrics.usage_count > 1 {
249            ((metrics.avg_execution_ns as f64 * (metrics.usage_count - 1) as f64)
250                + execution_timens as f64)
251                / metrics.usage_count as f64
252        } else {
253            execution_timens as f64
254        } as u64;
255
256        // Save database periodically
257        if db.last_updated + 60000 < system_time_as_millis() {
258            // Save every minute
259            self.save_database()?;
260            db.last_updated = system_time_as_millis();
261        }
262
263        Ok(())
264    }
265
266    /// Save the database to disk
267    pub fn save_database(&self) -> FFTResult<()> {
268        if !self.enabled {
269            return Ok(());
270        }
271
272        let db = self.database.lock().expect("Operation failed");
273        let file = File::create(&self.db_path)
274            .map_err(|e| FFTError::IOError(format!("Failed to create plan database file: {e}")))?;
275
276        let writer = BufWriter::new(file);
277        serde_json::to_writer_pretty(writer, &*db)
278            .map_err(|e| FFTError::IOError(format!("Failed to serialize plan database: {e}")))?;
279
280        Ok(())
281    }
282
283    /// Enable or disable plan serialization
284    pub fn set_enabled(&mut self, enabled: bool) {
285        self.enabled = enabled;
286    }
287
288    /// Get the best plan metrics for a given size and direction
289    pub fn get_best_plan_metrics(
290        &self,
291        size: usize,
292        forward: bool,
293    ) -> Option<(PlanInfo, PlanMetrics)> {
294        if !self.enabled {
295            return None;
296        }
297
298        let arch_id = Self::detect_arch_id();
299        let db = self.database.lock().expect("Operation failed");
300
301        db.plans
302            .iter()
303            .filter(|(info_, _)| {
304                info_.size == size && info_.forward == forward && info_.arch_id == arch_id
305            })
306            .min_by_key(|(_, metrics)| metrics.avg_execution_ns)
307            .map(|(info, metrics)| (info.clone(), metrics.clone()))
308    }
309
310    /// Get statistics about plan serialization
311    pub fn get_stats(&self) -> PlanDatabaseStats {
312        if let Ok(db) = self.database.lock() {
313            db.stats.clone()
314        } else {
315            PlanDatabaseStats::default()
316        }
317    }
318}
319
/// Current system time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
#[allow(dead_code)]
fn system_time_as_millis() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Clock is set before 1970: treat it as the epoch itself.
        Err(_) => Duration::from_secs(0),
    };
    since_epoch.as_millis() as u64
}
328
/// Run a warm-up FFT and report how long it took (OxiFFT backend).
///
/// With the `oxifft` feature enabled, this allocates zero-filled input and
/// output buffers of `size` elements, executes one complex-to-complex
/// transform in the requested direction through `oxifft_plan_cache`, and
/// returns the elapsed time in nanoseconds. The timer starts before the
/// buffers are allocated, and the result of the transform call is ignored.
///
/// Without the `oxifft` feature, nothing is executed and 0 is returned.
#[allow(dead_code)]
pub fn create_and_time_plan(size: usize, forward: bool) -> u64 {
    #[cfg(feature = "oxifft")]
    {
        use crate::oxifft_plan_cache;
        use oxifft::{Complex as OxiComplex, Direction};

        let timer = Instant::now();

        let input = vec![OxiComplex::new(0.0, 0.0); size];
        let mut output = vec![OxiComplex::new(0.0, 0.0); size];

        let direction = match forward {
            true => Direction::Forward,
            false => Direction::Backward,
        };
        let _ = oxifft_plan_cache::execute_c2c(&input, &mut output, direction);

        timer.elapsed().as_nanos() as u64
    }

    #[cfg(not(feature = "oxifft"))]
    {
        // No backend available: mark the parameters as used and report zero.
        let _ = (size, forward);
        0u64
    }
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Records one plan, saves the database and checks it survives a reload.
    #[test]
    fn test_plan_serialization_basic() {
        // Create a temporary directory for test
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_plan_db.json");

        // Create a manager
        let manager = PlanSerializationManager::new(&db_path);

        // Create a plan info
        let plan_info = manager.create_plan_info(1024, true);

        // Record usage
        manager
            .record_plan_usage(&plan_info, 5000)
            .expect("Operation failed");

        // Check if plan exists (and that direction is part of the lookup)
        assert!(manager.plan_exists(1024, true));
        assert!(!manager.plan_exists(1024, false));

        // Save database
        manager.save_database().expect("Operation failed");

        // Check that file exists
        assert!(db_path.exists());

        // Round trip: a fresh manager reading the saved file must see the plan.
        let reloaded = PlanSerializationManager::new(&db_path);
        assert!(reloaded.plan_exists(1024, true));
        assert!(!reloaded.plan_exists(1024, false));
    }

    /// The architecture id is never empty.
    #[test]
    fn test_arch_detection() {
        let arch_id = PlanSerializationManager::detect_arch_id();
        assert!(!arch_id.is_empty());
    }

    /// Of two plans for the same size/direction, the faster one is reported.
    #[test]
    fn test_get_best_plan() {
        // Create a temporary directory for test
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_best_plan.json");

        // Create a manager
        let manager = PlanSerializationManager::new(&db_path);

        // Create two plans with different performance
        let plan_info1 = manager.create_plan_info(512, true);

        // Use different timestamp to ensure uniqueness
        std::thread::sleep(Duration::from_millis(10));
        let plan_info2 = manager.create_plan_info(512, true);

        // Record usage with different times
        let time1 = 8000u64;
        let time2 = 5000u64;
        manager
            .record_plan_usage(&plan_info1, time1)
            .expect("Operation failed");
        manager
            .record_plan_usage(&plan_info2, time2)
            .expect("Operation failed");

        // Get best plan (must be plan2, the one with the smaller execution time)
        let best = manager.get_best_plan_metrics(512, true);
        assert!(best.is_some());

        let (info, metrics) = best.expect("Operation failed");
        assert_eq!(metrics.avg_execution_ns, time2);
        assert_eq!(metrics.usage_count, 1);
        assert_eq!(info, plan_info2);
    }
}