1use anyhow::Result;
4use chrono::{DateTime, Utc};
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tracing::{debug, info, warn};
11
/// A single cached MCP tool result, together with the metadata needed for
/// TTL/TTI expiry checks and LRU eviction decisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCacheEntry<T> {
    /// Name of the tool whose result is cached.
    pub tool_name: String,
    /// Parameters the tool was invoked with (part of the cache identity).
    pub parameters: HashMap<String, serde_json::Value>,
    /// The cached tool result.
    pub result: T,
    /// When the entry was inserted.
    pub cached_at: DateTime<Utc>,
    /// Absolute expiry deadline (`cached_at` + TTL).
    pub expires_at: DateTime<Utc>,
    /// Number of times the entry has been read.
    pub access_count: u64,
    /// Last read time; used for TTI checks and LRU ordering.
    pub last_accessed: DateTime<Utc>,
    /// The key this entry is stored under.
    pub cache_key: String,
    /// Serialized (JSON) size of `result`, in bytes.
    pub result_size_bytes: usize,
    /// Compression ratio; always 1.0 here — no compression code is visible
    /// in this file.
    pub compression_ratio: f64,
}
26
/// Tuning knobs for the MCP result cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCacheConfig {
    /// Maximum number of entries held at once; LRU eviction applies beyond this.
    pub max_entries: usize,
    /// Time-to-live: absolute lifetime of an entry after insertion.
    pub ttl: Duration,
    /// Time-to-idle: an entry unused for longer than this is treated as stale.
    pub tti: Duration,
    /// Whether result compression is enabled.
    /// NOTE(review): no compression logic is visible in this file — confirm
    /// whether this flag is honored elsewhere.
    pub enable_compression: bool,
    /// Minimum result size (bytes) before compression would apply.
    pub compression_threshold: usize,
    /// Results larger than this many bytes are never cached.
    pub max_result_size: usize,
    /// Whether to spawn the periodic cache-warming task.
    pub enable_cache_warming: bool,
    /// Interval between cache-warming ticks.
    pub warming_interval: Duration,
}
47
48impl Default for MCPCacheConfig {
49 fn default() -> Self {
50 Self {
51 max_entries: 1000,
52 ttl: Duration::from_secs(3600), tti: Duration::from_secs(300), enable_compression: true,
55 compression_threshold: 1024, max_result_size: 10 * 1024 * 1024, enable_cache_warming: true,
58 warming_interval: Duration::from_secs(60), }
60 }
61}
62
/// Aggregate statistics for the MCP cache.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MCPCacheStats {
    /// Cumulative number of entries ever inserted (not current occupancy).
    pub total_entries: u64,
    /// Lookups served from a fresh cache entry.
    pub hits: u64,
    /// Lookups that missed (or found only an expired/idle entry).
    pub misses: u64,
    /// `hits / (hits + misses)`; refreshed by [`MCPCacheStats::calculate_hit_rate`].
    pub hit_rate: f64,
    /// Sum of serialized result sizes currently accounted for, in bytes.
    pub total_size_bytes: u64,
    /// Count of compressed entries (unused while compression is absent).
    pub compressed_entries: u64,
    /// Count of uncompressed entries (unused while compression is absent).
    pub uncompressed_entries: u64,
    /// Entries removed by LRU eviction.
    pub evictions: u64,
    /// Entries registered for cache warming.
    pub warming_entries: u64,
    /// Average access latency in milliseconds.
    /// NOTE(review): nothing in this file updates this field — confirm it is
    /// populated elsewhere.
    pub average_access_time_ms: f64,
}
77
78impl MCPCacheStats {
79 pub fn calculate_hit_rate(&mut self) {
80 let total = self.hits + self.misses;
81 self.hit_rate = if total > 0 {
82 #[allow(clippy::cast_precision_loss)]
83 {
84 self.hits as f64 / total as f64
85 }
86 } else {
87 0.0
88 };
89 }
90}
91
/// Caching middleware for MCP tool execution.
///
/// Wraps tool calls with a TTL/TTI-bounded, LRU-evicted in-memory cache and
/// optionally runs a periodic background cache-warming task.
pub struct MCPCacheMiddleware<T> {
    /// Cached results, keyed by the string produced by `generate_cache_key`.
    cache: Arc<RwLock<HashMap<String, MCPCacheEntry<T>>>>,
    /// Configuration captured (cloned) at construction time.
    config: MCPCacheConfig,
    /// Shared hit/miss/eviction counters.
    stats: Arc<RwLock<MCPCacheStats>>,
    /// Keys registered for warming (currently only counted by the task).
    warming_entries: Arc<RwLock<HashMap<String, u32>>>,
    /// Handle of the spawned warming task; aborted in `Drop`.
    warming_task: Option<tokio::task::JoinHandle<()>>,
}
105
106impl<T> MCPCacheMiddleware<T>
107where
108 T: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
109{
110 #[must_use]
112 pub fn new(config: &MCPCacheConfig) -> Self {
113 let mut middleware = Self {
114 cache: Arc::new(RwLock::new(HashMap::new())),
115 config: config.clone(),
116 stats: Arc::new(RwLock::new(MCPCacheStats::default())),
117 warming_entries: Arc::new(RwLock::new(HashMap::new())),
118 warming_task: None,
119 };
120
121 if config.enable_cache_warming {
123 middleware.start_cache_warming();
124 }
125
126 middleware
127 }
128
129 #[must_use]
131 pub fn new_default() -> Self {
132 Self::new(&MCPCacheConfig::default())
133 }
134
135 pub async fn execute_tool<F, Fut>(
144 &self,
145 tool_name: &str,
146 parameters: HashMap<String, serde_json::Value>,
147 tool_executor: F,
148 ) -> Result<T>
149 where
150 F: FnOnce(HashMap<String, serde_json::Value>) -> Fut,
151 Fut: std::future::Future<Output = Result<T>>,
152 {
153 let cache_key = Self::generate_cache_key(tool_name, ¶meters);
154
155 if let Some(cached_entry) = self.get_cached_entry(&cache_key) {
157 if !cached_entry.is_expired() && !cached_entry.is_idle(self.config.tti) {
158 self.record_hit();
159 debug!(
160 "MCP cache hit for tool: {} with key: {}",
161 tool_name, cache_key
162 );
163 return Ok(cached_entry.result);
164 }
165 }
166
167 self.record_miss();
169 let start_time = std::time::Instant::now();
170
171 let result = tool_executor(parameters.clone()).await?;
172 let execution_time = start_time.elapsed();
173
174 let result_size = Self::calculate_result_size(&result);
176 if result_size > self.config.max_result_size {
177 warn!("MCP tool result too large to cache: {} bytes", result_size);
178 return Ok(result);
179 }
180
181 self.cache_result(
183 tool_name,
184 parameters,
185 result.clone(),
186 &cache_key,
187 result_size,
188 );
189
190 debug!(
191 "MCP tool executed and cached: {} ({}ms, {} bytes)",
192 tool_name,
193 execution_time.as_millis(),
194 result_size
195 );
196
197 Ok(result)
198 }
199
200 #[must_use]
202 pub fn get_cached_result(
203 &self,
204 tool_name: &str,
205 parameters: &HashMap<String, serde_json::Value>,
206 ) -> Option<T> {
207 let cache_key = Self::generate_cache_key(tool_name, parameters);
208
209 if let Some(cached_entry) = self.get_cached_entry(&cache_key) {
210 if !cached_entry.is_expired() && !cached_entry.is_idle(self.config.tti) {
211 self.record_hit();
212 return Some(cached_entry.result);
213 }
214 }
215
216 self.record_miss();
217 None
218 }
219
220 pub fn invalidate_tool(&self, tool_name: &str) {
222 let mut cache = self.cache.write();
223 let keys_to_remove: Vec<String> = cache
224 .iter()
225 .filter(|(_, entry)| entry.tool_name == tool_name)
226 .map(|(key, _)| key.clone())
227 .collect();
228
229 let count = keys_to_remove.len();
230 for key in keys_to_remove {
231 cache.remove(&key);
232 }
233
234 debug!(
235 "Invalidated {} cache entries for tool: {}",
236 count, tool_name
237 );
238 }
239
240 pub fn invalidate_all(&self) {
242 let mut cache = self.cache.write();
243 cache.clear();
244 info!("Invalidated all MCP cache entries");
245 }
246
247 #[must_use]
249 pub fn get_stats(&self) -> MCPCacheStats {
250 let mut stats = self.stats.read().clone();
251 stats.calculate_hit_rate();
252 stats
253 }
254
255 #[must_use]
257 pub fn get_cache_size(&self) -> usize {
258 let cache = self.cache.read();
259 cache.values().map(|entry| entry.result_size_bytes).sum()
260 }
261
262 #[must_use]
264 #[allow(clippy::cast_precision_loss)]
265 pub fn get_utilization(&self) -> f64 {
266 let current_size = self.get_cache_size();
267 let max_size = self.config.max_entries * self.config.max_result_size;
268 (current_size as f64 / max_size as f64) * 100.0
269 }
270
271 fn generate_cache_key(
273 tool_name: &str,
274 parameters: &HashMap<String, serde_json::Value>,
275 ) -> String {
276 use std::collections::hash_map::DefaultHasher;
277 use std::hash::{Hash, Hasher};
278
279 let mut key_parts = vec![tool_name.to_string()];
280
281 let mut sorted_params: Vec<_> = parameters.iter().collect();
283 sorted_params.sort_by_key(|(k, _)| *k);
284
285 for (param_name, param_value) in sorted_params {
286 key_parts.push(format!("{param_name}:{param_value}"));
287 }
288
289 let mut hasher = DefaultHasher::new();
291 key_parts.join("|").hash(&mut hasher);
292 format!("mcp:{}:{}", tool_name, hasher.finish())
293 }
294
295 fn get_cached_entry(&self, cache_key: &str) -> Option<MCPCacheEntry<T>> {
297 let mut cache = self.cache.write();
298 if let Some(entry) = cache.get_mut(cache_key) {
299 entry.access_count += 1;
300 entry.last_accessed = Utc::now();
301 Some(entry.clone())
302 } else {
303 None
304 }
305 }
306
307 fn cache_result(
309 &self,
310 tool_name: &str,
311 parameters: HashMap<String, serde_json::Value>,
312 result: T,
313 cache_key: &str,
314 result_size: usize,
315 ) {
316 let now = Utc::now();
317 let expires_at = now + chrono::Duration::from_std(self.config.ttl).unwrap_or_default();
318
319 let entry = MCPCacheEntry {
320 tool_name: tool_name.to_string(),
321 parameters,
322 result,
323 cached_at: now,
324 expires_at,
325 access_count: 0,
326 last_accessed: now,
327 cache_key: cache_key.to_string(),
328 result_size_bytes: result_size,
329 compression_ratio: 1.0, };
331
332 self.evict_if_needed();
334
335 let mut cache = self.cache.write();
336 cache.insert(cache_key.to_string(), entry);
337
338 {
340 let mut stats = self.stats.write();
341 stats.total_entries += 1;
342 stats.total_size_bytes += result_size as u64;
343 }
344 }
345
346 fn calculate_result_size(result: &T) -> usize {
348 serde_json::to_vec(result).map_or(0, |bytes| bytes.len())
349 }
350
351 fn evict_if_needed(&self) {
353 let mut cache = self.cache.write();
354
355 if cache.len() >= self.config.max_entries {
356 let mut entries: Vec<_> = cache
358 .iter()
359 .map(|(k, v)| (k.clone(), v.last_accessed))
360 .collect();
361 entries.sort_by_key(|(_, last_accessed)| *last_accessed);
362
363 let entries_to_remove = cache.len() - self.config.max_entries + 1;
364 for (key, _) in entries.iter().take(entries_to_remove) {
365 cache.remove(key);
366 }
367
368 {
370 let mut stats = self.stats.write();
371 stats.evictions += entries_to_remove as u64;
372 }
373 }
374 }
375
376 fn start_cache_warming(&mut self) {
378 let warming_entries = Arc::clone(&self.warming_entries);
379 let warming_interval = self.config.warming_interval;
380
381 let handle = tokio::spawn(async move {
382 let mut interval = tokio::time::interval(warming_interval);
383 loop {
384 interval.tick().await;
385
386 let entries_count = {
389 let entries = warming_entries.read();
390 entries.len()
391 };
392
393 if entries_count > 0 {
394 debug!("MCP cache warming {} entries", entries_count);
395 }
396 }
397 });
398
399 self.warming_task = Some(handle);
400 }
401
402 fn record_hit(&self) {
404 let mut stats = self.stats.write();
405 stats.hits += 1;
406 }
407
408 fn record_miss(&self) {
410 let mut stats = self.stats.write();
411 stats.misses += 1;
412 }
413}
414
415impl<T> MCPCacheEntry<T> {
416 pub fn is_expired(&self) -> bool {
418 Utc::now() > self.expires_at
419 }
420
421 pub fn is_idle(&self, tti: Duration) -> bool {
423 let now = Utc::now();
424 let idle_duration = now - self.last_accessed;
425 idle_duration > chrono::Duration::from_std(tti).unwrap_or_default()
426 }
427}
428
429impl<T> Drop for MCPCacheMiddleware<T> {
430 fn drop(&mut self) {
431 if let Some(handle) = self.warming_task.take() {
432 handle.abort();
433 }
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_mcp_cache_basic_operations() {
        let middleware = MCPCacheMiddleware::<String>::new_default();

        let mut parameters = HashMap::new();
        parameters.insert(
            "query".to_string(),
            serde_json::Value::String("test".to_string()),
        );

        // First call misses the cache and runs the executor.
        let result1 = middleware
            .execute_tool("test_tool", parameters.clone(), |_| async {
                Ok("test_result".to_string())
            })
            .await
            .unwrap();

        assert_eq!(result1, "test_result");

        // Second call must be served from the cache; the executor panics if run.
        let result2 = middleware
            .execute_tool("test_tool", parameters, |_| async {
                panic!("Should not execute on cache hit")
            })
            .await
            .unwrap();

        assert_eq!(result2, "test_result");

        let stats = middleware.get_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_mcp_cache_invalidation() {
        let middleware = MCPCacheMiddleware::<String>::new_default();

        let mut parameters = HashMap::new();
        parameters.insert(
            "query".to_string(),
            serde_json::Value::String("test".to_string()),
        );

        middleware
            .execute_tool("test_tool", parameters.clone(), |_| async {
                Ok("test_result".to_string())
            })
            .await
            .unwrap();

        // Fix: was mis-encoded as `¶meters` (HTML-entity mangling of
        // `&parameters`), which does not compile.
        let cached = middleware.get_cached_result("test_tool", &parameters);
        assert!(cached.is_some());

        middleware.invalidate_tool("test_tool");

        let cached = middleware.get_cached_result("test_tool", &parameters);
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_mcp_cache_key_generation() {
        let _middleware = MCPCacheMiddleware::<String>::new_default();

        let mut params1 = HashMap::new();
        params1.insert("a".to_string(), serde_json::Value::String("1".to_string()));
        params1.insert("b".to_string(), serde_json::Value::String("2".to_string()));

        let mut params2 = HashMap::new();
        params2.insert("b".to_string(), serde_json::Value::String("2".to_string()));
        params2.insert("a".to_string(), serde_json::Value::String("1".to_string()));

        // Keys must be independent of map insertion order.
        let key1 = MCPCacheMiddleware::<String>::generate_cache_key("test_tool", &params1);
        let key2 = MCPCacheMiddleware::<String>::generate_cache_key("test_tool", &params2);
        assert_eq!(key1, key2);
    }
}