Skip to main content

vtcode_core/mcp/
tool_discovery_cache.rs

1//! Tool discovery caching system for MCP to avoid redundant tool searches
2//!
3//! This module provides multi-level caching for MCP tool discovery with
4//! bloom filters for fast negative lookups and LRU cache for positive results.
5
6use lru::LruCache;
7use rustc_hash::FxHashMap;
8use std::num::NonZeroUsize;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11use tracing::error;
12
13use super::McpToolInfo;
14use super::tool_discovery::DetailLevel;
15
/// Bloom filter for fast negative lookups (tool doesn't exist).
///
/// A `false` answer from [`BloomFilter::contains`] is definitive: the item was
/// never inserted since the last [`BloomFilter::clear`]. A `true` answer only
/// means the item *may* have been inserted.
#[derive(Clone)]
pub struct BloomFilter {
    /// Bit array for the filter (one `bool` per bit position).
    bits: Vec<bool>,
    /// Number of hash functions applied per item; always at least 1.
    num_hashes: usize,
    /// Length of `bits`; always at least 1 so `hash % size` is well defined.
    size: usize,
}

impl BloomFilter {
    /// Creates an empty filter sized for `expected_items` entries at roughly
    /// `false_positive_rate`.
    ///
    /// Degenerate parameters are sanitized instead of producing a useless or
    /// panicking filter:
    /// - `expected_items == 0` is treated as 1;
    /// - rates `<= 0.0` are raised to the smallest positive `f64`, rates
    ///   above `1.0` are lowered to `1.0`;
    /// - the resulting bit count and hash count are never below 1.
    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
        let expected_items = expected_items.max(1);
        let size = Self::optimal_size(expected_items, false_positive_rate);
        let num_hashes = Self::optimal_num_hashes(size, expected_items);

