1use std::time::{Duration, SystemTime};
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub struct CacheKey {
11 pub tool_name: String,
12 pub arguments: String,
13}
14
15impl CacheKey {
16 pub fn new(tool_name: String, arguments: Value) -> Self {
17 let normalized = normalize_json(&arguments);
18 Self {
19 tool_name,
20 arguments: normalized,
21 }
22 }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CacheEntry {
28 pub result: Value,
29 pub timestamp: SystemTime,
30 pub ttl: Duration,
31 pub hit_count: u64,
32}
33
34impl CacheEntry {
35 pub fn new(result: Value, ttl: Duration) -> Self {
36 Self {
37 result,
38 timestamp: SystemTime::now(),
39 ttl,
40 hit_count: 0,
41 }
42 }
43
44 pub fn is_expired(&self) -> bool {
45 match self.timestamp.elapsed() {
46 Ok(elapsed) => elapsed > self.ttl,
47 Err(_) => true,
48 }
49 }
50
51 pub fn hit(&mut self) {
52 self.hit_count += 1;
53 }
54}
55
56#[derive(Clone)]
58pub struct ToolCallCache {
59 entries: dashmap::DashMap<CacheKey, CacheEntry>,
60 default_ttl: Duration,
61 max_size: usize,
62 enable_cache: bool,
63}
64
65impl ToolCallCache {
66 pub fn new() -> Self {
67 Self {
68 entries: dashmap::DashMap::new(),
69 default_ttl: Duration::from_secs(300),
70 max_size: 1000,
71 enable_cache: true,
72 }
73 }
74
75 pub fn with_ttl(mut self, ttl: Duration) -> Self {
76 self.default_ttl = ttl;
77 self
78 }
79
80 pub fn with_max_size(mut self, size: usize) -> Self {
81 self.max_size = size;
82 self
83 }
84
85 pub fn with_enabled(mut self, enabled: bool) -> Self {
86 self.enable_cache = enabled;
87 self
88 }
89
90 pub fn get(&self, key: &CacheKey) -> Option<Value> {
91 if !self.enable_cache {
92 return None;
93 }
94
95 let expired = {
97 let entry = self.entries.get(key)?;
98 entry.is_expired()
99 };
100
101 if expired {
103 self.entries.remove(key);
104 return None;
105 }
106
107 let mut entry = self.entries.get_mut(key)?;
109 entry.hit();
110 Some(entry.result.clone())
111 }
112
113 pub fn insert(&self, key: CacheKey, result: Value, ttl: Option<Duration>) {
114 if !self.enable_cache {
115 return;
116 }
117
118 if self.entries.len() >= self.max_size {
119 self.evict_lru();
120 }
121
122 let entry = CacheEntry::new(result, ttl.unwrap_or(self.default_ttl));
123 self.entries.insert(key, entry);
124 }
125
126 pub fn insert_with_key(&self, tool_name: String, arguments: Value, result: Value) {
127 let key = CacheKey::new(tool_name, arguments);
128 self.insert(key, result, None);
129 }
130
131 pub fn clear(&self) {
132 self.entries.clear();
133 }
134
135 pub fn invalidate_tool(&self, tool_name: &str) {
136 self.entries.retain(|key, _| key.tool_name != tool_name);
137 }
138
139 pub fn stats(&self) -> CacheStats {
140 let mut total_hits = 0u64;
141 let mut expired_count = 0u64;
142
143 for entry in self.entries.iter() {
144 total_hits += entry.hit_count;
145 if entry.is_expired() {
146 expired_count += 1;
147 }
148 }
149
150 CacheStats {
151 total_entries: self.entries.len(),
152 total_hits,
153 expired_count,
154 hit_rate: if self.entries.is_empty() {
155 0.0
156 } else {
157 total_hits as f64 / self.entries.len() as f64
158 },
159 }
160 }
161
162 fn evict_lru(&self) {
163 let mut entries: Vec<_> = self
164 .entries
165 .iter()
166 .map(|entry| (entry.key().clone(), entry.value().timestamp))
167 .collect();
168
169 entries.sort_by_key(|a| a.1);
170
171 let remove_count = (self.max_size / 10).max(1);
172 for (key, _) in entries.into_iter().take(remove_count) {
173 self.entries.remove(&key);
174 }
175 }
176}
177
178impl Default for ToolCallCache {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
/// Summary of cache contents returned by [`ToolCallCache::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of stored entries, including expired ones not yet removed.
    pub total_entries: usize,
    /// Sum of `hit_count` over all stored entries.
    pub total_hits: u64,
    /// Number of stored entries whose TTL has elapsed.
    pub expired_count: u64,
    /// `total_hits / total_entries` (0.0 when empty): the average number of
    /// hits per stored entry, not a hits-per-lookup ratio; may exceed 1.0.
    pub hit_rate: f64,
}
192
193fn normalize_json(value: &Value) -> String {
194 match value {
195 Value::Object(obj) => {
196 let mut normalized = serde_json::Map::new();
197 for (k, v) in obj {
198 let normalized_key = k.trim().to_lowercase();
199 let normalized_value = normalize_json_value(v);
200 normalized.insert(normalized_key, normalized_value);
201 }
202 serde_json::to_string(&normalized).unwrap_or_default()
203 },
204 Value::Array(arr) => {
205 let normalized: Vec<_> = arr.iter().map(normalize_json_value).collect();
206 serde_json::to_string(&normalized).unwrap_or_default()
207 },
208 Value::String(s) => s.clone(),
209 _ => serde_json::to_string(value).unwrap_or_default(),
210 }
211}
212
213fn normalize_json_value(value: &Value) -> Value {
214 match value {
215 Value::Object(obj) => {
216 let mut normalized = serde_json::Map::new();
217 for (k, v) in obj {
218 let normalized_key = k.trim().to_lowercase();
219 normalized.insert(normalized_key, normalize_json_value(v));
220 }
221 Value::Object(normalized)
222 },
223 Value::Array(arr) => {
224 let normalized: Vec<_> = arr.iter().map(normalize_json_value).collect();
225 Value::Array(normalized)
226 },
227 _ => value.clone(),
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_new() {
        let args = serde_json::json!({"city": "Shenzhen", "count": 5});
        let key = CacheKey::new("test_tool".to_string(), args);
        assert_eq!(key.tool_name, "test_tool");
        assert!(key.arguments.contains("city"));
    }

    #[test]
    fn test_key_ignores_argument_key_case() {
        let a = CacheKey::new("t".to_string(), serde_json::json!({"City": "x", "Count": 1}));
        let b = CacheKey::new("t".to_string(), serde_json::json!({"CITY": "x", " count ": 1}));
        assert_eq!(a, b);
    }

    #[test]
    fn test_cache_entry_expired() {
        let entry = CacheEntry::new(
            serde_json::json!({"result": "success"}),
            Duration::from_secs(1),
        );
        assert!(!entry.is_expired());

        let mut entry_mut = entry.clone();
        entry_mut.timestamp = SystemTime::now() - Duration::from_secs(2);
        assert!(entry_mut.is_expired());
    }

    #[test]
    fn test_cache_hit() {
        let mut entry = CacheEntry::new(
            serde_json::json!({"result": "success"}),
            Duration::from_secs(60),
        );
        entry.hit();
        entry.hit();
        assert_eq!(entry.hit_count, 2);
    }

    #[test]
    fn test_cache_insert_get() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});
        let result = serde_json::json!({"output": "success"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), result.clone());

        let key = CacheKey::new("test_tool".to_string(), args);
        assert_eq!(cache.get(&key), Some(result));
    }

    #[test]
    fn test_cache_expiration() {
        let cache = ToolCallCache::new().with_ttl(Duration::from_secs(60));
        let args = serde_json::json!({"input": "test"});
        let result = serde_json::json!({"output": "success"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), result.clone());

        let key = CacheKey::new("test_tool".to_string(), args);
        assert_eq!(cache.get(&key), Some(result));

        // Backdate the entry instead of sleeping, so the test does not depend
        // on scheduler timing.
        cache
            .entries
            .get_mut(&key)
            .expect("entry was just inserted")
            .timestamp = SystemTime::now() - Duration::from_secs(120);

        assert!(cache.get(&key).is_none());
        // An expired lookup also removes the entry.
        assert!(!cache.entries.contains_key(&key));
    }

    #[test]
    fn test_cache_stats() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});

        cache.insert_with_key("tool_a".to_string(), args.clone(), serde_json::json!({}));
        cache.insert_with_key("tool_b".to_string(), args.clone(), serde_json::json!({}));

        let key = CacheKey::new("tool_a".to_string(), args);
        let _ = cache.get(&key);
        let _ = cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_hits, 2);
        assert_eq!(stats.expired_count, 0);
        assert!((stats.hit_rate - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_disabled_cache_stores_nothing() {
        let cache = ToolCallCache::new().with_enabled(false);
        let args = serde_json::json!({"input": "test"});

        cache.insert_with_key("test_tool".to_string(), args.clone(), serde_json::json!({}));

        let key = CacheKey::new("test_tool".to_string(), args);
        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().total_entries, 0);
    }

    #[test]
    fn test_invalidate_tool() {
        let cache = ToolCallCache::new();
        let args = serde_json::json!({"input": "test"});
        cache.insert_with_key("tool_a".to_string(), args.clone(), serde_json::json!(1));
        cache.insert_with_key("tool_b".to_string(), args.clone(), serde_json::json!(2));

        cache.invalidate_tool("tool_a");

        let key_a = CacheKey::new("tool_a".to_string(), args.clone());
        let key_b = CacheKey::new("tool_b".to_string(), args);
        assert!(cache.get(&key_a).is_none());
        assert_eq!(cache.get(&key_b), Some(serde_json::json!(2)));
    }

    #[test]
    fn test_eviction_respects_max_size() {
        let cache = ToolCallCache::new().with_max_size(2);
        for i in 0..3 {
            cache.insert_with_key(
                "tool".to_string(),
                serde_json::json!({"i": i}),
                serde_json::json!(i),
            );
        }

        assert_eq!(cache.stats().total_entries, 2);
        let newest = CacheKey::new("tool".to_string(), serde_json::json!({"i": 2}));
        assert_eq!(cache.get(&newest), Some(serde_json::json!(2)));
    }

    #[test]
    fn test_normalize_json() {
        let obj = serde_json::json!({
            "CITY": "Shenzhen",
            "count": 5,
            "Data": {"NAME": "test"}
        });

        let normalized = normalize_json(&obj);
        let parsed: Value = serde_json::from_str(&normalized).unwrap();
        let parsed_obj = parsed
            .as_object()
            .expect("a normalized object must remain an object");

        assert_eq!(parsed_obj.len(), 3);
        assert_eq!(parsed_obj.get("city"), Some(&serde_json::json!("Shenzhen")));
        assert_eq!(parsed_obj.get("count"), Some(&serde_json::json!(5)));
        // Nested object keys are normalized too.
        assert_eq!(parsed_obj.get("data"), Some(&serde_json::json!({"name": "test"})));
    }
}