1use lru::LruCache;
7use rustc_hash::FxHashMap;
8use std::num::NonZeroUsize;
9use std::sync::{Arc, RwLock};
10use std::time::{Duration, Instant};
11use tracing::error;
12
13use super::McpToolInfo;
14use super::tool_discovery::DetailLevel;
15
/// A fixed-size Bloom filter over string keys.
///
/// Answers are probabilistic: [`BloomFilter::contains`] may return `true` for
/// an item that was never inserted (a false positive), but it never returns
/// `false` for an item inserted since the last [`BloomFilter::clear`].
#[derive(Debug, Clone)]
pub struct BloomFilter {
    /// One flag per slot; `true` once any inserted item has probed that slot.
    bits: Vec<bool>,
    /// Number of seeded probes performed per item (always at least 1).
    num_hashes: usize,
    /// Number of slots in `bits` (always at least 1).
    size: usize,
}

impl BloomFilter {
    /// Creates an empty filter sized for `expected_items` insertions at the
    /// requested `false_positive_rate`.
    ///
    /// Degenerate parameters are tolerated instead of producing a broken
    /// filter: `expected_items == 0` is treated as 1, and rates outside the
    /// open interval `(0, 1)` (including NaN) are clamped so that the filter
    /// always has at least one slot and performs at least one probe.
    pub fn new(expected_items: usize, false_positive_rate: f64) -> Self {
        let size = Self::optimal_size(expected_items, false_positive_rate);
        let num_hashes = Self::optimal_num_hashes(size, expected_items);

        Self {
            bits: vec![false; size],
            num_hashes,
            size,
        }
    }

    /// Records `item` by setting every slot its probes map to.
    pub fn insert(&mut self, item: &str) {
        for seed in 0..self.num_hashes {
            let index = self.slot(item, seed);
            self.bits[index] = true;
        }
    }

    /// Returns `false` if `item` was definitely not inserted, `true` if it
    /// may have been.
    pub fn contains(&self, item: &str) -> bool {
        (0..self.num_hashes).all(|seed| self.bits[self.slot(item, seed)])
    }

    /// Resets every slot, forgetting all inserted items. Sizing is kept.
    pub fn clear(&mut self) {
        self.bits.fill(false);
    }

    /// Computes the slot count `m = -n * ln(p) / ln(2)^2`, never less than 1.
    fn optimal_size(expected_items: usize, false_positive_rate: f64) -> usize {
        // Zero items would yield zero slots; a rate of 0 would yield an
        // infinite size (and a capacity-overflow panic when allocating).
        let items = expected_items.max(1) as f64;
        let rate = false_positive_rate.clamp(f64::MIN_POSITIVE, 1.0);

        let size = -(items * rate.ln() / (2.0_f64.ln().powi(2)));
        // A NaN rate propagates to here and casts to 0; the floor of 1 keeps
        // the modulo in `slot` well-defined.
        (size.ceil() as usize).max(1)
    }

    /// Computes the probe count `k = (m / n) * ln(2)`, never less than 1.
    fn optimal_num_hashes(size: usize, expected_items: usize) -> usize {
        let items = expected_items.max(1) as f64;
        let num_hashes = (size as f64 / items) * 2.0_f64.ln();
        // With zero probes `contains` would be vacuously true for every item.
        (num_hashes.ceil() as usize).max(1)
    }

    /// Maps `item` under probe number `seed` to an index into `bits`.
    fn slot(&self, item: &str, seed: usize) -> usize {
        self.hash(item, seed) % self.size
    }