        Self {
            bits: vec![false; size],
            num_hashes,
            size,
        }
    }

    /// Add an item to the bloom filter.
    pub fn insert(&mut self, item: &str) {
        for seed in 0..self.num_hashes {
            let index = self.hash(item, seed) % self.size;
            self.bits[index] = true;
        }
    }

    /// Check if an item might be in the set.
    ///
    /// Returns `false` as soon as one of the item's bit positions is unset.
    pub fn contains(&self, item: &str) -> bool {
        (0..self.num_hashes).all(|seed| self.bits[self.hash(item, seed) % self.size])
    }

    /// Clear the bloom filter, keeping its size and hash count.
    pub fn clear(&mut self) {
        self.bits.fill(false);
    }

    /// Calculate the optimal bit count: `-n * ln(p) / (ln 2)^2`, rounded up
    /// and never below 1.
    fn optimal_size(expected_items: usize, false_positive_rate: f64) -> usize {
        // Keep the rate inside (0, 1]: `ln(0)` is -inf (which would request a
        // `usize::MAX`-element allocation) and rates above 1 give a negative
        // size. A NaN rate passes through `clamp` and is handled below.
        let rate = false_positive_rate.clamp(f64::MIN_POSITIVE, 1.0);
        let size = -(expected_items as f64 * rate.ln() / (2.0_f64.ln().powi(2)));
        // Float-to-integer `as` casts saturate and map NaN to 0, so the
        // `max(1)` floor covers every remaining degenerate case.
        (size.ceil() as usize).max(1)
    }

    /// Calculate the optimal number of hash functions: `(m / n) * ln 2`,
    /// rounded up and never below 1.
    fn optimal_num_hashes(size: usize, expected_items: usize) -> usize {
        // Guard the division so a zero item count cannot yield inf/NaN.
        let num_hashes = (size as f64 / expected_items.max(1) as f64) * 2.0_f64.ln();
        (num_hashes.ceil() as usize).max(1)
    }

    /// Simple seeded hash: feeds the item and then the seed into a fresh
    /// `DefaultHasher`, giving one hash function per seed value.
    fn hash(&self, item: &str, seed: usize) -> usize {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        seed.hash(&mut hasher);
        hasher.finish() as usize
    }
}
88
89/// Cache key for tool discovery results
90#[derive(Debug, Clone, Hash, PartialEq, Eq)]
91struct ToolDiscoveryCacheKey {
92    provider_name: String,
93    keyword: String,
94    detail_level: DetailLevel,
95}
96
97/// Cached tool discovery result (internal cache entry)
98#[derive(Clone)]
99struct CachedToolDiscoveryEntry {
100    // OPTIMIZATION: Use Arc to avoid cloning large vectors on cache hits
101    results: Arc<Vec<ToolDiscoveryResult>>,
102    timestamp: Instant,
103}
104
105struct DiscoveryCacheInner {
106    bloom_filter: BloomFilter,
107    detailed_cache: LruCache<ToolDiscoveryCacheKey, CachedToolDiscoveryEntry>,
108    all_tools_cache: FxHashMap<String, Vec<McpToolInfo>>,
109    last_refresh: FxHashMap<String, Instant>,
110}
111
112/// Cached tool discovery result (matches actual API)
113#[derive(Debug, Clone)]
114pub struct ToolDiscoveryResult {
115    pub tool: McpToolInfo,
116    pub relevance_score: f64,
117    pub detail_level: DetailLevel,
118}
119
120/// Multi-level caching system for tool discovery
121pub struct ToolDiscoveryCache {
122    inner: Arc<RwLock<DiscoveryCacheInner>>,
123    /// Cache configuration
124    config: CacheConfig,
125}
126
/// Tunables for [`ToolDiscoveryCache`].
#[derive(Clone)]
struct CacheConfig {
    /// How long a cached query result stays valid.
    max_age: Duration,
    /// How long a provider's full tool list stays valid before a refresh is
    /// signalled.
    provider_refresh_interval: Duration,
    /// Number of tools the bloom filter is sized for.
    expected_tool_count: usize,
    /// Target false positive rate of the bloom filter.
    false_positive_rate: f64,
}
138
139impl ToolDiscoveryCache {
140    pub fn new(capacity: usize) -> Self {
141        let config = CacheConfig {
142            max_age: Duration::from_secs(300),                  // 5 minutes
143            provider_refresh_interval: Duration::from_secs(60), // 1 minute
144            expected_tool_count: 1000,
145            false_positive_rate: 0.01, // 1% false positive rate
146        };
147
148        let bloom_filter = BloomFilter::new(config.expected_tool_count, config.false_positive_rate);
149        let cache_size = NonZeroUsize::new(capacity).or(NonZeroUsize::new(100));
150
151        Self {
152            inner: Arc::new(RwLock::new(DiscoveryCacheInner {
153                bloom_filter,
154                detailed_cache: LruCache::new(cache_size.unwrap_or(NonZeroUsize::MIN)),
155                all_tools_cache: FxHashMap::default(),
156                last_refresh: FxHashMap::default(),
157            })),
158            config,
159        }
160    }
161
162    /// Check if a tool might exist (fast negative lookup)
163    pub fn might_have_tool(&self, tool_name: &str) -> bool {
164        match self.inner.read() {
165            Ok(inner) => inner.bloom_filter.contains(tool_name),
166            Err(_) => {
167                tracing::warn!("Bloom filter lock poisoned, assuming tool might exist");
168                true
169            }
170        }
171    }
172
173    /// Get cached tool discovery results
174    pub fn get_cached_discovery(
175        &self,
176        provider_name: &str,
177        keyword: &str,
178        detail_level: DetailLevel,
179    ) -> Option<Arc<Vec<ToolDiscoveryResult>>> {
180        // OPTIMIZATION: Use to_owned() for explicit String allocation
181        let key = ToolDiscoveryCacheKey {
182            provider_name: provider_name.to_owned(),
183            keyword: keyword.to_owned(),
184            detail_level,
185        };
186
187        let mut inner = match self.inner.write() {
188            Ok(inner) => inner,
189            Err(e) => {
190                tracing::error!("Detailed cache lock poisoned: {}", e);
191                return None;
192            }
193        };
194
195        if let Some(cached) = inner.detailed_cache.get(&key) {
196            // Check if the cached entry is still fresh
197            if cached.timestamp.elapsed() < self.config.max_age {
198                return Some(Arc::clone(&cached.results));
199            } else {
200                // Entry is stale, remove it
201                inner.detailed_cache.pop(&key);
202            }
203        }
204
205        None
206    }
207
208    /// Cache tool discovery results
209    pub fn cache_discovery(
210        &self,
211        provider_name: &str,
212        keyword: &str,
213        detail_level: DetailLevel,
214        results: Vec<ToolDiscoveryResult>,
215    ) {
216        self.cache_discovery_shared(provider_name, keyword, detail_level, Arc::new(results));
217    }
218
219    fn cache_discovery_shared(
220        &self,
221        provider_name: &str,
222        keyword: &str,
223        detail_level: DetailLevel,
224        results: Arc<Vec<ToolDiscoveryResult>>,
225    ) {
226        // OPTIMIZATION: Use to_owned() for explicit String allocation
227        let key = ToolDiscoveryCacheKey {
228            provider_name: provider_name.to_owned(),
229            keyword: keyword.to_owned(),
230            detail_level,
231        };
232
233        let cached = CachedToolDiscoveryEntry {
234            // OPTIMIZATION: Wrap in Arc once, share across cache hits
235            results: Arc::clone(&results),
236            timestamp: Instant::now(),
237        };
238
239        let Ok(mut inner) = self.inner.write() else {
240            tracing::error!("Failed to acquire discovery cache lock for writing");
241            return;
242        };
243
244        inner.detailed_cache.put(key, cached);
245
246        for result in results.iter() {
247            inner.bloom_filter.insert(&result.tool.name);
248        }
249    }
250
251    /// Get all cached tools for a provider (with refresh checking)
252    pub fn get_all_tools(
253        &self,
254        provider_name: &str,
255        refresh_if_stale: bool,
256    ) -> Option<Vec<McpToolInfo>> {
257        let inner = match self.inner.read() {
258            Ok(inner) => inner,
259            Err(e) => {
260                error!("Discovery cache lock poisoned: {}", e);
261                return None;
262            }
263        };
264
265        let should_refresh = if let Some(last) = inner.last_refresh.get(provider_name) {
266            last.elapsed() > self.config.provider_refresh_interval
267        } else {
268            true
269        };
270
271        if should_refresh && refresh_if_stale {
272            return None; // Signal that refresh is needed
273        }
274
275        inner.all_tools_cache.get(provider_name).cloned()
276    }
277
278    /// Cache all tools for a provider
279    pub fn cache_all_tools(&self, provider_name: &str, tools: Vec<McpToolInfo>) {
280        let mut inner = match self.inner.write() {
281            Ok(inner) => inner,
282            Err(e) => {
283                tracing::error!("Discovery cache lock poisoned: {}", e);
284                return;
285            }
286        };
287
288        inner
289            .all_tools_cache
290            .insert(provider_name.to_owned(), tools.clone());
291        inner
292            .last_refresh
293            .insert(provider_name.to_owned(), Instant::now());
294
295        // Update bloom filter with all tool names
296        inner.bloom_filter.clear(); // Clear and rebuild for accuracy
297
298        let all_tool_names: Vec<String> = inner
299            .all_tools_cache
300            .values()
301            .flat_map(|provider_tools| provider_tools.iter().map(|tool| tool.name.clone()))
302            .collect();
303
304        for tool_name in all_tool_names {
305            inner.bloom_filter.insert(&tool_name);
306        }
307    }
308
309    /// Cache a single tool result (for read-only tools)
310    pub fn cache_tool_result(&self, _cache_key: String, _result: serde_json::Value) {
311        // This would be implemented for caching individual tool execution results
312        // For now, we'll just store it in a simple cache
313        // In a full implementation, this would use a separate cache with different TTL
314    }
315
316    /// Clear all caches
317    pub fn clear(&self) {
318        if let Ok(mut inner) = self.inner.write() {
319            inner.bloom_filter.clear();
320            inner.detailed_cache.clear();
321            inner.all_tools_cache.clear();
322            inner.last_refresh.clear();
323        }
324    }
325
326    /// Get cache statistics
327    pub fn stats(&self) -> ToolCacheStats {
328        let (detailed_entries, detailed_capacity, all_tools_entries, bf_size, bf_hashes) = self
329            .inner
330            .read()
331            .map(|inner| {
332                (
333                    inner.detailed_cache.len(),
334                    inner.detailed_cache.cap().get(),
335                    inner.all_tools_cache.len(),
336                    inner.bloom_filter.size,
337                    inner.bloom_filter.num_hashes,
338                )
339            })
340            .unwrap_or((0, 0, 0, 0, 0));
341
342        ToolCacheStats {
343            detailed_cache_entries: detailed_entries,
344            detailed_cache_capacity: detailed_capacity,
345            all_tools_cache_entries: all_tools_entries,
346            bloom_filter_size: bf_size,
347            bloom_filter_hashes: bf_hashes,
348        }
349    }
350}
351
/// Point-in-time counters describing a tool discovery cache, for monitoring.
#[derive(Debug, Clone)]
pub struct ToolCacheStats {
    /// Number of query results currently held in the LRU.
    pub detailed_cache_entries: usize,
    /// Maximum number of query results the LRU can hold.
    pub detailed_cache_capacity: usize,
    /// Number of providers with a cached tool list.
    pub all_tools_cache_entries: usize,
    /// Number of bit positions in the bloom filter.
    pub bloom_filter_size: usize,
    /// Number of hash functions the bloom filter applies per item.
    pub bloom_filter_hashes: usize,
}
361
362/// Enhanced tool discovery with caching
363pub struct CachedToolDiscovery {
364    cache: Arc<ToolDiscoveryCache>,
365}
366
367impl CachedToolDiscovery {
368    pub fn new(cache_capacity: usize) -> Self {
369        Self {
370            cache: Arc::new(ToolDiscoveryCache::new(cache_capacity)),
371        }
372    }
373
374    /// Search for tools with multi-level caching
375    pub fn search_tools(
376        &self,
377        provider_name: &str,
378        keyword: &str,
379        detail_level: DetailLevel,
380        all_tools: Vec<McpToolInfo>,
381    ) -> Arc<Vec<ToolDiscoveryResult>> {
382        // Check bloom filter first (fast negative lookup)
383        if !self.cache.might_have_tool(keyword) && !keyword.is_empty() {
384            return Arc::new(Vec::new());
385        }
386
387        // Check detailed cache
388        if let Some(cached) = self
389            .cache
390            .get_cached_discovery(provider_name, keyword, detail_level)
391        {
392            return cached;
393        }
394
395        // Perform the search
396        let results = Arc::new(self.perform_search(&all_tools, keyword, detail_level));
397
398        // Cache the results
399        self.cache.cache_discovery_shared(
400            provider_name,
401            keyword,
402            detail_level,
403            Arc::clone(&results),
404        );
405
406        results
407    }
408
409    /// Get all tools for a provider with caching
410    pub fn get_all_tools_cached(
411        &self,
412        provider_name: &str,
413        all_tools: Vec<McpToolInfo>,
414    ) -> Vec<McpToolInfo> {
415        // Check cache first
416        if let Some(cached) = self.cache.get_all_tools(provider_name, true) {
417            return cached;
418        }
419
420        // Cache the results
421        self.cache.cache_all_tools(provider_name, all_tools.clone());
422
423        all_tools
424    }
425
426    /// Perform the actual search on tool list
427    fn perform_search(
428        &self,
429        tools: &[McpToolInfo],
430        keyword: &str,
431        detail_level: DetailLevel,
432    ) -> Vec<ToolDiscoveryResult> {
433        let keyword_lower = keyword.to_lowercase();
434        let mut results = Vec::new();
435
436        for tool in tools {
437            let relevance_score = self.calculate_relevance(tool, &keyword_lower);
438
439            if relevance_score > 0.0 {
440                let result = ToolDiscoveryResult {
441                    tool: tool.clone(),
442                    relevance_score,
443                    detail_level,
444                };
445                results.push(result);
446            }
447        }
448
449        // Sort by relevance score
450        results.sort_by(|a, b| {
451            b.relevance_score
452                .partial_cmp(&a.relevance_score)
453                .unwrap_or(std::cmp::Ordering::Equal)
454        });
455
456        results
457    }
458
459    /// Calculate relevance score for a tool
460    fn calculate_relevance(&self, tool: &McpToolInfo, keyword: &str) -> f64 {
461        let name_lower = tool.name.to_lowercase();
462        let description_lower = tool.description.to_lowercase();
463
464        let mut score: f64 = 0.0;
465
466        // Name exact match
467        if name_lower == keyword {
468            score += 1.0;
469        }
470        // Name starts with keyword
471        else if name_lower.starts_with(keyword) {
472            score += 0.8;
473        }
474        // Name contains keyword
475        else if name_lower.contains(keyword) {
476            score += 0.6;
477        }
478
479        // Description contains keyword
480        if description_lower.contains(keyword) {
481            score += 0.3;
482        }
483
484        // Input schema contains keyword
485        let schema_str = serde_json::to_string(&tool.input_schema)
486            .unwrap_or_default()
487            .to_lowercase();
488        if schema_str.contains(keyword) {
489            score += 0.2;
490        }
491
492        score.min(1.0)
493    }
494
495    /// Get cache statistics
496    pub fn stats(&self) -> ToolCacheStats {
497        self.cache.stats()
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserted names are always reported; a name that was never inserted is
    /// rejected for this filter configuration.
    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(100, 0.01);
        let inserted = ["tool1", "tool2", "tool3"];

        for &name in &inserted {
            filter.insert(name);
        }

        for &name in &inserted {
            assert!(filter.contains(name));
        }
        assert!(!filter.contains("tool4"));
    }

    /// Keys built from equal components compare equal.
    #[test]
    fn test_cache_key_equality() {
        let make_key = || ToolDiscoveryCacheKey {
            provider_name: "test".to_string(),
            keyword: "search".to_string(),
            detail_level: DetailLevel::Full,
        };

        assert_eq!(make_key(), make_key());
    }

    /// A query misses before results are stored and hits afterwards.
    #[test]
    fn test_tool_discovery_cache() {
        let cache = ToolDiscoveryCache::new(10);

        let provider_name = "test_provider";
        let keyword = "search";
        let detail_level = DetailLevel::Full;

        // Cache miss
        let before = cache.get_cached_discovery(provider_name, keyword, detail_level);
        assert!(before.is_none());

        // Cache some results
        let tool = McpToolInfo {
            name: "search_files".to_string(),
            description: "Search for files".to_string(),
            provider: "test".to_string(),
            input_schema: serde_json::json!({}),
        };
        let results = vec![ToolDiscoveryResult {
            tool,
            relevance_score: 0.9,
            detail_level,
        }];

        cache.cache_discovery(provider_name, keyword, detail_level, results);

        // Cache hit
        let cached = cache.get_cached_discovery(provider_name, keyword, detail_level);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }
}