1#![allow(clippy::doc_markdown)]
4
5use alloc::collections::VecDeque;
34use alloc::string::String;
35use alloc::vec::Vec;
36
37use spg_storage::Value;
38
/// Default cap on the number of cached entries. [`MemoizeCache::new`] uses it
/// both as the entry limit and as the initial capacity of its entry queue.
pub const DEFAULT_MAX_ENTRIES: usize = 1024;
42
/// Default cap on the approximate total size of cached entries, in bytes
/// (16 MiB). [`MemoizeCache::new`] uses it as the byte limit.
pub const DEFAULT_MAX_BYTES: usize = 16 * 1024 * 1024;
46
/// Lookup key for [`MemoizeCache`]: a subquery string paired with a list of
/// values.
///
/// Equality is the derived `PartialEq`, so two keys match only when both
/// fields are equal.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheKey {
    /// Text identifying the subquery.
    /// NOTE(review): the tests in this file pass SQL text here — confirm how
    /// production callers render it.
    pub subquery_repr: String,
    /// Values that, together with `subquery_repr`, identify one cached result.
    /// NOTE(review): presumably the outer-row values the subquery is
    /// correlated on — confirm against the caller.
    pub outer_values: Vec<Value>,
}
56
/// A column name paired with an ordered map from `String` keys to [`Value`]s.
///
/// NOTE(review): the name suggests a per-group lookup table; what the string
/// keys encode is not visible in this file — confirm against the code that
/// builds these maps.
pub type GroupMap = (
    spg_sql::ast::ColumnName,
    alloc::collections::BTreeMap<String, Value>,
);
63
/// A `usize`, a list of optional shared [`GroupMap`]s, and an expression.
///
/// NOTE(review): the meaning of the leading `usize` and the ordering of the
/// group-map list are not established in this file — confirm against the
/// code that populates `MemoizeCache::expr_plans`.
pub type ExprPlan = (
    usize,
    alloc::vec::Vec<Option<alloc::rc::Rc<GroupMap>>>,
    spg_sql::ast::Expr,
);
79
/// A list of column names paired with an ordered set of `String`s.
///
/// NOTE(review): presumably the set holds encoded key tuples used to answer
/// `EXISTS` checks — the encoding is not visible here; confirm with the
/// producer.
pub type ExistsSet = (
    alloc::vec::Vec<spg_sql::ast::ColumnName>,
    alloc::collections::BTreeSet<String>,
);
90
/// An ordered set of values of one homogeneous type: either `i64`s or
/// `String`s.
///
/// NOTE(review): presumably the materialized right-hand side of an `IN (...)`
/// list — confirm against the caller.
#[derive(Debug, Clone)]
pub enum InListSet {
    /// Set of 64-bit signed integers.
    Int(alloc::collections::BTreeSet<i64>),
    /// Set of owned strings.
    Text(alloc::collections::BTreeSet<String>),
}
101
/// An [`InListSet`] together with a NULL flag.
#[derive(Debug, Clone)]
pub struct InListSetEntry {
    /// The set of non-null members.
    pub set: InListSet,
    /// NOTE(review): presumably true when the source list contained a NULL
    /// (which cannot be stored in `set`) — confirm against the producer.
    pub has_null: bool,
}
109
/// Bounded cache from [`CacheKey`] to [`Value`], plus several public side
/// tables and hit/miss counters.
#[derive(Debug, Clone)]
pub struct MemoizeCache {
    /// Key/value pairs in recency order: lookups and inserts place entries at
    /// the front, eviction removes from the back.
    entries: VecDeque<(CacheKey, Value)>,
    /// Optional shared [`GroupMap`]s keyed by `String`.
    /// NOTE(review): key meaning and the role of `None` are not visible in
    /// this file — confirm with the code that fills this table.
    pub group_maps: alloc::collections::BTreeMap<String, Option<alloc::rc::Rc<GroupMap>>>,
    /// Optional shared [`ExistsSet`]s keyed by `String`.
    /// NOTE(review): key meaning and the role of `None` are not visible here.
    pub exists_sets: alloc::collections::BTreeMap<String, Option<alloc::rc::Rc<ExistsSet>>>,
    /// [`ExprPlan`]s keyed by `usize`.
    /// NOTE(review): what the `usize` identifies is not visible here.
    pub expr_plans: alloc::collections::BTreeMap<usize, ExprPlan>,
    /// Optional [`InListSetEntry`] values keyed by `usize`.
    /// NOTE(review): what the `usize` identifies is not visible here.
    pub in_sets: alloc::collections::BTreeMap<usize, Option<InListSetEntry>>,
    /// Boolean flags keyed by `usize`.
    /// NOTE(review): presumably "does this expression contain a subquery" —
    /// confirm with the caller.
    pub has_subquery: alloc::collections::BTreeMap<usize, bool>,
    /// Entry-count limit consulted when inserting into `entries`.
    max_entries: usize,
    /// Approximate byte limit consulted when inserting into `entries`.
    max_bytes: usize,
    /// Running approximate size, in bytes, of everything in `entries`.
    current_bytes: usize,
    /// Number of lookups that found their key.
    pub hit_count: u64,
    /// Number of lookups that did not find their key.
    pub miss_count: u64,
}
145
/// `Default` produces exactly what [`MemoizeCache::new`] produces.
impl Default for MemoizeCache {
    fn default() -> Self {
        Self::new()
    }
}
151
152impl MemoizeCache {
153 pub fn new() -> Self {
154 Self {
155 entries: VecDeque::with_capacity(DEFAULT_MAX_ENTRIES),
156 max_entries: DEFAULT_MAX_ENTRIES,
157 max_bytes: DEFAULT_MAX_BYTES,
158 current_bytes: 0,
159 hit_count: 0,
160 miss_count: 0,
161 group_maps: alloc::collections::BTreeMap::new(),
162 exists_sets: alloc::collections::BTreeMap::new(),
163 expr_plans: alloc::collections::BTreeMap::new(),
164 in_sets: alloc::collections::BTreeMap::new(),
165 has_subquery: alloc::collections::BTreeMap::new(),
166 }
167 }
168
169 pub const fn with_max_entries(mut self, n: usize) -> Self {
170 self.max_entries = n;
171 self
172 }
173
174 pub const fn with_max_bytes(mut self, b: usize) -> Self {
175 self.max_bytes = b;
176 self
177 }
178
179 pub fn len(&self) -> usize {
180 self.entries.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
184 self.entries.is_empty()
185 }
186
187 pub fn get(&mut self, key: &CacheKey) -> Option<Value> {
191 let pos = self.entries.iter().position(|(k, _)| k == key);
192 if let Some(p) = pos {
193 let (k, v) = self.entries.remove(p)?;
194 self.entries.push_front((k, v.clone()));
195 self.hit_count += 1;
196 Some(v)
197 } else {
198 self.miss_count += 1;
199 None
200 }
201 }
202
203 pub fn insert(&mut self, key: CacheKey, value: Value) {
207 let entry_bytes = approx_bytes(&key) + approx_value_bytes(&value);
208 while !self.entries.is_empty()
209 && (self.entries.len() >= self.max_entries
210 || self.current_bytes + entry_bytes > self.max_bytes)
211 {
212 let Some((k, v)) = self.entries.pop_back() else {
213 break;
214 };
215 self.current_bytes = self
216 .current_bytes
217 .saturating_sub(approx_bytes(&k) + approx_value_bytes(&v));
218 }
219 self.current_bytes = self.current_bytes.saturating_add(entry_bytes);
220 self.entries.push_front((key, value));
221 }
222}
223
224fn approx_bytes(key: &CacheKey) -> usize {
225 key.subquery_repr.len()
226 + key
227 .outer_values
228 .iter()
229 .map(approx_value_bytes)
230 .sum::<usize>()
231 + 16
232}
233
234fn approx_value_bytes(v: &Value) -> usize {
235 match v {
236 Value::Null | Value::Bool(_) | Value::SmallInt(_) => 1,
237 Value::Int(_) => 4,
238 Value::BigInt(_) | Value::Float(_) => 8,
239 Value::Date(_) | Value::Timestamp(_) => 8,
240 Value::Interval { .. } => 16,
241 Value::Numeric { .. } => 16,
242 Value::Text(s) | Value::Json(s) => s.len(),
243 Value::Vector(v) => v.len() * 4,
244 Value::Sq8Vector(q) => q.bytes.len() + 8,
245 Value::HalfVector(h) => h.dim() * 2,
246 _ => 16,
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a [`CacheKey`] from a subquery string and its outer values.
    fn key(repr: &str, outer: &[Value]) -> CacheKey {
        let subquery_repr = String::from(repr);
        let outer_values = Vec::from(outer);
        CacheKey {
            subquery_repr,
            outer_values,
        }
    }

    /// A lookup on a fresh cache misses and is counted as a miss.
    #[test]
    fn empty_cache_misses_everything() {
        let mut cache = MemoizeCache::new();
        let probe = key("SELECT 1", &[Value::Int(1)]);
        let found = cache.get(&probe);
        assert!(found.is_none());
        assert_eq!(cache.miss_count, 1);
        assert_eq!(cache.hit_count, 0);
    }

    /// A stored value is returned by the next lookup and counted as a hit.
    #[test]
    fn insert_then_get_hits() {
        let mut cache = MemoizeCache::new();
        let probe = key("SELECT 1", &[Value::Int(1)]);
        cache.insert(probe.clone(), Value::BigInt(42));
        assert_eq!(cache.get(&probe), Some(Value::BigInt(42)));
        assert_eq!(cache.hit_count, 1);
    }

    /// Five distinct keys cycled over 100 lookups miss once each and hit on
    /// every later lookup.
    #[test]
    fn repeated_outer_key_hits_after_first_insert() {
        let repr = "SELECT MAX(x) FROM y WHERE y.k = outer.k";
        let mut cache = MemoizeCache::new();
        for i in 0..100 {
            let probe = key(repr, &[Value::Int(i % 5)]);
            let cached = cache.get(&probe);
            if cached.is_none() {
                cache.insert(probe, Value::BigInt(i64::from(i)));
            }
        }
        assert_eq!(cache.miss_count, 5);
        assert_eq!(cache.hit_count, 95);
    }

    /// With room for three entries, inserting five keeps the newest three.
    #[test]
    fn lru_eviction_at_max_entries() {
        let q = |i| key("q", &[Value::Int(i)]);
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..5 {
            cache.insert(q(i), Value::BigInt(i64::from(i)));
        }
        assert!(cache.len() <= 3, "len={}", cache.len());
        for survivor in [4, 3, 2] {
            assert!(cache.get(&q(survivor)).is_some());
        }
        assert!(cache.get(&q(0)).is_none());
    }

    /// A small byte budget forces eviction before ten 64-byte texts fit.
    #[test]
    fn lru_eviction_at_max_bytes() {
        let mut cache = MemoizeCache::new().with_max_bytes(128);
        for i in 0..10 {
            let payload = "x".repeat(64);
            cache.insert(key("q", &[Value::Int(i)]), Value::Text(payload));
        }
        assert!(cache.len() < 10, "len={}", cache.len());
    }

    /// Keys with equal outer values but different subquery strings are
    /// separate entries.
    #[test]
    fn distinct_subquery_reprs_dont_collide() {
        let mut cache = MemoizeCache::new();
        let first = key("SELECT 1", &[Value::Int(1)]);
        let second = key("SELECT 2", &[Value::Int(1)]);
        cache.insert(first.clone(), Value::BigInt(10));
        cache.insert(second.clone(), Value::BigInt(20));
        assert_eq!(cache.get(&first), Some(Value::BigInt(10)));
        assert_eq!(cache.get(&second), Some(Value::BigInt(20)));
    }

    /// A hit moves its entry to the front, so the next eviction removes a
    /// different key.
    #[test]
    fn miss_then_hit_bumps_promotes_to_lru_front() {
        let q = |i| key("q", &[Value::Int(i)]);
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..3 {
            cache.insert(q(i), Value::BigInt(i64::from(i)));
        }
        // Touch key 0 so that key 1 is now the least recently used.
        let _ = cache.get(&q(0));
        cache.insert(q(3), Value::BigInt(3));
        assert!(cache.get(&q(0)).is_some());
        assert!(cache.get(&q(1)).is_none());
    }
}