// turbomcp_client/middleware/cache.rs
//! Response caching middleware for MCP client.
//!
//! Tower Layer that caches successful responses with configurable TTL
//! and LRU eviction. Supports method-based caching policies.
//!
//! ## Usage
//!
//! ```rust,ignore
//! use turbomcp_client::middleware::{CacheLayer, CacheConfig};
//! use tower::ServiceBuilder;
//!
//! let service = ServiceBuilder::new()
//!     .layer(CacheLayer::new(CacheConfig {
//!         max_entries: 1000,
//!         ttl: Duration::from_secs(300),
//!         ..Default::default()
//!     }))
//!     .service(inner_service);
//! ```

use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use futures_util::future::BoxFuture;
use parking_lot::RwLock;
use serde_json::Value;
use tower_layer::Layer;
use tower_service::Service;
use turbomcp_protocol::McpError;

use super::request::{McpRequest, McpResponse};
34
/// Cache configuration.
///
/// Controls capacity, entry lifetime, and which MCP methods are eligible
/// for caching. Exclusions always take precedence over the allow-list;
/// both lists match by exact name or by prefix (e.g. `"notifications/"`).
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached entries; inserting beyond this triggers
    /// LRU eviction
    pub max_entries: usize,
    /// Time-to-live for cached entries, measured from insertion
    pub ttl: Duration,
    /// Methods to cache (empty = cache all cacheable methods)
    pub cache_methods: Vec<String>,
    /// Methods to never cache, checked before `cache_methods`
    pub exclude_methods: Vec<String>,
}
47
48impl Default for CacheConfig {
49    fn default() -> Self {
50        Self {
51            max_entries: 1000,
52            ttl: Duration::from_secs(300), // 5 minutes
53            cache_methods: Vec::new(),
54            exclude_methods: vec![
55                // Mutations should not be cached
56                "tools/call".to_string(),
57                "sampling/createMessage".to_string(),
58                // Notifications
59                "notifications/".to_string(),
60            ],
61        }
62    }
63}
64
65impl CacheConfig {
66    /// Check if a method should be cached.
67    fn should_cache(&self, method: &str) -> bool {
68        // Check exclusions first
69        for excluded in &self.exclude_methods {
70            if method.starts_with(excluded) || method == excluded {
71                return false;
72            }
73        }
74
75        // If specific methods are configured, check membership
76        if !self.cache_methods.is_empty() {
77            return self
78                .cache_methods
79                .iter()
80                .any(|m| method.starts_with(m) || method == m);
81        }
82
83        // Default: cache read-like operations
84        method.starts_with("resources/")
85            || method.starts_with("prompts/")
86            || method == "tools/list"
87            || method == "resources/list"
88            || method == "prompts/list"
89    }
90}
91
/// Cache entry with metadata.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Cached response payload (the JSON value stored by `Cache::put`)
    data: Value,
    /// Insertion time; drives TTL expiry via `is_expired`
    created: Instant,
    /// Last read time; drives LRU ordering in eviction
    last_accessed: Instant,
    /// Times this entry has been served; currently only recorded, never
    /// read back — kept for future metrics/debugging
    access_count: u64,
}
100
101impl CacheEntry {
102    fn new(data: Value) -> Self {
103        let now = Instant::now();
104        Self {
105            data,
106            created: now,
107            last_accessed: now,
108            access_count: 0,
109        }
110    }
111
112    fn is_expired(&self, ttl: Duration) -> bool {
113        self.created.elapsed() > ttl
114    }
115
116    fn access(&mut self) -> &Value {
117        self.last_accessed = Instant::now();
118        self.access_count += 1;
119        &self.data
120    }
121}
122
/// Cache statistics.
///
/// A point-in-time snapshot produced by `Cache::stats`; counters are
/// cumulative since the cache was created.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Cache hits (lookups served from the cache)
    pub hits: u64,
    /// Cache misses (lookups that found nothing, or only an expired entry)
    pub misses: u64,
    /// Entries evicted due to size limit
    pub evictions: u64,
    /// Entries expired (removed after outliving the TTL)
    pub expirations: u64,
    /// Current entry count at snapshot time
    pub current_entries: usize,
}
137
/// Thread-safe response cache.
///
/// The entry map sits behind a `parking_lot::RwLock`; the statistics
/// counters are atomics updated with relaxed ordering so they can be
/// bumped without contending on the map lock.
#[derive(Debug)]
pub struct Cache {
    /// Caching policy: capacity, TTL, method filters
    config: CacheConfig,
    /// Entries keyed by the string produced by `cache_key`
    entries: RwLock<HashMap<String, CacheEntry>>,
    /// Lookups served from the cache
    hits: AtomicU64,
    /// Lookups that found nothing usable
    misses: AtomicU64,
    /// Entries removed by LRU eviction
    evictions: AtomicU64,
    /// Entries removed because their TTL elapsed
    expirations: AtomicU64,
}
148
149impl Cache {
150    /// Create a new cache with the given configuration.
151    #[must_use]
152    pub fn new(config: CacheConfig) -> Self {
153        Self {
154            config,
155            entries: RwLock::new(HashMap::new()),
156            hits: AtomicU64::new(0),
157            misses: AtomicU64::new(0),
158            evictions: AtomicU64::new(0),
159            expirations: AtomicU64::new(0),
160        }
161    }
162
163    /// Generate a cache key from request.
164    fn cache_key(req: &McpRequest) -> String {
165        let mut hasher = std::collections::hash_map::DefaultHasher::new();
166
167        req.method().hash(&mut hasher);
168        if let Some(params) = req.params() {
169            params.to_string().hash(&mut hasher);
170        }
171
172        format!("{}:{:x}", req.method(), hasher.finish())
173    }
174
175    /// Check if method should be cached.
176    pub fn should_cache(&self, method: &str) -> bool {
177        self.config.should_cache(method)
178    }
179
180    /// Get a cached value.
181    pub fn get(&self, key: &str) -> Option<Value> {
182        let mut entries = self.entries.write();
183
184        if let Some(entry) = entries.get_mut(key) {
185            if entry.is_expired(self.config.ttl) {
186                entries.remove(key);
187                self.expirations.fetch_add(1, Ordering::Relaxed);
188                self.misses.fetch_add(1, Ordering::Relaxed);
189                return None;
190            }
191
192            self.hits.fetch_add(1, Ordering::Relaxed);
193            return Some(entry.access().clone());
194        }
195
196        self.misses.fetch_add(1, Ordering::Relaxed);
197        None
198    }
199
200    /// Store a value in the cache.
201    pub fn put(&self, key: String, value: Value) {
202        let mut entries = self.entries.write();
203
204        // Evict if at capacity
205        if entries.len() >= self.config.max_entries {
206            self.evict_lru(&mut entries);
207        }
208
209        entries.insert(key, CacheEntry::new(value));
210    }
211
212    /// Evict least recently used entries.
213    fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
214        // Find the oldest entries
215        let mut to_evict: Vec<_> = entries
216            .iter()
217            .map(|(k, v)| (k.clone(), v.last_accessed))
218            .collect();
219
220        to_evict.sort_by_key(|(_, accessed)| *accessed);
221
222        // Evict 10% of entries or at least 1
223        let evict_count = (entries.len() / 10).max(1);
224        for (key, _) in to_evict.into_iter().take(evict_count) {
225            entries.remove(&key);
226            self.evictions.fetch_add(1, Ordering::Relaxed);
227        }
228    }
229
230    /// Get cache statistics.
231    #[must_use]
232    pub fn stats(&self) -> CacheStats {
233        CacheStats {
234            hits: self.hits.load(Ordering::Relaxed),
235            misses: self.misses.load(Ordering::Relaxed),
236            evictions: self.evictions.load(Ordering::Relaxed),
237            expirations: self.expirations.load(Ordering::Relaxed),
238            current_entries: self.entries.read().len(),
239        }
240    }
241
242    /// Clear all cached entries.
243    pub fn clear(&self) {
244        self.entries.write().clear();
245    }
246
247    /// Remove expired entries.
248    pub fn cleanup(&self) {
249        let mut entries = self.entries.write();
250        let ttl = self.config.ttl;
251
252        let expired: Vec<_> = entries
253            .iter()
254            .filter(|(_, e)| e.is_expired(ttl))
255            .map(|(k, _)| k.clone())
256            .collect();
257
258        for key in expired {
259            entries.remove(&key);
260            self.expirations.fetch_add(1, Ordering::Relaxed);
261        }
262    }
263}
264
265impl Default for Cache {
266    fn default() -> Self {
267        Self::new(CacheConfig::default())
268    }
269}
270
/// Tower Layer that adds response caching.
///
/// Clones of this layer — and every service it wraps — share one [`Cache`]
/// behind an `Arc`, so hits in one wrapped service are visible to all.
#[derive(Debug, Clone)]
pub struct CacheLayer {
    /// Shared cache handed to every service this layer wraps
    cache: Arc<Cache>,
}
276
277impl CacheLayer {
278    /// Create a new cache layer with the given configuration.
279    #[must_use]
280    pub fn new(config: CacheConfig) -> Self {
281        Self {
282            cache: Arc::new(Cache::new(config)),
283        }
284    }
285
286    /// Create a new cache layer with a shared cache.
287    #[must_use]
288    pub fn with_cache(cache: Arc<Cache>) -> Self {
289        Self { cache }
290    }
291
292    /// Get a reference to the cache.
293    #[must_use]
294    pub fn cache(&self) -> &Arc<Cache> {
295        &self.cache
296    }
297}
298
299impl Default for CacheLayer {
300    fn default() -> Self {
301        Self::new(CacheConfig::default())
302    }
303}
304
305impl<S> Layer<S> for CacheLayer {
306    type Service = CacheService<S>;
307
308    fn layer(&self, inner: S) -> Self::Service {
309        CacheService {
310            inner,
311            cache: Arc::clone(&self.cache),
312        }
313    }
314}
315
/// Tower Service that caches responses.
///
/// Built by [`CacheLayer`]; holds the wrapped service plus a handle to the
/// cache shared with its sibling clones.
#[derive(Debug, Clone)]
pub struct CacheService<S> {
    /// The wrapped downstream service
    inner: S,
    /// Cache shared with the layer and other clones of this service
    cache: Arc<Cache>,
}
322
impl<S> CacheService<S> {
    /// Get a reference to the inner service.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Get a mutable reference to the inner service.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Get a reference to the shared cache (e.g. to inspect `stats()`).
    pub fn cache(&self) -> &Arc<Cache> {
        &self.cache
    }
}
339
impl<S> Service<McpRequest> for CacheService<S>
where
    S: Service<McpRequest, Response = McpResponse> + Clone + Send + 'static,
    S::Future: Send,
    S::Error: Into<McpError>,
{
    type Response = McpResponse;
    type Error = McpError;
    // Boxed future so the hit path (no inner call) and the miss path
    // (inner call) share one concrete return type.
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the wrapped service.
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, req: McpRequest) -> Self::Future {
        let method = req.method().to_string();
        let cache = Arc::clone(&self.cache);

        // Check if this method should be cached
        if !cache.should_cache(&method) {
            // Clone-and-swap: the future must own a service instance, and the
            // instance driven to readiness by `poll_ready` is `self.inner`.
            // Swapping moves the ready instance into the future and leaves
            // the fresh clone behind in `self` for the next call.
            let mut inner = self.inner.clone();
            std::mem::swap(&mut self.inner, &mut inner);
            return Box::pin(async move { inner.call(req).await.map_err(Into::into) });
        }

        let cache_key = Cache::cache_key(&req);

        // Cache hit: synthesize a response without touching the inner
        // service. The synthetic response carries only the `cache.hit`
        // marker and a zero duration — metadata from the originally cached
        // response is not replayed, only its `result` value.
        if let Some(cached_value) = cache.get(&cache_key) {
            return Box::pin(async move {
                Ok(McpResponse {
                    result: Some(cached_value),
                    error: None,
                    metadata: {
                        let mut m = HashMap::new();
                        m.insert("cache.hit".to_string(), serde_json::json!(true));
                        m
                    },
                    duration: Duration::ZERO,
                })
            });
        }

        // Cache miss - call inner service (same clone-and-swap as above)
        let mut inner = self.inner.clone();
        std::mem::swap(&mut self.inner, &mut inner);

        Box::pin(async move {
            let start = Instant::now();
            let result = inner.call(req).await.map_err(Into::into)?;

            // Only successful responses that actually carry a result value
            // are stored; errors and empty results are never cached.
            if result.is_success()
                && let Some(ref data) = result.result
            {
                cache.put(cache_key, data.clone());
            }

            let mut response = result;
            response.insert_metadata("cache.hit", serde_json::json!(false));
            response.duration = start.elapsed();

            Ok(response)
        })
    }
}
406
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Build a request with a fixed id and params so that cache keys are
    /// reproducible across calls within a test.
    fn test_request(method: &str) -> McpRequest {
        McpRequest::new(JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test-1"),
            method: method.to_string(),
            params: Some(json!({"key": "value"})),
        })
    }

    /// Default policy: read-like methods cacheable, mutations excluded.
    #[test]
    fn test_cache_config_defaults() {
        let config = CacheConfig::default();

        // Should cache read operations
        assert!(config.should_cache("resources/list"));
        assert!(config.should_cache("resources/read"));
        assert!(config.should_cache("prompts/list"));
        assert!(config.should_cache("tools/list"));

        // Should not cache mutations
        assert!(!config.should_cache("tools/call"));
        assert!(!config.should_cache("sampling/createMessage"));
    }

    /// Round-trip: a stored value comes back equal on lookup.
    #[test]
    fn test_cache_put_get() {
        let cache = Cache::default();

        let key = "test:123".to_string();
        let value = json!({"result": "test"});

        cache.put(key.clone(), value.clone());

        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), value);
    }

    /// An unknown key returns None and is counted as a miss.
    #[test]
    fn test_cache_miss() {
        let cache = Cache::default();

        let retrieved = cache.get("nonexistent");
        assert!(retrieved.is_none());

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);
    }

    /// With a 1 ms TTL, a lookup after 5 ms removes the entry and records
    /// an expiration.
    #[test]
    fn test_cache_expiration() {
        let config = CacheConfig {
            ttl: Duration::from_millis(1),
            ..Default::default()
        };
        let cache = Cache::new(config);

        let key = "test:456".to_string();
        cache.put(key.clone(), json!({"data": "test"}));

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(5));

        let retrieved = cache.get(&key);
        assert!(retrieved.is_none());

        let stats = cache.stats();
        assert_eq!(stats.expirations, 1);
    }

    /// Inserting a third key into a two-entry cache triggers LRU eviction.
    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            ..Default::default()
        };
        let cache = Cache::new(config);

        cache.put("key1".to_string(), json!(1));
        cache.put("key2".to_string(), json!(2));
        cache.put("key3".to_string(), json!(3)); // Should trigger eviction

        let stats = cache.stats();
        assert!(stats.evictions > 0);
        assert!(stats.current_entries <= 2);
    }

    /// Keys are deterministic over (method, params) and differ by method.
    #[test]
    fn test_cache_key_generation() {
        let req1 = test_request("resources/read");
        let req2 = test_request("resources/read");
        let req3 = test_request("resources/list");

        // Same method + params should have same key
        assert_eq!(Cache::cache_key(&req1), Cache::cache_key(&req2));

        // Different method should have different key
        assert_ne!(Cache::cache_key(&req1), Cache::cache_key(&req3));
    }

    /// End-to-end through the tower stack: first call reaches the inner
    /// service (miss), second call is served from the shared cache (hit) —
    /// proven by a second service whose inner fn panics if invoked.
    #[tokio::test]
    async fn test_cache_service() {
        use tower::ServiceExt;

        let cache = Arc::new(Cache::default());
        let call_count = Arc::new(AtomicU64::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let mock_service = tower::service_fn(move |_req: McpRequest| {
            let count = Arc::clone(&call_count_clone);
            async move {
                count.fetch_add(1, Ordering::Relaxed);
                Ok::<_, McpError>(McpResponse::success(
                    json!({"result": "data"}),
                    Duration::from_millis(10),
                ))
            }
        });

        let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(mock_service);

        let request = test_request("resources/list");

        // First call - cache miss
        let response = service
            .ready()
            .await
            .unwrap()
            .call(request.clone())
            .await
            .unwrap();
        assert!(response.is_success());
        assert_eq!(call_count.load(Ordering::Relaxed), 1);

        // Second call - cache hit
        let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(tower::service_fn(
            |_req: McpRequest| async {
                panic!("Inner service should not be called on cache hit");
                #[allow(unreachable_code)]
                Ok::<_, McpError>(McpResponse::success(json!({}), Duration::ZERO))
            },
        ));

        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert!(response.is_success());
        assert_eq!(response.get_metadata("cache.hit"), Some(&json!(true)));
    }
563}