use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use rustfft::FftPlanner;
10
/// Cache key identifying a plan by its transform length and direction.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct PlanKey {
    /// Number of points in the transform.
    size: usize,
    /// `true` for a forward transform, `false` for an inverse one.
    forward: bool,
}
18
19#[derive(Clone)]
21struct CachedPlan {
22 plan: Arc<dyn rustfft::Fft<f64>>,
23 last_used: Instant,
24 usage_count: usize,
25}
26
27pub struct PlanCache {
29 cache: Arc<Mutex<HashMap<PlanKey, CachedPlan>>>,
30 max_entries: usize,
31 max_age: Duration,
32 enabled: Arc<Mutex<bool>>,
33 hit_count: Arc<Mutex<u64>>,
34 miss_count: Arc<Mutex<u64>>,
35}
36
37impl PlanCache {
38 pub fn new() -> Self {
40 Self {
41 cache: Arc::new(Mutex::new(HashMap::new())),
42 max_entries: 128,
43 max_age: Duration::from_secs(3600), enabled: Arc::new(Mutex::new(true)),
45 hit_count: Arc::new(Mutex::new(0)),
46 miss_count: Arc::new(Mutex::new(0)),
47 }
48 }
49
50 pub fn with_config(max_entries: usize, max_age: Duration) -> Self {
52 Self {
53 cache: Arc::new(Mutex::new(HashMap::new())),
54 max_entries,
55 max_age,
56 enabled: Arc::new(Mutex::new(true)),
57 hit_count: Arc::new(Mutex::new(0)),
58 miss_count: Arc::new(Mutex::new(0)),
59 }
60 }
61
62 pub fn set_enabled(&self, enabled: bool) {
64 *self.enabled.lock().unwrap() = enabled;
65 }
66
67 pub fn is_enabled(&self) -> bool {
69 *self.enabled.lock().unwrap()
70 }
71
72 pub fn clear(&self) {
74 if let Ok(mut cache) = self.cache.lock() {
75 cache.clear();
76 }
77 }
78
79 pub fn get_stats(&self) -> CacheStats {
81 let hit_count = *self.hit_count.lock().unwrap();
82 let miss_count = *self.miss_count.lock().unwrap();
83 let total_requests = hit_count + miss_count;
84 let hit_rate = if total_requests > 0 {
85 hit_count as f64 / total_requests as f64
86 } else {
87 0.0
88 };
89
90 let size = self.cache.lock().map(|c| c.len()).unwrap_or(0);
91
92 CacheStats {
93 hit_count,
94 miss_count,
95 hit_rate,
96 size,
97 max_size: self.max_entries,
98 }
99 }
100
101 pub fn get_or_create_plan(
103 &self,
104 size: usize,
105 forward: bool,
106 planner: &mut FftPlanner<f64>,
107 ) -> Arc<dyn rustfft::Fft<f64>> {
108 if !*self.enabled.lock().unwrap() {
109 return if forward {
110 planner.plan_fft_forward(size)
111 } else {
112 planner.plan_fft_inverse(size)
113 };
114 }
115
116 let key = PlanKey { size, forward };
117
118 if let Ok(mut cache) = self.cache.lock() {
120 if let Some(cached) = cache.get_mut(&key) {
121 if cached.last_used.elapsed() <= self.max_age {
123 cached.last_used = Instant::now();
124 cached.usage_count += 1;
125 *self.hit_count.lock().unwrap() += 1;
126 return cached.plan.clone();
127 } else {
128 cache.remove(&key);
130 }
131 }
132 }
133
134 *self.miss_count.lock().unwrap() += 1;
136
137 let plan: Arc<dyn rustfft::Fft<f64>> = if forward {
138 planner.plan_fft_forward(size)
139 } else {
140 planner.plan_fft_inverse(size)
141 };
142
143 if let Ok(mut cache) = self.cache.lock() {
145 if cache.len() >= self.max_entries {
147 self.evict_old_entries(&mut cache);
148 }
149
150 cache.insert(
151 key,
152 CachedPlan {
153 plan: plan.clone(),
154 last_used: Instant::now(),
155 usage_count: 1,
156 },
157 );
158 }
159
160 plan
161 }
162
163 fn evict_old_entries(&self, cache: &mut HashMap<PlanKey, CachedPlan>) {
165 cache.retain(|_, v| v.last_used.elapsed() <= self.max_age);
167
168 while cache.len() >= self.max_entries {
170 if let Some((key_to_remove_, _)) = cache
171 .iter()
172 .min_by_key(|(_, v)| (v.last_used, v.usage_count))
173 .map(|(k, v)| (k.clone(), v.clone()))
174 {
175 cache.remove(&key_to_remove_);
176 } else {
177 break;
178 }
179 }
180 }
181
182 pub fn precompute_common_sizes(&self, sizes: &[usize], planner: &mut FftPlanner<f64>) {
184 for &size in sizes {
185 self.get_or_create_plan(size, true, planner);
187 self.get_or_create_plan(size, false, planner);
188 }
189 }
190}
191
192impl Default for PlanCache {
193 fn default() -> Self {
194 Self::new()
195 }
196}
197
/// Point-in-time snapshot of plan-cache statistics, as produced by `PlanCache::get_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Requests served from the cache.
    pub hit_count: u64,
    /// Requests that required building a new plan.
    pub miss_count: u64,
    /// Fraction of requests that were hits, in `0.0..=1.0` (`0.0` when there were none).
    pub hit_rate: f64,
    /// Number of plans currently held.
    pub size: usize,
    /// Configured capacity.
    pub max_size: usize,
}

