1#![allow(clippy::doc_markdown)]
4
5use alloc::collections::VecDeque;
34use alloc::string::String;
35use alloc::vec::Vec;
36
37use spg_storage::Value;
38
/// Entry-count limit used by [`MemoizeCache::new`] (1024 entries).
pub const DEFAULT_MAX_ENTRIES: usize = 1 << 10;

/// Approximate byte budget used by [`MemoizeCache::new`] (16 MiB), measured
/// with the cache's per-entry size estimates rather than real allocations.
pub const DEFAULT_MAX_BYTES: usize = 16 << 20;
46
/// Lookup key for a memoized subquery result.
///
/// Keys compare field-by-field via the derived `PartialEq`: both the subquery
/// representation and every outer value must be equal for a cache hit.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheKey {
    /// String representation of the subquery (the tests use its SQL text).
    pub subquery_repr: String,
    /// Outer-row values the result was computed for.
    /// NOTE(review): presumably the correlated column values in a fixed
    /// order — confirm with the code that builds these keys.
    pub outer_values: Vec<Value>,
}
56
/// A column name paired with a string-keyed map of values.
///
/// Stored (behind `Rc`) in `MemoizeCache::group_maps` and inside `ExprPlan`.
/// NOTE(review): the meaning of the map keys (presumably a rendered group
/// key) is not established in this file — confirm against the producer.
pub type GroupMap = (
    spg_sql::ast::ColumnName,
    alloc::collections::BTreeMap<String, Value>,
);
63
/// Cached plan tuple: a `usize`, a list of optional shared [`GroupMap`]s, and
/// an expression.
///
/// Stored in `MemoizeCache::expr_plans`, keyed by `usize`.
/// NOTE(review): the roles of the `usize` and of the `GroupMap` list are not
/// visible in this file — TODO document them from the code that builds plans.
pub type ExprPlan = (
    usize,
    alloc::vec::Vec<Option<alloc::rc::Rc<GroupMap>>>,
    spg_sql::ast::Expr,
);
79
/// Bounded LRU cache of subquery results keyed by [`CacheKey`].
///
/// The cached entries are limited both by count (`max_entries`) and by an
/// approximate byte budget (`max_bytes`, measured with this module's size
/// estimates). The `group_maps` and `expr_plans` side tables live alongside
/// the entries but are not read or written by `get`/`insert` and do not count
/// toward either limit.
#[derive(Debug, Clone)]
pub struct MemoizeCache {
    /// Cached (key, value) pairs, most recently used at the front; eviction
    /// pops from the back.
    entries: VecDeque<(CacheKey, Value)>,
    /// Side table of optional shared group maps, keyed by string.
    pub group_maps: alloc::collections::BTreeMap<String, Option<alloc::rc::Rc<GroupMap>>>,
    /// Side table of expression plans, keyed by `usize`.
    pub expr_plans: alloc::collections::BTreeMap<usize, ExprPlan>,
    /// Maximum number of cached entries.
    max_entries: usize,
    /// Approximate byte budget for the cached entries.
    max_bytes: usize,
    /// Running sum of the estimated sizes of the currently cached entries.
    current_bytes: usize,
    /// Number of `get` calls that returned a cached value.
    pub hit_count: u64,
    /// Number of `get` calls that found no cached value.
    pub miss_count: u64,
}
101
102impl Default for MemoizeCache {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl MemoizeCache {
109 pub fn new() -> Self {
110 Self {
111 entries: VecDeque::with_capacity(DEFAULT_MAX_ENTRIES),
112 max_entries: DEFAULT_MAX_ENTRIES,
113 max_bytes: DEFAULT_MAX_BYTES,
114 current_bytes: 0,
115 hit_count: 0,
116 miss_count: 0,
117 group_maps: alloc::collections::BTreeMap::new(),
118 expr_plans: alloc::collections::BTreeMap::new(),
119 }
120 }
121
122 pub const fn with_max_entries(mut self, n: usize) -> Self {
123 self.max_entries = n;
124 self
125 }
126
127 pub const fn with_max_bytes(mut self, b: usize) -> Self {
128 self.max_bytes = b;
129 self
130 }
131
132 pub fn len(&self) -> usize {
133 self.entries.len()
134 }
135
136 pub fn is_empty(&self) -> bool {
137 self.entries.is_empty()
138 }
139
140 pub fn get(&mut self, key: &CacheKey) -> Option<Value> {
144 let pos = self.entries.iter().position(|(k, _)| k == key);
145 if let Some(p) = pos {
146 let (k, v) = self.entries.remove(p)?;
147 self.entries.push_front((k, v.clone()));
148 self.hit_count += 1;
149 Some(v)
150 } else {
151 self.miss_count += 1;
152 None
153 }
154 }
155
156 pub fn insert(&mut self, key: CacheKey, value: Value) {
160 let entry_bytes = approx_bytes(&key) + approx_value_bytes(&value);
161 while !self.entries.is_empty()
162 && (self.entries.len() >= self.max_entries
163 || self.current_bytes + entry_bytes > self.max_bytes)
164 {
165 let Some((k, v)) = self.entries.pop_back() else {
166 break;
167 };
168 self.current_bytes = self
169 .current_bytes
170 .saturating_sub(approx_bytes(&k) + approx_value_bytes(&v));
171 }
172 self.current_bytes = self.current_bytes.saturating_add(entry_bytes);
173 self.entries.push_front((key, value));
174 }
175}
176
177fn approx_bytes(key: &CacheKey) -> usize {
178 key.subquery_repr.len()
179 + key
180 .outer_values
181 .iter()
182 .map(approx_value_bytes)
183 .sum::<usize>()
184 + 16
185}
186
187fn approx_value_bytes(v: &Value) -> usize {
188 match v {
189 Value::Null | Value::Bool(_) | Value::SmallInt(_) => 1,
190 Value::Int(_) => 4,
191 Value::BigInt(_) | Value::Float(_) => 8,
192 Value::Date(_) | Value::Timestamp(_) => 8,
193 Value::Interval { .. } => 16,
194 Value::Numeric { .. } => 16,
195 Value::Text(s) | Value::Json(s) => s.len(),
196 Value::Vector(v) => v.len() * 4,
197 Value::Sq8Vector(q) => q.bytes.len() + 8,
198 Value::HalfVector(h) => h.dim() * 2,
199 _ => 16,
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache key from a subquery string and its outer values.
    fn key(repr: &str, outer: &[Value]) -> CacheKey {
        CacheKey {
            subquery_repr: repr.into(),
            outer_values: outer.to_vec(),
        }
    }

    #[test]
    fn empty_cache_misses_everything() {
        let mut c = MemoizeCache::new();
        assert!(c.is_empty());
        let k = key("SELECT 1", &[Value::Int(1)]);
        assert!(c.get(&k).is_none());
        assert_eq!(c.miss_count, 1);
        assert_eq!(c.hit_count, 0);
        // A miss must not populate the cache.
        assert!(c.is_empty());
    }

    #[test]
    fn default_matches_new() {
        let c = MemoizeCache::default();
        assert!(c.is_empty());
        assert_eq!(c.hit_count, 0);
        assert_eq!(c.miss_count, 0);
    }

    #[test]
    fn insert_then_get_hits() {
        let mut c = MemoizeCache::new();
        let k = key("SELECT 1", &[Value::Int(1)]);
        c.insert(k.clone(), Value::BigInt(42));
        assert_eq!(c.len(), 1);
        let v = c.get(&k);
        assert_eq!(v, Some(Value::BigInt(42)));
        assert_eq!(c.hit_count, 1);
        assert_eq!(c.miss_count, 0);
    }

    #[test]
    fn reinserting_a_key_returns_latest_value() {
        let mut c = MemoizeCache::new();
        let k = key("q", &[Value::Int(7)]);
        c.insert(k.clone(), Value::BigInt(1));
        c.insert(k.clone(), Value::BigInt(2));
        assert_eq!(c.get(&k), Some(Value::BigInt(2)));
    }

    #[test]
    fn repeated_outer_key_hits_after_first_insert() {
        let mut c = MemoizeCache::new();
        let repr = "SELECT MAX(x) FROM y WHERE y.k = outer.k";
        for i in 0..100 {
            let k = key(repr, &[Value::Int(i % 5)]);
            if c.get(&k).is_none() {
                c.insert(k, Value::BigInt(i64::from(i)));
            }
        }
        // Only the first occurrence of each of the 5 outer keys misses.
        assert_eq!(c.miss_count, 5);
        assert_eq!(c.hit_count, 95);
        assert_eq!(c.len(), 5);
    }

    #[test]
    fn lru_eviction_at_max_entries() {
        let mut c = MemoizeCache::new().with_max_entries(3);
        for i in 0..5 {
            c.insert(key("q", &[Value::Int(i)]), Value::BigInt(i64::from(i)));
        }
        assert_eq!(c.len(), 3);
        // The three most recent inserts survive; the two oldest are evicted.
        for i in 2..5 {
            assert!(c.get(&key("q", &[Value::Int(i)])).is_some(), "key {i} missing");
        }
        for i in 0..2 {
            assert!(c.get(&key("q", &[Value::Int(i)])).is_none(), "key {i} kept");
        }
    }

    #[test]
    fn lru_eviction_at_max_bytes() {
        let mut c = MemoizeCache::new().with_max_bytes(128);
        let big_str: alloc::string::String = core::iter::repeat_n('x', 64).collect();
        for i in 0..10 {
            c.insert(key("q", &[Value::Int(i)]), Value::Text(big_str.clone()));
        }
        assert!(c.len() < 10, "len={}", c.len());
        // The byte budget must hold after every insert.
        assert!(c.current_bytes <= 128, "current_bytes={}", c.current_bytes);
        // The most recent insert always survives eviction.
        assert!(c.get(&key("q", &[Value::Int(9)])).is_some());
    }

    #[test]
    fn distinct_subquery_reprs_dont_collide() {
        let mut c = MemoizeCache::new();
        let k1 = key("SELECT 1", &[Value::Int(1)]);
        let k2 = key("SELECT 2", &[Value::Int(1)]);
        c.insert(k1.clone(), Value::BigInt(10));
        c.insert(k2.clone(), Value::BigInt(20));
        assert_eq!(c.get(&k1), Some(Value::BigInt(10)));
        assert_eq!(c.get(&k2), Some(Value::BigInt(20)));
    }

    #[test]
    fn miss_then_hit_bumps_promotes_to_lru_front() {
        let mut c = MemoizeCache::new().with_max_entries(3);
        c.insert(key("q", &[Value::Int(0)]), Value::BigInt(0));
        c.insert(key("q", &[Value::Int(1)]), Value::BigInt(1));
        c.insert(key("q", &[Value::Int(2)]), Value::BigInt(2));
        // Touching key 0 makes key 1 the least recently used.
        let _ = c.get(&key("q", &[Value::Int(0)]));
        c.insert(key("q", &[Value::Int(3)]), Value::BigInt(3));
        assert!(c.get(&key("q", &[Value::Int(0)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(1)])).is_none());
        assert!(c.get(&key("q", &[Value::Int(2)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(3)])).is_some());
    }
}