use rustfft::FftPlanner;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant, SystemTime};

use crate::error::{FFTError, FFTResult};
17
18mod plan_map_serde {
20 use super::{PlanInfo, PlanMetrics};
21 use serde::{Deserialize, Deserializer, Serialize, Serializer};
22 use std::collections::HashMap;
23
24 pub fn serialize<S>(
25 map: &HashMap<PlanInfo, PlanMetrics>,
26 serializer: S,
27 ) -> Result<S::Ok, S::Error>
28 where
29 S: Serializer,
30 {
31 let vec: Vec<(PlanInfo, PlanMetrics)> =
33 map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
34 vec.serialize(serializer)
35 }
36
37 pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<PlanInfo, PlanMetrics>, D::Error>
38 where
39 D: Deserializer<'de>,
40 {
41 let vec: Vec<(PlanInfo, PlanMetrics)> = Vec::deserialize(deserializer)?;
43 Ok(vec.into_iter().collect())
44 }
45}
46
47#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
49pub struct PlanInfo {
50 pub size: usize,
52 pub forward: bool,
54 pub arch_id: String,
56 pub created_at: u64,
58 pub lib_version: String,
60}
61
62impl std::hash::Hash for PlanInfo {
64 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
65 self.size.hash(state);
66 self.forward.hash(state);
67 self.arch_id.hash(state);
68 }
70}
71
72#[derive(Serialize, Deserialize, Debug)]
74pub struct PlanDatabase {
75 #[serde(with = "plan_map_serde")]
77 pub plans: HashMap<PlanInfo, PlanMetrics>,
78 pub stats: PlanDatabaseStats,
80 pub last_updated: u64,
82}
83
84#[derive(Serialize, Deserialize, Debug, Clone)]
86pub struct PlanMetrics {
87 pub avg_execution_ns: u64,
89 pub usage_count: u64,
91 pub last_used: u64,
93}
94
95#[derive(Serialize, Deserialize, Debug, Default, Clone)]
97pub struct PlanDatabaseStats {
98 pub total_plans_created: u64,
100 pub total_plans_loaded: u64,
102 pub time_saved_ns: u64,
104}
105
106pub struct PlanSerializationManager {
108 db_path: PathBuf,
110 database: Arc<Mutex<PlanDatabase>>,
112 enabled: bool,
114}
115
116impl PlanSerializationManager {
117 pub fn new(dbpath: impl AsRef<Path>) -> Self {
119 let dbpath = dbpath.as_ref().to_path_buf();
120 let database = Self::load_or_create_database(&dbpath).unwrap_or_else(|_| {
121 Arc::new(Mutex::new(PlanDatabase {
122 plans: HashMap::new(),
123 stats: PlanDatabaseStats::default(),
124 last_updated: system_time_as_millis(),
125 }))
126 });
127
128 Self {
129 db_path: dbpath,
130 database,
131 enabled: true,
132 }
133 }
134
135 fn load_or_create_database(path: &Path) -> FFTResult<Arc<Mutex<PlanDatabase>>> {
137 if path.exists() {
138 let file = File::open(path)
139 .map_err(|e| FFTError::IOError(format!("Failed to open plan database: {e}")))?;
140 let reader = BufReader::new(file);
141 let database: PlanDatabase = serde_json::from_reader(reader)
142 .map_err(|e| FFTError::ValueError(format!("Failed to parse plan database: {e}")))?;
143 Ok(Arc::new(Mutex::new(database)))
144 } else {
145 if let Some(parent) = path.parent() {
147 fs::create_dir_all(parent).map_err(|e| {
148 FFTError::IOError(format!("Failed to create directory for plan database: {e}"))
149 })?;
150 }
151
152 let database = PlanDatabase {
154 plans: HashMap::new(),
155 stats: PlanDatabaseStats::default(),
156 last_updated: system_time_as_millis(),
157 };
158 Ok(Arc::new(Mutex::new(database)))
159 }
160 }
161
162 pub fn detect_arch_id() -> String {
164 let mut arch_id = String::new();
167
168 #[cfg(target_arch = "x86_64")]
169 {
170 arch_id.push_str("x86_64");
171 }
172
173 #[cfg(target_arch = "aarch64")]
174 {
175 arch_id.push_str("aarch64");
176 }
177
178 #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
180 {
181 arch_id.push_str("-avx");
182 }
183
184 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
185 {
186 arch_id.push_str("-avx2");
187 }
188
189 if arch_id.is_empty() {
190 arch_id = format!("unknown-{}", std::env::consts::ARCH);
191 }
192
193 arch_id
194 }
195
196 fn get_lib_version() -> String {
198 env!("CARGO_PKG_VERSION").to_string()
199 }
200
201 pub fn create_plan_info(&self, size: usize, forward: bool) -> PlanInfo {
203 PlanInfo {
204 size,
205 forward,
206 arch_id: Self::detect_arch_id(),
207 created_at: system_time_as_millis(),
208 lib_version: Self::get_lib_version(),
209 }
210 }
211
212 pub fn plan_exists(&self, size: usize, forward: bool) -> bool {
214 if !self.enabled {
215 return false;
216 }
217
218 let arch_id = Self::detect_arch_id();
219 let db = self.database.lock().unwrap();
220
221 db.plans
222 .keys()
223 .any(|info| info.size == size && info.forward == forward && info.arch_id == arch_id)
224 }
225
226 pub fn record_plan_usage(&self, plan_info: &PlanInfo, execution_timens: u64) -> FFTResult<()> {
228 if !self.enabled {
229 return Ok(());
230 }
231
232 let mut db = self.database.lock().unwrap();
233
234 let metrics = db
236 .plans
237 .entry(plan_info.clone())
238 .or_insert_with(|| PlanMetrics {
239 avg_execution_ns: execution_timens,
240 usage_count: 0,
241 last_used: system_time_as_millis(),
242 });
243
244 metrics.usage_count += 1;
246 metrics.last_used = system_time_as_millis();
247
248 metrics.avg_execution_ns = if metrics.usage_count > 1 {
250 ((metrics.avg_execution_ns as f64 * (metrics.usage_count - 1) as f64)
251 + execution_timens as f64)
252 / metrics.usage_count as f64
253 } else {
254 execution_timens as f64
255 } as u64;
256
257 if db.last_updated + 60000 < system_time_as_millis() {
259 self.save_database()?;
261 db.last_updated = system_time_as_millis();
262 }
263
264 Ok(())
265 }
266
267 pub fn save_database(&self) -> FFTResult<()> {
269 if !self.enabled {
270 return Ok(());
271 }
272
273 let db = self.database.lock().unwrap();
274 let file = File::create(&self.db_path)
275 .map_err(|e| FFTError::IOError(format!("Failed to create plan database file: {e}")))?;
276
277 let writer = BufWriter::new(file);
278 serde_json::to_writer_pretty(writer, &*db)
279 .map_err(|e| FFTError::IOError(format!("Failed to serialize plan database: {e}")))?;
280
281 Ok(())
282 }
283
284 pub fn set_enabled(&mut self, enabled: bool) {
286 self.enabled = enabled;
287 }
288
289 pub fn get_best_plan_metrics(
291 &self,
292 size: usize,
293 forward: bool,
294 ) -> Option<(PlanInfo, PlanMetrics)> {
295 if !self.enabled {
296 return None;
297 }
298
299 let arch_id = Self::detect_arch_id();
300 let db = self.database.lock().unwrap();
301
302 db.plans
303 .iter()
304 .filter(|(info_, _)| {
305 info_.size == size && info_.forward == forward && info_.arch_id == arch_id
306 })
307 .min_by_key(|(_, metrics)| metrics.avg_execution_ns)
308 .map(|(info, metrics)| (info.clone(), metrics.clone()))
309 }
310
311 pub fn get_stats(&self) -> PlanDatabaseStats {
313 if let Ok(db) = self.database.lock() {
314 db.stats.clone()
315 } else {
316 PlanDatabaseStats::default()
317 }
318 }
319}
320
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
#[allow(dead_code)]
fn system_time_as_millis() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::from_secs(0),
    };
    since_epoch.as_millis() as u64
}
329
330#[allow(dead_code)]
332pub fn create_and_time_plan(size: usize, forward: bool) -> (Arc<dyn rustfft::Fft<f64>>, u64) {
333 let start = Instant::now();
334 let mut planner = FftPlanner::new();
335 let plan = if forward {
336 planner.plan_fft_forward(size)
337 } else {
338 planner.plan_fft_inverse(size)
339 };
340 let elapsed_ns = start.elapsed().as_nanos() as u64;
341
342 (plan, elapsed_ns)
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_plan_serialization_basic() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_plan_db.json");

        let manager = PlanSerializationManager::new(&db_path);
        let plan_info = manager.create_plan_info(1024, true);
        manager.record_plan_usage(&plan_info, 5000).unwrap();
        assert!(manager.plan_exists(1024, true));

        manager.save_database().unwrap();
        assert!(db_path.exists());

        // A fresh manager must see the persisted plan: a round trip through the JSON file.
        let reloaded = PlanSerializationManager::new(&db_path);
        assert!(reloaded.plan_exists(1024, true));
        assert!(!reloaded.plan_exists(1024, false));
    }

    #[test]
    fn test_arch_detection() {
        let arch_id = PlanSerializationManager::detect_arch_id();
        assert!(!arch_id.is_empty());
        assert_eq!(arch_id, PlanSerializationManager::detect_arch_id());
    }

    #[test]
    fn test_get_best_plan() {
        let temp_dir = tempdir().unwrap();
        let db_path = temp_dir.path().join("test_best_plan.json");
        let manager = PlanSerializationManager::new(&db_path);

        // Keys that differ only in `created_at` are distinct database entries. Building
        // the second key explicitly, rather than sleeping between two `create_plan_info`
        // calls, makes them differ deterministically.
        let plan_info1 = manager.create_plan_info(512, true);
        let plan_info2 = PlanInfo {
            created_at: plan_info1.created_at + 1,
            ..plan_info1.clone()
        };

        let time1 = 8000u64;
        let time2 = 5000u64;
        manager.record_plan_usage(&plan_info1, time1).unwrap();
        manager.record_plan_usage(&plan_info2, time2).unwrap();

        let (best_info, metrics) = manager
            .get_best_plan_metrics(512, true)
            .expect("two plans of size 512 were recorded");
        assert_eq!(best_info, plan_info2);
        assert_eq!(metrics.avg_execution_ns, time2);
        assert_eq!(metrics.usage_count, 1);

        assert!(manager.get_best_plan_metrics(512, false).is_none());
    }
}