Skip to main content

zeph_mcp/
pruning.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dynamic MCP tool pruning for context optimization (#2204).
5//!
6//! The `prune_tools` free function filters a list of MCP tools to only those relevant
7//! to the current task, using an LLM call with a fast/cheap model. This reduces context
8//! usage and improves tool selection accuracy when MCP servers expose many tools.
9//!
10//! `zeph-mcp` does not depend on `zeph-config` (circular dependency: zeph-config ->
11//! zeph-mcp). Callers in `zeph-core` convert `ToolPruningConfig` into `PruningParams`
12//! before calling `prune_tools`.
13
14use std::fmt::Write as _;
15
16use zeph_llm::LlmError;
17use zeph_llm::provider::{LlmProvider, Message, Role};
18
19use crate::tool::McpTool;
20
21// ── Per-message pruning cache (#2298) ────────────────────────────────────────
22
23/// Cached outcome stored by [`PruningCache`].
24///
25/// [`Ok`] holds the previously-computed pruned tool list; [`Failed`] is a
26/// sentinel written when the LLM call failed, so subsequent lookups with the
27/// same key return the all-tools fallback without retrying the LLM.
28#[derive(Debug, Clone)]
29enum CachedResult {
30    Ok(Vec<McpTool>),
31    /// LLM call failed; caller should use the full tool list.
32    Failed,
33}
34
35/// Per-message cache for MCP tool pruning results.
36///
37/// Stores at most one entry keyed on `(message_content_hash, tool_list_hash)`.
38/// A cache miss triggers an LLM call; a hit returns the stored result
39/// immediately.  Negative entries (`Failed`) prevent retry storms when the
40/// pruning LLM is transiently unavailable.
41///
42/// # Cache contract
43///
44/// `PruningCache` returns previously-computed pruning results keyed on
45/// `(message_content_hash, tool_list_hash)`.
46///
47/// `tool_list_hash` includes: `server_id`, `name`, `description`, and
48/// `input_schema` for every tool.  Any change to tool metadata (not just the
49/// name set) produces a different hash and causes a cache miss.
50///
51/// `PruningCache::reset()` is additionally called on:
52/// - New user message (top of `process_user_message_inner`)
53/// - `tools/list_changed` notification (in `check_tool_refresh`)
54/// - Manual `/mcp add` or `/mcp remove` commands
55///
56/// `PruningParams` is **not** part of the cache key.  Callers must not change
57/// `PruningParams` within a single user turn; this invariant holds because
58/// params are derived from `ToolPruningConfig`, which is stable within a turn
59/// (config changes trigger a full agent rebuild, not a mid-turn param swap).
60///
61/// Designed for single-owner use (`&mut` on `Agent`). Not thread-safe.
62#[derive(Debug, Default, Clone)]
63pub struct PruningCache {
64    key: Option<(u64, u64)>,
65    result: Option<CachedResult>,
66}
67
68/// Outcome of a [`PruningCache::lookup`] call.
69enum CacheLookup<'a> {
70    /// Positive hit: pruned tool slice from a previous successful call.
71    Hit(&'a [McpTool]),
72    /// Negative hit: LLM previously failed; caller should use the full tool list.
73    NegativeHit,
74    /// No entry for this key.
75    Miss,
76}
77
78impl PruningCache {
79    /// Create a new, empty cache.
80    #[must_use]
81    pub fn new() -> Self {
82        Self::default()
83    }
84
85    /// Clear the cached entry.
86    ///
87    /// Must be called at the start of each user turn and whenever the MCP tool
88    /// list changes (via notification, `/mcp add`, or `/mcp remove`).
89    pub fn reset(&mut self) {
90        self.key = None;
91        self.result = None;
92    }
93
94    fn lookup(&self, msg_hash: u64, tool_hash: u64) -> CacheLookup<'_> {
95        match (&self.key, &self.result) {
96            (Some(k), Some(CachedResult::Ok(tools))) if *k == (msg_hash, tool_hash) => {
97                CacheLookup::Hit(tools)
98            }
99            (Some(k), Some(CachedResult::Failed)) if *k == (msg_hash, tool_hash) => {
100                CacheLookup::NegativeHit
101            }
102            _ => CacheLookup::Miss,
103        }
104    }
105
106    fn insert_ok(&mut self, msg_hash: u64, tool_hash: u64, tools: Vec<McpTool>) {
107        self.key = Some((msg_hash, tool_hash));
108        self.result = Some(CachedResult::Ok(tools));
109    }
110
111    fn insert_failed(&mut self, msg_hash: u64, tool_hash: u64) {
112        self.key = Some((msg_hash, tool_hash));
113        self.result = Some(CachedResult::Failed);
114    }
115}
116
117/// Compute a `u64` hash of a string using blake3 (first 8 bytes, little-endian).
118///
119/// # Panics
120///
121/// Never panics in practice: blake3 always produces at least 8 bytes of output.
122#[must_use]
123pub fn content_hash(s: &str) -> u64 {
124    let hash = blake3::hash(s.as_bytes());
125    u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
126}
127
128/// Compute a `u64` hash of the full tool list metadata using blake3.
129///
130/// Hashes `server_id`, `name`, `description`, and `input_schema` for every
131/// tool, sorted by qualified name (`server_id` then `name`) for deterministic
132/// ordering regardless of list order.
133///
134/// **`BTreeMap` assumption**: `serde_json::to_vec` produces deterministic output
135/// because `serde_json::Map` defaults to `BTreeMap`-backed storage (sorted
136/// keys).  If the `preserve_order` feature of `serde_json` is ever enabled
137/// (switching `Map` to `IndexMap`), key order becomes insertion-order and
138/// hashing becomes non-deterministic.  Should `preserve_order` be needed,
139/// sort `Map` keys before serialising here.
140///
141/// # Panics
142///
143/// Never panics in practice: blake3 always produces at least 8 bytes of output.
144#[must_use]
145pub fn tool_list_hash(tools: &[McpTool]) -> u64 {
146    let mut hasher = blake3::Hasher::new();
147    let mut sorted: Vec<&McpTool> = tools.iter().collect();
148    sorted.sort_by(|a, b| a.server_id.cmp(&b.server_id).then(a.name.cmp(&b.name)));
149    for tool in sorted {
150        hasher.update(tool.server_id.as_bytes());
151        hasher.update(b"\0");
152        hasher.update(tool.name.as_bytes());
153        hasher.update(b"\0");
154        hasher.update(tool.description.as_bytes());
155        hasher.update(b"\0");
156        match serde_json::to_vec(&tool.input_schema) {
157            Ok(schema_bytes) => {
158                hasher.update(&schema_bytes);
159            }
160            Err(_) => {
161                hasher.update(b"\x00");
162            }
163        }
164        // Tool separator — prevents adjacent-field collisions.
165        hasher.update(b"\x01");
166    }
167    let hash = hasher.finalize();
168    u64::from_le_bytes(hash.as_bytes()[..8].try_into().expect("blake3 >= 8 bytes"))
169}
170
171/// Cache-aware wrapper around [`prune_tools`].
172///
173/// On a **positive cache hit**: returns the previously-computed pruned list
174/// without an LLM call.
175///
176/// On a **negative cache hit** (LLM previously failed for this key): returns
177/// `Ok(all_tools.to_vec())` without retrying the LLM, avoiding retry storms
178/// when the pruning LLM is transiently unavailable.
179///
180/// On a **cache miss**: calls [`prune_tools`], stores the result (success or
181/// failure), and returns.  On LLM failure the negative sentinel is cached and
182/// `Err(PruningError)` is returned so the caller can log and fall back.
183///
184/// # Errors
185///
186/// Propagates `PruningError` from [`prune_tools`] on the first (uncached) LLM
187/// failure.  Subsequent calls with the same key return `Ok(all_tools.to_vec())`
188/// from the negative cache entry.
189#[cfg_attr(
190    feature = "profiling",
191    tracing::instrument(name = "mcp.pruning.prune_tools_cached", skip_all)
192)]
193pub async fn prune_tools_cached<P: LlmProvider>(
194    cache: &mut PruningCache,
195    all_tools: &[McpTool],
196    task_context: &str,
197    params: &PruningParams,
198    provider: &P,
199) -> Result<Vec<McpTool>, PruningError> {
200    let msg_hash = content_hash(task_context);
201    let tl_hash = tool_list_hash(all_tools);
202
203    match cache.lookup(msg_hash, tl_hash) {
204        CacheLookup::Hit(cached) => return Ok(cached.to_vec()),
205        CacheLookup::NegativeHit => {
206            // Negative cache hit: LLM previously failed for this key.
207            // Return all tools as fallback without retrying to avoid retry storms.
208            tracing::warn!("pruning cache: negative hit, returning all tools without LLM call");
209            return Ok(all_tools.to_vec());
210        }
211        CacheLookup::Miss => {}
212    }
213
214    match prune_tools(all_tools, task_context, params, provider).await {
215        Ok(result) => {
216            cache.insert_ok(msg_hash, tl_hash, result.clone());
217            Ok(result)
218        }
219        Err(e) => {
220            cache.insert_failed(msg_hash, tl_hash);
221            Err(e)
222        }
223    }
224}
225
226/// Errors that can occur during tool pruning.
227#[non_exhaustive]
228#[derive(Debug, thiserror::Error)]
229pub enum PruningError {
230    /// LLM call failed.
231    #[error("pruning LLM call failed: {0}")]
232    LlmError(#[from] LlmError),
233    /// Could not extract a valid JSON array from the LLM response.
234    #[error("failed to parse pruning response as JSON array of tool names")]
235    ParseError,
236}
237
/// Tuning knobs consumed by `prune_tools`.
///
/// This is the `zeph-mcp` counterpart of `zeph_config::ToolPruningConfig`.
/// It is defined here to avoid a circular crate dependency
/// (`zeph-config` → `zeph-mcp`); callers in `zeph-core` convert from
/// `ToolPruningConfig`.
#[derive(Clone, Debug)]
pub struct PruningParams {
    /// Maximum number of MCP tools to include after pruning.
    pub max_tools: usize,
    /// Tool count below which pruning is skipped entirely.
    pub min_tools_to_prune: usize,
    /// Names of tools that are never pruned (always included).
    ///
    /// Entries are compared against the bare tool `name`, not the qualified
    /// `server_id:name`.  When two MCP servers expose a tool with the same
    /// name, both instances are pinned.  This is intentional: the config is
    /// user-facing and users specify tool names, not server-qualified
    /// identifiers.
    pub always_include: Vec<String>,
}
257
258impl Default for PruningParams {
259    fn default() -> Self {
260        Self {
261            max_tools: 15,
262            min_tools_to_prune: 10,
263            always_include: Vec::new(),
264        }
265    }
266}
267
268/// Prune MCP tools to those relevant to the current task.
269///
270/// Returns a filtered subset of `all_tools` based on the LLM's assessment of relevance
271/// to `task_context`. Tools listed in `params.always_include` bypass the LLM filter.
272///
273/// # Behavior
274///
275/// - If `all_tools.len() < params.min_tools_to_prune`, returns `Ok(all_tools.to_vec())`.
276/// - On LLM failure or parse failure, returns `Err(PruningError)` — the caller should
277///   fall back to the full tool list and log at `WARN` level.
278/// - Result is capped at `params.max_tools` total tools. `max_tools == 0` means no cap.
279///
280/// # Errors
281///
282/// Returns `PruningError::LlmError` if the provider call fails.
283/// Returns `PruningError::ParseError` if the response cannot be parsed as a JSON array.
284#[cfg_attr(
285    feature = "profiling",
286    tracing::instrument(name = "mcp.pruning.prune_tools", skip_all)
287)]
288pub async fn prune_tools<P: LlmProvider>(
289    all_tools: &[McpTool],
290    task_context: &str,
291    params: &PruningParams,
292    provider: &P,
293) -> Result<Vec<McpTool>, PruningError> {
294    if all_tools.len() < params.min_tools_to_prune {
295        return Ok(all_tools.to_vec());
296    }
297
298    // Partition: always-include tools bypass the LLM filter.
299    let (pinned, candidates): (Vec<_>, Vec<_>) = all_tools
300        .iter()
301        .partition(|t| params.always_include.iter().any(|a| a == &t.name));
302
303    // Build the pruning prompt.
304    // Sanitize tool names and descriptions before interpolation to prevent prompt injection
305    // from attacker-controlled MCP servers.
306    let tool_list = candidates.iter().fold(String::new(), |mut acc, t| {
307        let name = sanitize_tool_name(&t.name);
308        let desc = sanitize_tool_description(&t.description);
309        let _ = writeln!(acc, "- {name}: {desc}");
310        acc
311    });
312
313    let prompt = format!(
314        "Return a JSON array of tool names that are relevant to the task below.\n\
315         Return ONLY the JSON array, no explanation, no markdown.\n\n\
316         Task: {task_context}\n\n\
317         Available tools:\n{tool_list}"
318    );
319
320    let messages = vec![Message::from_legacy(Role::User, prompt)];
321    let response = provider.chat(&messages).await?;
322
323    // Parse: strip markdown fences, find first `[` to last `]`.
324    let relevant_names = parse_name_array(&response)?;
325
326    // always_include tools are added unconditionally and bypass the max_tools cap;
327    // max_tools applies only to LLM-selected candidates.
328    let mut result: Vec<McpTool> = pinned.into_iter().cloned().collect();
329    let mut candidates_added: usize = 0;
330    for tool in &candidates {
331        // max_tools == 0 means no cap on LLM-selected candidates.
332        if params.max_tools > 0 && candidates_added >= params.max_tools {
333            break;
334        }
335        if relevant_names.iter().any(|n| n == &tool.name) {
336            result.push((*tool).clone());
337            candidates_added += 1;
338        }
339    }
340
341    Ok(result)
342}
343
/// Make a tool name safe to interpolate into an LLM prompt.
///
/// Control characters are dropped and the result is limited to the first
/// 64 remaining characters.
fn sanitize_tool_name(name: &str) -> String {
    const MAX_CHARS: usize = 64;

    let mut cleaned = String::new();
    let mut kept = 0;
    for ch in name.chars() {
        if kept == MAX_CHARS {
            break;
        }
        if !ch.is_control() {
            cleaned.push(ch);
            kept += 1;
        }
    }
    cleaned
}
350
/// Make a tool description safe to interpolate into an LLM prompt.
///
/// Control characters are dropped and the result is limited to the first
/// 200 remaining characters.
fn sanitize_tool_description(desc: &str) -> String {
    const MAX_CHARS: usize = 200;

    let mut printable = desc.chars().filter(|ch| !ch.is_control());
    let mut cleaned = String::new();
    for _ in 0..MAX_CHARS {
        match printable.next() {
            Some(ch) => cleaned.push(ch),
            None => break,
        }
    }
    cleaned
}
357
358/// Extract tool names from an LLM response expected to contain a JSON array of strings.
359///
360/// Handles markdown code fences (` ```json ... ``` `) and leading/trailing whitespace.
361fn parse_name_array(response: &str) -> Result<Vec<String>, PruningError> {
362    // Strip markdown code fence lines.
363    let stripped = response
364        .lines()
365        .filter(|l| !l.trim_start().starts_with("```"))
366        .collect::<Vec<_>>()
367        .join("\n");
368
369    // Find the first `[` and last `]` to isolate the JSON array.
370    let start = stripped.find('[').ok_or(PruningError::ParseError)?;
371    let end = stripped.rfind(']').ok_or(PruningError::ParseError)?;
372    if end <= start {
373        return Err(PruningError::ParseError);
374    }
375
376    let json_fragment = &stripped[start..=end];
377    let names: Vec<String> =
378        serde_json::from_str(json_fragment).map_err(|_| PruningError::ParseError)?;
379    Ok(names)
380}
381
#[cfg(test)]
mod tests {
    use zeph_llm::mock::MockProvider;

