Skip to main content

zeph_tools/
schema_filter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dynamic tool schema filtering based on query-tool embedding similarity (#2020).
5//!
6//! Filters the set of tool definitions sent to the LLM on each turn, selecting
7//! only the most relevant tools based on cosine similarity between the user query
8//! embedding and pre-computed tool description embeddings.
9
10use std::collections::{HashMap, HashSet};
11
12use zeph_common::ToolName;
13use zeph_common::math::cosine_similarity;
14
15use crate::config::ToolDependency;
16
17/// Cached embedding for a tool definition.
18#[derive(Debug, Clone)]
19pub struct ToolEmbedding {
20    pub tool_id: ToolName,
21    pub embedding: Vec<f32>,
22}
23
/// Why a given tool ended up in the filtered set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InclusionReason {
    /// The tool appears in the configured always-on list.
    AlwaysOn,
    /// The user query explicitly contains the tool's name.
    NameMentioned,
    /// The tool ranked within the top-K by similarity to the query.
    SimilarityRank,
    /// The tool's description is too short to filter reliably (MCP tools).
    ShortDescription,
    /// No cached embedding exists for the tool (e.g. added after startup via MCP).
    NoEmbedding,
    /// All hard requirements (`requires`) of the tool are satisfied.
    DependencyMet,
    /// Satisfied soft prerequisites (`prefers`) boosted the tool's similarity score.
    PreferenceBoost,
}
42
43/// Exclusion reason for a tool that was blocked by the dependency gate.
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct DependencyExclusion {
46    pub tool_id: ToolName,
47    /// IDs of `requires` entries that are not yet satisfied.
48    pub unmet_requires: Vec<String>,
49}
50
51/// Result of filtering tool schemas against a query.
52#[derive(Debug, Clone)]
53pub struct ToolFilterResult {
54    /// Tool IDs that passed the filter.
55    pub included: HashSet<String>,
56    /// Tool IDs that were filtered out by similarity/embedding.
57    pub excluded: Vec<String>,
58    /// Per-tool similarity scores for filterable tools (sorted descending).
59    pub scores: Vec<(String, f32)>,
60    /// Reason each included tool was included.
61    pub inclusion_reasons: Vec<(String, InclusionReason)>,
62    /// Tools excluded specifically due to unmet hard dependencies.
63    pub dependency_exclusions: Vec<DependencyExclusion>,
64}
65
66/// Dependency graph for sequential tool availability (issue #2024).
67///
68/// Built once from `DependencyConfig` at agent start, reused across turns.
69/// Implements cycle detection via DFS topological sort: any tool in a detected
70/// cycle has all its `requires` removed (made unconditionally available) so it
71/// can never be permanently blocked by a dependency loop.
72///
73/// # Deadlock fallback
74///
75/// If all non-always-on tools would be blocked (either by config cycles or
76/// unreachable `requires` chains), `apply()` detects this at filter time and
77/// disables hard gates for that turn, logging a warning.
78#[derive(Debug, Clone, Default)]
79pub struct ToolDependencyGraph {
80    /// Map from `tool_id` -> its dependency spec.
81    /// Tools in cycles have their `requires` cleared at construction time.
82    deps: HashMap<String, ToolDependency>,
83}
84
85impl ToolDependencyGraph {
86    /// Build a dependency graph from a map of tool rules.
87    ///
88    /// Performs DFS-based cycle detection. All tools participating in any cycle
89    /// have their `requires` entries removed so they are always available.
90    #[must_use]
91    pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
92        if deps.is_empty() {
93            return Self { deps };
94        }
95        let cycled = detect_cycles(&deps);
96        if !cycled.is_empty() {
97            tracing::warn!(
98                tools = ?cycled,
99                "tool dependency graph: cycles detected, removing requires for cycle participants"
100            );
101        }
102        let mut resolved = deps;
103        for tool_id in &cycled {
104            if let Some(dep) = resolved.get_mut(tool_id) {
105                dep.requires.clear();
106            }
107        }
108        Self { deps: resolved }
109    }
110
111    /// Returns true if no dependency rules are configured.
112    #[must_use]
113    pub fn is_empty(&self) -> bool {
114        self.deps.is_empty()
115    }
116
117    /// Check if a tool's hard requirements are all satisfied.
118    ///
119    /// Returns `true` if the tool has no `requires` entries, or if all entries
120    /// are present in `completed`. Returns `true` for unconfigured tools.
121    #[must_use]
122    pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
123        self.deps
124            .get(tool_id)
125            .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
126    }
127
128    /// Returns the unmet `requires` entries for a tool, if any.
129    #[must_use]
130    pub fn unmet_requires<'a>(
131        &'a self,
132        tool_id: &str,
133        completed: &HashSet<String>,
134    ) -> Vec<&'a str> {
135        self.deps.get(tool_id).map_or_else(Vec::new, |d| {
136            d.requires
137                .iter()
138                .filter(|r| !completed.contains(r.as_str()))
139                .map(String::as_str)
140                .collect()
141        })
142    }
143
144    /// Calculate similarity boost for soft prerequisites.
145    ///
146    /// Returns `min(boost_per_dep * met_count, max_total_boost)`.
147    #[must_use]
148    pub fn preference_boost(
149        &self,
150        tool_id: &str,
151        completed: &HashSet<String>,
152        boost_per_dep: f32,
153        max_total_boost: f32,
154    ) -> f32 {
155        self.deps.get(tool_id).map_or(0.0, |d| {
156            let met = d
157                .prefers
158                .iter()
159                .filter(|p| completed.contains(p.as_str()))
160                .count();
161            #[allow(clippy::cast_precision_loss)]
162            let boost = met as f32 * boost_per_dep;
163            boost.min(max_total_boost)
164        })
165    }
166
167    /// Apply hard dependency gates and preference boosts to a `ToolFilterResult`.
168    ///
169    /// Called after `ToolSchemaFilter::filter()` returns so the filter signature
170    /// remains unchanged (HIGH-03 fix). Dependency gates are applied AFTER TAFC
171    /// augmentation to prevent re-adding gated tools through augmentation (MED-04 fix).
172    ///
173    /// Only `AlwaysOn` tools bypass hard gates. `NameMentioned` tools are still subject
174    /// to `requires` checks — a user mentioning a gated tool name does not grant access.
175    ///
176    /// # Deadlock fallback (CRIT-01)
177    ///
178    /// If applying hard gates would remove ALL non-always-on included tools, the
179    /// gates are disabled for this turn and a warning is logged.
180    pub fn apply(
181        &self,
182        result: &mut ToolFilterResult,
183        completed: &HashSet<String>,
184        boost_per_dep: f32,
185        max_total_boost: f32,
186        always_on: &HashSet<String>,
187    ) {
188        if self.deps.is_empty() {
189            return;
190        }
191
192        // Only AlwaysOn tools bypass the hard dependency gate.
193        // NameMentioned tools still respect `requires` constraints: a user mentioning a gated
194        // tool name in their query does not grant access to it before its prerequisites run.
195        let bypassed: HashSet<&str> = result
196            .inclusion_reasons
197            .iter()
198            .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
199            .map(|(id, _)| id.as_str())
200            .collect();
201
202        let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
203        for tool_id in &result.included {
204            if bypassed.contains(tool_id.as_str()) {
205                continue;
206            }
207            let unmet: Vec<String> = self
208                .unmet_requires(tool_id, completed)
209                .into_iter()
210                .map(str::to_owned)
211                .collect();
212            if !unmet.is_empty() {
213                to_exclude.push(DependencyExclusion {
214                    tool_id: tool_id.as_str().into(),
215                    unmet_requires: unmet,
216                });
217            }
218        }
219
220        // CRIT-01: deadlock fallback — if gating would leave no non-always-on tools,
221        // skip hard gates for this turn.
222        let non_always_on_included: usize = result
223            .included
224            .iter()
225            .filter(|id| !always_on.contains(id.as_str()))
226            .count();
227        if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
228            tracing::warn!(
229                gated = to_exclude.len(),
230                non_always_on = non_always_on_included,
231                "tool dependency graph: all non-always-on tools would be blocked; \
232                 disabling hard gates for this turn"
233            );
234            to_exclude.clear();
235        }
236
237        // Apply hard gates.
238        for excl in &to_exclude {
239            result.included.remove(excl.tool_id.as_str());
240            result.excluded.push(excl.tool_id.to_string());
241            tracing::debug!(
242                tool_id = %excl.tool_id,
243                unmet = ?excl.unmet_requires,
244                "tool dependency gate: excluded (requires not met)"
245            );
246        }
247        result.dependency_exclusions = to_exclude;
248
249        // Apply preference boosts: adjust scores for tools with satisfied prefers deps.
250        for (tool_id, score) in &mut result.scores {
251            if !result.included.contains(tool_id) {
252                continue;
253            }
254            let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
255            if boost > 0.0 {
256                *score += boost;
257                // Record reason if not already recorded with a higher-priority reason.
258                let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
259                if !already_recorded {
260                    result
261                        .inclusion_reasons
262                        .push((tool_id.clone(), InclusionReason::PreferenceBoost));
263                }
264                tracing::debug!(
265                    tool_id = %tool_id,
266                    boost,
267                    "tool dependency: preference boost applied"
268                );
269            }
270        }
271        // Re-sort scores after boosts.
272        result
273            .scores
274            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
275    }
276
277    /// Filter a slice of tool IDs to those whose hard requirements are met.
278    ///
279    /// Used on iterations 1+ in the native tool loop via the agent helper
280    /// `apply_hard_dependency_gate_to_names`. Returns only the IDs that pass.
281    #[must_use]
282    pub fn filter_tool_names<'a>(
283        &self,
284        names: &[&'a str],
285        completed: &HashSet<String>,
286        always_on: &HashSet<String>,
287    ) -> Vec<&'a str> {
288        names
289            .iter()
290            .copied()
291            .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
292            .collect()
293    }
294}
295
296/// DFS-based cycle detection for tool dependency graphs.
297///
298/// Returns the set of tool IDs that participate in any cycle.
299/// Algorithm: standard DFS with three states (unvisited/in-progress/done).
300/// When a back-edge is found (visiting an in-progress node), all nodes in the
301/// current DFS path that form part of the cycle are collected.
302fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
303    #[derive(Clone, Copy, PartialEq)]
304    enum State {
305        Unvisited,
306        InProgress,
307        Done,
308    }
309
310    let mut state: HashMap<&str, State> = HashMap::new();
311    let mut cycled: HashSet<String> = HashSet::new();
312
313    for start in deps.keys() {
314        if state
315            .get(start.as_str())
316            .copied()
317            .unwrap_or(State::Unvisited)
318            != State::Unvisited
319        {
320            continue;
321        }
322        let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
323        state.insert(start.as_str(), State::InProgress);
324
325        while let Some((node, child_idx)) = stack.last_mut() {
326            let node = *node;
327            let requires = deps
328                .get(node)
329                .map_or(&[] as &[String], |d| d.requires.as_slice());
330
331            if *child_idx >= requires.len() {
332                state.insert(node, State::Done);
333                stack.pop();
334                continue;
335            }
336
337            let child = requires[*child_idx].as_str();
338            *child_idx += 1;
339
340            match state.get(child).copied().unwrap_or(State::Unvisited) {
341                State::InProgress => {
342                    // Back-edge found: child is the cycle entry point already on the stack.
343                    // Only mark nodes from that entry point to the top of the stack as cycled.
344                    // Ancestors above the cycle entry are NOT part of the cycle.
345                    let cycle_start = stack.iter().position(|(n, _)| *n == child);
346                    if let Some(start) = cycle_start {
347                        for (path_node, _) in &stack[start..] {
348                            cycled.insert((*path_node).to_owned());
349                        }
350                    }
351                    cycled.insert(child.to_owned());
352                }
353                State::Unvisited => {
354                    state.insert(child, State::InProgress);
355                    stack.push((child, 0));
356                }
357                State::Done => {}
358            }
359        }
360    }
361
362    cycled
363}
364
365/// Core filter holding cached tool embeddings and config.
366pub struct ToolSchemaFilter {
367    always_on: HashSet<String>,
368    top_k: usize,
369    min_description_words: usize,
370    embeddings: Vec<ToolEmbedding>,
371    version: u64,
372}
373
374impl ToolSchemaFilter {
375    /// Create a new filter with pre-computed tool embeddings.
376    #[must_use]
377    pub fn new(
378        always_on: Vec<String>,
379        top_k: usize,
380        min_description_words: usize,
381        embeddings: Vec<ToolEmbedding>,
382    ) -> Self {
383        Self {
384            always_on: always_on.into_iter().collect(),
385            top_k,
386            min_description_words,
387            embeddings,
388            version: 0,
389        }
390    }
391
392    /// Current version counter. Incremented on recompute.
393    #[must_use]
394    pub fn version(&self) -> u64 {
395        self.version
396    }
397
398    /// Number of cached tool embeddings.
399    #[must_use]
400    pub fn embedding_count(&self) -> usize {
401        self.embeddings.len()
402    }
403
404    /// Configured top-K limit for similarity ranking.
405    #[must_use]
406    pub fn top_k(&self) -> usize {
407        self.top_k
408    }
409
410    /// Number of always-on tools in the filter config.
411    #[must_use]
412    pub fn always_on_count(&self) -> usize {
413        self.always_on.len()
414    }
415
416    /// Replace tool embeddings (e.g. after MCP tool changes) and bump version.
417    pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
418        self.embeddings = embeddings;
419        self.version += 1;
420    }
421
422    /// Filter tools for a given user query embedding.
423    ///
424    /// `all_tool_ids` is the full set of tool IDs currently available.
425    /// `tool_descriptions` maps tool ID to its description (for short-description check).
426    /// `query_embedding` is the embedded user query.
427    #[must_use]
428    pub fn filter(
429        &self,
430        all_tool_ids: &[&str],
431        tool_descriptions: &[(&str, &str)],
432        query: &str,
433        query_embedding: &[f32],
434    ) -> ToolFilterResult {
435        let mut included = HashSet::new();
436        let mut inclusion_reasons = Vec::new();
437
438        // 1. Always-on tools.
439        for id in all_tool_ids {
440            if self.always_on.contains(*id) {
441                included.insert((*id).to_owned());
442                inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
443            }
444        }
445
446        // 2. Name-mentioned tools.
447        let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
448        for id in &mentioned {
449            if included.insert(id.clone()) {
450                inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
451            }
452        }
453
454        // 3. Short-description MCP tools.
455        for &(id, desc) in tool_descriptions {
456            let word_count = desc.split_whitespace().count();
457            if word_count < self.min_description_words && included.insert(id.to_owned()) {
458                inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
459            }
460        }
461
462        // 4. Similarity-ranked filterable tools.
463        let mut scores: Vec<(String, f32)> = self
464            .embeddings
465            .iter()
466            .filter(|e| !included.contains(e.tool_id.as_str()))
467            .map(|e| {
468                let score = cosine_similarity(query_embedding, &e.embedding);
469                (e.tool_id.to_string(), score)
470            })
471            .collect();
472
473        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
474
475        let take = if self.top_k == 0 {
476            scores.len()
477        } else {
478            self.top_k.min(scores.len())
479        };
480
481        for (id, _score) in scores.iter().take(take) {
482            if included.insert(id.clone()) {
483                inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
484            }
485        }
486
487        // 5. Auto-include tools without embeddings (e.g. new MCP tools added after startup).
488        let embedded_ids: HashSet<&str> =
489            self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
490        for id in all_tool_ids {
491            if !included.contains(*id) && !embedded_ids.contains(*id) {
492                included.insert((*id).to_owned());
493                inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
494            }
495        }
496
497        // Build excluded list.
498        let excluded: Vec<String> = all_tool_ids
499            .iter()
500            .filter(|id| !included.contains(**id))
501            .map(|id| (*id).to_owned())
502            .collect();
503
504        ToolFilterResult {
505            included,
506            excluded,
507            scores,
508            inclusion_reasons,
509            dependency_exclusions: Vec::new(),
510        }
511    }
512}
513
/// Find tool IDs explicitly mentioned in the query (case-insensitive, word-boundary aware).
///
/// Both the query and each tool ID are lowercased before matching. A match only
/// counts when the byte before and the byte after it are not ASCII alphanumerics
/// or underscores (or the match touches the start/end of the query). This prevents
/// false positives like "read" matching inside "thread".
///
/// Empty tool IDs never match. The scan resumes one whole character after a
/// rejected match, so queries and IDs containing multi-byte UTF-8 characters are
/// handled without slicing inside a character.
///
/// Returns the matching IDs in their original spelling and input order.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    /// Bytes that count as part of a word for boundary purposes.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    let query_lower = query.to_lowercase();
    let haystack = query_lower.as_bytes();
    all_tool_ids
        .iter()
        .filter(|id| {
            let id_lower = id.to_lowercase();
            // An empty needle matches at every offset, which would walk `start`
            // past the end of the query; an empty ID cannot be "mentioned" anyway.
            if id_lower.is_empty() {
                return false;
            }
            let mut start = 0;
            while let Some(pos) = query_lower[start..].find(&id_lower) {
                let abs_pos = start + pos;
                let end_pos = abs_pos + id_lower.len();
                let before_ok = abs_pos == 0 || !is_word_byte(haystack[abs_pos - 1]);
                let after_ok = end_pos >= haystack.len() || !is_word_byte(haystack[end_pos]);
                if before_ok && after_ok {
                    return true;
                }
                // Step over the whole first character of the rejected match so the
                // next slice starts on a character boundary.
                start = abs_pos
                    + query_lower[abs_pos..]
                        .chars()
                        .next()
                        .map_or(1, char::len_utf8);
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
545
546#[cfg(test)]
547mod tests {
548    use super::*;
549
550    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
551        ToolSchemaFilter::new(
552            always_on.into_iter().map(String::from).collect(),
553            top_k,
554            5,
555            vec![
556                ToolEmbedding {
557                    tool_id: "grep".into(),
558                    embedding: vec![0.9, 0.1, 0.0],
559                },
560                ToolEmbedding {
561                    tool_id: "write".into(),
562                    embedding: vec![0.1, 0.9, 0.0],
563                },
564                ToolEmbedding {
565                    tool_id: "find_path".into(),
566                    embedding: vec![0.5, 0.5, 0.0],
567                },
568                ToolEmbedding {
569                    tool_id: "web_scrape".into(),
570                    embedding: vec![0.0, 0.0, 1.0],
571                },
572                ToolEmbedding {
573                    tool_id: "diagnostics".into(),
574                    embedding: vec![0.0, 0.1, 0.9],
575                },
576            ],
577        )
578    }
579
580    #[test]
581    fn top_k_ranking_selects_most_similar() {
582        let filter = make_filter(vec!["bash"], 2);
583        let all_ids: Vec<&str> = vec![
584            "bash",
585            "grep",
586            "write",
587            "find_path",
588            "web_scrape",
589            "diagnostics",
590        ];
591        let query_emb = vec![0.8, 0.2, 0.0]; // close to grep
592        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);
593
594        assert!(result.included.contains("bash")); // always-on
595        assert!(result.included.contains("grep")); // top similarity
596        assert!(result.included.contains("find_path")); // 2nd top
597        // web_scrape and diagnostics should be excluded
598        assert!(!result.included.contains("web_scrape"));
599        assert!(!result.included.contains("diagnostics"));
600    }
601
602    #[test]
603    fn always_on_tools_always_included() {
604        let filter = make_filter(vec!["bash", "read"], 1);
605        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
606        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
607        let result = filter.filter(&all_ids, &[], "test query", &query_emb);
608
609        assert!(result.included.contains("bash"));
610        assert!(result.included.contains("read"));
611        assert!(result.included.contains("write")); // top-1
612        assert!(!result.included.contains("grep"));
613    }
614
615    #[test]
616    fn name_mention_force_includes() {
617        let filter = make_filter(vec!["bash"], 1);
618        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
619        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
620        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);
621
622        assert!(result.included.contains("web_scrape")); // name match
623        assert!(result.included.contains("write")); // top-1
624        assert!(result.included.contains("bash")); // always-on
625    }
626
627    #[test]
628    fn short_mcp_description_auto_included() {
629        let filter = make_filter(vec!["bash"], 1);
630        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
631        let descriptions: Vec<(&str, &str)> = vec![
632            ("mcp_query", "Run query"),
633            ("grep", "Search file contents recursively"),
634        ];
635        let query_emb = vec![0.9, 0.1, 0.0];
636        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);
637
638        assert!(result.included.contains("mcp_query")); // short desc (2 words)
639    }
640
641    #[test]
642    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
643        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
644        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
645        let query_emb = vec![0.5, 0.5, 0.0];
646        let result = filter.filter(&all_ids, &[], "test", &query_emb);
647
648        // All tools included: bash (always-on), grep+write (NoEmbedding fallback)
649        assert!(result.included.contains("bash"));
650        assert!(result.included.contains("grep"));
651        assert!(result.included.contains("write"));
652        assert!(result.excluded.is_empty());
653    }
654
655    #[test]
656    fn top_k_zero_includes_all_filterable() {
657        let filter = make_filter(vec!["bash"], 0);
658        let all_ids: Vec<&str> = vec![
659            "bash",
660            "grep",
661            "write",
662            "find_path",
663            "web_scrape",
664            "diagnostics",
665        ];
666        let query_emb = vec![0.1, 0.1, 0.1];
667        let result = filter.filter(&all_ids, &[], "test", &query_emb);
668
669        assert_eq!(result.included.len(), 6); // all included
670        assert!(result.excluded.is_empty());
671    }
672
673    #[test]
674    fn top_k_exceeds_filterable_count_includes_all() {
675        let filter = make_filter(vec!["bash"], 100);
676        let all_ids: Vec<&str> = vec![
677            "bash",
678            "grep",
679            "write",
680            "find_path",
681            "web_scrape",
682            "diagnostics",
683        ];
684        let query_emb = vec![0.1, 0.1, 0.1];
685        let result = filter.filter(&all_ids, &[], "test", &query_emb);
686
687        assert_eq!(result.included.len(), 6);
688    }
689
690    #[test]
691    fn accessors_return_configured_values() {
692        let filter = make_filter(vec!["bash", "read"], 7);
693        assert_eq!(filter.top_k(), 7);
694        assert_eq!(filter.always_on_count(), 2);
695        assert_eq!(filter.embedding_count(), 5);
696    }
697
698    #[test]
699    fn version_counter_incremented_on_recompute() {
700        let mut filter = make_filter(vec![], 3);
701        assert_eq!(filter.version(), 0);
702        filter.recompute(vec![]);
703        assert_eq!(filter.version(), 1);
704        filter.recompute(vec![]);
705        assert_eq!(filter.version(), 2);
706    }
707
708    #[test]
709    fn inclusion_reason_correctness() {
710        let filter = make_filter(vec!["bash"], 1);
711        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
712        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")]; // 1 word
713        let query_emb = vec![0.1, 0.9, 0.0]; // close to write
714        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);
715
716        let reasons: std::collections::HashMap<String, InclusionReason> =
717            result.inclusion_reasons.into_iter().collect();
718        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
719        assert_eq!(
720            reasons.get("web_scrape"),
721            Some(&InclusionReason::ShortDescription)
722        );
723        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
724    }
725
726    #[test]
727    fn cosine_similarity_identical_vectors() {
728        let v = vec![1.0, 2.0, 3.0];
729        let sim = cosine_similarity(&v, &v);
730        assert!((sim - 1.0).abs() < 1e-5);
731    }
732
733    #[test]
734    fn cosine_similarity_orthogonal_vectors() {
735        let a = vec![1.0, 0.0];
736        let b = vec![0.0, 1.0];
737        let sim = cosine_similarity(&a, &b);
738        assert!(sim.abs() < 1e-5);
739    }
740
741    #[test]
742    fn cosine_similarity_empty_returns_zero() {
743        assert!(cosine_similarity(&[], &[]) < f32::EPSILON);
744    }
745
746    #[test]
747    fn cosine_similarity_mismatched_length_returns_zero() {
748        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]) < f32::EPSILON);
749    }
750
751    #[test]
752    fn find_mentioned_tool_ids_case_insensitive() {
753        let ids = vec!["web_scrape", "grep", "Bash"];
754        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
755        assert!(found.contains(&"web_scrape".to_owned()));
756        assert!(found.contains(&"Bash".to_owned()));
757        assert!(!found.contains(&"grep".to_owned()));
758    }
759
760    #[test]
761    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
762        let ids = vec!["read", "edit", "fetch", "grep"];
763        // "read" should NOT match inside "thread" or "breadcrumb"
764        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
765        assert!(found.is_empty());
766    }
767
768    #[test]
769    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
770        let ids = vec!["read", "edit"];
771        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
772        assert!(found.contains(&"read".to_owned()));
773        assert!(found.contains(&"edit".to_owned()));
774    }
775
776    // --- ToolDependencyGraph tests ---
777
778    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
779        let deps = rules
780            .iter()
781            .map(|(id, requires, prefers)| {
782                (
783                    (*id).to_owned(),
784                    crate::config::ToolDependency {
785                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
786                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
787                    },
788                )
789            })
790            .collect();
791        ToolDependencyGraph::new(deps)
792    }
793
794    fn completed(ids: &[&str]) -> HashSet<String> {
795        ids.iter().map(|s| (*s).to_owned()).collect()
796    }
797
798    #[test]
799    fn requirements_met_no_deps() {
800        let graph = make_dep_graph(&[]);
801        assert!(graph.requirements_met("any_tool", &completed(&[])));
802    }
803
804    #[test]
805    fn requirements_met_all_satisfied() {
806        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
807        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
808    }
809
810    #[test]
811    fn requirements_met_unmet() {
812        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
813        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
814    }
815
816    #[test]
817    fn requirements_met_unconfigured_tool() {
818        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
819        // tools not in the graph are always available
820        assert!(graph.requirements_met("grep", &completed(&[])));
821    }
822
823    #[test]
824    fn preference_boost_none_met() {
825        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
826        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
827        assert!(boost < f32::EPSILON);
828    }
829
830    #[test]
831    fn preference_boost_partial() {
832        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
833        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
834        assert!((boost - 0.15).abs() < 1e-5);
835    }
836
837    #[test]
838    fn preference_boost_capped_at_max() {
839        // 3 prefs x 0.15 = 0.45 but max is 0.2
840        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
841        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
842        assert!((boost - 0.2).abs() < 1e-5);
843    }
844
845    #[test]
846    fn cycle_detection_simple_cycle() {
847        // A requires B, B requires A → both should have requires cleared
848        let graph = make_dep_graph(&[
849            ("tool_a", vec!["tool_b"], vec![]),
850            ("tool_b", vec!["tool_a"], vec![]),
851        ]);
852        // After cycle removal both should be unconditionally available
853        assert!(graph.requirements_met("tool_a", &completed(&[])));
854        assert!(graph.requirements_met("tool_b", &completed(&[])));
855    }
856
857    #[test]
858    fn cycle_detection_does_not_affect_non_cycle_tools() {
859        // A requires B, B requires C (no cycle), C requires D (cycle: D requires C)
860        let graph = make_dep_graph(&[
861            ("tool_a", vec!["tool_b"], vec![]),
862            ("tool_b", vec!["tool_c"], vec![]),
863            ("tool_c", vec!["tool_d"], vec![]),
864            ("tool_d", vec!["tool_c"], vec![]), // cycle: c <-> d
865        ]);
866        // tool_c and tool_d participate in cycle → unconditionally available
867        assert!(graph.requirements_met("tool_c", &completed(&[])));
868        assert!(graph.requirements_met("tool_d", &completed(&[])));
869        // tool_a and tool_b are NOT in a cycle → still gated
870        assert!(!graph.requirements_met("tool_a", &completed(&[])));
871        assert!(!graph.requirements_met("tool_b", &completed(&[])));
872    }
873
874    #[test]
875    fn apply_excludes_gated_tool() {
876        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
877        let filter = make_filter(vec!["bash"], 5);
878        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
879        let query_emb = vec![0.5, 0.5, 0.0];
880        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
881        // Ensure apply_patch is included before dependency gate
882        result.included.insert("apply_patch".into());
883
884        let always_on: HashSet<String> = ["bash".into()].into();
885        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
886
887        assert!(!result.included.contains("apply_patch"));
888        assert_eq!(result.dependency_exclusions.len(), 1);
889        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
890        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
891    }
892
893    #[test]
894    fn apply_includes_gated_tool_when_dep_met() {
895        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
896        let filter = make_filter(vec!["bash"], 5);
897        let all_ids = vec!["bash", "read", "apply_patch"];
898        let query_emb = vec![0.5, 0.5, 0.0];
899        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
900        result.included.insert("apply_patch".into());
901
902        let always_on: HashSet<String> = ["bash".into()].into();
903        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);
904
905        assert!(result.included.contains("apply_patch"));
906        assert!(result.dependency_exclusions.is_empty());
907    }
908
909    #[test]
910    fn apply_deadlock_fallback_when_all_gated() {
911        // Build a minimal filter with no embeddings so only bash (always-on) and
912        // only_tool (NoEmbedding) are in the result set.
913        let filter = ToolSchemaFilter::new(
914            vec!["bash".into()],
915            5,
916            5,
917            vec![], // no embeddings: only_tool will be included via NoEmbedding fallback
918        );
919        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
920        let all_ids = vec!["bash", "only_tool"];
921        let query_emb = vec![0.5, 0.5, 0.0];
922        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
923
924        // At this point: included = {bash, only_tool}, non_always_on_included = 1
925        assert!(result.included.contains("only_tool"));
926        assert!(result.included.contains("bash"));
927
928        let always_on: HashSet<String> = ["bash".into()].into();
929        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
930
931        // Deadlock fallback: only_tool remains included (all non-always-on would be blocked)
932        assert!(result.included.contains("only_tool"));
933        assert!(result.dependency_exclusions.is_empty());
934    }
935
936    #[test]
937    fn apply_always_on_bypasses_gate() {
938        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
939        let filter = make_filter(vec!["bash"], 5);
940        let all_ids = vec!["bash", "grep"];
941        let query_emb = vec![0.5, 0.5, 0.0];
942        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
943
944        let always_on: HashSet<String> = ["bash".into()].into();
945        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
946
947        // bash is always-on, bypasses hard gate
948        assert!(result.included.contains("bash"));
949    }
950
951    // --- Regression tests for HIGH-01 and HIGH-02 ---
952
953    /// HIGH-01 regression: ancestors of a cycle must NOT lose their `requires`.
954    ///
955    /// Graph: A requires B, B requires C, C requires D, D requires C (cycle: C↔D).
956    /// Before fix: A and B were marked cycled and had their requires cleared.
957    /// After fix: only C and D are in the cycle; A and B remain gated.
958    #[test]
959    fn cycle_detection_does_not_clear_ancestor_requires() {
960        let graph = make_dep_graph(&[
961            ("tool_a", vec!["tool_b"], vec![]),
962            ("tool_b", vec!["tool_c"], vec![]),
963            ("tool_c", vec!["tool_d"], vec![]),
964            ("tool_d", vec!["tool_c"], vec![]),
965        ]);
966        // Cycle participants (C, D) must be unconditionally available.
967        assert!(graph.requirements_met("tool_c", &completed(&[])));
968        assert!(graph.requirements_met("tool_d", &completed(&[])));
969        // Non-cycle ancestors (A, B) must still be gated.
970        assert!(!graph.requirements_met("tool_a", &completed(&[])));
971        assert!(!graph.requirements_met("tool_b", &completed(&[])));
972        // A unlocks when B completes; B unlocks when C completes (C is free).
973        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
974        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
975    }
976
977    /// HIGH-02 regression: `NameMentioned` tools must still respect hard gates.
978    ///
979    /// If the user says "use `apply_patch` to fix the bug", `apply_patch` is
980    /// `NameMentioned` but must NOT bypass its `requires=[read]` constraint.
981    #[test]
982    fn name_mentioned_does_not_bypass_hard_gate() {
983        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
984        let filter = make_filter(vec!["bash"], 5);
985        // Query explicitly mentions apply_patch → NameMentioned reason
986        let all_ids = vec!["bash", "read", "apply_patch"];
987        let query_emb = vec![0.5, 0.5, 0.0];
988        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);
989
990        // apply_patch must be in included (name-mentioned) before dependency gate
991        assert!(result.included.contains("apply_patch"));
992        let reason = result
993            .inclusion_reasons
994            .iter()
995            .find(|(id, _)| id == "apply_patch")
996            .map(|(_, r)| r);
997        assert_eq!(reason, Some(&InclusionReason::NameMentioned));
998
999        let always_on: HashSet<String> = ["bash".into()].into();
1000        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1001
1002        // After gate: apply_patch must be excluded (read not completed)
1003        assert!(!result.included.contains("apply_patch"));
1004        assert_eq!(result.dependency_exclusions.len(), 1);
1005        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
1006    }
1007
1008    // --- Multi-turn dependency chain integration tests ---
1009    //
1010    // These tests simulate the session lifecycle: `completed_tool_ids` grows
1011    // across turns, unlocking downstream tools one step at a time.
1012
1013    /// Turn 1: only `read` is available (no completed tools yet).
1014    /// Turn 2: after `read` completes, `apply_patch` unlocks.
1015    #[test]
1016    fn multi_turn_chain_two_steps() {
1017        // read → apply_patch (linear dependency)
1018        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1019        let always_on: HashSet<String> = ["bash".into()].into();
1020
1021        // --- Turn 1: nothing completed yet ---
1022        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1023        let all_ids = vec!["bash", "read", "apply_patch"];
1024        let q = vec![0.5, 0.5, 0.0];
1025        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
1026        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1027
1028        // apply_patch should be excluded (read not completed)
1029        assert!(!result.included.contains("apply_patch"));
1030        assert_eq!(result.dependency_exclusions.len(), 1);
1031
1032        // --- Turn 2: `read` was executed successfully ---
1033        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
1034        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);
1035
1036        // apply_patch should now be included
1037        assert!(result2.included.contains("apply_patch"));
1038        assert!(result2.dependency_exclusions.is_empty());
1039    }
1040
1041    /// Three-step linear chain: `read` → `search` → `apply_patch`.
1042    /// Each turn unlocks exactly one more tool.
1043    #[test]
1044    fn multi_turn_chain_three_steps() {
1045        let graph = make_dep_graph(&[
1046            ("search", vec!["read"], vec![]),
1047            ("apply_patch", vec!["search"], vec![]),
1048        ]);
1049        let always_on: HashSet<String> = ["bash".into()].into();
1050        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1051        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1052        let q = vec![0.5, 0.5, 0.0];
1053
1054        // Turn 1: only read available
1055        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1056        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
1057        assert!(r1.included.contains("read"));
1058        assert!(!r1.included.contains("search"));
1059        assert!(!r1.included.contains("apply_patch"));
1060
1061        // Turn 2: read done, search unlocked
1062        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1063        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
1064        assert!(r2.included.contains("search"));
1065        assert!(!r2.included.contains("apply_patch"));
1066
1067        // Turn 3: search done, apply_patch unlocked
1068        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1069        graph.apply(
1070            &mut r3,
1071            &completed(&["read", "search"]),
1072            0.15,
1073            0.2,
1074            &always_on,
1075        );
1076        assert!(r3.included.contains("apply_patch"));
1077        assert!(r3.dependency_exclusions.is_empty());
1078    }
1079
1080    /// Multi-requires: `apply_patch` needs both `read` AND `search` to be done.
1081    #[test]
1082    fn multi_turn_multi_requires_both_must_complete() {
1083        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
1084        let always_on: HashSet<String> = ["bash".into()].into();
1085        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1086        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1087        let q = vec![0.5, 0.5, 0.0];
1088
1089        // Only `read` done — not enough
1090        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1091        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
1092        assert!(!r1.included.contains("apply_patch"));
1093        let excl = &r1.dependency_exclusions[0];
1094        assert_eq!(excl.unmet_requires, vec!["search"]);
1095
1096        // Both done — unlocked
1097        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1098        graph.apply(
1099            &mut r2,
1100            &completed(&["read", "search"]),
1101            0.15,
1102            0.2,
1103            &always_on,
1104        );
1105        assert!(r2.included.contains("apply_patch"));
1106        assert!(r2.dependency_exclusions.is_empty());
1107    }
1108
1109    /// Preference boost increases across turns as soft deps are satisfied.
1110    ///
1111    /// A tool must have a cached embedding to appear in `scores` and receive a
1112    /// score adjustment from `apply()`. This test uses a filter with an explicit
1113    /// embedding for `format` so the score is trackable.
1114    #[test]
1115    fn multi_turn_preference_boost_accumulates() {
1116        // format prefers search and grep (soft deps)
1117        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
1118        let always_on: HashSet<String> = HashSet::new();
1119        // Give format a real embedding so it appears in `scores`.
1120        let filter = ToolSchemaFilter::new(
1121            vec![],
1122            5,
1123            5,
1124            vec![
1125                ToolEmbedding {
1126                    tool_id: "format".into(),
1127                    embedding: vec![0.6, 0.4, 0.0],
1128                },
1129                ToolEmbedding {
1130                    tool_id: "search".into(),
1131                    embedding: vec![0.7, 0.3, 0.0],
1132                },
1133                ToolEmbedding {
1134                    tool_id: "grep".into(),
1135                    embedding: vec![0.8, 0.2, 0.0],
1136                },
1137            ],
1138        );
1139        let all_ids = vec!["format", "search", "grep"];
1140        let q = vec![0.5, 0.5, 0.0];
1141        let boost_per = 0.15_f32;
1142        let max_boost = 0.3_f32;
1143
1144        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
1145            result
1146                .scores
1147                .iter()
1148                .find(|(tid, _)| tid == id)
1149                .map_or(0.0, |(_, s)| *s)
1150        };
1151
1152        // Turn 1: no prefs satisfied — no boost
1153        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1154        let base_score = score_of(&r1, "format");
1155        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
1156        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);
1157
1158        // Turn 2: search done → +0.15
1159        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1160        graph.apply(
1161            &mut r2,
1162            &completed(&["search"]),
1163            boost_per,
1164            max_boost,
1165            &always_on,
1166        );
1167        let delta2 = score_of(&r2, "format") - base_score;
1168        assert!(
1169            (delta2 - 0.15).abs() < 1e-4,
1170            "expected +0.15 boost, got {delta2}"
1171        );
1172
1173        // Turn 3: both done → +0.30 (2 * 0.15, within max_boost=0.30)
1174        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1175        graph.apply(
1176            &mut r3,
1177            &completed(&["search", "grep"]),
1178            boost_per,
1179            max_boost,
1180            &always_on,
1181        );
1182        let delta3 = score_of(&r3, "format") - base_score;
1183        assert!(
1184            (delta3 - 0.30).abs() < 1e-4,
1185            "expected +0.30 boost, got {delta3}"
1186        );
1187    }
1188
1189    /// `filter_tool_names` used for iteration 1+ gating in the native tool loop.
1190    /// Simulates: iteration 0 executes `read`, iteration 1 should now allow `apply_patch`.
1191    #[test]
1192    fn filter_tool_names_multi_turn_unlocks_after_completion() {
1193        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1194        let always_on: HashSet<String> = ["bash".into()].into();
1195        let all_names = vec!["bash", "read", "apply_patch"];
1196
1197        // Before read completes
1198        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1199        assert!(filtered_before.contains(&"bash")); // always-on
1200        assert!(filtered_before.contains(&"read")); // no deps
1201        assert!(!filtered_before.contains(&"apply_patch")); // gated
1202
1203        // After read completes
1204        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
1205        assert!(filtered_after.contains(&"bash"));
1206        assert!(filtered_after.contains(&"read"));
1207        assert!(filtered_after.contains(&"apply_patch")); // unlocked
1208    }
1209
1210    /// Deadlock fallback in `filter_tool_names`: if all non-always-on names would
1211    /// be filtered, return them all unfiltered.
1212    #[test]
1213    fn filter_tool_names_deadlock_fallback_passes_all() {
1214        // only_tool requires `missing` which is never completed
1215        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
1216        let always_on: HashSet<String> = ["bash".into()].into();
1217        let all_names = vec!["bash", "only_tool"];
1218
1219        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1220
1221        // bash is always-on, only_tool would be gated.
1222        // filter_tool_names does NOT implement deadlock fallback itself —
1223        // it is the caller's responsibility. Verify gating behaviour here:
1224        // only_tool is excluded, only bash passes.
1225        assert!(filtered.contains(&"bash"));
1226        assert!(!filtered.contains(&"only_tool"));
1227    }
1228}