sqlitegraph/backend/native/v2/edge_cluster/
cache.rs1use super::cluster::EdgeCluster;
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Instant;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
14pub struct CacheKey {
15 pub node_id: i64,
16 pub direction: super::cluster_trace::Direction,
17}
18
19impl CacheKey {
20 pub fn new(node_id: i64, direction: super::cluster_trace::Direction) -> Self {
21 Self { node_id, direction }
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct AccessPatternTracker {
28 access_history: Vec<i64>,
30 max_history: usize,
32}
33
34impl AccessPatternTracker {
35 pub fn new(max_history: usize) -> Self {
36 Self {
37 access_history: Vec::with_capacity(max_history),
38 max_history,
39 }
40 }
41
42 pub fn record_access(&mut self, node_id: i64) -> AccessType {
44 self.access_history.push(node_id);
45 if self.access_history.len() > self.max_history {
46 self.access_history.remove(0);
47 }
48
49 if self.access_history.len() >= 2 {
51 let _last = self.access_history[self.access_history.len() - 2];
52 if self.is_traversal_pattern(node_id) {
54 return AccessType::Traversal;
55 }
56 }
57
58 AccessType::Lookup
59 }
60
61 fn is_traversal_pattern(&self, node_id: i64) -> bool {
63 self.access_history
66 .iter()
67 .filter(|&&id| id == node_id)
68 .count()
69 > 0
70 }
71}
72
/// Classification assigned to a single cache access.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessType {
    /// The access was judged to be part of a traversal.
    Traversal,
    /// The access was judged to be an isolated lookup.
    Lookup,
}
81
82#[derive(Debug, Clone)]
84pub struct CacheEntry {
85 pub data: Arc<EdgeCluster>,
87 pub access_count: u32,
89 pub last_access: Instant,
91 pub traversal_score: f64,
93 access_history: [Option<Instant>; 2],
95}
96
97impl CacheEntry {
98 pub fn new(data: Arc<EdgeCluster>) -> Self {
99 let now = Instant::now();
100 Self {
101 data,
102 access_count: 0,
103 last_access: now,
104 traversal_score: 0.0,
105 access_history: [None, None],
106 }
107 }
108
109 pub fn record_access(&mut self, access_type: AccessType) {
111 self.access_count += 1;
112 self.last_access = Instant::now();
113
114 self.access_history[1] = self.access_history[0];
116 self.access_history[0] = Some(Instant::now());
117
118 match access_type {
120 AccessType::Traversal => {
121 self.traversal_score += 1.0;
123 }
124 AccessType::Lookup => {
125 self.traversal_score += 0.1;
127 }
128 }
129 }
130
131 pub fn eviction_score(&self) -> f64 {
133 let recency_score = if let Some(most_recent) = self.access_history[0] {
135 1.0 / (most_recent.elapsed().as_secs_f64() + 1.0)
136 } else {
137 0.0
138 };
139
140 self.traversal_score * 10.0 + recency_score
141 }
142
143 pub fn is_high_degree(&self) -> bool {
145 self.data.edge_count() > 100
146 }
147}
148
149pub struct TraversalAwareCache {
151 entries: HashMap<CacheKey, CacheEntry>,
153 access_pattern: AccessPatternTracker,
155 max_capacity: usize,
157 stats: CacheStats,
159}
160
/// Counters describing how a cache has been used. All counters start at zero.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Lookups classified as traversal accesses.
    pub traversals: u64,
    /// Lookups classified as isolated lookups.
    pub lookups: u64,
}
169
170impl TraversalAwareCache {
171 pub fn new(max_capacity: usize) -> Self {
173 Self {
174 entries: HashMap::with_capacity(max_capacity),
175 access_pattern: AccessPatternTracker::new(100),
176 max_capacity,
177 stats: CacheStats::default(),
178 }
179 }
180
181 pub fn get(&mut self, key: CacheKey) -> Option<Arc<EdgeCluster>> {
183 let access_type = self.access_pattern.record_access(key.node_id);
185
186 match access_type {
188 AccessType::Traversal => self.stats.traversals += 1,
189 AccessType::Lookup => self.stats.lookups += 1,
190 }
191
192 if let Some(entry) = self.entries.get_mut(&key) {
194 self.stats.hits += 1;
195 entry.record_access(access_type);
196 return Some(Arc::clone(&entry.data));
197 }
198
199 self.stats.misses += 1;
200 None
201 }
202
203 pub fn insert(&mut self, key: CacheKey, cluster: Arc<EdgeCluster>) {
205 if let Some(entry) = self.entries.get_mut(&key) {
207 entry.data = cluster;
208 entry.record_access(AccessType::Lookup);
209 return;
210 }
211
212 if self.entries.len() >= self.max_capacity {
214 self.evict_one();
215 }
216
217 let entry = CacheEntry::new(cluster);
219 self.entries.insert(key, entry);
220 }
221
222 pub fn remove(&mut self, key: &CacheKey) -> Option<Arc<EdgeCluster>> {
224 self.entries.remove(key).map(|entry| entry.data)
225 }
226
227 fn evict_one(&mut self) {
229 if self.entries.is_empty() {
230 return;
231 }
232
233 let mut worst_key = None;
235 let mut worst_score = f64::MAX;
236
237 for (key, entry) in &self.entries {
238 let score = entry.eviction_score();
239
240 let adjusted_score = if entry.is_high_degree() {
242 score * 2.0 } else {
244 score
245 };
246
247 if adjusted_score < worst_score {
248 worst_score = adjusted_score;
249 worst_key = Some(*key);
250 }
251 }
252
253 if let Some(key) = worst_key {
255 self.entries.remove(&key);
256 }
257 }
258
259 pub fn clear(&mut self) {
261 self.entries.clear();
262 }
263
264 pub fn len(&self) -> usize {
266 self.entries.len()
267 }
268
269 pub fn is_empty(&self) -> bool {
271 self.entries.is_empty()
272 }
273
274 pub fn stats(&self) -> &CacheStats {
276 &self.stats
277 }
278
279 pub fn hit_ratio(&self) -> f64 {
281 let total = self.stats.hits + self.stats.misses;
282 if total == 0 {
283 0.0
284 } else {
285 self.stats.hits as f64 / total as f64
286 }
287 }
288}
289
290pub struct ThreadSafeCache {
292 inner: Arc<RwLock<TraversalAwareCache>>,
293}
294
295impl ThreadSafeCache {
296 pub fn new(max_capacity: usize) -> Self {
298 Self {
299 inner: Arc::new(RwLock::new(TraversalAwareCache::new(max_capacity))),
300 }
301 }
302
303 pub fn get(&self, key: CacheKey) -> Option<Arc<EdgeCluster>> {
305 self.inner.write().get(key)
306 }
307
308 pub fn insert(&self, key: CacheKey, cluster: Arc<EdgeCluster>) {
310 self.inner.write().insert(key, cluster);
311 }
312
313 pub fn remove(&self, key: &CacheKey) -> Option<Arc<EdgeCluster>> {
315 self.inner.write().remove(key)
316 }
317
318 pub fn stats(&self) -> CacheStats {
320 self.inner.read().stats().clone()
321 }
322
323 pub fn hit_ratio(&self) -> f64 {
325 self.inner.read().hit_ratio()
326 }
327
328 pub fn inner(&self) -> &Arc<RwLock<TraversalAwareCache>> {
330 &self.inner
331 }
332}
333
334impl Clone for ThreadSafeCache {
335 fn clone(&self) -> Self {
336 Self {
337 inner: Arc::clone(&self.inner),
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::native::v2::edge_cluster::cluster_trace::Direction;

    /// Builds the empty outgoing cluster that every test uses as a payload.
    fn empty_cluster() -> Arc<EdgeCluster> {
        Arc::new(EdgeCluster::create_from_compact_edges(vec![], 1, Direction::Outgoing).unwrap())
    }

    /// Shorthand for an outgoing-direction key.
    fn outgoing(node_id: i64) -> CacheKey {
        CacheKey::new(node_id, Direction::Outgoing)
    }

    /// A stored key is a hit, an unknown key is a miss, and both are counted.
    #[test]
    fn test_cache_basics() {
        let mut cache = TraversalAwareCache::new(3);
        let (key1, key2) = (outgoing(1), outgoing(2));
        let cluster = empty_cluster();

        cache.insert(key1, Arc::clone(&cluster));
        assert!(cache.get(key1).is_some());
        assert!(cache.get(key2).is_none());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// Inserting past capacity evicts an entry instead of growing the cache.
    #[test]
    fn test_cache_eviction() {
        let mut cache = TraversalAwareCache::new(2);
        let cluster = empty_cluster();

        for node_id in 1..=3 {
            cache.insert(outgoing(node_id), Arc::clone(&cluster));
        }

        assert_eq!(cache.len(), 2);
    }

    /// Five hits and five misses give a hit ratio of one half.
    #[test]
    fn test_hit_ratio() {
        let mut cache = TraversalAwareCache::new(10);
        let key1 = outgoing(1);
        let cluster = empty_cluster();

        cache.insert(key1, Arc::clone(&cluster));

        for _ in 0..5 {
            cache.get(key1);
        }
        for node_id in 2..7 {
            cache.get(outgoing(node_id));
        }

        let ratio = cache.hit_ratio();
        assert!((ratio - 0.5).abs() < 0.01);
    }
}