// things3_core/mcp_cache_middleware.rs

1//! Caching middleware for MCP (Model Context Protocol) tool results
2
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tracing::{debug, info, warn};
11
/// MCP tool result cache entry
///
/// One cached tool invocation: the result value plus the metadata needed
/// for expiry (TTL/TTI) checks and LRU eviction bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCacheEntry<T> {
    /// Name of the MCP tool whose result is cached.
    pub tool_name: String,
    /// The exact parameters the tool was invoked with.
    pub parameters: HashMap<String, serde_json::Value>,
    /// The cached tool result.
    pub result: T,
    /// When the entry was inserted into the cache.
    pub cached_at: DateTime<Utc>,
    /// Absolute expiry time (insertion time + configured TTL).
    pub expires_at: DateTime<Utc>,
    /// Number of times this entry has been read from the cache.
    pub access_count: u64,
    /// Last read time; used for TTI (idle) checks and LRU eviction order.
    pub last_accessed: DateTime<Utc>,
    /// The key this entry is stored under (tool name + parameter hash).
    pub cache_key: String,
    /// Serialized size of `result` in bytes, measured at insert time.
    pub result_size_bytes: usize,
    /// Compression ratio; currently always 1.0 (compression not implemented).
    pub compression_ratio: f64,
}
26
/// MCP cache configuration
///
/// Tuning knobs for the middleware: capacity, expiry (TTL / TTI), result
/// size limits, and the optional background cache-warming task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCacheConfig {
    /// Maximum number of cached results
    pub max_entries: usize,
    /// Time to live for cache entries
    pub ttl: Duration,
    /// Time to idle for cache entries
    pub tti: Duration,
    /// Enable compression for large results
    // NOTE(review): compression is not implemented yet in this file; entries
    // record a fixed ratio of 1.0 regardless of this flag.
    pub enable_compression: bool,
    /// Compression threshold in bytes
    pub compression_threshold: usize,
    /// Maximum result size to cache
    pub max_result_size: usize,
    /// Enable cache warming for frequently used tools
    pub enable_cache_warming: bool,
    /// Cache warming interval
    pub warming_interval: Duration,
}
47
48impl Default for MCPCacheConfig {
49    fn default() -> Self {
50        Self {
51            max_entries: 1000,
52            ttl: Duration::from_secs(3600), // 1 hour
53            tti: Duration::from_secs(300),  // 5 minutes
54            enable_compression: true,
55            compression_threshold: 1024,       // 1KB
56            max_result_size: 10 * 1024 * 1024, // 10MB
57            enable_cache_warming: true,
58            warming_interval: Duration::from_secs(60), // 1 minute
59        }
60    }
61}
62
/// MCP cache statistics
///
/// Counters maintained by the middleware; `hit_rate` is derived from
/// `hits` / `misses` via [`MCPCacheStats::calculate_hit_rate`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MCPCacheStats {
    /// Entries inserted into the cache (incremented on insert).
    pub total_entries: u64,
    /// Lookups served from a fresh cache entry.
    pub hits: u64,
    /// Lookups that found no usable entry.
    pub misses: u64,
    /// hits / (hits + misses); 0.0 when there have been no lookups.
    pub hit_rate: f64,
    /// Running sum of cached result sizes in bytes.
    pub total_size_bytes: u64,
    // NOTE(review): compression is not implemented; the two counters below
    // are never updated by the code visible in this file.
    pub compressed_entries: u64,
    pub uncompressed_entries: u64,
    /// Entries removed by LRU eviction.
    pub evictions: u64,
    // NOTE(review): not updated by the code visible in this file.
    pub warming_entries: u64,
    // NOTE(review): not updated by the code visible in this file.
    pub average_access_time_ms: f64,
}
77
78impl MCPCacheStats {
79    pub fn calculate_hit_rate(&mut self) {
80        let total = self.hits + self.misses;
81        self.hit_rate = if total > 0 {
82            #[allow(clippy::cast_precision_loss)]
83            {
84                self.hits as f64 / total as f64
85            }
86        } else {
87            0.0
88        };
89    }
90}
91
/// MCP tool cache middleware
///
/// Generic over the tool result type `T`. Shared state lives behind
/// `Arc<RwLock<...>>` so the middleware can be used from multiple tasks.
pub struct MCPCacheMiddleware<T> {
    /// Cache entries by tool name and parameters
    cache: Arc<RwLock<HashMap<String, MCPCacheEntry<T>>>>,
    /// Configuration
    config: MCPCacheConfig,
    /// Statistics
    stats: Arc<RwLock<MCPCacheStats>>,
    /// Cache warming entries (key -> priority)
    warming_entries: Arc<RwLock<HashMap<String, u32>>>,
    /// Cache warming task handle (aborted in `Drop`)
    warming_task: Option<tokio::task::JoinHandle<()>>,
}
105
106impl<T> MCPCacheMiddleware<T>
107where
108    T: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
109{
110    /// Create a new MCP cache middleware
111    #[must_use]
112    pub fn new(config: &MCPCacheConfig) -> Self {
113        let mut middleware = Self {
114            cache: Arc::new(RwLock::new(HashMap::new())),
115            config: config.clone(),
116            stats: Arc::new(RwLock::new(MCPCacheStats::default())),
117            warming_entries: Arc::new(RwLock::new(HashMap::new())),
118            warming_task: None,
119        };
120
121        // Start cache warming task if enabled
122        if config.enable_cache_warming {
123            middleware.start_cache_warming();
124        }
125
126        middleware
127    }
128
129    /// Create a new middleware with default configuration
130    #[must_use]
131    pub fn new_default() -> Self {
132        Self::new(&MCPCacheConfig::default())
133    }
134
135    /// Execute a tool with caching
136    ///
137    /// # Errors
138    ///
139    /// This function will return an error if:
140    /// - Tool execution fails
141    /// - Cache operations fail
142    /// - Serialization/deserialization fails
143    pub async fn execute_tool<F, Fut>(
144        &self,
145        tool_name: &str,
146        parameters: HashMap<String, serde_json::Value>,
147        tool_executor: F,
148    ) -> Result<T>
149    where
150        F: FnOnce(HashMap<String, serde_json::Value>) -> Fut,
151        Fut: std::future::Future<Output = Result<T>>,
152    {
153        let cache_key = Self::generate_cache_key(tool_name, &parameters);
154
155        // Check cache first
156        if let Some(cached_entry) = self.get_cached_entry(&cache_key) {
157            if !cached_entry.is_expired() && !cached_entry.is_idle(self.config.tti) {
158                self.record_hit();
159                debug!(
160                    "MCP cache hit for tool: {} with key: {}",
161                    tool_name, cache_key
162                );
163                return Ok(cached_entry.result);
164            }
165        }
166
167        // Cache miss - execute tool
168        self.record_miss();
169        let start_time = std::time::Instant::now();
170
171        let result = tool_executor(parameters.clone()).await?;
172        let execution_time = start_time.elapsed();
173
174        // Check if result is too large to cache
175        let result_size = Self::calculate_result_size(&result);
176        if result_size > self.config.max_result_size {
177            warn!("MCP tool result too large to cache: {} bytes", result_size);
178            return Ok(result);
179        }
180
181        // Cache the result
182        self.cache_result(
183            tool_name,
184            parameters,
185            result.clone(),
186            &cache_key,
187            result_size,
188        );
189
190        debug!(
191            "MCP tool executed and cached: {} ({}ms, {} bytes)",
192            tool_name,
193            execution_time.as_millis(),
194            result_size
195        );
196
197        Ok(result)
198    }
199
200    /// Get a cached result without executing the tool
201    #[must_use]
202    pub fn get_cached_result(
203        &self,
204        tool_name: &str,
205        parameters: &HashMap<String, serde_json::Value>,
206    ) -> Option<T> {
207        let cache_key = Self::generate_cache_key(tool_name, parameters);
208
209        if let Some(cached_entry) = self.get_cached_entry(&cache_key) {
210            if !cached_entry.is_expired() && !cached_entry.is_idle(self.config.tti) {
211                self.record_hit();
212                return Some(cached_entry.result);
213            }
214        }
215
216        self.record_miss();
217        None
218    }
219
220    /// Invalidate cache entries for a specific tool
221    pub fn invalidate_tool(&self, tool_name: &str) {
222        let mut cache = self.cache.write();
223        let keys_to_remove: Vec<String> = cache
224            .iter()
225            .filter(|(_, entry)| entry.tool_name == tool_name)
226            .map(|(key, _)| key.clone())
227            .collect();
228
229        let count = keys_to_remove.len();
230        for key in keys_to_remove {
231            cache.remove(&key);
232        }
233
234        debug!(
235            "Invalidated {} cache entries for tool: {}",
236            count, tool_name
237        );
238    }
239
240    /// Invalidate all cache entries
241    pub fn invalidate_all(&self) {
242        let mut cache = self.cache.write();
243        cache.clear();
244        info!("Invalidated all MCP cache entries");
245    }
246
247    /// Get cache statistics
248    #[must_use]
249    pub fn get_stats(&self) -> MCPCacheStats {
250        let mut stats = self.stats.read().clone();
251        stats.calculate_hit_rate();
252        stats
253    }
254
255    /// Get cache size in bytes
256    #[must_use]
257    pub fn get_cache_size(&self) -> usize {
258        let cache = self.cache.read();
259        cache.values().map(|entry| entry.result_size_bytes).sum()
260    }
261
262    /// Get cache utilization percentage
263    #[must_use]
264    #[allow(clippy::cast_precision_loss)]
265    pub fn get_utilization(&self) -> f64 {
266        let current_size = self.get_cache_size();
267        let max_size = self.config.max_entries * self.config.max_result_size;
268        (current_size as f64 / max_size as f64) * 100.0
269    }
270
271    /// Generate cache key from tool name and parameters
272    fn generate_cache_key(
273        tool_name: &str,
274        parameters: &HashMap<String, serde_json::Value>,
275    ) -> String {
276        use std::collections::hash_map::DefaultHasher;
277        use std::hash::{Hash, Hasher};
278
279        let mut key_parts = vec![tool_name.to_string()];
280
281        // Sort parameters for consistent key generation
282        let mut sorted_params: Vec<_> = parameters.iter().collect();
283        sorted_params.sort_by_key(|(k, _)| *k);
284
285        for (param_name, param_value) in sorted_params {
286            key_parts.push(format!("{param_name}:{param_value}"));
287        }
288
289        // Use a hash of the key parts to keep it manageable
290        let mut hasher = DefaultHasher::new();
291        key_parts.join("|").hash(&mut hasher);
292        format!("mcp:{}:{}", tool_name, hasher.finish())
293    }
294
295    /// Get a cached entry
296    fn get_cached_entry(&self, cache_key: &str) -> Option<MCPCacheEntry<T>> {
297        let mut cache = self.cache.write();
298        if let Some(entry) = cache.get_mut(cache_key) {
299            entry.access_count += 1;
300            entry.last_accessed = Utc::now();
301            Some(entry.clone())
302        } else {
303            None
304        }
305    }
306
307    /// Cache a tool result
308    fn cache_result(
309        &self,
310        tool_name: &str,
311        parameters: HashMap<String, serde_json::Value>,
312        result: T,
313        cache_key: &str,
314        result_size: usize,
315    ) {
316        let now = Utc::now();
317        let expires_at = now + chrono::Duration::from_std(self.config.ttl).unwrap_or_default();
318
319        let entry = MCPCacheEntry {
320            tool_name: tool_name.to_string(),
321            parameters,
322            result,
323            cached_at: now,
324            expires_at,
325            access_count: 0,
326            last_accessed: now,
327            cache_key: cache_key.to_string(),
328            result_size_bytes: result_size,
329            compression_ratio: 1.0, // TODO: Implement compression
330        };
331
332        // Check if we need to evict entries
333        self.evict_if_needed();
334
335        let mut cache = self.cache.write();
336        cache.insert(cache_key.to_string(), entry);
337
338        // Update statistics
339        {
340            let mut stats = self.stats.write();
341            stats.total_entries += 1;
342            stats.total_size_bytes += result_size as u64;
343        }
344    }
345
346    /// Calculate result size in bytes
347    fn calculate_result_size(result: &T) -> usize {
348        serde_json::to_vec(result).map_or(0, |bytes| bytes.len())
349    }
350
351    /// Evict entries if cache is full
352    fn evict_if_needed(&self) {
353        let mut cache = self.cache.write();
354
355        if cache.len() >= self.config.max_entries {
356            // Remove oldest entries (LRU)
357            let mut entries: Vec<_> = cache
358                .iter()
359                .map(|(k, v)| (k.clone(), v.last_accessed))
360                .collect();
361            entries.sort_by_key(|(_, last_accessed)| *last_accessed);
362
363            let entries_to_remove = cache.len() - self.config.max_entries + 1;
364            for (key, _) in entries.iter().take(entries_to_remove) {
365                cache.remove(key);
366            }
367
368            // Update statistics
369            {
370                let mut stats = self.stats.write();
371                stats.evictions += entries_to_remove as u64;
372            }
373        }
374    }
375
376    /// Start cache warming background task
377    fn start_cache_warming(&mut self) {
378        let warming_entries = Arc::clone(&self.warming_entries);
379        let warming_interval = self.config.warming_interval;
380
381        let handle = tokio::spawn(async move {
382            let mut interval = tokio::time::interval(warming_interval);
383            loop {
384                interval.tick().await;
385
386                // In a real implementation, you would warm frequently accessed entries
387                // by calling the appropriate tool executors
388                let entries_count = {
389                    let entries = warming_entries.read();
390                    entries.len()
391                };
392
393                if entries_count > 0 {
394                    debug!("MCP cache warming {} entries", entries_count);
395                }
396            }
397        });
398
399        self.warming_task = Some(handle);
400    }
401
402    /// Record a cache hit
403    fn record_hit(&self) {
404        let mut stats = self.stats.write();
405        stats.hits += 1;
406    }
407
408    /// Record a cache miss
409    fn record_miss(&self) {
410        let mut stats = self.stats.write();
411        stats.misses += 1;
412    }
413}
414
415impl<T> MCPCacheEntry<T> {
416    /// Check if the cache entry is expired
417    pub fn is_expired(&self) -> bool {
418        Utc::now() > self.expires_at
419    }
420
421    /// Check if the cache entry is idle
422    pub fn is_idle(&self, tti: Duration) -> bool {
423        let now = Utc::now();
424        let idle_duration = now - self.last_accessed;
425        idle_duration > chrono::Duration::from_std(tti).unwrap_or_default()
426    }
427}
428
429impl<T> Drop for MCPCacheMiddleware<T> {
430    fn drop(&mut self) {
431        if let Some(handle) = self.warming_task.take() {
432            handle.abort();
433        }
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build the single-entry parameter map shared by the cache tests.
    fn query_params() -> HashMap<String, serde_json::Value> {
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            serde_json::Value::String("test".to_string()),
        );
        params
    }

    #[tokio::test]
    async fn test_mcp_cache_basic_operations() {
        let middleware = MCPCacheMiddleware::<String>::new_default();
        let params = query_params();

        // First call executes the tool (cache miss).
        let first = middleware
            .execute_tool("test_tool", params.clone(), |_| async {
                Ok("test_result".to_string())
            })
            .await
            .unwrap();
        assert_eq!(first, "test_result");

        // Second call must be served from the cache; the executor would panic.
        let second = middleware
            .execute_tool("test_tool", params, |_| async {
                panic!("Should not execute on cache hit")
            })
            .await
            .unwrap();
        assert_eq!(second, "test_result");

        let stats = middleware.get_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_mcp_cache_invalidation() {
        let middleware = MCPCacheMiddleware::<String>::new_default();
        let params = query_params();

        // Prime the cache with one result.
        middleware
            .execute_tool("test_tool", params.clone(), |_| async {
                Ok("test_result".to_string())
            })
            .await
            .unwrap();
        assert!(middleware.get_cached_result("test_tool", &params).is_some());

        // Invalidating the tool removes every entry for it.
        middleware.invalidate_tool("test_tool");
        assert!(middleware.get_cached_result("test_tool", &params).is_none());
    }

    #[tokio::test]
    async fn test_mcp_cache_key_generation() {
        let mut forward = HashMap::new();
        forward.insert("a".to_string(), serde_json::Value::String("1".to_string()));
        forward.insert("b".to_string(), serde_json::Value::String("2".to_string()));

        let mut reversed = HashMap::new();
        reversed.insert("b".to_string(), serde_json::Value::String("2".to_string()));
        reversed.insert("a".to_string(), serde_json::Value::String("1".to_string()));

        // Insertion order must not affect the generated key.
        assert_eq!(
            MCPCacheMiddleware::<String>::generate_cache_key("test_tool", &forward),
            MCPCacheMiddleware::<String>::generate_cache_key("test_tool", &reversed),
        );
    }
}