unistore_cache/
cache.rs

1//! 【缓存核心】- LRU 缓存实现
2//!
3//! 职责:
4//! - 实现 LRU(最近最少使用)淘汰策略
5//! - 管理缓存条目的存取
6//! - 处理 TTL 过期
7
8use crate::config::CacheConfig;
9use crate::deps::*;
10use crate::entry::CacheEntry;
11use crate::stats::CacheStats;
12use tracing::debug;
13
14/// LRU 缓存
15pub struct Cache<K, V>
16where
17    K: Eq + Hash + Clone,
18{
19    /// 缓存数据
20    data: RwLock<HashMap<K, CacheEntry<V>>>,
21    /// 配置
22    config: CacheConfig,
23    /// 统计
24    stats: CacheStats,
25}
26
27impl<K, V> Cache<K, V>
28where
29    K: Eq + Hash + Clone,
30{
31    /// 创建新缓存
32    pub fn new(config: CacheConfig) -> Self {
33        Self {
34            data: RwLock::new(HashMap::with_capacity(config.max_capacity)),
35            config,
36            stats: CacheStats::new(),
37        }
38    }
39
40    /// 使用默认配置创建
41    pub fn with_defaults() -> Self {
42        Self::new(CacheConfig::default())
43    }
44
45    /// 插入条目
46    pub fn insert(&self, key: K, value: V) {
47        self.insert_with_ttl(key, value, self.config.default_ttl)
48    }
49
50    /// 插入条目(自定义 TTL)
51    pub fn insert_with_ttl(&self, key: K, value: V, ttl: Option<Duration>) {
52        let entry = CacheEntry::new(value, ttl);
53        let mut data = self.data.write();
54
55        // 检查是否需要淘汰
56        if data.len() >= self.config.max_capacity && !data.contains_key(&key) {
57            self.evict_lru(&mut data);
58        }
59
60        data.insert(key, entry);
61
62        if self.config.enable_stats {
63            self.stats.record_insert();
64        }
65    }
66
67    /// 获取值(如果存在且未过期)
68    pub fn get(&self, key: &K) -> Option<V>
69    where
70        V: Clone,
71    {
72        let mut data = self.data.write();
73
74        if let Some(entry) = data.get_mut(key) {
75            if entry.is_expired() {
76                data.remove(key);
77                if self.config.enable_stats {
78                    self.stats.record_expiration();
79                    self.stats.record_miss();
80                }
81                return None;
82            }
83
84            entry.touch();
85            if self.config.enable_stats {
86                self.stats.record_hit();
87            }
88            return Some(entry.value().clone());
89        }
90
91        if self.config.enable_stats {
92            self.stats.record_miss();
93        }
94        None
95    }
96
97    /// 获取值的引用(通过闭包访问)
98    pub fn with_value<R, F>(&self, key: &K, f: F) -> Option<R>
99    where
100        F: FnOnce(&V) -> R,
101    {
102        let mut data = self.data.write();
103
104        if let Some(entry) = data.get_mut(key) {
105            if entry.is_expired() {
106                data.remove(key);
107                if self.config.enable_stats {
108                    self.stats.record_expiration();
109                    self.stats.record_miss();
110                }
111                return None;
112            }
113
114            entry.touch();
115            if self.config.enable_stats {
116                self.stats.record_hit();
117            }
118            return Some(f(entry.value()));
119        }
120
121        if self.config.enable_stats {
122            self.stats.record_miss();
123        }
124        None
125    }
126
127    /// 检查键是否存在(不更新访问时间)
128    pub fn contains_key(&self, key: &K) -> bool {
129        let data = self.data.read();
130        if let Some(entry) = data.get(key) {
131            !entry.is_expired()
132        } else {
133            false
134        }
135    }
136
137    /// 删除条目
138    pub fn remove(&self, key: &K) -> Option<V> {
139        self.data.write().remove(key).map(|e| e.into_value())
140    }
141
142    /// 获取或插入
143    pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
144    where
145        V: Clone,
146        F: FnOnce() -> V,
147    {
148        // 先尝试读取
149        if let Some(value) = self.get(&key) {
150            return value;
151        }
152
153        // 需要插入
154        let value = f();
155        self.insert(key, value.clone());
156        value
157    }
158
159    /// 清空缓存
160    pub fn clear(&self) {
161        self.data.write().clear();
162    }
163
164    /// 获取当前大小
165    pub fn len(&self) -> usize {
166        self.data.read().len()
167    }
168
169    /// 检查是否为空
170    pub fn is_empty(&self) -> bool {
171        self.data.read().is_empty()
172    }
173
174    /// 获取统计快照
175    pub fn stats(&self) -> crate::stats::CacheStatsSnapshot {
176        self.stats.snapshot()
177    }
178
179    /// 重置统计
180    pub fn reset_stats(&self) {
181        self.stats.reset();
182    }
183
184    /// 手动清理过期条目
185    pub fn cleanup_expired(&self) -> usize {
186        let mut data = self.data.write();
187        let before = data.len();
188
189        data.retain(|_, entry| {
190            let expired = entry.is_expired();
191            if expired && self.config.enable_stats {
192                self.stats.record_expiration();
193            }
194            !expired
195        });
196
197        let removed = before - data.len();
198        if removed > 0 {
199            debug!(removed = removed, "清理过期缓存条目");
200        }
201        removed
202    }
203
204    /// LRU 淘汰(内部方法)
205    fn evict_lru(&self, data: &mut HashMap<K, CacheEntry<V>>) {
206        // 找到最旧的条目
207        let to_evict: Vec<K> = data
208            .iter()
209            .filter(|(_, entry)| entry.is_expired())
210            .map(|(k, _)| k.clone())
211            .take(self.config.eviction_batch_size)
212            .collect();
213
214        // 如果过期条目不够,按 LRU 淘汰
215        let mut to_evict = to_evict;
216        if to_evict.len() < self.config.eviction_batch_size {
217            let needed = self.config.eviction_batch_size - to_evict.len();
218            let mut entries: Vec<_> = data
219                .iter()
220                .filter(|(k, _)| !to_evict.contains(k))
221                .map(|(k, e)| (k.clone(), e.last_accessed()))
222                .collect();
223
224            entries.sort_by_key(|(_, t)| *t);
225            to_evict.extend(entries.into_iter().take(needed).map(|(k, _)| k));
226        }
227
228        // 执行淘汰
229        for key in to_evict {
230            data.remove(&key);
231            if self.config.enable_stats {
232                self.stats.record_eviction();
233            }
234        }
235    }
236
237    /// 获取所有键
238    pub fn keys(&self) -> Vec<K> {
239        self.data.read().keys().cloned().collect()
240    }
241
242    /// 批量插入
243    pub fn insert_many<I>(&self, items: I)
244    where
245        I: IntoIterator<Item = (K, V)>,
246    {
247        let mut data = self.data.write();
248        for (key, value) in items {
249            let entry = CacheEntry::new(value, self.config.default_ttl);
250            data.insert(key, entry);
251            if self.config.enable_stats {
252                self.stats.record_insert();
253            }
254        }
255    }
256
257    /// 更新已有条目的值(如果存在)
258    pub fn update<F>(&self, key: &K, f: F) -> bool
259    where
260        F: FnOnce(&mut V),
261    {
262        let mut data = self.data.write();
263        if let Some(entry) = data.get_mut(key) {
264            if entry.is_expired() {
265                data.remove(key);
266                if self.config.enable_stats {
267                    self.stats.record_expiration();
268                }
269                return false;
270            }
271            f(entry.value_mut());
272            entry.touch();
273            true
274        } else {
275            false
276        }
277    }
278}
279
280impl<K, V> Default for Cache<K, V>
281where
282    K: Eq + Hash + Clone,
283{
284    fn default() -> Self {
285        Self::with_defaults()
286    }
287}
288
289#[cfg(test)]
290mod tests {
291    use super::*;
292
293    #[test]
294    fn test_insert_and_get() {
295        let cache: Cache<&str, i32> = Cache::with_defaults();
296        cache.insert("key1", 42);
297
298        assert_eq!(cache.get(&"key1"), Some(42));
299        assert_eq!(cache.get(&"key2"), None);
300    }
301
302    #[test]
303    fn test_ttl_expiration() {
304        let config = CacheConfig::default().default_ttl(Duration::from_millis(10));
305        let cache: Cache<&str, i32> = Cache::new(config);
306
307        cache.insert("key1", 42);
308        assert!(cache.contains_key(&"key1"));
309
310        std::thread::sleep(Duration::from_millis(20));
311        assert!(!cache.contains_key(&"key1"));
312    }
313
314    #[test]
315    fn test_remove() {
316        let cache: Cache<&str, i32> = Cache::with_defaults();
317        cache.insert("key1", 42);
318
319        assert_eq!(cache.remove(&"key1"), Some(42));
320        assert!(!cache.contains_key(&"key1"));
321    }
322
323    #[test]
324    fn test_get_or_insert() {
325        let cache: Cache<&str, i32> = Cache::with_defaults();
326
327        let value = cache.get_or_insert_with("key1", || 42);
328        assert_eq!(value, 42);
329
330        let value = cache.get_or_insert_with("key1", || 100);
331        assert_eq!(value, 42); // 应该返回已有值
332    }
333
334    #[test]
335    fn test_stats() {
336        let cache: Cache<&str, i32> = Cache::with_defaults();
337
338        cache.insert("key1", 42);
339        cache.get(&"key1");
340        cache.get(&"key2");
341
342        let stats = cache.stats();
343        assert_eq!(stats.hits, 1);
344        assert_eq!(stats.misses, 1);
345        assert_eq!(stats.inserts, 1);
346    }
347
348    #[test]
349    fn test_lru_eviction() {
350        let config = CacheConfig::default()
351            .max_capacity(3)
352            .no_ttl()
353            .eviction_batch_size(1);
354        let cache: Cache<i32, i32> = Cache::new(config);
355
356        cache.insert(1, 10);
357        std::thread::sleep(Duration::from_millis(1));
358        cache.insert(2, 20);
359        std::thread::sleep(Duration::from_millis(1));
360        cache.insert(3, 30);
361        std::thread::sleep(Duration::from_millis(1));
362
363        // 访问 key 1,使其成为最近使用
364        cache.get(&1);
365        std::thread::sleep(Duration::from_millis(1));
366
367        // 插入新条目,应该淘汰 key 2(最久未使用)
368        cache.insert(4, 40);
369
370        assert!(cache.contains_key(&1));
371        assert!(!cache.contains_key(&2)); // 被淘汰
372        assert!(cache.contains_key(&3));
373        assert!(cache.contains_key(&4));
374    }
375}