turbomcp_client/middleware/
cache.rs1use super::request::{McpRequest, McpResponse};
22use futures_util::future::BoxFuture;
23use parking_lot::RwLock;
24use serde_json::Value;
25use std::collections::HashMap;
26use std::hash::{Hash, Hasher};
27use std::sync::Arc;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::task::{Context, Poll};
30use std::time::{Duration, Instant};
31use tower_layer::Layer;
32use tower_service::Service;
33use turbomcp_protocol::McpError;
34
/// Policy settings for the response [`Cache`]: capacity, entry lifetime, and which MCP
/// methods are eligible for caching.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries kept at once; inserting into a full cache evicts older ones.
    pub max_entries: usize,
    /// How long an entry stays valid after it is stored (reads do not extend it).
    pub ttl: Duration,
    /// Method-name prefixes to cache. When empty, a built-in set of read-only MCP
    /// methods is used instead.
    pub cache_methods: Vec<String>,
    /// Method-name prefixes that are never cached. Checked first, so they take
    /// precedence over `cache_methods` and the built-in set.
    pub exclude_methods: Vec<String>,
}

impl Default for CacheConfig {
    /// 1000 entries, a 5-minute TTL, the built-in method set, and exclusions for
    /// `tools/call`, `sampling/createMessage`, and everything under `notifications/`.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            ttl: Duration::from_secs(300),
            cache_methods: Vec::new(),
            exclude_methods: vec![
                "tools/call".to_string(),
                "sampling/createMessage".to_string(),
                "notifications/".to_string(),
            ],
        }
    }
}

impl CacheConfig {
    /// Decides whether responses to `method` may be stored in and served from the cache.
    ///
    /// All list entries are matched as prefixes (an exact name is a prefix of itself).
    /// Evaluation order:
    /// 1. any match in `exclude_methods` → not cacheable;
    /// 2. a non-empty `cache_methods` list is then authoritative and taken literally;
    /// 3. otherwise the built-in read-only set applies: everything under `resources/`
    ///    and `prompts/` plus `tools/list`, except the resource subscription methods.
    fn should_cache(&self, method: &str) -> bool {
        if Self::matches_any(method, &self.exclude_methods) {
            return false;
        }

        if !self.cache_methods.is_empty() {
            return Self::matches_any(method, &self.cache_methods);
        }

        let read_only_family =
            method.starts_with("resources/") || method.starts_with("prompts/") || method == "tools/list";
        // Subscribing/unsubscribing changes server-side state; replaying a cached reply
        // would silently skip the real request (e.g. a re-subscribe after an unsubscribe).
        let mutates_server_state =
            matches!(method, "resources/subscribe" | "resources/unsubscribe");

        read_only_family && !mutates_server_state
    }

