Skip to main content

zai_rs/toolkits/
cache.rs

1//! Tool call result cache with intelligent invalidation
2
3use std::time::{Duration, SystemTime};
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8/// Cache key for tool calls
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct CacheKey {
11    pub tool_name: String,
12    pub arguments: String,
13}
14
15impl CacheKey {
16    pub fn new(tool_name: String, arguments: Value) -> Self {
17        let normalized = normalize_json(&arguments);
18        Self {
19            tool_name,
20            arguments: normalized,
21        }
22    }
23}
24
25/// Cache entry with TTL
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CacheEntry {
28    pub result: Value,
29    pub timestamp: SystemTime,
30    pub ttl: Duration,
31    pub hit_count: u64,
32}
33
34impl CacheEntry {
35    pub fn new(result: Value, ttl: Duration) -> Self {
36        Self {
37            result,
38            timestamp: SystemTime::now(),
39            ttl,
40            hit_count: 0,
41        }
42    }
43
44    pub fn is_expired(&self) -> bool {
45        match self.timestamp.elapsed() {
46            Ok(elapsed) => elapsed > self.ttl,
47            Err(_) => true,
48        }
49    }
50
51    pub fn hit(&mut self) {
52        self.hit_count += 1;
53    }
54}
55
56/// Intelligent tool call result cache
57#[derive(Clone)]
58pub struct ToolCallCache {
59    entries: dashmap::DashMap<CacheKey, CacheEntry>,
60    default_ttl: Duration,
61    max_size: usize,
62    enable_cache: bool,
63}
64
65impl ToolCallCache {
66    pub fn new() -> Self {
67        Self {
68            entries: dashmap::DashMap::new(),
69            default_ttl: Duration::from_secs(300),
70            max_size: 1000,
71            enable_cache: true,
72        }
73    }
74
75    pub fn with_ttl(mut self, ttl: Duration) -> Self {
76        self.default_ttl = ttl;
77        self
78    }
79
80    pub fn with_max_size(mut self, size: usize) -> Self {
81        self.max_size = size;
82        self
83    }
84
85    pub fn with_enabled(mut self, enabled: bool) -> Self {
86        self.enable_cache = enabled;
87        self
88    }
89
90    pub fn get(&self, key: &CacheKey) -> Option<Value> {
91        if !self.enable_cache {
92            return None;
93        }
94
95        // First check if entry exists and is expired without holding the lock
96        let expired = {
97            let entry = self.entries.get(key)?;
98            entry.is_expired()
99        };
100
101        // If expired, remove it and return None
102        if expired {
103            self.entries.remove(key);
104            return None;
105        }
106
107        // Get mut reference for hit counting and result cloning
108        let mut entry = self.entries.get_mut(key)?;
109        entry.hit();
110        Some(entry.result.clone())
111    }
112
113    pub fn insert(&self, key: CacheKey, result: Value, ttl: Option<Duration>) {
114        if !self.enable_cache {
115            return;
116        }
117
118        if self.entries.len() >= self.max_size {
119            self.evict_lru();
120        }
121
122        let entry = CacheEntry::new(result, ttl.unwrap_or(self.default_ttl));
123        self.entries.insert(key, entry);
124    }
125
126    pub fn insert_with_key(&self, tool_name: String, arguments: Value, result: Value) {
127        let key = CacheKey::new(tool_name, arguments);
128        self.insert(key, result, None);
129    }
130
131    pub fn clear(&self) {
132        self.entries.clear();
133    }
134
135    pub fn invalidate_tool(&self, tool_name: &str) {
136        self.entries.retain(|key, _| key.tool_name != tool_name);
137    }
138
139    pub fn stats(&self) -> CacheStats {
140        let mut total_hits = 0u64;
141        let mut expired_count = 0u64;
142
143        for entry in self.entries.iter() {
144            total_hits += entry.hit_count;
145            if entry.is_expired() {
146                expired_count += 1;
147            }
148        }
149
150        CacheStats {
151            total_entries: self.entries.len(),
152            total_hits,
153            expired_count,
154            hit_rate: if self.entries.is_empty() {
155                0.0
156            } else {
157                total_hits as f64 / self.entries.len() as f64
158            },
159        }
160    }
161
162    fn evict_lru(&self) {
163        let mut entries: Vec<_> = self
164            .entries
165            .iter()
166            .map(|entry| (entry.key().clone(), entry.value().timestamp))
167            .collect();
168
169        entries.sort_by_key(|a| a.1);
170
171        let remove_count = (self.max_size / 10).max(1);
172        for (key, _) in entries.into_iter().take(remove_count) {
173            self.entries.remove(&key);
174        }
175    }
176}
177
178impl Default for ToolCallCache {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184/// Cache statistics
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct CacheStats {
187    pub total_entries: usize,
188    pub total_hits: u64,
189    pub expired_count: u64,
190    pub hit_rate: f64,
191}
192
193fn normalize_json(value: &Value) -> String {
194    match value {
195        Value::Object(obj) => {
196            let mut normalized = serde_json::Map::new();
197            for (k, v) in obj {
198                let normalized_key = k.trim().to_lowercase();
199                let normalized_value = normalize_json_value(v);
200                normalized.insert(normalized_key, normalized_value);
201            }
202            serde_json::to_string(&normalized).unwrap_or_default()
203        },
204        Value::Array(arr) => {
205            let normalized: Vec<_> = arr.iter().map(normalize_json_value).collect();
206            serde_json::to_string(&normalized).unwrap_or_default()
207        },
208        Value::String(s) => s.clone(),
209        _ => serde_json::to_string(value).unwrap_or_default(),
210    }
211}
212
213fn normalize_json_value(value: &Value) -> Value {
214    match value {
215        Value::Object(obj) => {
216            let mut normalized = serde_json::Map::new();
217            for (k, v) in obj {
218                let normalized_key = k.trim().to_lowercase();
219                normalized.insert(normalized_key, normalize_json_value(v));
220            }
221            Value::Object(normalized)
222        },
223        Value::Array(arr) => {
224            let normalized: Vec<_> = arr.iter().map(normalize_json_value).collect();
225            Value::Array(normalized)
226        },
227        _ => value.clone(),
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_new() {
        let args = serde_json::json!({"city": "Shenzhen", "count": 5});
        let key = CacheKey::new("test_tool".to_string(), args);
        assert_eq!(key.tool_name, "test_tool");
        assert!(key.arguments.contains("city"));
    }

    #[test]
    fn test_cache_entry_expired() {
        let entry = CacheEntry::new(
            serde_json::json!({"result": "success"}),
            Duration::from_secs(1),
        );
        assert!(!entry.is_expired());

        let mut entry_mut = entry.clone();
        entry_mut.timestamp = SystemTime::now() - Duration::from_secs(2);
        assert!(entry_mut.is_expired());
    }

    #[test]
    fn test_cache_hit() {
        let mut entry = CacheEntry::new(
            serde_json::json!({"result": "success"}),
            Duration::from_secs(60),
        );
        entry.hit();
        entry.hit();
        assert_eq!(entry.hit_count, 2);
    }

    #[test]
    fn test_cache_insert_get() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});
        let result = serde_json::json!({"output": "success"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), result.clone());

        let key = CacheKey::new("test_tool".to_string(), args);
        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), result);
    }

    #[test]
    fn test_cache_expiration() {
        // Backdate the entry instead of sleeping past a tiny TTL. A sleep-based test is slower
        // and can flake if the thread is descheduled between insert and the first lookup.
        let cache = ToolCallCache::new().with_ttl(Duration::from_secs(60));
        let args = serde_json::json!({"input": "test"});
        let result = serde_json::json!({"output": "success"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), result.clone());

        let key = CacheKey::new("test_tool".to_string(), args);

        // Entry should be cached initially
        assert_eq!(cache.get(&key), Some(result));

        cache
            .entries
            .get_mut(&key)
            .expect("entry was just inserted")
            .timestamp = SystemTime::now() - Duration::from_secs(120);

        // Entry should be expired now, and the lookup should have removed it
        assert!(cache.get(&key).is_none());
        assert!(!cache.entries.contains_key(&key));
    }

    #[test]
    fn test_cache_stats() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});

        cache.insert_with_key("tool_a".to_string(), args.clone(), serde_json::json!({}));
        cache.insert_with_key("tool_b".to_string(), args.clone(), serde_json::json!({}));

        let key = CacheKey::new("tool_a".to_string(), args.clone());
        let _ = cache.get(&key);
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_hits, 2);
    }

    #[test]
    fn test_disabled_cache_stores_nothing() {
        let cache = ToolCallCache::new().with_enabled(false);
        let args = serde_json::json!({"input": "test"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), serde_json::json!({}));

        let key = CacheKey::new("test_tool".to_string(), args);
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_invalidate_tool() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});

        cache.insert_with_key("tool_a".to_string(), args.clone(), serde_json::json!(1));
        cache.insert_with_key("tool_b".to_string(), args.clone(), serde_json::json!(2));
        cache.invalidate_tool("tool_a");

        assert!(cache
            .get(&CacheKey::new("tool_a".to_string(), args.clone()))
            .is_none());
        assert_eq!(
            cache.get(&CacheKey::new("tool_b".to_string(), args)),
            Some(serde_json::json!(2))
        );
    }

    #[test]
    fn test_max_size_bounds_entry_count() {
        let cache = ToolCallCache::new().with_max_size(2);
        for i in 0..5 {
            cache.insert_with_key(
                "tool".to_string(),
                serde_json::json!({"i": i}),
                serde_json::json!(i),
            );
        }

        assert!(cache.stats().total_entries <= 2);
        let newest = CacheKey::new("tool".to_string(), serde_json::json!({"i": 4}));
        assert_eq!(cache.get(&newest), Some(serde_json::json!(4)));
    }

    #[test]
    fn test_normalize_json() {
        let obj = serde_json::json!({
            "CITY": "Shenzhen",
            "count": 5,
            "Data": {"NAME": "test"}
        });

        let normalized = normalize_json(&obj);
        let parsed: Value = serde_json::from_str(&normalized).unwrap();

        // Fail loudly instead of silently skipping every assertion if the shape is wrong.
        let parsed_obj = parsed
            .as_object()
            .expect("normalized object should serialize to a JSON object");

        // Keys should be lowercase (recursively), but values should be preserved
        assert!(parsed_obj.contains_key("city"));
        assert!(parsed_obj.contains_key("count"));
        assert!(parsed_obj.contains_key("data"));
        assert_eq!(parsed_obj.get("city"), Some(&serde_json::json!("Shenzhen")));
        assert_eq!(parsed_obj.get("count"), Some(&serde_json::json!(5)));
        assert_eq!(
            parsed_obj.get("data"),
            Some(&serde_json::json!({"name": "test"}))
        );
    }
}