Skip to main content

haystack_server/session/
working_set.rs

1use std::collections::HashMap;
2
/// Cached result of a filter evaluation.
#[derive(Debug, Clone)]
struct CachedResult {
    /// Entity IDs matching the filter.
    entity_ids: Vec<String>,
    /// Graph version when this result was cached.
    graph_version: u64,
    /// Sum of connector cache versions when cached.
    // NOTE(review): a sum only detects a change reliably if connector versions never move
    // in opposite directions by equal amounts; presumably they only increase — confirm.
    connector_versions_sum: u64,
}

impl CachedResult {
    /// Whether this entry was computed against exactly the given versions.
    fn is_fresh(&self, graph_version: u64, connector_versions_sum: u64) -> bool {
        self.graph_version == graph_version
            && self.connector_versions_sum == connector_versions_sum
    }
}

/// Per-session cache of filter → entity ID results.
///
/// An entry is served only while both the graph version and the connector
/// versions sum match the values it was cached with; a stale entry is dropped
/// when it is looked up. When the cache is full, inserting a *new* filter evicts
/// the least-recently-used entry (both [`get`](Self::get) hits and
/// [`insert`](Self::insert) count as uses). A capacity of 0 disables caching.
#[derive(Debug)]
pub struct WorkingSetCache {
    entries: HashMap<String, CachedResult>,
    /// Maximum number of entries held at once.
    capacity: usize,
    /// Keys ordered from least- to most-recently used; kept in sync with `entries`.
    order: Vec<String>,
    /// Number of `get` calls that returned a cached result.
    hits: u64,
    /// Number of `get` calls that found no entry or only a stale one.
    misses: u64,
}

impl WorkingSetCache {
    /// Create an empty cache holding at most `capacity` entries.
    ///
    /// A `capacity` of 0 yields a cache that never stores anything.
    pub fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::new(),
            capacity,
            order: Vec::new(),
            hits: 0,
            misses: 0,
        }
    }

    /// Look up the cached entity IDs for `filter`.
    ///
    /// Returns `None` and records a miss when `filter` is not cached, or when its
    /// entry was cached under a different graph version or connector versions sum
    /// (the stale entry is removed). On a hit the entry becomes the most recently
    /// used and a hit is recorded.
    pub fn get(
        &mut self,
        filter: &str,
        current_graph_version: u64,
        current_connector_versions_sum: u64,
    ) -> Option<&[String]> {
        // Decide freshness first so the shared borrow of `entries` ends before we mutate it.
        let fresh = match self.entries.get(filter) {
            Some(entry) => {
                entry.is_fresh(current_graph_version, current_connector_versions_sum)
            }
            None => {
                self.misses += 1;
                return None;
            }
        };

        if !fresh {
            self.entries.remove(filter);
            self.order.retain(|k| k != filter);
            self.misses += 1;
            return None;
        }

        self.hits += 1;
        self.touch(filter);
        self.entries
            .get(filter)
            .map(|entry| entry.entity_ids.as_slice())
    }

    /// Cache `entity_ids` as the result of `filter`, computed at the given versions.
    ///
    /// Replacing an existing entry refreshes its recency without evicting anything.
    /// Adding a new filter to a full cache first evicts the least-recently-used
    /// entry. Does nothing when the capacity is 0.
    pub fn insert(
        &mut self,
        filter: String,
        entity_ids: Vec<String>,
        graph_version: u64,
        connector_versions_sum: u64,
    ) {
        if self.capacity == 0 {
            return;
        }

        if self.entries.contains_key(&filter) {
            // Overwriting does not grow the map, so no other entry needs to go.
            self.touch(&filter);
        } else {
            while self.entries.len() >= self.capacity && !self.order.is_empty() {
                let oldest = self.order.remove(0);
                self.entries.remove(&oldest);
            }
            self.order.push(filter.clone());
        }

        self.entries.insert(
            filter,
            CachedResult {
                entity_ids,
                graph_version,
                connector_versions_sum,
            },
        );
    }

    /// Clear all cached entries. The hit/miss counters are left untouched.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.order.clear();
    }

    /// Return `(hits, misses)` accumulated by [`get`](Self::get) since construction.
    pub fn stats(&self) -> (u64, u64) {
        (self.hits, self.misses)
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Move `filter` to the most-recently-used end of `order`, if it is present.
    fn touch(&mut self, filter: &str) {
        if let Some(pos) = self.order.iter().position(|k| k == filter) {
            let key = self.order.remove(pos);
            self.order.push(key);
        }
    }
}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned entity-ID list built from string slices, to keep fixtures short.
    fn ids(items: &[&str]) -> Vec<String> {
        items.iter().map(|&s| String::from(s)).collect()
    }

    #[test]
    fn cache_hit_on_same_version() {
        let mut cache = WorkingSetCache::new(16);
        cache.insert(String::from("site"), ids(&["s1", "s2"]), 1, 10);
        let expected = ids(&["s1", "s2"]);
        assert_eq!(cache.get("site", 1, 10), Some(expected.as_slice()));
    }

    #[test]
    fn cache_miss_on_graph_version_change() {
        let mut cache = WorkingSetCache::new(16);
        cache.insert(String::from("site"), ids(&["s1"]), 1, 10);
        // Same connector sum, newer graph version: the entry must not be served.
        let lookup = cache.get("site", 2, 10);
        assert!(lookup.is_none());
    }

    #[test]
    fn cache_miss_on_connector_version_change() {
        let mut cache = WorkingSetCache::new(16);
        cache.insert(String::from("site"), ids(&["s1"]), 1, 10);
        // Same graph version, bumped connector sum.
        let lookup = cache.get("site", 1, 11);
        assert!(lookup.is_none());
    }

    #[test]
    fn lru_eviction() {
        let mut cache = WorkingSetCache::new(2);
        for &(key, id) in &[("a", "1"), ("b", "2")] {
            cache.insert(String::from(key), ids(&[id]), 1, 0);
        }
        // Reading "a" leaves "b" as the least recently used entry.
        let _ = cache.get("a", 1, 0);
        // A third key overflows capacity 2 and should push out "b".
        cache.insert(String::from("c"), ids(&["3"]), 1, 0);
        assert!(cache.get("b", 1, 0).is_none());
        assert!(cache.get("a", 1, 0).is_some());
        assert!(cache.get("c", 1, 0).is_some());
    }

    #[test]
    fn clear_empties_cache() {
        let mut cache = WorkingSetCache::new(16);
        cache.insert(String::from("site"), ids(&["s1"]), 1, 0);
        assert!(!cache.is_empty());

        cache.clear();

        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn stats_tracking() {
        let mut cache = WorkingSetCache::new(16);
        cache.insert(String::from("site"), ids(&["s1"]), 1, 0);
        // Two lookups of a fresh entry are hits...
        for _ in 0..2 {
            let _ = cache.get("site", 1, 0);
        }
        // ...and an unknown filter is a miss.
        let _ = cache.get("missing", 1, 0);
        let (hits, misses) = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
    }
}