1use std::collections::HashMap;
15use std::hash::Hash;
16use std::sync::RwLock;
17use std::time::Duration;
18
19use tokio::time::Instant;
20use tracing::{debug, warn};
21
/// A cached value together with the instant it was stored and how long it stays valid.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    /// The cached payload.
    value: V,
    /// When the entry was created; every age computation is relative to this instant.
    inserted_at: Instant,
    /// How long after `inserted_at` the entry counts as fresh.
    ttl: Duration,
}

impl<V> CacheEntry<V> {
    /// Wraps `value`, timestamping it with the current instant.
    fn new(value: V, ttl: Duration) -> Self {
        Self {
            value,
            inserted_at: Instant::now(),
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since insertion.
    fn is_expired(&self) -> bool {
        self.age() > self.ttl
    }

    /// Returns `true` once strictly more than 3/4 of `ttl` has elapsed since insertion,
    /// i.e. the entry is still usable but should be refreshed soon.
    ///
    /// An expired entry is always stale as well.
    fn is_stale(&self) -> bool {
        self.age() > self.stale_threshold()
    }

    /// Time since insertion after which the entry is considered stale (`ttl * 3 / 4`).
    ///
    /// `Duration * u32` panics on overflow, so the naive `ttl * 3 / 4` would panic for very
    /// large TTLs (e.g. `Duration::MAX` used as "never expires"). In that case divide first;
    /// the result then differs from the exact value by at most a few nanoseconds.
    fn stale_threshold(&self) -> Duration {
        match self.ttl.checked_mul(3) {
            Some(tripled) => tripled / 4,
            None => self.ttl / 4 * 3,
        }
    }

    /// How long ago this entry was inserted.
    fn age(&self) -> Duration {
        self.inserted_at.elapsed()
    }
}
56
57pub struct TtlCache<K, V> {
81 entries: RwLock<HashMap<K, CacheEntry<V>>>,
82 default_ttl: Duration,
83 max_capacity: usize,
86}
87
88const DEFAULT_MAX_CAPACITY: usize = 1024;
90
91impl<K, V> TtlCache<K, V>
92where
93 K: Eq + Hash + Clone + std::fmt::Debug,
94 V: Clone,
95{
96 pub fn new(default_ttl: Duration) -> Self {
98 Self {
99 entries: RwLock::new(HashMap::new()),
100 default_ttl,
101 max_capacity: DEFAULT_MAX_CAPACITY,
102 }
103 }
104
105 pub fn with_max_capacity(default_ttl: Duration, max_capacity: usize) -> Self {
107 Self {
108 entries: RwLock::new(HashMap::new()),
109 default_ttl,
110 max_capacity,
111 }
112 }
113
114 pub fn get(&self, key: &K) -> Option<V> {
119 let entries = match self.entries.read() {
120 Ok(guard) => guard,
121 Err(poisoned) => {
122 warn!("Cache read lock poisoned, recovering");
123 poisoned.into_inner()
124 }
125 };
126 let entry = entries.get(key)?;
127
128 if entry.is_expired() {
129 debug!(
130 hit = false,
131 ?key,
132 age_secs = entry.age().as_secs(),
133 "cache lookup (expired)"
134 );
135 None
136 } else {
137 debug!(hit = true, ?key, "cache lookup");
138 Some(entry.value.clone())
139 }
140 }
141
142 pub fn get_stale(&self, key: &K) -> Option<V> {
147 let entries = match self.entries.read() {
148 Ok(guard) => guard,
149 Err(poisoned) => {
150 warn!("Cache read lock poisoned, recovering");
151 poisoned.into_inner()
152 }
153 };
154 entries.get(key).map(|entry| {
155 if entry.is_expired() {
156 debug!(
157 ?key,
158 age_secs = entry.age().as_secs(),
159 "Serving stale cache entry"
160 );
161 }
162 entry.value.clone()
163 })
164 }
165
166 pub fn needs_refresh(&self, key: &K) -> bool {
170 let entries = match self.entries.read() {
171 Ok(guard) => guard,
172 Err(poisoned) => {
173 warn!("Cache read lock poisoned, recovering");
174 poisoned.into_inner()
175 }
176 };
177
178 entries.get(key).is_some_and(|entry| entry.is_stale())
179 }
180
181 pub fn insert(&self, key: K, value: V) {
183 self.insert_with_ttl(key, value, self.default_ttl);
184 }
185
186 pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) {
191 let mut entries = match self.entries.write() {
192 Ok(guard) => guard,
193 Err(poisoned) => {
194 warn!("Cache write lock poisoned, recovering");
195 poisoned.into_inner()
196 }
197 };
198
199 if entries.len() >= self.max_capacity && !entries.contains_key(&key) {
201 let before = entries.len();
203 entries.retain(|_, entry| !entry.is_expired());
204 let removed = before - entries.len();
205 if removed > 0 {
206 debug!(removed, "Evicted expired entries to make room");
207 }
208
209 if entries.len() >= self.max_capacity {
211 if let Some(oldest_key) = entries
212 .iter()
213 .max_by_key(|(_, entry)| entry.age())
214 .map(|(k, _)| k.clone())
215 {
216 entries.remove(&oldest_key);
217 debug!(?oldest_key, "Evicted oldest entry to make room");
218 }
219 }
220 }
221
222 debug!(?key, ttl_secs = ttl.as_secs(), "Inserting cache entry");
223 entries.insert(key, CacheEntry::new(value, ttl));
224 }
225
226 pub fn remove(&self, key: &K) -> Option<V> {
228 let mut entries = match self.entries.write() {
229 Ok(guard) => guard,
230 Err(poisoned) => {
231 warn!("Cache write lock poisoned, recovering");
232 poisoned.into_inner()
233 }
234 };
235 entries.remove(key).map(|e| e.value)
236 }
237
238 pub fn cleanup(&self) {
242 let mut entries = match self.entries.write() {
243 Ok(guard) => guard,
244 Err(poisoned) => {
245 warn!("Cache write lock poisoned, recovering");
246 poisoned.into_inner()
247 }
248 };
249 let before = entries.len();
250 entries.retain(|_, entry| !entry.is_expired());
251 let removed = before - entries.len();
252 if removed > 0 {
253 debug!(removed, remaining = entries.len(), "Cache cleanup complete");
254 }
255 }
256
257 pub fn len(&self) -> usize {
259 match self.entries.read() {
260 Ok(entries) => entries.len(),
261 Err(poisoned) => {
262 warn!("Cache read lock poisoned, recovering");
263 poisoned.into_inner().len()
264 }
265 }
266 }
267
268 pub fn is_empty(&self) -> bool {
270 self.len() == 0
271 }
272
273 pub fn clear(&self) {
275 let mut entries = match self.entries.write() {
276 Ok(guard) => guard,
277 Err(poisoned) => {
278 warn!("Cache write lock poisoned, recovering");
279 poisoned.into_inner()
280 }
281 };
282 entries.clear();
283 }
284}
285
286pub struct SingleValueCache<V> {
291 entry: RwLock<Option<CacheEntry<V>>>,
292 ttl: Duration,
293}
294
295impl<V: Clone> SingleValueCache<V> {
296 pub fn new(ttl: Duration) -> Self {
298 Self {
299 entry: RwLock::new(None),
300 ttl,
301 }
302 }
303
304 pub fn get(&self) -> Option<V> {
306 let guard = match self.entry.read() {
307 Ok(guard) => guard,
308 Err(poisoned) => {
309 warn!("SingleValueCache read lock poisoned, recovering");
310 poisoned.into_inner()
311 }
312 };
313 let entry = guard.as_ref()?;
314
315 if entry.is_expired() {
316 None
317 } else {
318 Some(entry.value.clone())
319 }
320 }
321
322 pub fn get_stale(&self) -> Option<V> {
324 let guard = match self.entry.read() {
325 Ok(guard) => guard,
326 Err(poisoned) => {
327 warn!("SingleValueCache read lock poisoned, recovering");
328 poisoned.into_inner()
329 }
330 };
331 guard.as_ref().map(|e| e.value.clone())
332 }
333
334 pub fn needs_refresh(&self) -> bool {
336 let guard = match self.entry.read() {
337 Ok(guard) => guard,
338 Err(poisoned) => {
339 warn!("SingleValueCache read lock poisoned, recovering");
340 poisoned.into_inner()
341 }
342 };
343
344 match guard.as_ref() {
345 Some(e) => e.is_stale(),
346 None => true,
347 }
348 }
349
350 pub fn has_value(&self) -> bool {
352 let guard = match self.entry.read() {
353 Ok(guard) => guard,
354 Err(poisoned) => {
355 warn!("SingleValueCache read lock poisoned, recovering");
356 poisoned.into_inner()
357 }
358 };
359 guard.is_some()
360 }
361
362 pub fn set(&self, value: V) {
364 let mut guard = match self.entry.write() {
365 Ok(guard) => guard,
366 Err(poisoned) => {
367 warn!("SingleValueCache write lock poisoned, recovering");
368 poisoned.into_inner()
369 }
370 };
371 *guard = Some(CacheEntry::new(value, self.ttl));
372 }
373
374 pub fn clear(&self) {
376 let mut guard = match self.entry.write() {
377 Ok(guard) => guard,
378 Err(poisoned) => {
379 warn!("SingleValueCache write lock poisoned, recovering");
380 poisoned.into_inner()
381 }
382 };
383 *guard = None;
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    // Time-dependent tests run on tokio's paused test clock (`start_paused = true`).
    // `tokio::time::Instant::now()` then only moves when the test calls
    // `tokio::time::advance`. A "still fresh" assertion made right after an insert therefore
    // cannot flake on a loaded machine, as it could with a 10 ms wall-clock TTL and
    // `thread::sleep`. Tests that never depend on elapsed time stay plain `#[test]`s.

    #[test]
    fn test_cache_insert_and_get() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));

