1use std::collections::{HashMap, HashSet};
5use std::sync::LazyLock;
6use std::time::{Duration, Instant};
7
8use zeph_common::ToolName;
9
10use crate::executor::ToolOutput;
11
/// Built-in tool names whose results must never be stored in or served from
/// the cache.
///
/// Consulted by [`is_cacheable`] in addition to its `mcp_` prefix rule.
/// NOTE(review): the entries look like tools with side effects or
/// time-varying output — confirm that criterion before adding new names.
static NON_CACHEABLE_TOOLS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    ["bash", "memory_save", "memory_search", "scheduler", "write"]
        .iter()
        .copied()
        .collect()
});
26
27#[must_use]
32pub fn is_cacheable(tool_name: &str) -> bool {
33 if tool_name.starts_with("mcp_") {
34 return false;
35 }
36 !NON_CACHEABLE_TOOLS.contains(tool_name)
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41pub struct CacheKey {
42 pub tool_name: ToolName,
43 pub args_hash: u64,
44}
45
46impl CacheKey {
47 #[must_use]
48 pub fn new(tool_name: impl Into<ToolName>, args_hash: u64) -> Self {
49 Self {
50 tool_name: tool_name.into(),
51 args_hash,
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct CacheEntry {
59 pub output: ToolOutput,
60 pub inserted_at: Instant,
61}
62
63impl CacheEntry {
64 fn is_expired(&self, ttl: Duration) -> bool {
65 self.inserted_at.elapsed() > ttl
66 }
67}
68
69#[derive(Debug)]
78pub struct ToolResultCache {
79 entries: HashMap<CacheKey, CacheEntry>,
80 ttl: Option<Duration>,
82 enabled: bool,
83 hits: u64,
84 misses: u64,
85}
86
87impl ToolResultCache {
88 #[must_use]
92 pub fn new(enabled: bool, ttl: Option<Duration>) -> Self {
93 Self {
94 entries: HashMap::new(),
95 ttl,
96 enabled,
97 hits: 0,
98 misses: 0,
99 }
100 }
101
102 pub fn get(&mut self, key: &CacheKey) -> Option<ToolOutput> {
106 if !self.enabled {
107 return None;
108 }
109 if let Some(entry) = self.entries.get(key) {
110 if self.ttl.is_some_and(|ttl| entry.is_expired(ttl)) {
111 self.entries.remove(key);
112 return None;
113 }
114 let output = entry.output.clone();
115 self.hits += 1;
116 return Some(output);
117 }
118 self.misses += 1;
119 None
120 }
121
122 pub fn put(&mut self, key: CacheKey, output: ToolOutput) {
124 if !self.enabled {
125 return;
126 }
127 self.entries.insert(
128 key,
129 CacheEntry {
130 output,
131 inserted_at: Instant::now(),
132 },
133 );
134 }
135
136 pub fn clear(&mut self) {
138 self.entries.clear();
139 self.hits = 0;
140 self.misses = 0;
141 }
142
143 #[must_use]
145 pub fn len(&self) -> usize {
146 self.entries.len()
147 }
148
149 #[must_use]
151 pub fn is_empty(&self) -> bool {
152 self.entries.is_empty()
153 }
154
155 #[must_use]
157 pub fn hits(&self) -> u64 {
158 self.hits
159 }
160
161 #[must_use]
163 pub fn misses(&self) -> u64 {
164 self.misses
165 }
166
167 #[must_use]
169 pub fn is_enabled(&self) -> bool {
170 self.enabled
171 }
172
173 #[must_use]
175 pub fn ttl_secs(&self) -> u64 {
176 self.ttl.map_or(0, |d| d.as_secs())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// TTL long enough that nothing expires while a test is running.
    const LONG_TTL: Duration = Duration::from_secs(300);

    /// Builds a minimal `ToolOutput` whose only interesting field is `summary`.
    fn make_output(summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: ToolName::new("test"),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Shorthand for `CacheKey::new`.
    fn key(name: &str, hash: u64) -> CacheKey {
        CacheKey::new(name, hash)
    }

    /// An enabled cache using [`LONG_TTL`].
    fn enabled_cache() -> ToolResultCache {
        ToolResultCache::new(true, Some(LONG_TTL))
    }

    #[test]
    fn miss_on_empty_cache() {
        let mut cache = enabled_cache();
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn put_then_get_returns_cached() {
        let mut cache = enabled_cache();
        cache.put(key("read", 42), make_output("file contents"));
        let looked_up = cache.get(&key("read", 42));
        assert!(looked_up.is_some());
        assert_eq!(looked_up.unwrap().summary, "file contents");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn different_hash_is_miss() {
        let mut cache = enabled_cache();
        cache.put(key("read", 1), make_output("a"));
        let looked_up = cache.get(&key("read", 2));
        assert!(looked_up.is_none());
    }

    #[test]
    fn different_tool_name_is_miss() {
        let mut cache = enabled_cache();
        cache.put(key("read", 1), make_output("a"));
        let looked_up = cache.get(&key("write", 1));
        assert!(looked_up.is_none());
    }

    #[test]
    fn ttl_none_never_expires() {
        let mut cache = ToolResultCache::new(true, None);
        cache.put(key("read", 1), make_output("content"));
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_some());
        assert_eq!(cache.hits(), 1);
    }

    #[test]
    fn ttl_zero_duration_expires_immediately() {
        let mut cache = ToolResultCache::new(true, Some(Duration::ZERO));
        cache.put(key("read", 1), make_output("content"));
        let looked_up = cache.get(&key("read", 1));
        assert!(
            looked_up.is_none(),
            "Duration::ZERO entry must expire on first get()"
        );
        assert_eq!(cache.len(), 0, "expired entry must be removed from map");
    }

    #[test]
    fn ttl_expired_returns_none() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_millis(1)));
        cache.put(key("read", 1), make_output("content"));
        // Sleep well past the 1 ms TTL so the entry is stale on lookup.
        std::thread::sleep(Duration::from_millis(10));
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_removes_all_and_resets_counters() {
        let mut cache = enabled_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("web_scrape", 2), make_output("b"));
        // One hit followed by one miss.
        let _ = cache.get(&key("read", 1));
        let _ = cache.get(&key("missing", 99));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_none());
    }

    #[test]
    fn disabled_cache_always_misses() {
        let mut cache = ToolResultCache::new(false, Some(LONG_TTL));
        cache.put(key("read", 1), make_output("content"));
        let looked_up = cache.get(&key("read", 1));
        assert!(looked_up.is_none());
        // A disabled cache neither stores entries nor counts lookups.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn is_cacheable_returns_false_for_deny_list() {
        for name in ["bash", "memory_save", "memory_search", "scheduler", "write"] {
            assert!(!is_cacheable(name), "{name} must not be cacheable");
        }
    }

    #[test]
    fn is_cacheable_returns_false_for_mcp_prefix() {
        for name in ["mcp_github_list_issues", "mcp_send_email", "mcp_"] {
            assert!(!is_cacheable(name), "{name} must not be cacheable");
        }
    }

    #[test]
    fn is_cacheable_returns_true_for_read_only_tools() {
        for name in ["read", "web_scrape", "search_code", "load_skill", "diagnostics"] {
            assert!(is_cacheable(name), "{name} should be cacheable");
        }
    }

    #[test]
    fn counter_increments_correctly() {
        let mut cache = enabled_cache();
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("read", 2), make_output("b"));

        // Two hits on hash 1, then one miss on the absent hash 99.
        for hash in [1, 1, 99] {
            let _ = cache.get(&key("read", hash));
        }

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn ttl_secs_returns_zero_for_none() {
        let cache = ToolResultCache::new(true, None);
        assert_eq!(cache.ttl_secs(), 0);
    }

    #[test]
    fn ttl_secs_returns_seconds_for_some() {
        let cache = enabled_cache();
        assert_eq!(cache.ttl_secs(), 300);
    }
}