    use super::*;

    /// Build a tool on the fixed server id `"test"`.
    fn make_tool(name: &str, description: &str) -> McpTool {
        make_tool_with_server("test", name, description)
    }

    /// Build a tool with an explicit server id, a null input schema, and
    /// default security metadata.
    fn make_tool_with_server(server_id: &str, name: &str, description: &str) -> McpTool {
        McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: description.into(),
            input_schema: serde_json::Value::Null,
            output_schema: None,
            security_meta: crate::tool::ToolSecurityMeta::default(),
        }
    }

    /// Build params with low `min_tools_to_prune` so tests aren't skipped early.
    fn params_with_max(max_tools: usize) -> PruningParams {
        PruningParams {
            max_tools,
            min_tools_to_prune: 1,
            always_include: Vec::new(),
        }
    }

    /// Collect tool names in order, for comparing two tool lists.
    fn names_of(tools: &[McpTool]) -> Vec<String> {
        tools.iter().map(|t| t.name.clone()).collect()
    }

    #[test]
    fn parse_plain_array() {
        let names = parse_name_array(r#"["bash", "read", "write"]"#).unwrap();
        assert_eq!(names, vec!["bash", "read", "write"]);
    }

    #[test]
    fn parse_array_with_markdown_fences() {
        let input = "```json\n[\"bash\", \"read\"]\n```";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_array_with_preamble() {
        let input = "Here are the relevant tools:\n[\"bash\", \"read\"]";
        let names = parse_name_array(input).unwrap();
        assert_eq!(names, vec!["bash", "read"]);
    }

    #[test]
    fn parse_empty_array() {
        let names = parse_name_array("[]").unwrap();
        assert!(names.is_empty());
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_name_array("not json").is_err());
        assert!(parse_name_array("").is_err());
        assert!(parse_name_array("{\"key\": \"val\"}").is_err());
    }

    // Replaced below_min_detected tautology (#2300): call prune_tools with a failing
    // mock to verify the early-return path fires before the LLM is ever contacted.
    #[tokio::test]
    async fn below_min_detected_early_return() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // MockProvider::failing() makes every LLM call return an error — if
        // prune_tools invoked it, the `unwrap()` below would panic.
        let provider = MockProvider::failing();
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 10, // 5 tools < 10 → early return before LLM
            always_include: Vec::new(),
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "all tools returned when below threshold");
    }

    #[tokio::test]
    async fn always_include_pinned() {
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        // LLM returns only candidate_a; pinned must still appear.
        let provider = MockProvider::with_responses(vec![r#"["candidate_a"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned must survive pruning"
        );
        assert!(result.iter().any(|t| t.name == "candidate_a"));
        assert!(
            !result.iter().any(|t| t.name == "candidate_b"),
            "unselected candidate must be pruned"
        );
    }

    /// S4: `always_include` pins tools by bare name across multiple servers.
    #[tokio::test]
    async fn always_include_matches_bare_name_across_servers() {
        let tools = vec![
            make_tool_with_server("server_a", "search", "search on A"),
            make_tool_with_server("server_b", "search", "search on B"),
            make_tool_with_server("server_a", "other", "other tool"),
        ];
        // LLM returns only "other"; both "search" instances should still be pinned.
        let provider = MockProvider::with_responses(vec![r#"["other"]"#.into()]);
        let params = PruningParams {
            max_tools: 0,
            min_tools_to_prune: 1,
            always_include: vec!["search".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 3, "both search tools + other must be present");
        let search_count = result.iter().filter(|t| t.name == "search").count();
        assert_eq!(
            search_count, 2,
            "both server_a:search and server_b:search must be pinned"
        );
        assert!(result.iter().any(|t| t.name == "other"));
    }

    #[tokio::test]
    async fn max_tools_cap_respected() {
        let tools: Vec<McpTool> = (0..5).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // LLM returns all 5 as relevant; max_tools=2 must cap candidates.
        let names_json = r#"["t0","t1","t2","t3","t4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);

        let result = prune_tools(&tools, "task", &params_with_max(2), &provider)
            .await
            .unwrap();
        assert_eq!(
            result.len(),
            2,
            "max_tools=2 must cap LLM-selected candidates"
        );
        // Candidates are taken in input order, so the first two survive.
        assert_eq!(names_of(&result), vec!["t0", "t1"]);
    }

    #[tokio::test]
    async fn llm_failure_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::LlmError(_))));
    }

    #[tokio::test]
    async fn parse_error_propagates() {
        let tools: Vec<McpTool> = (0..3).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec!["not valid json at all".into()]);
        let result = prune_tools(&tools, "task", &params_with_max(0), &provider).await;
        assert!(matches!(result, Err(PruningError::ParseError)));
    }

    #[tokio::test]
    async fn max_tools_zero_means_no_cap() {
        let tools: Vec<McpTool> = (0..5)
            .map(|i| make_tool(&format!("tool{i}"), "desc"))
            .collect();
        let names_json = r#"["tool0","tool1","tool2","tool3","tool4"]"#;
        let provider = MockProvider::with_responses(vec![names_json.into()]);
        let params = params_with_max(0);

        let result = prune_tools(&tools, "any task", &params, &provider)
            .await
            .unwrap();
        assert_eq!(result.len(), 5, "max_tools=0 must not cap the result");
    }

    #[test]
    fn description_sanitization_strips_control_chars_and_caps() {
        // Newline and tab are control characters; they are removed, not replaced.
        let desc = "line1\nline2\tinject";
        let sanitized = sanitize_tool_description(desc);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\t'));
        assert_eq!(sanitized, "line1line2inject");

        // Cap at 200 characters.
        let long_desc = "x".repeat(300);
        assert_eq!(sanitize_tool_description(&long_desc).len(), 200);

        // Name capped at 64 characters.
        let long_name = "a".repeat(100);
        assert_eq!(sanitize_tool_name(&long_name).len(), 64);
    }

    #[tokio::test]
    async fn always_include_bypasses_max_tools_cap() {
        // max_tools=1 — only 1 candidate from LLM allowed; but always_include adds unconditionally.
        let tools = vec![
            make_tool("pinned", "always here"),
            make_tool("candidate_a", "desc a"),
            make_tool("candidate_b", "desc b"),
        ];
        let provider =
            MockProvider::with_responses(vec![r#"["candidate_a","candidate_b"]"#.into()]);
        let params = PruningParams {
            max_tools: 1,
            min_tools_to_prune: 1,
            always_include: vec!["pinned".into()],
        };

        let result = prune_tools(&tools, "task", &params, &provider)
            .await
            .unwrap();

        // "pinned" is always present regardless of max_tools.
        assert!(
            result.iter().any(|t| t.name == "pinned"),
            "pinned tool must bypass cap"
        );
        // The cap counts LLM-selected candidates only; total = 1 (pinned) + 1 (candidate).
        assert_eq!(result.len(), 2);
    }

    // ── PruningCache tests (#2298, #2300) ────────────────────────────────────

    #[tokio::test]
    async fn cache_positive_hit() {
        // Two tools to exceed min_tools_to_prune=1; MockProvider has exactly one response.
        // The second call must succeed from cache without consuming the (empty) response queue.
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        // Compare contents, not just lengths: the hit must replay the same tools.
        assert_eq!(
            names_of(&r1),
            names_of(&r2),
            "cache hit must return same result"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_message_change() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider =
            MockProvider::with_responses(vec![r#"["t0","t1"]"#.into(), r#"["t0"]"#.into()]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools, "query_a", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools, "query_b", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2, "first call returns both tools");
        assert_eq!(
            r2.len(),
            1,
            "different message triggers cache miss and LLM call"
        );
    }

    #[tokio::test]
    async fn cache_miss_on_tool_list_change() {
        let tools1: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let mut tools2 = tools1.clone();
        tools2.push(make_tool("t2", "new tool"));

        let provider = MockProvider::with_responses(vec![
            r#"["t0","t1"]"#.into(),
            r#"["t0","t1","t2"]"#.into(),
        ]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        let r1 = prune_tools_cached(&mut cache, &tools1, "query", &params, &provider)
            .await
            .unwrap();
        let r2 = prune_tools_cached(&mut cache, &tools2, "query", &params, &provider)
            .await
            .unwrap();

        assert_eq!(r1.len(), 2);
        assert_eq!(r2.len(), 3, "new tool triggers cache miss");
    }

    #[tokio::test]
    async fn cache_negative_hit_skips_llm() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        let provider = MockProvider::failing();
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        // First call: LLM fails → error is returned and negative entry is cached.
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err(), "first call must propagate LLM error");

        // Second call: negative cache hit → returns all tools without calling LLM.
        // A second LLM call on the failing provider would return an error again and
        // make the `unwrap()` below panic, so success proves the cache was used.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "negative cache hit must return all tools");
    }

    #[tokio::test]
    async fn cache_negative_hit_clears_on_reset() {
        let tools: Vec<McpTool> = (0..2).map(|i| make_tool(&format!("t{i}"), "d")).collect();
        // Fail on the first LLM call; succeed on the second (after cache.reset()).
        let provider = MockProvider::with_responses(vec![r#"["t0","t1"]"#.into()])
            .with_errors(vec![zeph_llm::LlmError::Other("simulated failure".into())]);
        let params = params_with_max(0);
        let mut cache = PruningCache::new();

        // First call: LLM fails → negative entry cached.
        let r1 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider).await;
        assert!(r1.is_err());

        // Reset clears the negative entry.
        cache.reset();

        // After reset the LLM is retried; the queued success response is now returned.
        let r2 = prune_tools_cached(&mut cache, &tools, "query", &params, &provider)
            .await
            .unwrap();
        assert_eq!(r2.len(), 2, "after reset the LLM must be retried");
    }
}