1use std::fmt::Write as _;
15
16use zeph_llm::LlmError;
17use zeph_llm::provider::{LlmProvider, Message, Role};
18
19use crate::tool::McpTool;
20
21#[derive(Debug, Clone)]
29enum CachedResult {
30 Ok(Vec<McpTool>),
31 Failed,
33}
34
35#[derive(Debug, Default, Clone)]
63pub struct PruningCache {
64 key: Option<(u64, u64)>,
65 result: Option<CachedResult>,
66}
67
68enum CacheLookup<'a> {
70 Hit(&'a [McpTool]),
72 NegativeHit,
74 Miss,
76}
77
78impl PruningCache {
79 #[must_use]
81 pub fn new() -> Self {
82 Self::default()
83 }
84
85 pub fn reset(&mut self) {
90 self.key = None;
91 self.result = None;
92 }
93
94 fn lookup(&self, msg_hash: u64, tool_hash: u64) -> CacheLookup<'_> {
95 match (&self.key, &self.result) {
96 (Some(k), Some(CachedResult::Ok(tools))) if *k == (msg_hash, tool_hash) => {
97 CacheLookup::Hit(tools)
98 }
99 (Some(k), Some(CachedResult::Failed)) if *k == (msg_hash, tool_hash) => {
100 CacheLookup::NegativeHit
101 }
102 _ => CacheLookup::Miss,
103 }
104 }
105
106 fn insert_ok(&mut self, msg_hash: u64, tool_hash: u64, tools: Vec<McpTool>) {
107 self.key = Some((msg_hash, tool_hash));
108 self.result = Some(CachedResult::Ok(tools));
109 }
110
111 fn insert_failed(&mut self, msg_hash: u64, tool_hash: u64) {
112 self.key = Some((msg_hash, tool_hash));
113 self.result = Some(CachedResult::Failed);
114 }
115}
116
117#[must_use]
123pub fn content_hash(s: &str) -> u64 {
124 let hash = blake3::hash(s.as_bytes());
125 u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
126}
127
128#[must_use]
145pub fn tool_list_hash(tools: &[McpTool]) -> u64 {
146 let mut hasher = blake3::Hasher::new();
147 let mut sorted: Vec<&McpTool> = tools.iter().collect();
148 sorted.sort_by(|a, b| a.server_id.cmp(&b.server_id).then(a.name.cmp(&b.name)));
149 for tool in sorted {
150 hasher.update(tool.server_id.as_bytes());
151 hasher.update(b"\0");
152 hasher.update(tool.name.as_bytes());
153 hasher.update(b"\0");
154 hasher.update(tool.description.as_bytes());
155 hasher.update(b"\0");
156 match serde_json::to_vec(&tool.input_schema) {
157 Ok(schema_bytes) => {
158 hasher.update(&schema_bytes);
159 }
160 Err(_) => {
161 hasher.update(b"\x00");
162 }
163 }
164 hasher.update(b"\x01");
166 }
167 let hash = hasher.finalize();
168 u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
169}
170
171pub async fn prune_tools_cached<P: LlmProvider>(
190 cache: &mut PruningCache,
191 all_tools: &[McpTool],
192 task_context: &str,
193 params: &PruningParams,
194 provider: &P,
195) -> Result<Vec<McpTool>, PruningError> {
196 let msg_hash = content_hash(task_context);
197 let tl_hash = tool_list_hash(all_tools);
198
199 match cache.lookup(msg_hash, tl_hash) {
200 CacheLookup::Hit(cached) => return Ok(cached.to_vec()),
201 CacheLookup::NegativeHit => {
202 tracing::warn!("pruning cache: negative hit, returning all tools without LLM call");
205 return Ok(all_tools.to_vec());
206 }
207 CacheLookup::Miss => {}
208 }
209
210 match prune_tools(all_tools, task_context, params, provider).await {
211 Ok(result) => {
212 cache.insert_ok(msg_hash, tl_hash, result.clone());
213 Ok(result)
214 }
215 Err(e) => {
216 cache.insert_failed(msg_hash, tl_hash);
217 Err(e)
218 }
219 }
220}
221
222#[derive(Debug, thiserror::Error)]
224pub enum PruningError {
225 #[error("pruning LLM call failed: {0}")]
227 LlmError(#[from] LlmError),
228 #[error("failed to parse pruning response as JSON array of tool names")]
230 ParseError,
231}
232
/// Tuning knobs for [`prune_tools`] and [`prune_tools_cached`].
#[derive(Clone, Debug)]
pub struct PruningParams {
    /// Maximum number of LLM-selected (non-pinned) tools to return; `0` means no cap.
    /// Tools listed in `always_include` never count towards this limit.
    pub max_tools: usize,
    /// When the tool list has fewer entries than this, pruning is skipped and
    /// every tool is returned without calling the LLM.
    pub min_tools_to_prune: usize,
    /// Bare tool names that are always kept, matched exactly against `McpTool::name`
    /// on every server that exposes them.
    pub always_include: Vec<String>,
}
252
253impl Default for PruningParams {
254 fn default() -> Self {
255 Self {
256 max_tools: 15,
257 min_tools_to_prune: 10,
258 always_include: Vec::new(),
259 }
260 }
261}
262
263pub async fn prune_tools<P: LlmProvider>(
280 all_tools: &[McpTool],
281 task_context: &str,
282 params: &PruningParams,
283 provider: &P,
284) -> Result<Vec<McpTool>, PruningError> {
285 if all_tools.len() < params.min_tools_to_prune {
286 return Ok(all_tools.to_vec());
287 }
288
289 let (pinned, candidates): (Vec<_>, Vec<_>) = all_tools
291 .iter()
292 .partition(|t| params.always_include.iter().any(|a| a == &t.name));
293
294 let tool_list = candidates.iter().fold(String::new(), |mut acc, t| {
298 let name = sanitize_tool_name(&t.name);
299 let desc = sanitize_tool_description(&t.description);
300 let _ = writeln!(acc, "- {name}: {desc}");
301 acc
302 });
303
304 let prompt = format!(
305 "Return a JSON array of tool names that are relevant to the task below.\n\
306 Return ONLY the JSON array, no explanation, no markdown.\n\n\
307 Task: {task_context}\n\n\
308 Available tools:\n{tool_list}"
309 );
310
311 let messages = vec![Message::from_legacy(Role::User, prompt)];
312 let response = provider.chat(&messages).await?;
313
314 let relevant_names = parse_name_array(&response)?;
316
317 let mut result: Vec<McpTool> = pinned.into_iter().cloned().collect();
320 let mut candidates_added: usize = 0;
321 for tool in &candidates {
322 if params.max_tools > 0 && candidates_added >= params.max_tools {
324 break;
325 }
326 if relevant_names.iter().any(|n| n == &tool.name) {
327 result.push((*tool).clone());
328 candidates_added += 1;
329 }
330 }
331
332 Ok(result)
333}
334
/// Prepares a tool name for inclusion in the pruning prompt.
///
/// Drops every character for which `char::is_control` is true and keeps at most
/// the first 64 remaining characters (counted as `char`s, not bytes).
fn sanitize_tool_name(name: &str) -> String {
    const MAX_CHARS: usize = 64;
    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in name.chars() {
        if kept == MAX_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
341
/// Prepares a tool description for inclusion in the pruning prompt.
///
/// Drops every character for which `char::is_control` is true (newlines and tabs
/// included) and keeps at most the first 200 remaining characters.
fn sanitize_tool_description(desc: &str) -> String {
    const MAX_CHARS: usize = 200;
    let printable = desc.chars().filter(|ch| !ch.is_control());
    String::from_iter(printable.take(MAX_CHARS))
}
348
349fn parse_name_array(response: &str) -> Result<Vec<String>, PruningError> {
353 let stripped = response
355 .lines()
356 .filter(|l| !l.trim_start().starts_with("```"))
357 .collect::<Vec<_>>()
358 .join("\n");
359
360 let start = stripped.find('[').ok_or(PruningError::ParseError)?;
362 let end = stripped.rfind(']').ok_or(PruningError::ParseError)?;
363 if end <= start {
364 return Err(PruningError::ParseError);
365 }
366
367 let json_fragment = &stripped[start..=end];
368 let names: Vec<String> =
369 serde_json::from_str(json_fragment).map_err(|_| PruningError::ParseError)?;
370 Ok(names)
371}
372
#[cfg(test)]
mod tests {
    use zeph_llm::mock::MockProvider;

    use super::*;

    /// Builds a tool on the `"test"` server with a null input schema.
    fn make_tool(name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: "test".into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Builds a tool on the given server with a null input schema.
    fn make_tool_with_server(server_id: &str, name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Params that always prune (threshold 1) with the given cap and no pinned tools.
    fn params_with_max(max_tools: usize) -> PruningParams {
        PruningParams {
            max_tools,
            min_tools_to_prune: 1,
            always_include: Vec::new(),
        }
    }

    #[test]
    fn parse_plain_array() {
        let names = parse_name_array(r#"["bash", "read", "write"]"#).unwrap();
        assert_eq!(names, vec!["bash", "read", "write"]);
    }

    #[test]
    fn parse_array_with_markdown_fences() {
        let input = "```json\n[\"bash\", \"read\"]\n```";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_array_with_preamble() {
        let input = "Here are the relevant tools:\n[\"bash\", \"read\"]";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_empty_array() {
        let names = parse_name_array("[]").unwrap();
        assert!(names.is_empty());
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_name_array("not json").is_err());
        assert!(parse_name_array("").is_err());
        assert!(parse_name_array("{\"key\": \"val\"}").is_err());
    }

    #[tokio::test]
    async fn below_min_detected_early_return() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // A failing provider proves the LLM is never contacted below the threshold.
        let provider = MockProvider::failing();
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 10,
            always_include: Vec::new(),
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "all tools returned when below threshold");
    }

    #[tokio::test]
    async fn always_include_pinned() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["candidate_a"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned must survive pruning"
        );
        assert!(result.iter().any(|t| t.name == "candidate_a"));
    }

    #[tokio::test]
    async fn always_include_matches_bare_name_across_servers() {
        let tools = vec![
            make_tool_with_server("server_a", "search", "search on A"),
            make_tool_with_server("server_b", "search", "search on B"),
            make_tool_with_server("server_a", "other", "other tool"),
        ];
        let provider = MockProvider::with_responses(vec![r#"["other"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["search".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 3, "both search tools + other must be present");
        let search_count = result.iter().filter(|t| t.name == "search").count();
        assert_eq!(
            search_count, 2,
            "both server_a:search and server_b:search must be pinned"
        );
        assert!(result.iter().any(|t| t.name == "other"));
    }

    #[tokio::test]
    async fn max_tools_cap_respected() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let names_json = r#"["t0","t1","t2","t3","t4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);

        let result = prune_tools(&tools, "task", &params_with_max(2), &provider)
            .await
            .unwrap();
        assert_eq!(
            result.len(),
            2,
            "max_tools=2 must cap LLM-selected candidates"
        );
    }

    #[tokio::test]
    async fn llm_failure_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::LlmError(_))));
    }

    #[tokio::test]
    async fn parse_error_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec!["not valid json at all".into()]);
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::ParseError)));
    }

    #[tokio::test]
    async fn max_tools_zero_means_no_cap() {
        let tools: Vec<McpTool> = (0..5)
            .map(|i| make_tool(&format!("tool{i}"), "desc"))
            .collect();
        let names_json = r#"["tool0","tool1","tool2","tool3","tool4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let params = params_with_max(0);

        let result = prune_tools(&tools, "any task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "max_tools=0 must not cap the result");
    }

    #[test]
    fn description_sanitization_strips_control_chars_and_caps() {
        let desc = "line1\nline2\tinject";
        let sanitized = sanitize_tool_description(desc);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\t'));

        let long_desc = "x".repeat(300);
        assert_eq!(sanitize_tool_description(&long_desc).len(), 200);

        let long_name = "a".repeat(100);
        assert_eq!(sanitize_tool_name(&long_name).len(), 64);
    }

    #[tokio::test]
    async fn always_include_bypasses_max_tools_cap() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider =
            MockProvider::with_responses(vec![r#"["candidate_a","candidate_b"]"#.into()]);
        let params = PruningParams {
            max_tools: 1,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();

        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned tool must bypass cap"
        );
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn cache_positive_hit() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // Only one response is queued: a second LLM call would fail the test.
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r1.len(), r2.len(), "cache hit must return same result");
    }

    #[tokio::test]
    async fn cache_miss_on_message_change() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider =
            MockProvider::with_responses(vec![r#"["t0","t1"]"#.into(), r#"["t0"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query_a", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query_b", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2, "first call returns both tools");
        assert_eq!(
            r2.len(),
            1,
            "different message triggers cache miss and LLM call"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_tool_list_change() {
        let tools1: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let mut tools2 = tools1.clone();
        tools2.push(make_tool("t2", "new tool"));

        let provider = MockProvider::with_responses(vec![
            r#"["t0","t1"]"#.into(),
            r#"["t0","t1","t2"]"#.into(),
        ]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools1, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools2, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r2.len(), 3, "new tool triggers cache miss");
    }

    #[tokio::test]
    async fn cache_negative_hit_skips_llm() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err(), "first call must propagate LLM error");

        // The provider would fail again; a negative hit must avoid calling it.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "negative cache hit must return all tools");
    }

    #[tokio::test]
    async fn cache_negative_hit_clears_on_reset() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()])
            .with_errors(vec![zeph_llm::LlmError::Other("simulated failure".into())]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err());

        cache.reset();

        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "after reset the LLM must be retried");
    }
}