    /// Returns `true` if `method` starts with any of `prefixes`.
    fn matches_any(method: &str, prefixes: &[String]) -> bool {
        prefixes
            .iter()
            .any(|prefix| method.starts_with(prefix.as_str()))
    }
}
91
92#[derive(Debug, Clone)]
94struct CacheEntry {
95 data: Value,
96 created: Instant,
97 last_accessed: Instant,
98 access_count: u64,
99}
100
101impl CacheEntry {
102 fn new(data: Value) -> Self {
103 let now = Instant::now();
104 Self {
105 data,
106 created: now,
107 last_accessed: now,
108 access_count: 0,
109 }
110 }
111
112 fn is_expired(&self, ttl: Duration) -> bool {
113 self.created.elapsed() > ttl
114 }
115
116 fn access(&mut self) -> &Value {
117 self.last_accessed = Instant::now();
118 self.access_count += 1;
119 &self.data
120 }
121}
122
/// Point-in-time snapshot of the cache's counters, as returned by `Cache::stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from a live entry.
    pub hits: u64,
    /// Lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// Live entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed because their TTL had elapsed.
    pub expirations: u64,
    /// Number of entries stored when the snapshot was taken.
    pub current_entries: usize,
}
137
138#[derive(Debug)]
140pub struct Cache {
141 config: CacheConfig,
142 entries: RwLock<HashMap<String, CacheEntry>>,
143 hits: AtomicU64,
144 misses: AtomicU64,
145 evictions: AtomicU64,
146 expirations: AtomicU64,
147}
148
149impl Cache {
150 #[must_use]
152 pub fn new(config: CacheConfig) -> Self {
153 Self {
154 config,
155 entries: RwLock::new(HashMap::new()),
156 hits: AtomicU64::new(0),
157 misses: AtomicU64::new(0),
158 evictions: AtomicU64::new(0),
159 expirations: AtomicU64::new(0),
160 }
161 }
162
163 fn cache_key(req: &McpRequest) -> String {
165 let mut hasher = std::collections::hash_map::DefaultHasher::new();
166
167 req.method().hash(&mut hasher);
168 if let Some(params) = req.params() {
169 params.to_string().hash(&mut hasher);
170 }
171
172 format!("{}:{:x}", req.method(), hasher.finish())
173 }
174
175 pub fn should_cache(&self, method: &str) -> bool {
177 self.config.should_cache(method)
178 }
179
180 pub fn get(&self, key: &str) -> Option<Value> {
182 let mut entries = self.entries.write();
183
184 if let Some(entry) = entries.get_mut(key) {
185 if entry.is_expired(self.config.ttl) {
186 entries.remove(key);
187 self.expirations.fetch_add(1, Ordering::Relaxed);
188 self.misses.fetch_add(1, Ordering::Relaxed);
189 return None;
190 }
191
192 self.hits.fetch_add(1, Ordering::Relaxed);
193 return Some(entry.access().clone());
194 }
195
196 self.misses.fetch_add(1, Ordering::Relaxed);
197 None
198 }
199
200 pub fn put(&self, key: String, value: Value) {
202 let mut entries = self.entries.write();
203
204 if entries.len() >= self.config.max_entries {
206 self.evict_lru(&mut entries);
207 }
208
209 entries.insert(key, CacheEntry::new(value));
210 }
211
212 fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
214 let mut to_evict: Vec<_> = entries
216 .iter()
217 .map(|(k, v)| (k.clone(), v.last_accessed))
218 .collect();
219
220 to_evict.sort_by_key(|(_, accessed)| *accessed);
221
222 let evict_count = (entries.len() / 10).max(1);
224 for (key, _) in to_evict.into_iter().take(evict_count) {
225 entries.remove(&key);
226 self.evictions.fetch_add(1, Ordering::Relaxed);
227 }
228 }
229
230 #[must_use]
232 pub fn stats(&self) -> CacheStats {
233 CacheStats {
234 hits: self.hits.load(Ordering::Relaxed),
235 misses: self.misses.load(Ordering::Relaxed),
236 evictions: self.evictions.load(Ordering::Relaxed),
237 expirations: self.expirations.load(Ordering::Relaxed),
238 current_entries: self.entries.read().len(),
239 }
240 }
241
242 pub fn clear(&self) {
244 self.entries.write().clear();
245 }
246
247 pub fn cleanup(&self) {
249 let mut entries = self.entries.write();
250 let ttl = self.config.ttl;
251
252 let expired: Vec<_> = entries
253 .iter()
254 .filter(|(_, e)| e.is_expired(ttl))
255 .map(|(k, _)| k.clone())
256 .collect();
257
258 for key in expired {
259 entries.remove(&key);
260 self.expirations.fetch_add(1, Ordering::Relaxed);
261 }
262 }
263}
264
265impl Default for Cache {
266 fn default() -> Self {
267 Self::new(CacheConfig::default())
268 }
269}
270
271#[derive(Debug, Clone)]
273pub struct CacheLayer {
274 cache: Arc<Cache>,
275}
276
277impl CacheLayer {
278 #[must_use]
280 pub fn new(config: CacheConfig) -> Self {
281 Self {
282 cache: Arc::new(Cache::new(config)),
283 }
284 }
285
286 #[must_use]
288 pub fn with_cache(cache: Arc<Cache>) -> Self {
289 Self { cache }
290 }
291
292 #[must_use]
294 pub fn cache(&self) -> &Arc<Cache> {
295 &self.cache
296 }
297}
298
299impl Default for CacheLayer {
300 fn default() -> Self {
301 Self::new(CacheConfig::default())
302 }
303}
304
305impl<S> Layer<S> for CacheLayer {
306 type Service = CacheService<S>;
307
308 fn layer(&self, inner: S) -> Self::Service {
309 CacheService {
310 inner,
311 cache: Arc::clone(&self.cache),
312 }
313 }
314}
315
316#[derive(Debug, Clone)]
318pub struct CacheService<S> {
319 inner: S,
320 cache: Arc<Cache>,
321}
322
323impl<S> CacheService<S> {
324 pub fn inner(&self) -> &S {
326 &self.inner
327 }
328
329 pub fn inner_mut(&mut self) -> &mut S {
331 &mut self.inner
332 }
333
334 pub fn cache(&self) -> &Arc<Cache> {
336 &self.cache
337 }
338}
339
340impl<S> Service<McpRequest> for CacheService<S>
341where
342 S: Service<McpRequest, Response = McpResponse> + Clone + Send + 'static,
343 S::Future: Send,
344 S::Error: Into<McpError>,
345{
346 type Response = McpResponse;
347 type Error = McpError;
348 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
349
350 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
351 self.inner.poll_ready(cx).map_err(Into::into)
352 }
353
354 fn call(&mut self, req: McpRequest) -> Self::Future {
355 let method = req.method().to_string();
356 let cache = Arc::clone(&self.cache);
357
358 if !cache.should_cache(&method) {
360 let mut inner = self.inner.clone();
361 std::mem::swap(&mut self.inner, &mut inner);
362 return Box::pin(async move { inner.call(req).await.map_err(Into::into) });
363 }
364
365 let cache_key = Cache::cache_key(&req);
366
367 if let Some(cached_value) = cache.get(&cache_key) {
369 return Box::pin(async move {
370 Ok(McpResponse {
371 result: Some(cached_value),
372 error: None,
373 metadata: {
374 let mut m = HashMap::new();
375 m.insert("cache.hit".to_string(), serde_json::json!(true));
376 m
377 },
378 duration: Duration::ZERO,
379 })
380 });
381 }
382
383 let mut inner = self.inner.clone();
385 std::mem::swap(&mut self.inner, &mut inner);
386
387 Box::pin(async move {
388 let start = Instant::now();
389 let result = inner.call(req).await.map_err(Into::into)?;
390
391 if result.is_success()
393 && let Some(ref data) = result.result
394 {
395 cache.put(cache_key, data.clone());
396 }
397
398 let mut response = result;
399 response.insert_metadata("cache.hit", serde_json::json!(false));
400 response.duration = start.elapsed();
401
402 Ok(response)
403 })
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Builds a request for `method` with a fixed id and `{"key": "value"}` params.
    fn test_request(method: &str) -> McpRequest {
        let rpc = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test-1"),
            method: method.to_string(),
            params: Some(json!({"key": "value"})),
        };
        McpRequest::new(rpc)
    }

    #[test]
    fn test_cache_config_defaults() {
        let config = CacheConfig::default();

        for cacheable in ["resources/list", "resources/read", "prompts/list", "tools/list"] {
            assert!(config.should_cache(cacheable));
        }
        for uncacheable in ["tools/call", "sampling/createMessage"] {
            assert!(!config.should_cache(uncacheable));
        }
    }

    #[test]
    fn test_cache_put_get() {
        let cache = Cache::default();
        let key = String::from("test:123");
        let value = json!({"result": "test"});

        cache.put(key.clone(), value.clone());

        assert_eq!(cache.get(&key), Some(value));
    }

    #[test]
    fn test_cache_miss() {
        let cache = Cache::default();

        assert!(cache.get("nonexistent").is_none());

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(misses, 1);
        assert_eq!(hits, 0);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = Cache::new(CacheConfig {
            ttl: Duration::from_millis(1),
            ..Default::default()
        });
        let key = String::from("test:456");

        cache.put(key.clone(), json!({"data": "test"}));
        // Sleep well past the 1 ms TTL so the entry is definitely stale.
        std::thread::sleep(Duration::from_millis(5));

        assert!(cache.get(&key).is_none());
        assert_eq!(cache.stats().expirations, 1);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = Cache::new(CacheConfig {
            max_entries: 2,
            ttl: Duration::from_secs(300),
            ..Default::default()
        });

        for (key, value) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            cache.put(key.to_string(), json!(value));
        }

        let stats = cache.stats();
        assert!(stats.evictions > 0);
        assert!(stats.current_entries <= 2);
    }

    #[test]
    fn test_cache_key_generation() {
        let read_a = Cache::cache_key(&test_request("resources/read"));
        let read_b = Cache::cache_key(&test_request("resources/read"));
        let list = Cache::cache_key(&test_request("resources/list"));

        // Identical requests share a key; a different method must not.
        assert_eq!(read_a, read_b);
        assert_ne!(read_a, list);
    }

    #[tokio::test]
    async fn test_cache_service() {
        use tower::ServiceExt;

        let cache = Arc::new(Cache::default());
        let calls = Arc::new(AtomicU64::new(0));

        let counting_service = {
            let calls = Arc::clone(&calls);
            tower::service_fn(move |_req: McpRequest| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::Relaxed);
                    Ok::<_, McpError>(McpResponse::success(
                        json!({"result": "data"}),
                        Duration::from_millis(10),
                    ))
                }
            })
        };
        let request = test_request("resources/list");

        // First call misses, so it must reach the inner service exactly once.
        let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(counting_service);
        let response = service
            .ready()
            .await
            .unwrap()
            .call(request.clone())
            .await
            .unwrap();
        assert!(response.is_success());
        assert_eq!(calls.load(Ordering::Relaxed), 1);

        // Second call goes through a fresh service sharing the same cache. It must be a
        // hit: this inner service panics if it is ever invoked.
        let mut service = CacheLayer::with_cache(Arc::clone(&cache)).layer(tower::service_fn(
            |_req: McpRequest| async {
                panic!("Inner service should not be called on cache hit");
                // The trailing Ok only pins the future's output type for inference.
                #[allow(unreachable_code)]
                Ok::<_, McpError>(McpResponse::success(json!({}), Duration::ZERO))
            },
        ));

        let response = service.ready().await.unwrap().call(request).await.unwrap();
        assert!(response.is_success());
        assert_eq!(response.get_metadata("cache.hit"), Some(&json!(true)));
    }
}