1#![allow(clippy::doc_markdown)]
4
5use alloc::collections::VecDeque;
34use alloc::string::String;
35use alloc::vec::Vec;
36
37use spg_storage::Value;
38
/// Entry-count limit that [`MemoizeCache::new`] installs on a fresh cache.
pub const DEFAULT_MAX_ENTRIES: usize = 1024;

/// Byte budget that [`MemoizeCache::new`] installs on a fresh cache
/// (16 * 1024 * 1024 = 16 MiB). It is compared against the cache's
/// approximate per-entry size estimates, not against real allocator usage.
pub const DEFAULT_MAX_BYTES: usize = 16 * 1024 * 1024;
46
/// Lookup key for [`MemoizeCache`].
///
/// Equality is the derived `PartialEq`: two keys match only when both the
/// subquery representation and the full list of outer values compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheKey {
    /// Textual representation of the subquery (the unit tests use SQL text).
    pub subquery_repr: String,
    /// Values that accompany the subquery text in the key — presumably the
    /// outer-row values a correlated subquery depends on; confirm with callers.
    pub outer_values: Vec<Value>,
}
56
/// Bounded key/value cache with least-recently-used eviction.
///
/// Lookups are a linear scan over `entries`; both a hit in `get` and a new
/// `insert` place the entry at the front, and eviction removes from the back.
#[derive(Debug, Clone)]
pub struct MemoizeCache {
    /// Cached pairs, most recently used at the front.
    entries: VecDeque<(CacheKey, Value)>,
    /// Entry count at which `insert` starts evicting from the back.
    max_entries: usize,
    /// Approximate byte budget that `insert` evicts against.
    max_bytes: usize,
    /// Running total of the approximate sizes of the stored entries.
    current_bytes: usize,
    /// Number of `get` calls that found an entry.
    pub hit_count: u64,
    /// Number of `get` calls that found nothing.
    pub miss_count: u64,
}
70
71impl Default for MemoizeCache {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl MemoizeCache {
78 pub fn new() -> Self {
79 Self {
80 entries: VecDeque::with_capacity(DEFAULT_MAX_ENTRIES),
81 max_entries: DEFAULT_MAX_ENTRIES,
82 max_bytes: DEFAULT_MAX_BYTES,
83 current_bytes: 0,
84 hit_count: 0,
85 miss_count: 0,
86 }
87 }
88
89 pub const fn with_max_entries(mut self, n: usize) -> Self {
90 self.max_entries = n;
91 self
92 }
93
94 pub const fn with_max_bytes(mut self, b: usize) -> Self {
95 self.max_bytes = b;
96 self
97 }
98
99 pub fn len(&self) -> usize {
100 self.entries.len()
101 }
102
103 pub fn is_empty(&self) -> bool {
104 self.entries.is_empty()
105 }
106
107 pub fn get(&mut self, key: &CacheKey) -> Option<Value> {
111 let pos = self.entries.iter().position(|(k, _)| k == key);
112 if let Some(p) = pos {
113 let (k, v) = self.entries.remove(p)?;
114 self.entries.push_front((k, v.clone()));
115 self.hit_count += 1;
116 Some(v)
117 } else {
118 self.miss_count += 1;
119 None
120 }
121 }
122
123 pub fn insert(&mut self, key: CacheKey, value: Value) {
127 let entry_bytes = approx_bytes(&key) + approx_value_bytes(&value);
128 while !self.entries.is_empty()
129 && (self.entries.len() >= self.max_entries
130 || self.current_bytes + entry_bytes > self.max_bytes)
131 {
132 let Some((k, v)) = self.entries.pop_back() else {
133 break;
134 };
135 self.current_bytes = self
136 .current_bytes
137 .saturating_sub(approx_bytes(&k) + approx_value_bytes(&v));
138 }
139 self.current_bytes = self.current_bytes.saturating_add(entry_bytes);
140 self.entries.push_front((key, value));
141 }
142}
143
144fn approx_bytes(key: &CacheKey) -> usize {
145 key.subquery_repr.len()
146 + key
147 .outer_values
148 .iter()
149 .map(approx_value_bytes)
150 .sum::<usize>()
151 + 16
152}
153
154fn approx_value_bytes(v: &Value) -> usize {
155 match v {
156 Value::Null | Value::Bool(_) | Value::SmallInt(_) => 1,
157 Value::Int(_) => 4,
158 Value::BigInt(_) | Value::Float(_) => 8,
159 Value::Date(_) | Value::Timestamp(_) => 8,
160 Value::Interval { .. } => 16,
161 Value::Numeric { .. } => 16,
162 Value::Text(s) | Value::Json(s) => s.len(),
163 Value::Vector(v) => v.len() * 4,
164 Value::Sq8Vector(q) => q.bytes.len() + 8,
165 Value::HalfVector(h) => h.dim() * 2,
166 _ => 16,
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CacheKey` from a subquery string and a slice of outer values.
    fn key(repr: &str, outer: &[Value]) -> CacheKey {
        CacheKey {
            outer_values: outer.to_vec(),
            subquery_repr: repr.into(),
        }
    }

    /// A lookup on a fresh cache is a miss and is counted as one.
    #[test]
    fn empty_cache_misses_everything() {
        let probe = key("SELECT 1", &[Value::Int(1)]);
        let mut cache = MemoizeCache::new();
        assert!(cache.get(&probe).is_none());
        assert_eq!(cache.miss_count, 1);
        assert_eq!(cache.hit_count, 0);
    }

    /// A stored value comes back on lookup and is counted as a hit.
    #[test]
    fn insert_then_get_hits() {
        let probe = key("SELECT 1", &[Value::Int(1)]);
        let mut cache = MemoizeCache::new();
        cache.insert(probe.clone(), Value::BigInt(42));
        assert_eq!(cache.get(&probe), Some(Value::BigInt(42)));
        assert_eq!(cache.hit_count, 1);
    }

    /// With five distinct outer values over 100 probes, only the first probe
    /// of each value misses.
    #[test]
    fn repeated_outer_key_hits_after_first_insert() {
        let repr = "SELECT MAX(x) FROM y WHERE y.k = outer.k";
        let mut cache = MemoizeCache::new();
        for row in 0..100 {
            let probe = key(repr, &[Value::Int(row % 5)]);
            if cache.get(&probe).is_none() {
                cache.insert(probe, Value::BigInt(i64::from(row)));
            }
        }
        assert_eq!(cache.miss_count, 5);
        assert_eq!(cache.hit_count, 95);
    }

    /// Inserting past the entry limit keeps the newest entries and drops the
    /// oldest.
    #[test]
    fn lru_eviction_at_max_entries() {
        let q = |i| key("q", &[Value::Int(i)]);
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..5 {
            cache.insert(q(i), Value::BigInt(i64::from(i)));
        }
        assert!(cache.len() <= 3, "len={}", cache.len());
        assert!(cache.get(&q(4)).is_some());
        assert!(cache.get(&q(3)).is_some());
        assert!(cache.get(&q(2)).is_some());
        assert!(cache.get(&q(0)).is_none());
    }

    /// A small byte budget forces eviction even though the entry limit is
    /// nowhere near reached.
    #[test]
    fn lru_eviction_at_max_bytes() {
        let mut cache = MemoizeCache::new().with_max_bytes(128);
        for i in 0..10 {
            let payload: alloc::string::String = core::iter::repeat_n('x', 64).collect();
            cache.insert(key("q", &[Value::Int(i)]), Value::Text(payload));
        }
        assert!(cache.len() < 10, "len={}", cache.len());
    }

    /// Keys that differ only in subquery text are separate entries.
    #[test]
    fn distinct_subquery_reprs_dont_collide() {
        let first = key("SELECT 1", &[Value::Int(1)]);
        let second = key("SELECT 2", &[Value::Int(1)]);
        let mut cache = MemoizeCache::new();
        cache.insert(first.clone(), Value::BigInt(10));
        cache.insert(second.clone(), Value::BigInt(20));
        assert_eq!(cache.get(&first), Some(Value::BigInt(10)));
        assert_eq!(cache.get(&second), Some(Value::BigInt(20)));
    }

    /// A hit refreshes an entry, so the next eviction removes a different,
    /// older one.
    #[test]
    fn miss_then_hit_bumps_promotes_to_lru_front() {
        let q = |i| key("q", &[Value::Int(i)]);
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..3 {
            cache.insert(q(i), Value::BigInt(i64::from(i)));
        }
        // Touch the oldest entry so it becomes the most recently used.
        let _ = cache.get(&q(0));
        cache.insert(q(3), Value::BigInt(3));
        assert!(cache.get(&q(0)).is_some());
        assert!(cache.get(&q(1)).is_none());
    }
}