1use std::fmt::Write as _;
15
16use zeph_llm::LlmError;
17use zeph_llm::provider::{LlmProvider, Message, Role};
18
19use crate::tool::McpTool;
20
/// Outcome stored in the single slot of a [`PruningCache`].
#[derive(Debug, Clone)]
enum CachedResult {
    /// The pruning call succeeded; holds exactly the tools it returned.
    Ok(Vec<McpTool>),
    /// The pruning call failed (LLM or parse error). Acts as a negative-cache marker so that
    /// the same key does not trigger another LLM call.
    Failed,
}
34
/// Single-entry memo used by [`prune_tools_cached`].
///
/// Holds at most one key/result pair; every insert overwrites the previous one. The key is
/// `(message hash, tool-list hash)` as computed by [`content_hash`] and [`tool_list_hash`]
/// inside [`prune_tools_cached`]. A stored failure acts as a negative entry until the key
/// changes or [`PruningCache::reset`] is called.
#[derive(Debug, Default, Clone)]
pub struct PruningCache {
    /// `(message hash, tool-list hash)` the stored result belongs to; `None` when empty.
    key: Option<(u64, u64)>,
    /// Stored outcome for `key`; `None` when empty.
    result: Option<CachedResult>,
}
67
/// Classification of a [`PruningCache::lookup`] for a given key.
enum CacheLookup<'a> {
    /// A successful result is stored for this key; borrows the cached tools.
    Hit(&'a [McpTool]),
    /// A failure is stored for this key.
    NegativeHit,
    /// Nothing is stored, or the stored entry belongs to a different key.
    Miss,
}
77
78impl PruningCache {
79 #[must_use]
81 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn reset(&mut self) {
90 self.key = None;
91 self.result = None;
92 }
93
94 fn lookup(&self, msg_hash: u64, tool_hash: u64) -> CacheLookup<'_> {
95 match (&self.key, &self.result) {
96 (Some(k), Some(CachedResult::Ok(tools))) if *k == (msg_hash, tool_hash) => {
97 CacheLookup::Hit(tools)
98 }
99 (Some(k), Some(CachedResult::Failed)) if *k == (msg_hash, tool_hash) => {
100 CacheLookup::NegativeHit
101 }
102 _ => CacheLookup::Miss,
103 }
104 }
105
106 fn insert_ok(&mut self, msg_hash: u64, tool_hash: u64, tools: Vec<McpTool>) {
107 self.key = Some((msg_hash, tool_hash));
108 self.result = Some(CachedResult::Ok(tools));
109 }
110
111 fn insert_failed(&mut self, msg_hash: u64, tool_hash: u64) {
112 self.key = Some((msg_hash, tool_hash));
113 self.result = Some(CachedResult::Failed);
114 }
115}
116
117#[must_use]
123pub fn content_hash(s: &str) -> u64 {
124 let hash = blake3::hash(s.as_bytes());
125 u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
126}
127
128#[must_use]
145pub fn tool_list_hash(tools: &[McpTool]) -> u64 {
146 let mut hasher = blake3::Hasher::new();
147 let mut sorted: Vec<&McpTool> = tools.iter().collect();
148 sorted.sort_by(|a, b| a.server_id.cmp(&b.server_id).then(a.name.cmp(&b.name)));
149 for tool in sorted {
150 hasher.update(tool.server_id.as_bytes());
151 hasher.update(b"\0");
152 hasher.update(tool.name.as_bytes());
153 hasher.update(b"\0");
154 hasher.update(tool.description.as_bytes());
155 hasher.update(b"\0");
156 match serde_json::to_vec(&tool.input_schema) {
157 Ok(schema_bytes) => {
158 hasher.update(&schema_bytes);
159 }
160 Err(_) => {
161 hasher.update(b"\x00");
162 }
163 }
164 hasher.update(b"\x01");
166 }
167 let hash = hasher.finalize();
168 u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
169}
170
171#[cfg_attr(
190 feature = "profiling",
191 tracing::instrument(name = "mcp.pruning.prune_tools_cached", skip_all)
192)]
193pub async fn prune_tools_cached<P: LlmProvider>(
194 cache: &mut PruningCache,
195 all_tools: &[McpTool],
196 task_context: &str,
197 params: &PruningParams,
198 provider: &P,
199) -> Result<Vec<McpTool>, PruningError> {
200 let msg_hash = content_hash(task_context);
201 let tl_hash = tool_list_hash(all_tools);
202
203 match cache.lookup(msg_hash, tl_hash) {
204 CacheLookup::Hit(cached) => return Ok(cached.to_vec()),
205 CacheLookup::NegativeHit => {
206 tracing::warn!("pruning cache: negative hit, returning all tools without LLM call");
209 return Ok(all_tools.to_vec());
210 }
211 CacheLookup::Miss => {}
212 }
213
214 match prune_tools(all_tools, task_context, params, provider).await {
215 Ok(result) => {
216 cache.insert_ok(msg_hash, tl_hash, result.clone());
217 Ok(result)
218 }
219 Err(e) => {
220 cache.insert_failed(msg_hash, tl_hash);
221 Err(e)
222 }
223 }
224}
225
/// Errors returned by [`prune_tools`] and [`prune_tools_cached`].
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum PruningError {
    /// The provider's `chat` call failed.
    #[error("pruning LLM call failed: {0}")]
    LlmError(#[from] LlmError),
    /// The LLM response did not contain a JSON array of strings that could be parsed.
    #[error("failed to parse pruning response as JSON array of tool names")]
    ParseError,
}
237
/// Tuning knobs for [`prune_tools`].
#[derive(Debug, Clone)]
pub struct PruningParams {
    /// Maximum number of LLM-selected (non-pinned) tools to return; `0` disables the cap.
    /// Tools matched by `always_include` do not count toward this limit.
    pub max_tools: usize,
    /// Pruning is skipped (all tools returned, no LLM call) when fewer than this many tools
    /// are available.
    pub min_tools_to_prune: usize,
    /// Bare tool names that are always kept. A name matches every tool with that `name`,
    /// whichever server provides it.
    pub always_include: Vec<String>,
}
257
258impl Default for PruningParams {
259 fn default() -> Self {
260 Self {
261 max_tools: 15,
262 min_tools_to_prune: 10,
263 always_include: Vec::new(),
264 }
265 }
266}
267
268#[cfg_attr(
285 feature = "profiling",
286 tracing::instrument(name = "mcp.pruning.prune_tools", skip_all)
287)]
288pub async fn prune_tools<P: LlmProvider>(
289 all_tools: &[McpTool],
290 task_context: &str,
291 params: &PruningParams,
292 provider: &P,
293) -> Result<Vec<McpTool>, PruningError> {
294 if all_tools.len() < params.min_tools_to_prune {
295 return Ok(all_tools.to_vec());
296 }
297
298 let (pinned, candidates): (Vec<_>, Vec<_>) = all_tools
300 .iter()
301 .partition(|t| params.always_include.iter().any(|a| a == &t.name));
302
303 let tool_list = candidates.iter().fold(String::new(), |mut acc, t| {
307 let name = sanitize_tool_name(&t.name);
308 let desc = sanitize_tool_description(&t.description);
309 let _ = writeln!(acc, "- {name}: {desc}");
310 acc
311 });
312
313 let prompt = format!(
314 "Return a JSON array of tool names that are relevant to the task below.\n\
315 Return ONLY the JSON array, no explanation, no markdown.\n\n\
316 Task: {task_context}\n\n\
317 Available tools:\n{tool_list}"
318 );
319
320 let messages = vec![Message::from_legacy(Role::User, prompt)];
321 let response = provider.chat(&messages).await?;
322
323 let relevant_names = parse_name_array(&response)?;
325
326 let mut result: Vec<McpTool> = pinned.into_iter().cloned().collect();
329 let mut candidates_added: usize = 0;
330 for tool in &candidates {
331 if params.max_tools > 0 && candidates_added >= params.max_tools {
333 break;
334 }
335 if relevant_names.iter().any(|n| n == &tool.name) {
336 result.push((*tool).clone());
337 candidates_added += 1;
338 }
339 }
340
341 Ok(result)
342}
343
/// Prepare a tool name for embedding in the pruning prompt.
///
/// Drops every character for which [`char::is_control`] is true, then keeps at most the
/// first 64 remaining characters (counted as `char`s, not bytes).
fn sanitize_tool_name(name: &str) -> String {
    const MAX_NAME_CHARS: usize = 64;

    let mut sanitized = String::new();
    let mut kept = 0usize;
    for ch in name.chars() {
        if kept == MAX_NAME_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        sanitized.push(ch);
        kept += 1;
    }
    sanitized
}
350
/// Prepare a tool description for embedding in the pruning prompt.
///
/// Removes control characters (newlines and tabs included, so a description cannot break
/// out of its `- name: desc` line) and keeps at most the first 200 remaining `char`s.
fn sanitize_tool_description(desc: &str) -> String {
    const MAX_DESCRIPTION_CHARS: usize = 200;

    let printable = desc.chars().filter(|ch| !ch.is_control());
    printable.take(MAX_DESCRIPTION_CHARS).collect::<String>()
}
357
358fn parse_name_array(response: &str) -> Result<Vec<String>, PruningError> {
362 let stripped = response
364 .lines()
365 .filter(|l| !l.trim_start().starts_with("```"))
366 .collect::<Vec<_>>()
367 .join("\n");
368
369 let start = stripped.find('[').ok_or(PruningError::ParseError)?;
371 let end = stripped.rfind(']').ok_or(PruningError::ParseError)?;
372 if end <= start {
373 return Err(PruningError::ParseError);
374 }
375
376 let json_fragment = &stripped[start..=end];
377 let names: Vec<String> =
378 serde_json::from_str(json_fragment).map_err(|_| PruningError::ParseError)?;
379 Ok(names)
380}
381
#[cfg(test)]
mod tests {
    use zeph_llm::mock::MockProvider;