    /// Hashes `item` together with `seed` so each probe gets its own value.
    fn hash(&self, item: &str, seed: usize) -> usize {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        item.hash(&mut hasher);
        seed.hash(&mut hasher);
        hasher.finish() as usize
    }
}
88
89#[derive(Debug, Clone, Hash, PartialEq, Eq)]
91struct ToolDiscoveryCacheKey {
92 provider_name: String,
93 keyword: String,
94 detail_level: DetailLevel,
95}
96
97#[derive(Clone)]
99struct CachedToolDiscoveryEntry {
100 results: Arc<Vec<ToolDiscoveryResult>>,
102 timestamp: Instant,
103}
104
105struct DiscoveryCacheInner {
106 bloom_filter: BloomFilter,
107 detailed_cache: LruCache<ToolDiscoveryCacheKey, CachedToolDiscoveryEntry>,
108 all_tools_cache: FxHashMap<String, Vec<McpToolInfo>>,
109 last_refresh: FxHashMap<String, Instant>,
110}
111
112#[derive(Debug, Clone)]
114pub struct ToolDiscoveryResult {
115 pub tool: McpToolInfo,
116 pub relevance_score: f64,
117 pub detail_level: DetailLevel,
118}
119
120pub struct ToolDiscoveryCache {
122 inner: Arc<RwLock<DiscoveryCacheInner>>,
123 config: CacheConfig,
125}
126
/// Tunables for `ToolDiscoveryCache`.
#[derive(Clone)]
struct CacheConfig {
    /// Maximum age of a per-query entry before a lookup discards it.
    max_age: Duration,
    /// Age after which a provider's complete tool list counts as stale.
    provider_refresh_interval: Duration,
    /// Number of tool names the Bloom filter is sized for.
    expected_tool_count: usize,
    /// Target false-positive rate used to size the Bloom filter.
    false_positive_rate: f64,
}
138
139impl ToolDiscoveryCache {
140 pub fn new(capacity: usize) -> Self {
141 let config = CacheConfig {
142 max_age: Duration::from_secs(300), provider_refresh_interval: Duration::from_secs(60), expected_tool_count: 1000,
145 false_positive_rate: 0.01, };
147
148 let bloom_filter = BloomFilter::new(config.expected_tool_count, config.false_positive_rate);
149 let cache_size = NonZeroUsize::new(capacity).or(NonZeroUsize::new(100));
150
151 Self {
152 inner: Arc::new(RwLock::new(DiscoveryCacheInner {
153 bloom_filter,
154 detailed_cache: LruCache::new(cache_size.unwrap_or(NonZeroUsize::MIN)),
155 all_tools_cache: FxHashMap::default(),
156 last_refresh: FxHashMap::default(),
157 })),
158 config,
159 }
160 }
161
162 pub fn might_have_tool(&self, tool_name: &str) -> bool {
164 match self.inner.read() {
165 Ok(inner) => inner.bloom_filter.contains(tool_name),
166 Err(_) => {
167 tracing::warn!("Bloom filter lock poisoned, assuming tool might exist");
168 true
169 }
170 }
171 }
172
173 pub fn get_cached_discovery(
175 &self,
176 provider_name: &str,
177 keyword: &str,
178 detail_level: DetailLevel,
179 ) -> Option<Arc<Vec<ToolDiscoveryResult>>> {
180 let key = ToolDiscoveryCacheKey {
182 provider_name: provider_name.to_owned(),
183 keyword: keyword.to_owned(),
184 detail_level,
185 };
186
187 let mut inner = match self.inner.write() {
188 Ok(inner) => inner,
189 Err(e) => {
190 tracing::error!("Detailed cache lock poisoned: {}", e);
191 return None;
192 }
193 };
194
195 if let Some(cached) = inner.detailed_cache.get(&key) {
196 if cached.timestamp.elapsed() < self.config.max_age {
198 return Some(Arc::clone(&cached.results));
199 } else {
200 inner.detailed_cache.pop(&key);
202 }
203 }
204
205 None
206 }
207
208 pub fn cache_discovery(
210 &self,
211 provider_name: &str,
212 keyword: &str,
213 detail_level: DetailLevel,
214 results: Vec<ToolDiscoveryResult>,
215 ) {
216 self.cache_discovery_shared(provider_name, keyword, detail_level, Arc::new(results));
217 }
218
219 fn cache_discovery_shared(
220 &self,
221 provider_name: &str,
222 keyword: &str,
223 detail_level: DetailLevel,
224 results: Arc<Vec<ToolDiscoveryResult>>,
225 ) {
226 let key = ToolDiscoveryCacheKey {
228 provider_name: provider_name.to_owned(),
229 keyword: keyword.to_owned(),
230 detail_level,
231 };
232
233 let cached = CachedToolDiscoveryEntry {
234 results: Arc::clone(&results),
236 timestamp: Instant::now(),
237 };
238
239 let Ok(mut inner) = self.inner.write() else {
240 tracing::error!("Failed to acquire discovery cache lock for writing");
241 return;
242 };
243
244 inner.detailed_cache.put(key, cached);
245
246 for result in results.iter() {
247 inner.bloom_filter.insert(&result.tool.name);
248 }
249 }
250
251 pub fn get_all_tools(
253 &self,
254 provider_name: &str,
255 refresh_if_stale: bool,
256 ) -> Option<Vec<McpToolInfo>> {
257 let inner = match self.inner.read() {
258 Ok(inner) => inner,
259 Err(e) => {
260 error!("Discovery cache lock poisoned: {}", e);
261 return None;
262 }
263 };
264
265 let should_refresh = if let Some(last) = inner.last_refresh.get(provider_name) {
266 last.elapsed() > self.config.provider_refresh_interval
267 } else {
268 true
269 };
270
271 if should_refresh && refresh_if_stale {
272 return None; }
274
275 inner.all_tools_cache.get(provider_name).cloned()
276 }
277
278 pub fn cache_all_tools(&self, provider_name: &str, tools: Vec<McpToolInfo>) {
280 let mut inner = match self.inner.write() {
281 Ok(inner) => inner,
282 Err(e) => {
283 tracing::error!("Discovery cache lock poisoned: {}", e);
284 return;
285 }
286 };
287
288 inner
289 .all_tools_cache
290 .insert(provider_name.to_owned(), tools.clone());
291 inner
292 .last_refresh
293 .insert(provider_name.to_owned(), Instant::now());
294
295 inner.bloom_filter.clear(); let all_tool_names: Vec<String> = inner
299 .all_tools_cache
300 .values()
301 .flat_map(|provider_tools| provider_tools.iter().map(|tool| tool.name.clone()))
302 .collect();
303
304 for tool_name in all_tool_names {
305 inner.bloom_filter.insert(&tool_name);
306 }
307 }
308
309 pub fn cache_tool_result(&self, _cache_key: String, _result: serde_json::Value) {
311 }
315
316 pub fn clear(&self) {
318 if let Ok(mut inner) = self.inner.write() {
319 inner.bloom_filter.clear();
320 inner.detailed_cache.clear();
321 inner.all_tools_cache.clear();
322 inner.last_refresh.clear();
323 }
324 }
325
326 pub fn stats(&self) -> ToolCacheStats {
328 let (detailed_entries, detailed_capacity, all_tools_entries, bf_size, bf_hashes) = self
329 .inner
330 .read()
331 .map(|inner| {
332 (
333 inner.detailed_cache.len(),
334 inner.detailed_cache.cap().get(),
335 inner.all_tools_cache.len(),
336 inner.bloom_filter.size,
337 inner.bloom_filter.num_hashes,
338 )
339 })
340 .unwrap_or((0, 0, 0, 0, 0));
341
342 ToolCacheStats {
343 detailed_cache_entries: detailed_entries,
344 detailed_cache_capacity: detailed_capacity,
345 all_tools_cache_entries: all_tools_entries,
346 bloom_filter_size: bf_size,
347 bloom_filter_hashes: bf_hashes,
348 }
349 }
350}
351
/// Point-in-time counters describing a `ToolDiscoveryCache`.
#[derive(Clone, Debug)]
pub struct ToolCacheStats {
    /// Number of per-query result sets currently held.
    pub detailed_cache_entries: usize,
    /// Maximum number of per-query result sets the LRU can hold.
    pub detailed_cache_capacity: usize,
    /// Number of providers with a cached complete tool list.
    pub all_tools_cache_entries: usize,
    /// Number of slots in the Bloom filter.
    pub bloom_filter_size: usize,
    /// Number of probes the Bloom filter performs per item.
    pub bloom_filter_hashes: usize,
}
361
362pub struct CachedToolDiscovery {
364 cache: Arc<ToolDiscoveryCache>,
365}
366
367impl CachedToolDiscovery {
368 pub fn new(cache_capacity: usize) -> Self {
369 Self {
370 cache: Arc::new(ToolDiscoveryCache::new(cache_capacity)),
371 }
372 }
373
374 pub fn search_tools(
376 &self,
377 provider_name: &str,
378 keyword: &str,
379 detail_level: DetailLevel,
380 all_tools: Vec<McpToolInfo>,
381 ) -> Arc<Vec<ToolDiscoveryResult>> {
382 if !self.cache.might_have_tool(keyword) && !keyword.is_empty() {
384 return Arc::new(Vec::new());
385 }
386
387 if let Some(cached) = self
389 .cache
390 .get_cached_discovery(provider_name, keyword, detail_level)
391 {
392 return cached;
393 }
394
395 let results = Arc::new(self.perform_search(&all_tools, keyword, detail_level));
397
398 self.cache.cache_discovery_shared(
400 provider_name,
401 keyword,
402 detail_level,
403 Arc::clone(&results),
404 );
405
406 results
407 }
408
409 pub fn get_all_tools_cached(
411 &self,
412 provider_name: &str,
413 all_tools: Vec<McpToolInfo>,
414 ) -> Vec<McpToolInfo> {
415 if let Some(cached) = self.cache.get_all_tools(provider_name, true) {
417 return cached;
418 }
419
420 self.cache.cache_all_tools(provider_name, all_tools.clone());
422
423 all_tools
424 }
425
426 fn perform_search(
428 &self,
429 tools: &[McpToolInfo],
430 keyword: &str,
431 detail_level: DetailLevel,
432 ) -> Vec<ToolDiscoveryResult> {
433 let keyword_lower = keyword.to_lowercase();
434 let mut results = Vec::new();
435
436 for tool in tools {
437 let relevance_score = self.calculate_relevance(tool, &keyword_lower);
438
439 if relevance_score > 0.0 {
440 let result = ToolDiscoveryResult {
441 tool: tool.clone(),
442 relevance_score,
443 detail_level,
444 };
445 results.push(result);
446 }
447 }
448
449 results.sort_by(|a, b| {
451 b.relevance_score
452 .partial_cmp(&a.relevance_score)
453 .unwrap_or(std::cmp::Ordering::Equal)
454 });
455
456 results
457 }
458
459 fn calculate_relevance(&self, tool: &McpToolInfo, keyword: &str) -> f64 {
461 let name_lower = tool.name.to_lowercase();
462 let description_lower = tool.description.to_lowercase();
463
464 let mut score: f64 = 0.0;
465
466 if name_lower == keyword {
468 score += 1.0;
469 }
470 else if name_lower.starts_with(keyword) {
472 score += 0.8;
473 }
474 else if name_lower.contains(keyword) {
476 score += 0.6;
477 }
478
479 if description_lower.contains(keyword) {
481 score += 0.3;
482 }
483
484 let schema_str = serde_json::to_string(&tool.input_schema)
486 .unwrap_or_default()
487 .to_lowercase();
488 if schema_str.contains(keyword) {
489 score += 0.2;
490 }
491
492 score.min(1.0)
493 }
494
495 pub fn stats(&self) -> ToolCacheStats {
497 self.cache.stats()
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the key shared by the equality test.
    fn sample_key() -> ToolDiscoveryCacheKey {
        ToolDiscoveryCacheKey {
            provider_name: "test".to_string(),
            keyword: "search".to_string(),
            detail_level: DetailLevel::Full,
        }
    }

    /// Inserted names are reported present; an unseen name is reported absent.
    #[test]
    fn test_bloom_filter() {
        let mut filter = BloomFilter::new(100, 0.01);
        let inserted = ["tool1", "tool2", "tool3"];

        for name in inserted {
            filter.insert(name);
        }

        for name in inserted {
            assert!(filter.contains(name));
        }
        assert!(!filter.contains("tool4"));
    }

    /// Keys built from identical parts compare equal.
    #[test]
    fn test_cache_key_equality() {
        let key1 = sample_key();
        let key2 = sample_key();

        assert_eq!(key1, key2);
    }

    /// A query misses before it is cached and hits afterwards.
    #[test]
    fn test_tool_discovery_cache() {
        let cache = ToolDiscoveryCache::new(10);

        let provider_name = "test_provider";
        let keyword = "search";
        let detail_level = DetailLevel::Full;

        let before = cache.get_cached_discovery(provider_name, keyword, detail_level);
        assert!(before.is_none());

        let tool = McpToolInfo {
            name: "search_files".to_string(),
            description: "Search for files".to_string(),
            provider: "test".to_string(),
            input_schema: serde_json::json!({}),
        };
        let results = vec![ToolDiscoveryResult {
            tool,
            relevance_score: 0.9,
            detail_level,
        }];

        cache.cache_discovery(provider_name, keyword, detail_level, results);

        let cached = cache.get_cached_discovery(provider_name, keyword, detail_level);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }
}