1use std::collections::{HashMap, HashSet};
5use std::sync::LazyLock;
6use std::time::{Duration, Instant};
7
8use crate::executor::ToolOutput;
9
/// Tool names (matched exactly) whose results are never cached.
///
/// NOTE(review): presumably these tools have side effects or return
/// time-varying data, so replaying a stored result would be wrong — confirm
/// before adding new names.
static NON_CACHEABLE_TOOLS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    let names: [&'static str; 5] = ["bash", "memory_save", "memory_search", "scheduler", "write"];
    names.iter().copied().collect()
});
24
25#[must_use]
30pub fn is_cacheable(tool_name: &str) -> bool {
31 if tool_name.starts_with("mcp_") {
32 return false;
33 }
34 !NON_CACHEABLE_TOOLS.contains(tool_name)
35}
36
/// Identifies one cached tool invocation: the tool's name plus a hash of its
/// arguments.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// Name of the tool whose output is cached.
    pub tool_name: String,
    /// Argument hash supplied by the caller; this type does not compute it.
    pub args_hash: u64,
}

impl CacheKey {
    /// Builds a key from any string-like tool name and a precomputed
    /// argument hash.
    #[must_use]
    pub fn new(tool_name: impl Into<String>, args_hash: u64) -> Self {
        let tool_name: String = tool_name.into();
        Self { tool_name, args_hash }
    }
}
53
54#[derive(Debug, Clone)]
56pub struct CacheEntry {
57 pub output: ToolOutput,
58 pub inserted_at: Instant,
59}
60
61impl CacheEntry {
62 fn is_expired(&self, ttl: Duration) -> bool {
63 self.inserted_at.elapsed() > ttl
64 }
65}
66
67#[derive(Debug)]
76pub struct ToolResultCache {
77 entries: HashMap<CacheKey, CacheEntry>,
78 ttl: Option<Duration>,
80 enabled: bool,
81 hits: u64,
82 misses: u64,
83}
84
85impl ToolResultCache {
86 #[must_use]
90 pub fn new(enabled: bool, ttl: Option<Duration>) -> Self {
91 Self {
92 entries: HashMap::new(),
93 ttl,
94 enabled,
95 hits: 0,
96 misses: 0,
97 }
98 }
99
100 pub fn get(&mut self, key: &CacheKey) -> Option<ToolOutput> {
104 if !self.enabled {
105 return None;
106 }
107 if let Some(entry) = self.entries.get(key) {
108 if self.ttl.is_some_and(|ttl| entry.is_expired(ttl)) {
109 self.entries.remove(key);
110 return None;
111 }
112 let output = entry.output.clone();
113 self.hits += 1;
114 return Some(output);
115 }
116 self.misses += 1;
117 None
118 }
119
120 pub fn put(&mut self, key: CacheKey, output: ToolOutput) {
122 if !self.enabled {
123 return;
124 }
125 self.entries.insert(
126 key,
127 CacheEntry {
128 output,
129 inserted_at: Instant::now(),
130 },
131 );
132 }
133
134 pub fn clear(&mut self) {
136 self.entries.clear();
137 self.hits = 0;
138 self.misses = 0;
139 }
140
141 #[must_use]
143 pub fn len(&self) -> usize {
144 self.entries.len()
145 }
146
147 #[must_use]
149 pub fn is_empty(&self) -> bool {
150 self.entries.is_empty()
151 }
152
153 #[must_use]
155 pub fn hits(&self) -> u64 {
156 self.hits
157 }
158
159 #[must_use]
161 pub fn misses(&self) -> u64 {
162 self.misses
163 }
164
165 #[must_use]
167 pub fn is_enabled(&self) -> bool {
168 self.enabled
169 }
170
171 #[must_use]
173 pub fn ttl_secs(&self) -> u64 {
174 self.ttl.map_or(0, |d| d.as_secs())
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `ToolOutput` whose only interesting field is `summary`.
    fn make_output(summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: "test".to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }
    }

    fn key(name: &str, hash: u64) -> CacheKey {
        CacheKey::new(name, hash)
    }

    #[test]
    fn miss_on_empty_cache() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn put_then_get_returns_cached() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 42), make_output("file contents"));
        let result = cache
            .get(&key("read", 42))
            .expect("entry was just inserted");
        assert_eq!(result.summary, "file contents");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn different_hash_is_miss() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        assert!(cache.get(&key("read", 2)).is_none());
    }

    #[test]
    fn different_tool_name_is_miss() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        assert!(cache.get(&key("write", 1)).is_none());
    }

    #[test]
    fn ttl_none_never_expires() {
        let mut cache = ToolResultCache::new(true, None);
        cache.put(key("read", 1), make_output("content"));
        assert!(cache.get(&key("read", 1)).is_some());
        assert_eq!(cache.hits(), 1);
    }

    #[test]
    fn ttl_zero_duration_expires_immediately() {
        let mut cache = ToolResultCache::new(true, Some(Duration::ZERO));
        cache.put(key("read", 1), make_output("content"));
        let result = cache.get(&key("read", 1));
        assert!(
            result.is_none(),
            "Duration::ZERO entry must expire on first get()"
        );
        assert_eq!(cache.len(), 0, "expired entry must be removed from map");
    }

    #[test]
    fn ttl_expired_returns_none() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_millis(1)));
        cache.put(key("read", 1), make_output("content"));
        std::thread::sleep(Duration::from_millis(10));
        assert!(cache.get(&key("read", 1)).is_none());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_removes_all_and_resets_counters() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("web_scrape", 2), make_output("b"));
        cache.get(&key("read", 1));
        cache.get(&key("missing", 99));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert!(cache.get(&key("read", 1)).is_none());
    }

    #[test]
    fn disabled_cache_always_misses() {
        let mut cache = ToolResultCache::new(false, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("content"));
        assert!(cache.get(&key("read", 1)).is_none());
        // A disabled cache neither stores entries nor counts lookups.
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("old"));
        cache.put(key("read", 1), make_output("new"));
        assert_eq!(cache.len(), 1);
        let result = cache.get(&key("read", 1)).expect("entry should be cached");
        assert_eq!(result.summary, "new");
    }

    #[test]
    fn len_and_is_empty_track_entries() {
        let mut cache = ToolResultCache::new(true, None);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.put(key("read", 1), make_output("a"));
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn is_enabled_reflects_constructor() {
        assert!(ToolResultCache::new(true, None).is_enabled());
        assert!(!ToolResultCache::new(false, None).is_enabled());
    }

    #[test]
    fn is_cacheable_returns_false_for_deny_list() {
        assert!(!is_cacheable("bash"));
        assert!(!is_cacheable("memory_save"));
        assert!(!is_cacheable("memory_search"));
        assert!(!is_cacheable("scheduler"));
        assert!(!is_cacheable("write"));
    }

    #[test]
    fn is_cacheable_returns_false_for_mcp_prefix() {
        assert!(!is_cacheable("mcp_github_list_issues"));
        assert!(!is_cacheable("mcp_send_email"));
        assert!(!is_cacheable("mcp_"));
    }

    #[test]
    fn is_cacheable_returns_true_for_read_only_tools() {
        assert!(is_cacheable("read"));
        assert!(is_cacheable("web_scrape"));
        assert!(is_cacheable("search_code"));
        assert!(is_cacheable("load_skill"));
        assert!(is_cacheable("diagnostics"));
    }

    #[test]
    fn counter_increments_correctly() {
        let mut cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        cache.put(key("read", 1), make_output("a"));
        cache.put(key("read", 2), make_output("b"));

        cache.get(&key("read", 1));
        cache.get(&key("read", 1));
        cache.get(&key("read", 99));
        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
    }

    #[test]
    fn ttl_secs_returns_zero_for_none() {
        let cache = ToolResultCache::new(true, None);
        assert_eq!(cache.ttl_secs(), 0);
    }

    #[test]
    fn ttl_secs_returns_seconds_for_some() {
        let cache = ToolResultCache::new(true, Some(Duration::from_secs(300)));
        assert_eq!(cache.ttl_secs(), 300);
    }
}