    use super::*;

    /// Tool on the default `"test"` server with a null input schema.
    fn make_tool(name: &str, description: &str) -> McpTool {
        make_tool_with_server("test", name, description)
    }

    /// Tool on `server_id` with a null input schema and default security metadata.
    fn make_tool_with_server(server_id: &str, name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            output_schema: None,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Params that always prune (threshold 1), with the given cap and no pinned tools.
    fn params_with_max(max_tools: usize) -> PruningParams {
        PruningParams {
            max_tools,
            min_tools_to_prune: 1,
            always_include: Vec::new(),
        }
    }

    #[test]
    fn parse_plain_array() {
        let names = parse_name_array(r#"["bash", "read", "write"]"#).unwrap();
        assert_eq!(names, vec!["bash", "read", "write"]);
    }

    #[test]
    fn parse_array_with_markdown_fences() {
        let input = "```json\n[\"bash\", \"read\"]\n```";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_array_with_preamble() {
        let input = "Here are the relevant tools:\n[\"bash\", \"read\"]";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_empty_array() {
        let names = parse_name_array("[]").unwrap();
        assert!(names.is_empty());
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_name_array("not json").is_err());
        assert!(parse_name_array("").is_err());
        assert!(parse_name_array("{\"key\": \"val\"}").is_err());
    }

    #[tokio::test]
    async fn below_min_detected_early_return() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // A failing provider proves the LLM is never consulted below the threshold.
        let provider = MockProvider::failing();
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 10,
            always_include: Vec::new(),
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "all tools returned when below threshold");
    }

    #[tokio::test]
    async fn always_include_pinned() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["candidate_a"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned must survive pruning"
        );
        assert!(result.iter().any(|t| t.name == "candidate_a"));
    }

