1use std::collections::{HashMap, HashSet};
5use std::sync::LazyLock;
6use std::time::{Duration, Instant};
7
8use crate::executor::ToolOutput;
9
/// Names of built-in tools whose results must never be served from the cache.
///
/// Consulted by `is_cacheable`; any name in this set is always re-executed.
static NON_CACHEABLE_TOOLS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    let denied = ["bash", "memory_save", "memory_search", "scheduler", "write"];
    denied.iter().copied().collect()
});
24
25#[must_use]
30pub fn is_cacheable(tool_name: &str) -> bool {
31 if tool_name.starts_with("mcp_") {
32 return false;
33 }
34 !NON_CACHEABLE_TOOLS.contains(tool_name)
35}
36
/// Identifies one cached result: the tool that ran plus a hash of its arguments.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Name of the tool whose output is cached.
    pub tool_name: String,
    /// Hash of the tool's arguments, computed by the caller.
    pub args_hash: u64,
}

impl CacheKey {
    /// Builds a key from any string-like tool name and a precomputed argument hash.
    #[must_use]
    pub fn new(tool_name: impl Into<String>, args_hash: u64) -> Self {
        let tool_name: String = tool_name.into();
        Self { tool_name, args_hash }
    }
}
53
54#[derive(Debug, Clone)]
56pub struct CacheEntry {
57 pub output: ToolOutput,
58 pub inserted_at: Instant,
59}
60
61impl CacheEntry {
62 fn is_expired(&self, ttl: Duration) -> bool {
63 self.inserted_at.elapsed() > ttl
64 }
65}
66
67#[derive(Debug)]
76pub struct ToolResultCache {
77 entries: HashMap<CacheKey, CacheEntry>,
78 ttl: Option<Duration>,
80 enabled: bool,
81 hits: u64,
82 misses: u64,
83}
84
85impl ToolResultCache {
86 #[must_use]
90 pub fn new(enabled: bool, ttl: Option<Duration>) -> Self {
91 Self {
92 entries: HashMap::new(),
93 ttl,
94 enabled,
95 hits: 0,
96 misses: 0,
97 }
98 }
99
100 pub fn get(&mut self, key: &CacheKey) -> Option<ToolOutput> {
104 if !self.enabled {
105 return None;
106 }
107 if let Some(entry) = self.entries.get(key) {
108 if self.ttl.is_some_and(|ttl| entry.is_expired(ttl)) {
109 self.entries.remove(key);
110 return None;
111 }
112 let output = entry.output.clone();
113 self.hits += 1;
114 return Some(output);
115 }
116 self.misses += 1;
117 None
118 }
119
120 pub fn put(&mut self, key: CacheKey, output: ToolOutput) {
122 if !self.enabled {
123 return;
124 }
125 self.entries.insert(
126 key,
127 CacheEntry {
128 output,
129 inserted_at: Instant::now(),
130 },
131 );
132 }
133
134 pub fn clear(&mut self) {
136 self.entries.clear();
137 self.hits = 0;
138 self.misses = 0;
139 }
140
141 #[must_use]
143 pub fn len(&self) -> usize {
144 self.entries.len()
145 }
146
147 #[must_use]
149 pub fn is_empty(&self) -> bool {
150 self.entries.is_empty()
151 }
152
153 #[must_use]
155 pub fn hits(&self) -> u64 {
156 self.hits
157 }
158
159 #[must_use]
161 pub fn misses(&self) -> u64 {
162 self.misses
163 }
164
165 #[must_use]
167 pub fn is_enabled(&self) -> bool {
168 self.enabled
169 }
170
171 #[must_use]
173 pub fn ttl_secs(&self) -> u64 {
174 self.ttl.map_or(0, |d| d.as_secs())
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolOutput` carrying `summary`; every other field is a fixed placeholder.
    fn make_output(summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: "test".to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Shorthand for `CacheKey::new`.
    fn key(name: &str, hash: u64) -> CacheKey {
        CacheKey::new(name, hash)
    }

    #[test]
    fn miss_on_empty_cache() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn put_then_get_returns_cached() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        let out = make_output("file contents");
        cache.put(key("read", 42), out.clone());
        let result = cache.get(&key("read", 42));
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "file contents");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn different_hash_is_miss() {
        // Same tool, different argument hash: must not collide.
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        assert!(cache.get(&key("read", 2)).is_none());
    }

    #[test]
    fn different_tool_name_is_miss() {
        // Same argument hash, different tool: must not collide.
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        assert!(cache.get(&key("write", 1)).is_none());
    }

    #[test]
    fn ttl_none_never_expires() {
        // With no TTL, the entry is served regardless of its age.
        let mut cache = ToolResultCache::new(true, None);
        cache.put(key("read", 1), make_output("content"));
        assert!(cache.get(&key("read", 1)).is_some());
        assert_eq!(cache.hits(), 1);
    }

    #[test]
    fn ttl_zero_duration_expires_immediately() {
        // A zero TTL: the first lookup must already treat the entry as stale
        // and evict it.
        let mut cache = ToolResultCache::new(true, Some(Duration::ZERO));
        cache.put(key("read", 1), make_output("content"));
        let result = cache.get(&key("read", 1));
        assert!(
            result.is_none(),
            "Duration::ZERO entry must expire on first get()"
        );
        assert_eq!(cache.len(), 0, "expired entry must be removed from map");
    }

    #[test]
    fn ttl_expired_returns_none() {
        // Sleep well past the 1 ms TTL so expiry does not hinge on clock resolution.
        let mut cache = ToolResultCache::new(true, Some(Duration::from_millis(1)));
        cache.put(key("read", 1), make_output("content"));
        std::thread::sleep(Duration::from_millis(10));
        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_removes_all_and_resets_counters() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("web_scrape", 2), make_output("b"));
        // One hit, then one miss, so both counters are non-zero before clear().
        cache.get(&key("read", 1));
        cache.get(&key("missing", 99));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert!(cache.get(&key("read", 1)).is_none());
    }

    #[test]
    fn disabled_cache_always_misses() {
        // A disabled cache stores nothing and does not count lookups as misses.
        let mut cache = ToolResultCache::new(false, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("content"));
        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn is_cacheable_returns_false_for_deny_list() {
        assert!(!is_cacheable("bash"));
        assert!(!is_cacheable("memory_save"));
        assert!(!is_cacheable("memory_search"));
        assert!(!is_cacheable("scheduler"));
        assert!(!is_cacheable("write"));
    }

    #[test]
    fn is_cacheable_returns_false_for_mcp_prefix() {
        // Any name with the `mcp_` prefix is rejected, including the bare prefix.
        assert!(!is_cacheable("mcp_github_list_issues"));
        assert!(!is_cacheable("mcp_send_email"));
        assert!(!is_cacheable("mcp_"));
    }

    #[test]
    fn is_cacheable_returns_true_for_read_only_tools() {
        assert!(is_cacheable("read"));
        assert!(is_cacheable("web_scrape"));
        assert!(is_cacheable("search_code"));
        assert!(is_cacheable("load_skill"));
        assert!(is_cacheable("diagnostics"));
    }

    #[test]
    fn counter_increments_correctly() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("read", 2), make_output("b"));

        cache.get(&key("read", 1)); // hit
        cache.get(&key("read", 1)); // hit again: repeated lookups keep counting
        cache.get(&key("read", 99)); // miss: key never stored
        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn ttl_secs_returns_zero_for_none() {
        let cache = ToolResultCache::new(true, None);
        assert_eq!(cache.ttl_secs(), 0);
    }

    #[test]
    fn ttl_secs_returns_seconds_for_some() {
        let cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        assert_eq!(cache.ttl_secs(), 300);
    }
}