1use std::collections::{HashMap, HashSet, VecDeque};
5use std::sync::LazyLock;
6use std::time::{Duration, Instant};
7
8use zeph_common::ToolName;
9
10use crate::executor::ToolOutput;
11
/// Tool names whose results are never stored in a [`ToolResultCache`].
///
/// NOTE(review): presumably these tools have side effects or can return different
/// results for identical arguments — confirm before extending the list.
static NON_CACHEABLE_TOOLS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    ["bash", "memory_save", "memory_search", "scheduler", "write"]
        .iter()
        .copied()
        .collect()
});

/// Reports whether the output of `tool_name` may be cached.
///
/// Returns `false` for every tool whose name begins with `mcp_` and for every name
/// listed in `NON_CACHEABLE_TOOLS`; every other name (matched case-sensitively) is
/// considered cacheable.
#[must_use]
pub fn is_cacheable(tool_name: &str) -> bool {
    // `||` short-circuits, so MCP tools never consult the deny list.
    !(tool_name.starts_with("mcp_") || NON_CACHEABLE_TOOLS.contains(tool_name))
}
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct CacheKey {
42 pub tool_name: ToolName,
43 pub args_hash: u64,
44}
45
46impl CacheKey {
47 #[must_use]
48 pub fn new(tool_name: impl Into<ToolName>, args_hash: u64) -> Self {
49 Self {
50 tool_name: tool_name.into(),
51 args_hash,
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct CacheEntry {
59 pub output: ToolOutput,
60 pub inserted_at: Instant,
61}
62
63impl CacheEntry {
64 fn is_expired(&self, ttl: Duration) -> bool {
65 self.inserted_at.elapsed() > ttl
66 }
67}
68
/// Maximum number of entries a `ToolResultCache` holds; storing a new key beyond
/// this evicts the oldest-inserted entry.
const MAX_CACHE_ENTRIES: usize = 512;
75
76#[derive(Debug)]
86pub struct ToolResultCache {
87 entries: HashMap<CacheKey, CacheEntry>,
88 insertion_order: VecDeque<CacheKey>,
90 ttl: Option<Duration>,
92 enabled: bool,
93 hits: u64,
94 misses: u64,
95}
96
97impl ToolResultCache {
98 #[must_use]
102 pub fn new(enabled: bool, ttl: Option<Duration>) -> Self {
103 Self {
104 entries: HashMap::new(),
105 insertion_order: VecDeque::new(),
106 ttl,
107 enabled,
108 hits: 0,
109 misses: 0,
110 }
111 }
112
113 pub fn get(&mut self, key: &CacheKey) -> Option<ToolOutput> {
117 if !self.enabled {
118 return None;
119 }
120 if let Some(entry) = self.entries.get(key) {
121 if self.ttl.is_some_and(|ttl| entry.is_expired(ttl)) {
122 self.entries.remove(key);
123 return None;
124 }
125 let output = entry.output.clone();
126 self.hits += 1;
127 return Some(output);
128 }
129 self.misses += 1;
130 None
131 }
132
133 pub fn put(&mut self, key: CacheKey, output: ToolOutput) {
138 if !self.enabled {
139 return;
140 }
141 if self.entries.len() >= MAX_CACHE_ENTRIES
142 && let Some(oldest_key) = self.insertion_order.pop_front()
143 {
144 self.entries.remove(&oldest_key);
145 tracing::debug!(
146 tool = %oldest_key.tool_name,
147 args_hash = oldest_key.args_hash,
148 "tool cache: evicted oldest entry (LRU cap {})",
149 MAX_CACHE_ENTRIES
150 );
151 }
152 self.insertion_order.push_back(key.clone());
153 self.entries.insert(
154 key,
155 CacheEntry {
156 output,
157 inserted_at: Instant::now(),
158 },
159 );
160 }
161
162 pub fn clear(&mut self) {
164 self.entries.clear();
165 self.insertion_order.clear();
166 self.hits = 0;
167 self.misses = 0;
168 }
169
170 #[must_use]
172 pub fn len(&self) -> usize {
173 self.entries.len()
174 }
175
176 #[must_use]
178 pub fn is_empty(&self) -> bool {
179 self.entries.is_empty()
180 }
181
182 #[must_use]
184 pub fn hits(&self) -> u64 {
185 self.hits
186 }
187
188 #[must_use]
190 pub fn misses(&self) -> u64 {
191 self.misses
192 }
193
194 #[must_use]
196 pub fn is_enabled(&self) -> bool {
197 self.enabled
198 }
199
200 #[must_use]
202 pub fn ttl_secs(&self) -> u64 {
203 self.ttl.map_or(0, |d| d.as_secs())
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// TTL used by tests that are not about expiry.
    const FIVE_MINUTES: Duration = Duration::from_secs(300);

    /// Builds a minimal `ToolOutput` carrying `summary`.
    fn make_output(summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: ToolName::new("test"),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    fn key(name: &str, hash: u64) -> CacheKey {
        CacheKey::new(name, hash)
    }

    /// An enabled cache with a five-minute TTL.
    fn warm_cache() -> ToolResultCache {
        ToolResultCache::new(true, Some(FIVE_MINUTES))
    }

    #[test]
    fn miss_on_empty_cache() {
        let mut cache = warm_cache();
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn put_then_get_returns_cached() {
        let mut cache = warm_cache();
        let stored = make_output("file contents");
        cache.put(key("read", 42), stored.clone());
        let fetched = cache.get(&key("read", 42));
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().summary, "file contents");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn different_hash_is_miss() {
        let mut cache = warm_cache();
        cache.put(key("read", 1), make_output("a"));
        let other = key("read", 2);
        assert!(cache.get(&other).is_none());
    }

    #[test]
    fn different_tool_name_is_miss() {
        let mut cache = warm_cache();
        cache.put(key("read", 1), make_output("a"));
        let other = key("write", 1);
        assert!(cache.get(&other).is_none());
    }

    #[test]
    fn ttl_none_never_expires() {
        let mut cache = ToolResultCache::new(true, None);
        cache.put(key("read", 1), make_output("content"));
        let fetched = cache.get(&key("read", 1));
        assert!(fetched.is_some());
        assert_eq!(cache.hits(), 1);
    }

    #[test]
    fn ttl_zero_duration_expires_immediately() {
        let mut cache = ToolResultCache::new(true, Some(Duration::ZERO));
        cache.put(key("read", 1), make_output("content"));
        let fetched = cache.get(&key("read", 1));
        assert!(
            fetched.is_none(),
            "Duration::ZERO entry must expire on first get()"
        );
        assert_eq!(cache.len(), 0, "expired entry must be removed from map");
    }

    #[test]
    fn ttl_expired_returns_none() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_millis(1)));
        cache.put(key("read", 1), make_output("content"));
        // Sleep well past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(10));
        let fetched = cache.get(&key("read", 1));
        assert!(fetched.is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_removes_all_and_resets_counters() {
        let mut cache = warm_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("web_scrape", 2), make_output("b"));
        let _ = cache.get(&key("read", 1)); // hit
        let _ = cache.get(&key("missing", 99)); // miss
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert!(cache.get(&key("read", 1)).is_none());
    }

    #[test]
    fn disabled_cache_always_misses() {
        let mut cache = ToolResultCache::new(false, Some(FIVE_MINUTES));
        cache.put(key("read", 1), make_output("content"));
        let fetched = cache.get(&key("read", 1));
        assert!(fetched.is_none());
        // A disabled cache stores nothing and does not count lookups.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn is_cacheable_returns_false_for_deny_list() {
        for name in ["bash", "memory_save", "memory_search", "scheduler", "write"] {
            assert!(!is_cacheable(name));
        }
    }

    #[test]
    fn is_cacheable_returns_false_for_mcp_prefix() {
        for name in ["mcp_github_list_issues", "mcp_send_email", "mcp_"] {
            assert!(!is_cacheable(name));
        }
    }

    #[test]
    fn is_cacheable_returns_true_for_read_only_tools() {
        for name in ["read", "web_scrape", "search_code", "load_skill", "diagnostics"] {
            assert!(is_cacheable(name));
        }
    }

    #[test]
    fn counter_increments_correctly() {
        let mut cache = warm_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("read", 2), make_output("b"));

        let _ = cache.get(&key("read", 1)); // hit
        let _ = cache.get(&key("read", 1)); // hit
        let _ = cache.get(&key("read", 99)); // miss

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn ttl_secs_returns_zero_for_none() {
        let cache = ToolResultCache::new(true, None);
        assert_eq!(cache.ttl_secs(), 0);
    }

    #[test]
    fn ttl_secs_returns_seconds_for_some() {
        let cache = warm_cache();
        assert_eq!(cache.ttl_secs(), 300);
    }

    #[test]
    fn lru_eviction_at_capacity() {
        let mut cache = ToolResultCache::new(true, None);
        let cap = MAX_CACHE_ENTRIES as u64;
        for hash in 0..cap {
            cache.put(key("read", hash), make_output("v"));
        }
        assert_eq!(cache.len(), MAX_CACHE_ENTRIES);

        // One more distinct key must push out the very first one.
        cache.put(key("read", cap), make_output("new"));

        assert_eq!(cache.len(), MAX_CACHE_ENTRIES, "size must stay at cap");
        assert!(
            cache.get(&key("read", 0)).is_none(),
            "oldest entry must be evicted"
        );
        assert!(
            cache.get(&key("read", cap)).is_some(),
            "new entry must be present"
        );
    }
}