    #[tokio::test]
    async fn always_include_matches_bare_name_across_servers() {
        let tools = vec![
            make_tool_with_server("server_a", "search", "search on A"),
            make_tool_with_server("server_b", "search", "search on B"),
            make_tool_with_server("server_a", "other", "other tool"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["other"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["search".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 3, "both search tools + other must be present");
        let search_count = result.iter().filter(|t| t.name == "search").count();
        assert_eq!(
            search_count, 2,
            "both server_a:search and server_b:search must be pinned"
        );
        assert!(result.iter().any(|t| t.name == "other"));
    }

    #[tokio::test]
    async fn max_tools_cap_respected() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let names_json = r#"["t0","t1","t2","t3","t4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);

        let result = prune_tools(&tools, "task", &params_with_max(2), &provider)
            .await
            .unwrap();
        assert_eq!(
            result.len(),
            2,
            "max_tools=2 must cap LLM-selected candidates"
        );
    }

    #[tokio::test]
    async fn llm_failure_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::LlmError(_))));
    }

    #[tokio::test]
    async fn parse_error_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec!["not valid json at all".into()]);
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::ParseError)));
    }

    #[tokio::test]
    async fn max_tools_zero_means_no_cap() {
        let tools: Vec<McpTool> = (0..5)
            .map(|i| make_tool(&format!("tool{i}"), "desc"))
            .collect();
        let names_json = r#"["tool0","tool1","tool2","tool3","tool4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let params = params_with_max(0);

        let result = prune_tools(&tools, "any task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "max_tools=0 must not cap the result");
    }

    #[test]
    fn description_sanitization_strips_control_chars_and_caps() {
        let desc = "line1\nline2\tinject";
        let sanitized = sanitize_tool_description(desc);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\t'));

        let long_desc = "x".repeat(300);
        assert_eq!(sanitize_tool_description(&long_desc).len(), 200);

        let long_name = "a".repeat(100);
        assert_eq!(sanitize_tool_name(&long_name).len(), 64);
    }

    #[tokio::test]
    async fn always_include_bypasses_max_tools_cap() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider =
            MockProvider::with_responses(vec![r#"["candidate_a","candidate_b"]"#.into()]);
        let params = PruningParams {
            max_tools: 1,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();

        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned tool must bypass cap"
        );
        // One pinned tool plus exactly one capped candidate.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn content_hash_is_deterministic_and_input_sensitive() {
        assert_eq!(content_hash("query"), content_hash("query"));
        assert_ne!(content_hash("query_a"), content_hash("query_b"));
    }

    #[test]
    fn tool_list_hash_ignores_order_but_tracks_content() {
        let a = make_tool_with_server("s1", "alpha", "first");
        let b = make_tool_with_server("s2", "beta", "second");
        assert_eq!(
            tool_list_hash(&[a.clone(), b.clone()]),
            tool_list_hash(&[b.clone(), a.clone()]),
            "input order must not affect the fingerprint"
        );

        let changed = make_tool_with_server("s1", "alpha", "different");
        assert_ne!(
            tool_list_hash(&[a, b.clone()]),
            tool_list_hash(&[changed, b]),
            "a description change must change the fingerprint"
        );
    }

    #[tokio::test]
    async fn cache_positive_hit() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // Only one scripted response: a second LLM call would not get the same answer.
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r1.len(), r2.len(), "cache hit must return same result");
    }

    #[tokio::test]
    async fn cache_miss_on_message_change() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider =
            MockProvider::with_responses(vec![r#"["t0","t1"]"#.into(), r#"["t0"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query_a", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query_b", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2, "first call returns both tools");
        assert_eq!(
            r2.len(),
            1,
            "different message triggers cache miss and LLM call"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_tool_list_change() {
        let tools1: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let mut tools2 = tools1.clone();
        tools2.push(make_tool("t2", "new tool"));

        let provider = MockProvider::with_responses(vec![
            r#"["t0","t1"]"#.into(),
            r#"["t0","t1","t2"]"#.into(),
        ]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools1, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools2, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r2.len(), 3, "new tool triggers cache miss");
    }

    #[tokio::test]
    async fn cache_negative_hit_skips_llm() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err(), "first call must propagate LLM error");

        // The provider still fails, so an Ok here proves the LLM was not called again.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "negative cache hit must return all tools");
    }

    #[tokio::test]
    async fn cache_negative_hit_clears_on_reset() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()])
            .with_errors(vec![zeph_llm::LlmError::Other("simulated failure".into())]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err());

        cache.reset();

        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "after reset the LLM must be retried");
    }
}