impl std::fmt::Display for CacheStats {
    /// Writes a one-line summary, with the hit rate shown as a percentage to one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CacheStats {
            hit_count,
            miss_count,
            hit_rate,
            size,
            max_size,
        } = self;
        let hit_percent = hit_rate * 100.0;
        write!(
            f,
            "Cache Stats: {} hits, {} misses ({:.1}% hit rate), {}/{} entries",
            hit_count, miss_count, hit_percent, size, max_size
        )
    }
}
221
222static GLOBAL_PLAN_CACHE: std::sync::OnceLock<PlanCache> = std::sync::OnceLock::new();
224
225#[allow(dead_code)]
227pub fn get_global_cache() -> &'static PlanCache {
228 GLOBAL_PLAN_CACHE.get_or_init(PlanCache::new)
229}
230
231#[allow(dead_code)]
233pub fn init_global_cache(max_entries: usize, max_age: Duration) -> Result<(), &'static str> {
234 GLOBAL_PLAN_CACHE
235 .set(PlanCache::with_config(max_entries, max_age))
236 .map_err(|_| "Global plan cache already initialized")
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// A repeated request is a hit and returns the very same plan object.
    #[test]
    fn test_plan_cache_basic() {
        let cache = PlanCache::new();
        let mut planner = FftPlanner::new();

        let plan1 = cache.get_or_create_plan(128, true, &mut planner);
        let plan2 = cache.get_or_create_plan(128, true, &mut planner);

        assert!(Arc::ptr_eq(&plan1, &plan2));
        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 1);
        assert_eq!(stats.miss_count, 1);
    }

    /// Forward and inverse plans of the same size occupy separate entries.
    #[test]
    fn test_forward_and_inverse_cached_separately() {
        let cache = PlanCache::new();
        let mut planner = FftPlanner::new();

        cache.get_or_create_plan(64, true, &mut planner);
        cache.get_or_create_plan(64, false, &mut planner);
        cache.get_or_create_plan(64, true, &mut planner);
        cache.get_or_create_plan(64, false, &mut planner);

        let stats = cache.get_stats();
        assert_eq!(stats.miss_count, 2);
        assert_eq!(stats.hit_count, 2);
        assert_eq!(stats.size, 2);
    }

    /// Inserting past capacity evicts so the size never exceeds `max_entries`.
    #[test]
    fn test_cache_eviction() {
        let cache = PlanCache::with_config(2, Duration::from_secs(3600));
        let mut planner = FftPlanner::new();

        cache.get_or_create_plan(64, true, &mut planner);
        cache.get_or_create_plan(128, true, &mut planner);
        cache.get_or_create_plan(256, true, &mut planner);

        let stats = cache.get_stats();
        assert_eq!(stats.size, 2);
        assert_eq!(stats.miss_count, 3);
        assert_eq!(stats.max_size, 2);
    }

    /// Precomputing fills both directions; clearing drops plans but keeps the counters.
    #[test]
    fn test_precompute_and_clear() {
        let cache = PlanCache::new();
        let mut planner = FftPlanner::new();

        cache.precompute_common_sizes(&[16, 32], &mut planner);
        let stats = cache.get_stats();
        assert_eq!(stats.size, 4);
        assert_eq!(stats.miss_count, 4);
        assert_eq!(stats.hit_count, 0);

        cache.clear();
        let stats = cache.get_stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.miss_count, 4);
    }

    /// A disabled cache neither stores plans nor counts requests.
    #[test]
    fn test_cache_disabled() {
        let cache = PlanCache::new();
        cache.set_enabled(false);
        assert!(!cache.is_enabled());

        let mut planner = FftPlanner::new();

        cache.get_or_create_plan(128, true, &mut planner);
        cache.get_or_create_plan(128, true, &mut planner);

        let stats = cache.get_stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
        assert_eq!(stats.size, 0);
    }

    /// With no requests the hit rate is reported as zero rather than NaN.
    #[test]
    fn test_empty_stats() {
        let stats = PlanCache::with_config(8, Duration::from_secs(1)).get_stats();
        assert_eq!(stats.hit_rate, 0.0);
        assert_eq!(
            stats.to_string(),
            "Cache Stats: 0 hits, 0 misses (0.0% hit rate), 0/8 entries"
        );
    }
}