        cache.insert("key".to_string(), "value".to_string());

        assert_eq!(cache.get(&"key".to_string()), Some("value".to_string()));
    }

    #[test]
    fn test_cache_get_missing_key() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));

        assert_eq!(cache.get(&"missing".to_string()), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_expiration() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));

        cache.insert("key".to_string(), "value".to_string());
        assert_eq!(cache.get(&"key".to_string()), Some("value".to_string()));

        tokio::time::advance(Duration::from_millis(20)).await;

        assert_eq!(cache.get(&"key".to_string()), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_get_stale_after_expiration() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));

        cache.insert("key".to_string(), "value".to_string());

        tokio::time::advance(Duration::from_millis(20)).await;

        assert_eq!(cache.get(&"key".to_string()), None);
        // Expired entries stay in the map until removed, so the stale read still succeeds.
        assert_eq!(
            cache.get_stale(&"key".to_string()),
            Some("value".to_string())
        );
    }

    #[test]
    fn test_cache_remove() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));

        cache.insert("key".to_string(), "value".to_string());
        assert!(cache.get(&"key".to_string()).is_some());

        assert_eq!(
            cache.remove(&"key".to_string()),
            Some("value".to_string())
        );
        assert!(cache.get(&"key".to_string()).is_none());
        assert_eq!(cache.remove(&"key".to_string()), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cache_cleanup() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_millis(10));

        cache.insert("key1".to_string(), "value1".to_string());
        cache.insert("key2".to_string(), "value2".to_string());

        tokio::time::advance(Duration::from_millis(20)).await;

        cache.insert_with_ttl(
            "key3".to_string(),
            "value3".to_string(),
            Duration::from_secs(3600),
        );

        // `len` counts expired entries until they are cleaned up.
        assert_eq!(cache.len(), 3);

        cache.cleanup();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&"key3".to_string()), Some("value3".to_string()));
    }

    #[test]
    fn test_cache_clear() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(3600));

        cache.insert("key1".to_string(), "value1".to_string());
        cache.insert("key2".to_string(), "value2".to_string());

        assert_eq!(cache.len(), 2);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn test_capacity_evicts_oldest_entry() {
        let cache: TtlCache<&str, u32> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);

        cache.insert("a", 1);
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert("b", 2);
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert("c", 3);

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some(2));
        assert_eq!(cache.get(&"c"), Some(3));
    }

    #[tokio::test(start_paused = true)]
    async fn test_capacity_evicts_expired_before_oldest() {
        let cache: TtlCache<&str, u32> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);

        cache.insert("long", 1);
        tokio::time::advance(Duration::from_millis(1)).await;
        cache.insert_with_ttl("short", 2, Duration::from_millis(10));
        tokio::time::advance(Duration::from_millis(20)).await;
        cache.insert("new", 3);

        // "long" is the oldest entry but still fresh; the expired "short" goes instead.
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get_stale(&"short"), None);
        assert_eq!(cache.get(&"long"), Some(1));
        assert_eq!(cache.get(&"new"), Some(3));
    }

    #[test]
    fn test_overwriting_key_at_capacity_keeps_other_entries() {
        let cache: TtlCache<&str, u32> =
            TtlCache::with_max_capacity(Duration::from_secs(3600), 2);

        cache.insert("a", 1);
        cache.insert("b", 2);
        cache.insert("a", 10);

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"a"), Some(10));
        assert_eq!(cache.get(&"b"), Some(2));
    }

    #[test]
    fn test_needs_refresh_when_nothing_cached() {
        // A missing key has nothing to refresh...
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(1));
        assert!(!cache.needs_refresh(&"missing".to_string()));

        // ...whereas an empty single-value cache always wants a fetch.
        let single: SingleValueCache<String> = SingleValueCache::new(Duration::from_secs(1));
        assert!(single.needs_refresh());
    }

    #[test]
    fn test_single_value_cache() {
        let cache: SingleValueCache<String> = SingleValueCache::new(Duration::from_secs(3600));

        assert!(!cache.has_value());
        assert!(cache.get().is_none());

        cache.set("value".to_string());

        assert!(cache.has_value());
        assert_eq!(cache.get(), Some("value".to_string()));
    }

    #[tokio::test(start_paused = true)]
    async fn test_single_value_cache_expiration() {
        let cache: SingleValueCache<String> = SingleValueCache::new(Duration::from_millis(10));

        cache.set("value".to_string());
        assert_eq!(cache.get(), Some("value".to_string()));

        tokio::time::advance(Duration::from_millis(20)).await;

        assert!(cache.get().is_none());
        assert_eq!(cache.get_stale(), Some("value".to_string()));
        assert!(
            cache.has_value(),
            "expired values are kept until replaced or cleared"
        );
        assert!(cache.needs_refresh());
    }

    #[test]
    fn test_single_value_cache_clear() {
        let cache: SingleValueCache<u32> = SingleValueCache::new(Duration::from_secs(3600));

        cache.set(7);
        cache.clear();

        assert!(!cache.has_value());
        assert_eq!(cache.get_stale(), None);
        assert!(cache.needs_refresh());
    }

    #[tokio::test(start_paused = true)]
    async fn test_needs_refresh() {
        let cache: TtlCache<String, String> = TtlCache::new(Duration::from_secs(1));
        cache.insert("key".to_string(), "value".to_string());

        assert!(!cache.needs_refresh(&"key".to_string()));

        tokio::time::advance(Duration::from_millis(800)).await;
        assert!(
            cache.needs_refresh(&"key".to_string()),
            "entry must be stale at t=800ms (> 750ms threshold)"
        );
        assert!(
            cache.get(&"key".to_string()).is_some(),
            "entry must not be expired at t=800ms (< 1000ms TTL)"
        );

        tokio::time::advance(Duration::from_millis(300)).await;
        assert!(
            cache.get(&"key".to_string()).is_none(),
            "entry must be expired at t=1100ms (> 1000ms TTL)"
        );
    }
}