Skip to main content

zeph_tools/
schema_filter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dynamic tool schema filtering based on query-tool embedding similarity (#2020).
5//!
6//! Filters the set of tool definitions sent to the LLM on each turn, selecting
7//! only the most relevant tools based on cosine similarity between the user query
8//! embedding and pre-computed tool description embeddings.
9
10use std::collections::{HashMap, HashSet};
11
12use zeph_common::ToolName;
13use zeph_common::math::cosine_similarity;
14
15use crate::config::ToolDependency;
16
17/// Cached embedding for a tool definition.
18#[derive(Debug, Clone)]
19pub struct ToolEmbedding {
20    pub tool_id: ToolName,
21    pub embedding: Vec<f32>,
22}
23
/// Why a given tool ended up in the filtered tool set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum InclusionReason {
    /// The tool appears in the configured always-on list.
    AlwaysOn,
    /// The user query explicitly mentions the tool by name.
    NameMentioned,
    /// The tool ranked within the top-K by embedding similarity.
    SimilarityRank,
    /// The tool's description has too few words to be filtered reliably
    /// (typically an MCP tool), so it is kept unconditionally.
    ShortDescription,
    /// No cached embedding exists for the tool (e.g. an MCP tool added after startup).
    NoEmbedding,
    /// The tool is included because all of its hard requirements (`requires`) are satisfied.
    DependencyMet,
    /// The tool's similarity score was boosted by satisfied soft prerequisites (`prefers`).
    PreferenceBoost,
}
43
44/// Exclusion reason for a tool that was blocked by the dependency gate.
45#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct DependencyExclusion {
47    pub tool_id: ToolName,
48    /// IDs of `requires` entries that are not yet satisfied.
49    pub unmet_requires: Vec<String>,
50}
51
52/// Result of filtering tool schemas against a query.
53#[derive(Debug, Clone)]
54pub struct ToolFilterResult {
55    /// Tool IDs that passed the filter.
56    pub included: HashSet<String>,
57    /// Tool IDs that were filtered out by similarity/embedding.
58    pub excluded: Vec<String>,
59    /// Per-tool similarity scores for filterable tools (sorted descending).
60    pub scores: Vec<(String, f32)>,
61    /// Reason each included tool was included.
62    pub inclusion_reasons: Vec<(String, InclusionReason)>,
63    /// Tools excluded specifically due to unmet hard dependencies.
64    pub dependency_exclusions: Vec<DependencyExclusion>,
65}
66
67/// Dependency graph for sequential tool availability (issue #2024).
68///
69/// Built once from `DependencyConfig` at agent start, reused across turns.
70/// Implements cycle detection via DFS topological sort: any tool in a detected
71/// cycle has all its `requires` removed (made unconditionally available) so it
72/// can never be permanently blocked by a dependency loop.
73///
74/// # Deadlock fallback
75///
76/// If all non-always-on tools would be blocked (either by config cycles or
77/// unreachable `requires` chains), `apply()` detects this at filter time and
78/// disables hard gates for that turn, logging a warning.
79#[derive(Debug, Clone, Default)]
80pub struct ToolDependencyGraph {
81    /// Map from `tool_id` -> its dependency spec.
82    /// Tools in cycles have their `requires` cleared at construction time.
83    deps: HashMap<String, ToolDependency>,
84}
85
86impl ToolDependencyGraph {
87    /// Build a dependency graph from a map of tool rules.
88    ///
89    /// Performs DFS-based cycle detection. All tools participating in any cycle
90    /// have their `requires` entries removed so they are always available.
91    #[must_use]
92    pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
93        if deps.is_empty() {
94            return Self { deps };
95        }
96        let cycled = detect_cycles(&deps);
97        if !cycled.is_empty() {
98            tracing::warn!(
99                tools = ?cycled,
100                "tool dependency graph: cycles detected, removing requires for cycle participants"
101            );
102        }
103        let mut resolved = deps;
104        for tool_id in &cycled {
105            if let Some(dep) = resolved.get_mut(tool_id) {
106                dep.requires.clear();
107            }
108        }
109        Self { deps: resolved }
110    }
111
112    /// Returns true if no dependency rules are configured.
113    #[must_use]
114    pub fn is_empty(&self) -> bool {
115        self.deps.is_empty()
116    }
117
118    /// Check if a tool's hard requirements are all satisfied.
119    ///
120    /// Returns `true` if the tool has no `requires` entries, or if all entries
121    /// are present in `completed`. Returns `true` for unconfigured tools.
122    #[must_use]
123    pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
124        self.deps
125            .get(tool_id)
126            .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
127    }
128
129    /// Returns the unmet `requires` entries for a tool, if any.
130    #[must_use]
131    pub fn unmet_requires<'a>(
132        &'a self,
133        tool_id: &str,
134        completed: &HashSet<String>,
135    ) -> Vec<&'a str> {
136        self.deps.get(tool_id).map_or_else(Vec::new, |d| {
137            d.requires
138                .iter()
139                .filter(|r| !completed.contains(r.as_str()))
140                .map(String::as_str)
141                .collect()
142        })
143    }
144
145    /// Calculate similarity boost for soft prerequisites.
146    ///
147    /// Returns `min(boost_per_dep * met_count, max_total_boost)`.
148    #[must_use]
149    pub fn preference_boost(
150        &self,
151        tool_id: &str,
152        completed: &HashSet<String>,
153        boost_per_dep: f32,
154        max_total_boost: f32,
155    ) -> f32 {
156        self.deps.get(tool_id).map_or(0.0, |d| {
157            let met = d
158                .prefers
159                .iter()
160                .filter(|p| completed.contains(p.as_str()))
161                .count();
162            #[allow(clippy::cast_precision_loss)]
163            let boost = met as f32 * boost_per_dep;
164            boost.min(max_total_boost)
165        })
166    }
167
168    /// Apply hard dependency gates and preference boosts to a `ToolFilterResult`.
169    ///
170    /// Called after `ToolSchemaFilter::filter()` returns so the filter signature
171    /// remains unchanged (HIGH-03 fix). Dependency gates are applied AFTER TAFC
172    /// augmentation to prevent re-adding gated tools through augmentation (MED-04 fix).
173    ///
174    /// Only `AlwaysOn` tools bypass hard gates. `NameMentioned` tools are still subject
175    /// to `requires` checks — a user mentioning a gated tool name does not grant access.
176    ///
177    /// # Deadlock fallback (CRIT-01)
178    ///
179    /// If applying hard gates would remove ALL non-always-on included tools, the
180    /// gates are disabled for this turn and a warning is logged.
181    pub fn apply(
182        &self,
183        result: &mut ToolFilterResult,
184        completed: &HashSet<String>,
185        boost_per_dep: f32,
186        max_total_boost: f32,
187        always_on: &HashSet<String>,
188    ) {
189        if self.deps.is_empty() {
190            return;
191        }
192
193        // Only AlwaysOn tools bypass the hard dependency gate.
194        // NameMentioned tools still respect `requires` constraints: a user mentioning a gated
195        // tool name in their query does not grant access to it before its prerequisites run.
196        let bypassed: HashSet<&str> = result
197            .inclusion_reasons
198            .iter()
199            .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
200            .map(|(id, _)| id.as_str())
201            .collect();
202
203        let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
204        for tool_id in &result.included {
205            if bypassed.contains(tool_id.as_str()) {
206                continue;
207            }
208            let unmet: Vec<String> = self
209                .unmet_requires(tool_id, completed)
210                .into_iter()
211                .map(str::to_owned)
212                .collect();
213            if !unmet.is_empty() {
214                to_exclude.push(DependencyExclusion {
215                    tool_id: tool_id.as_str().into(),
216                    unmet_requires: unmet,
217                });
218            }
219        }
220
221        // CRIT-01: deadlock fallback — if gating would leave no non-always-on tools,
222        // skip hard gates for this turn.
223        let non_always_on_included: usize = result
224            .included
225            .iter()
226            .filter(|id| !always_on.contains(id.as_str()))
227            .count();
228        if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
229            tracing::warn!(
230                gated = to_exclude.len(),
231                non_always_on = non_always_on_included,
232                "tool dependency graph: all non-always-on tools would be blocked; \
233                 disabling hard gates for this turn"
234            );
235            to_exclude.clear();
236        }
237
238        // Apply hard gates.
239        for excl in &to_exclude {
240            result.included.remove(excl.tool_id.as_str());
241            result.excluded.push(excl.tool_id.to_string());
242            tracing::debug!(
243                tool_id = %excl.tool_id,
244                unmet = ?excl.unmet_requires,
245                "tool dependency gate: excluded (requires not met)"
246            );
247        }
248        result.dependency_exclusions = to_exclude;
249
250        // Apply preference boosts: adjust scores for tools with satisfied prefers deps.
251        for (tool_id, score) in &mut result.scores {
252            if !result.included.contains(tool_id) {
253                continue;
254            }
255            let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
256            if boost > 0.0 {
257                *score += boost;
258                // Record reason if not already recorded with a higher-priority reason.
259                let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
260                if !already_recorded {
261                    result
262                        .inclusion_reasons
263                        .push((tool_id.clone(), InclusionReason::PreferenceBoost));
264                }
265                tracing::debug!(
266                    tool_id = %tool_id,
267                    boost,
268                    "tool dependency: preference boost applied"
269                );
270            }
271        }
272        // Re-sort scores after boosts.
273        result
274            .scores
275            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276    }
277
278    /// Filter a slice of tool IDs to those whose hard requirements are met.
279    ///
280    /// Used on iterations 1+ in the native tool loop via the agent helper
281    /// `apply_hard_dependency_gate_to_names`. Returns only the IDs that pass.
282    #[must_use]
283    pub fn filter_tool_names<'a>(
284        &self,
285        names: &[&'a str],
286        completed: &HashSet<String>,
287        always_on: &HashSet<String>,
288    ) -> Vec<&'a str> {
289        names
290            .iter()
291            .copied()
292            .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
293            .collect()
294    }
295}
296
297/// DFS-based cycle detection for tool dependency graphs.
298///
299/// Returns the set of tool IDs that participate in any cycle.
300/// Algorithm: standard DFS with three states (unvisited/in-progress/done).
301/// When a back-edge is found (visiting an in-progress node), all nodes in the
302/// current DFS path that form part of the cycle are collected.
303fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
304    #[derive(Clone, Copy, PartialEq)]
305    enum State {
306        Unvisited,
307        InProgress,
308        Done,
309    }
310
311    let mut state: HashMap<&str, State> = HashMap::new();
312    let mut cycled: HashSet<String> = HashSet::new();
313
314    for start in deps.keys() {
315        if state
316            .get(start.as_str())
317            .copied()
318            .unwrap_or(State::Unvisited)
319            != State::Unvisited
320        {
321            continue;
322        }
323        let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
324        state.insert(start.as_str(), State::InProgress);
325
326        while let Some((node, child_idx)) = stack.last_mut() {
327            let node = *node;
328            let requires = deps
329                .get(node)
330                .map_or(&[] as &[String], |d| d.requires.as_slice());
331
332            if *child_idx >= requires.len() {
333                state.insert(node, State::Done);
334                stack.pop();
335                continue;
336            }
337
338            let child = requires[*child_idx].as_str();
339            *child_idx += 1;
340
341            match state.get(child).copied().unwrap_or(State::Unvisited) {
342                State::InProgress => {
343                    // Back-edge found: child is the cycle entry point already on the stack.
344                    // Only mark nodes from that entry point to the top of the stack as cycled.
345                    // Ancestors above the cycle entry are NOT part of the cycle.
346                    let cycle_start = stack.iter().position(|(n, _)| *n == child);
347                    if let Some(start) = cycle_start {
348                        for (path_node, _) in &stack[start..] {
349                            cycled.insert((*path_node).to_owned());
350                        }
351                    }
352                    cycled.insert(child.to_owned());
353                }
354                State::Unvisited => {
355                    state.insert(child, State::InProgress);
356                    stack.push((child, 0));
357                }
358                State::Done => {}
359            }
360        }
361    }
362
363    cycled
364}
365
366/// Core filter holding cached tool embeddings and config.
367pub struct ToolSchemaFilter {
368    always_on: HashSet<String>,
369    top_k: usize,
370    min_description_words: usize,
371    embeddings: Vec<ToolEmbedding>,
372    version: u64,
373}
374
375impl ToolSchemaFilter {
376    /// Create a new filter with pre-computed tool embeddings.
377    #[must_use]
378    pub fn new(
379        always_on: Vec<String>,
380        top_k: usize,
381        min_description_words: usize,
382        embeddings: Vec<ToolEmbedding>,
383    ) -> Self {
384        Self {
385            always_on: always_on.into_iter().collect(),
386            top_k,
387            min_description_words,
388            embeddings,
389            version: 0,
390        }
391    }
392
393    /// Current version counter. Incremented on recompute.
394    #[must_use]
395    pub fn version(&self) -> u64 {
396        self.version
397    }
398
399    /// Number of cached tool embeddings.
400    #[must_use]
401    pub fn embedding_count(&self) -> usize {
402        self.embeddings.len()
403    }
404
405    /// Configured top-K limit for similarity ranking.
406    #[must_use]
407    pub fn top_k(&self) -> usize {
408        self.top_k
409    }
410
411    /// Number of always-on tools in the filter config.
412    #[must_use]
413    pub fn always_on_count(&self) -> usize {
414        self.always_on.len()
415    }
416
417    /// Replace tool embeddings (e.g. after MCP tool changes) and bump version.
418    pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
419        self.embeddings = embeddings;
420        self.version += 1;
421    }
422
423    /// Filter tools for a given user query embedding.
424    ///
425    /// `all_tool_ids` is the full set of tool IDs currently available.
426    /// `tool_descriptions` maps tool ID to its description (for short-description check).
427    /// `query_embedding` is the embedded user query.
428    #[must_use]
429    pub fn filter(
430        &self,
431        all_tool_ids: &[&str],
432        tool_descriptions: &[(&str, &str)],
433        query: &str,
434        query_embedding: &[f32],
435    ) -> ToolFilterResult {
436        let mut included = HashSet::new();
437        let mut inclusion_reasons = Vec::new();
438
439        // 1. Always-on tools.
440        for id in all_tool_ids {
441            if self.always_on.contains(*id) {
442                included.insert((*id).to_owned());
443                inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
444            }
445        }
446
447        // 2. Name-mentioned tools.
448        let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
449        for id in &mentioned {
450            if included.insert(id.clone()) {
451                inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
452            }
453        }
454
455        // 3. Short-description MCP tools.
456        for &(id, desc) in tool_descriptions {
457            let word_count = desc.split_whitespace().count();
458            if word_count < self.min_description_words && included.insert(id.to_owned()) {
459                inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
460            }
461        }
462
463        // 4. Similarity-ranked filterable tools.
464        let mut scores: Vec<(String, f32)> = self
465            .embeddings
466            .iter()
467            .filter(|e| !included.contains(e.tool_id.as_str()))
468            .map(|e| {
469                let score = cosine_similarity(query_embedding, &e.embedding);
470                (e.tool_id.to_string(), score)
471            })
472            .collect();
473
474        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
475
476        let take = if self.top_k == 0 {
477            scores.len()
478        } else {
479            self.top_k.min(scores.len())
480        };
481
482        for (id, _score) in scores.iter().take(take) {
483            if included.insert(id.clone()) {
484                inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
485            }
486        }
487
488        // 5. Auto-include tools without embeddings (e.g. new MCP tools added after startup).
489        let embedded_ids: HashSet<&str> =
490            self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
491        for id in all_tool_ids {
492            if !included.contains(*id) && !embedded_ids.contains(*id) {
493                included.insert((*id).to_owned());
494                inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
495            }
496        }
497
498        // Build excluded list.
499        let excluded: Vec<String> = all_tool_ids
500            .iter()
501            .filter(|id| !included.contains(**id))
502            .map(|id| (*id).to_owned())
503            .collect();
504
505        ToolFilterResult {
506            included,
507            excluded,
508            scores,
509            inclusion_reasons,
510            dependency_exclusions: Vec::new(),
511        }
512    }
513}
514
/// Find tool IDs explicitly mentioned in the query (case-insensitive, word-boundary aware).
///
/// Both the query and each tool ID are lowercased before matching. A match
/// counts only when the byte before it and the byte after it are not ASCII
/// alphanumeric or underscore (or the match touches the start/end of the
/// query). This prevents false positives like "read" matching "thread".
///
/// Empty tool IDs never match. Returned IDs keep their original casing and the
/// order they have in `all_tool_ids`.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    /// Bytes that count as part of a word for boundary purposes.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    let query_lower = query.to_lowercase();
    let haystack = query_lower.as_bytes();

    all_tool_ids
        .iter()
        .filter(|id| {
            let needle = id.to_lowercase();
            // An empty needle "matches" at every offset; treat it as never mentioned.
            let Some(first_char) = needle.chars().next() else {
                return false;
            };
            // Resume the search one whole character past a rejected match so the
            // next slice always starts on a UTF-8 boundary.
            let step = first_char.len_utf8();

            let mut start = 0;
            while let Some(rel) = query_lower[start..].find(&needle) {
                let match_start = start + rel;
                let match_end = match_start + needle.len();
                let before_ok = match_start == 0 || !is_word_byte(haystack[match_start - 1]);
                let after_ok = haystack
                    .get(match_end)
                    .map_or(true, |&b| !is_word_byte(b));
                if before_ok && after_ok {
                    return true;
                }
                start = match_start + step;
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
546
547#[cfg(test)]
548mod tests {
549    use super::*;
550
551    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
552        ToolSchemaFilter::new(
553            always_on.into_iter().map(String::from).collect(),
554            top_k,
555            5,
556            vec![
557                ToolEmbedding {
558                    tool_id: "grep".into(),
559                    embedding: vec![0.9, 0.1, 0.0],
560                },
561                ToolEmbedding {
562                    tool_id: "write".into(),
563                    embedding: vec![0.1, 0.9, 0.0],
564                },
565                ToolEmbedding {
566                    tool_id: "find_path".into(),
567                    embedding: vec![0.5, 0.5, 0.0],
568                },
569                ToolEmbedding {
570                    tool_id: "web_scrape".into(),
571                    embedding: vec![0.0, 0.0, 1.0],
572                },
573                ToolEmbedding {
574                    tool_id: "diagnostics".into(),
575                    embedding: vec![0.0, 0.1, 0.9],
576                },
577            ],
578        )
579    }
580
581    #[test]
582    fn top_k_ranking_selects_most_similar() {
583        let filter = make_filter(vec!["bash"], 2);
584        let all_ids: Vec<&str> = vec![
585            "bash",
586            "grep",
587            "write",
588            "find_path",
589            "web_scrape",
590            "diagnostics",
591        ];
592        let query_emb = vec![0.8, 0.2, 0.0]; // close to grep
593        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);
594
595        assert!(result.included.contains("bash")); // always-on
596        assert!(result.included.contains("grep")); // top similarity
597        assert!(result.included.contains("find_path")); // 2nd top
598        // web_scrape and diagnostics should be excluded
599        assert!(!result.included.contains("web_scrape"));
600        assert!(!result.included.contains("diagnostics"));
601    }
602
603    #[test]
604    fn always_on_tools_always_included() {
605        let filter = make_filter(vec!["bash", "read"], 1);
606        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
607        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
608        let result = filter.filter(&all_ids, &[], "test query", &query_emb);
609
610        assert!(result.included.contains("bash"));
611        assert!(result.included.contains("read"));
612        assert!(result.included.contains("write")); // top-1
613        assert!(!result.included.contains("grep"));
614    }
615
616    #[test]
617    fn name_mention_force_includes() {
618        let filter = make_filter(vec!["bash"], 1);
619        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
620        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
621        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);
622
623        assert!(result.included.contains("web_scrape")); // name match
624        assert!(result.included.contains("write")); // top-1
625        assert!(result.included.contains("bash")); // always-on
626    }
627
628    #[test]
629    fn short_mcp_description_auto_included() {
630        let filter = make_filter(vec!["bash"], 1);
631        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
632        let descriptions: Vec<(&str, &str)> = vec![
633            ("mcp_query", "Run query"),
634            ("grep", "Search file contents recursively"),
635        ];
636        let query_emb = vec![0.9, 0.1, 0.0];
637        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);
638
639        assert!(result.included.contains("mcp_query")); // short desc (2 words)
640    }
641
642    #[test]
643    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
644        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
645        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
646        let query_emb = vec![0.5, 0.5, 0.0];
647        let result = filter.filter(&all_ids, &[], "test", &query_emb);
648
649        // All tools included: bash (always-on), grep+write (NoEmbedding fallback)
650        assert!(result.included.contains("bash"));
651        assert!(result.included.contains("grep"));
652        assert!(result.included.contains("write"));
653        assert!(result.excluded.is_empty());
654    }
655
656    #[test]
657    fn top_k_zero_includes_all_filterable() {
658        let filter = make_filter(vec!["bash"], 0);
659        let all_ids: Vec<&str> = vec![
660            "bash",
661            "grep",
662            "write",
663            "find_path",
664            "web_scrape",
665            "diagnostics",
666        ];
667        let query_emb = vec![0.1, 0.1, 0.1];
668        let result = filter.filter(&all_ids, &[], "test", &query_emb);
669
670        assert_eq!(result.included.len(), 6); // all included
671        assert!(result.excluded.is_empty());
672    }
673
674    #[test]
675    fn top_k_exceeds_filterable_count_includes_all() {
676        let filter = make_filter(vec!["bash"], 100);
677        let all_ids: Vec<&str> = vec![
678            "bash",
679            "grep",
680            "write",
681            "find_path",
682            "web_scrape",
683            "diagnostics",
684        ];
685        let query_emb = vec![0.1, 0.1, 0.1];
686        let result = filter.filter(&all_ids, &[], "test", &query_emb);
687
688        assert_eq!(result.included.len(), 6);
689    }
690
691    #[test]
692    fn accessors_return_configured_values() {
693        let filter = make_filter(vec!["bash", "read"], 7);
694        assert_eq!(filter.top_k(), 7);
695        assert_eq!(filter.always_on_count(), 2);
696        assert_eq!(filter.embedding_count(), 5);
697    }
698
699    #[test]
700    fn version_counter_incremented_on_recompute() {
701        let mut filter = make_filter(vec![], 3);
702        assert_eq!(filter.version(), 0);
703        filter.recompute(vec![]);
704        assert_eq!(filter.version(), 1);
705        filter.recompute(vec![]);
706        assert_eq!(filter.version(), 2);
707    }
708
709    #[test]
710    fn inclusion_reason_correctness() {
711        let filter = make_filter(vec!["bash"], 1);
712        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
713        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")]; // 1 word
714        let query_emb = vec![0.1, 0.9, 0.0]; // close to write
715        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);
716
717        let reasons: std::collections::HashMap<String, InclusionReason> =
718            result.inclusion_reasons.into_iter().collect();
719        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
720        assert_eq!(
721            reasons.get("web_scrape"),
722            Some(&InclusionReason::ShortDescription)
723        );
724        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
725    }
726
727    #[test]
728    fn cosine_similarity_identical_vectors() {
729        let v = vec![1.0, 2.0, 3.0];
730        let sim = cosine_similarity(&v, &v);
731        assert!((sim - 1.0).abs() < 1e-5);
732    }
733
734    #[test]
735    fn cosine_similarity_orthogonal_vectors() {
736        let a = vec![1.0, 0.0];
737        let b = vec![0.0, 1.0];
738        let sim = cosine_similarity(&a, &b);
739        assert!(sim.abs() < 1e-5);
740    }
741
742    #[test]
743    fn cosine_similarity_empty_returns_zero() {
744        assert!(cosine_similarity(&[], &[]) < f32::EPSILON);
745    }
746
747    #[test]
748    fn cosine_similarity_mismatched_length_returns_zero() {
749        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]) < f32::EPSILON);
750    }
751
752    #[test]
753    fn find_mentioned_tool_ids_case_insensitive() {
754        let ids = vec!["web_scrape", "grep", "Bash"];
755        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
756        assert!(found.contains(&"web_scrape".to_owned()));
757        assert!(found.contains(&"Bash".to_owned()));
758        assert!(!found.contains(&"grep".to_owned()));
759    }
760
761    #[test]
762    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
763        let ids = vec!["read", "edit", "fetch", "grep"];
764        // "read" should NOT match inside "thread" or "breadcrumb"
765        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
766        assert!(found.is_empty());
767    }
768
769    #[test]
770    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
771        let ids = vec!["read", "edit"];
772        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
773        assert!(found.contains(&"read".to_owned()));
774        assert!(found.contains(&"edit".to_owned()));
775    }
776
777    // --- ToolDependencyGraph tests ---
778
779    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
780        let deps = rules
781            .iter()
782            .map(|(id, requires, prefers)| {
783                (
784                    (*id).to_owned(),
785                    crate::config::ToolDependency {
786                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
787                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
788                    },
789                )
790            })
791            .collect();
792        ToolDependencyGraph::new(deps)
793    }
794
795    fn completed(ids: &[&str]) -> HashSet<String> {
796        ids.iter().map(|s| (*s).to_owned()).collect()
797    }
798
799    #[test]
800    fn requirements_met_no_deps() {
801        let graph = make_dep_graph(&[]);
802        assert!(graph.requirements_met("any_tool", &completed(&[])));
803    }
804
805    #[test]
806    fn requirements_met_all_satisfied() {
807        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
808        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
809    }
810
811    #[test]
812    fn requirements_met_unmet() {
813        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
814        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
815    }
816
817    #[test]
818    fn requirements_met_unconfigured_tool() {
819        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
820        // tools not in the graph are always available
821        assert!(graph.requirements_met("grep", &completed(&[])));
822    }
823
824    #[test]
825    fn preference_boost_none_met() {
826        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
827        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
828        assert!(boost < f32::EPSILON);
829    }
830
831    #[test]
832    fn preference_boost_partial() {
833        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
834        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
835        assert!((boost - 0.15).abs() < 1e-5);
836    }
837
838    #[test]
839    fn preference_boost_capped_at_max() {
840        // 3 prefs x 0.15 = 0.45 but max is 0.2
841        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
842        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
843        assert!((boost - 0.2).abs() < 1e-5);
844    }
845
846    #[test]
847    fn cycle_detection_simple_cycle() {
848        // A requires B, B requires A → both should have requires cleared
849        let graph = make_dep_graph(&[
850            ("tool_a", vec!["tool_b"], vec![]),
851            ("tool_b", vec!["tool_a"], vec![]),
852        ]);
853        // After cycle removal both should be unconditionally available
854        assert!(graph.requirements_met("tool_a", &completed(&[])));
855        assert!(graph.requirements_met("tool_b", &completed(&[])));
856    }
857
858    #[test]
859    fn cycle_detection_does_not_affect_non_cycle_tools() {
860        // A requires B, B requires C (no cycle), C requires D (cycle: D requires C)
861        let graph = make_dep_graph(&[
862            ("tool_a", vec!["tool_b"], vec![]),
863            ("tool_b", vec!["tool_c"], vec![]),
864            ("tool_c", vec!["tool_d"], vec![]),
865            ("tool_d", vec!["tool_c"], vec![]), // cycle: c <-> d
866        ]);
867        // tool_c and tool_d participate in cycle → unconditionally available
868        assert!(graph.requirements_met("tool_c", &completed(&[])));
869        assert!(graph.requirements_met("tool_d", &completed(&[])));
870        // tool_a and tool_b are NOT in a cycle → still gated
871        assert!(!graph.requirements_met("tool_a", &completed(&[])));
872        assert!(!graph.requirements_met("tool_b", &completed(&[])));
873    }
874
875    #[test]
876    fn apply_excludes_gated_tool() {
877        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
878        let filter = make_filter(vec!["bash"], 5);
879        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
880        let query_emb = vec![0.5, 0.5, 0.0];
881        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
882        // Ensure apply_patch is included before dependency gate
883        result.included.insert("apply_patch".into());
884
885        let always_on: HashSet<String> = ["bash".into()].into();
886        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
887
888        assert!(!result.included.contains("apply_patch"));
889        assert_eq!(result.dependency_exclusions.len(), 1);
890        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
891        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
892    }
893
894    #[test]
895    fn apply_includes_gated_tool_when_dep_met() {
896        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
897        let filter = make_filter(vec!["bash"], 5);
898        let all_ids = vec!["bash", "read", "apply_patch"];
899        let query_emb = vec![0.5, 0.5, 0.0];
900        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
901        result.included.insert("apply_patch".into());
902
903        let always_on: HashSet<String> = ["bash".into()].into();
904        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);
905
906        assert!(result.included.contains("apply_patch"));
907        assert!(result.dependency_exclusions.is_empty());
908    }
909
910    #[test]
911    fn apply_deadlock_fallback_when_all_gated() {
912        // Build a minimal filter with no embeddings so only bash (always-on) and
913        // only_tool (NoEmbedding) are in the result set.
914        let filter = ToolSchemaFilter::new(
915            vec!["bash".into()],
916            5,
917            5,
918            vec![], // no embeddings: only_tool will be included via NoEmbedding fallback
919        );
920        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
921        let all_ids = vec!["bash", "only_tool"];
922        let query_emb = vec![0.5, 0.5, 0.0];
923        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
924
925        // At this point: included = {bash, only_tool}, non_always_on_included = 1
926        assert!(result.included.contains("only_tool"));
927        assert!(result.included.contains("bash"));
928
929        let always_on: HashSet<String> = ["bash".into()].into();
930        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
931
932        // Deadlock fallback: only_tool remains included (all non-always-on would be blocked)
933        assert!(result.included.contains("only_tool"));
934        assert!(result.dependency_exclusions.is_empty());
935    }
936
937    #[test]
938    fn apply_always_on_bypasses_gate() {
939        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
940        let filter = make_filter(vec!["bash"], 5);
941        let all_ids = vec!["bash", "grep"];
942        let query_emb = vec![0.5, 0.5, 0.0];
943        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
944
945        let always_on: HashSet<String> = ["bash".into()].into();
946        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
947
948        // bash is always-on, bypasses hard gate
949        assert!(result.included.contains("bash"));
950    }
951
952    // --- Regression tests for HIGH-01 and HIGH-02 ---
953
954    /// HIGH-01 regression: ancestors of a cycle must NOT lose their `requires`.
955    ///
956    /// Graph: A requires B, B requires C, C requires D, D requires C (cycle: C↔D).
957    /// Before fix: A and B were marked cycled and had their requires cleared.
958    /// After fix: only C and D are in the cycle; A and B remain gated.
959    #[test]
960    fn cycle_detection_does_not_clear_ancestor_requires() {
961        let graph = make_dep_graph(&[
962            ("tool_a", vec!["tool_b"], vec![]),
963            ("tool_b", vec!["tool_c"], vec![]),
964            ("tool_c", vec!["tool_d"], vec![]),
965            ("tool_d", vec!["tool_c"], vec![]),
966        ]);
967        // Cycle participants (C, D) must be unconditionally available.
968        assert!(graph.requirements_met("tool_c", &completed(&[])));
969        assert!(graph.requirements_met("tool_d", &completed(&[])));
970        // Non-cycle ancestors (A, B) must still be gated.
971        assert!(!graph.requirements_met("tool_a", &completed(&[])));
972        assert!(!graph.requirements_met("tool_b", &completed(&[])));
973        // A unlocks when B completes; B unlocks when C completes (C is free).
974        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
975        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
976    }
977
978    /// HIGH-02 regression: `NameMentioned` tools must still respect hard gates.
979    ///
980    /// If the user says "use `apply_patch` to fix the bug", `apply_patch` is
981    /// `NameMentioned` but must NOT bypass its `requires=[read]` constraint.
982    #[test]
983    fn name_mentioned_does_not_bypass_hard_gate() {
984        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
985        let filter = make_filter(vec!["bash"], 5);
986        // Query explicitly mentions apply_patch → NameMentioned reason
987        let all_ids = vec!["bash", "read", "apply_patch"];
988        let query_emb = vec![0.5, 0.5, 0.0];
989        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);
990
991        // apply_patch must be in included (name-mentioned) before dependency gate
992        assert!(result.included.contains("apply_patch"));
993        let reason = result
994            .inclusion_reasons
995            .iter()
996            .find(|(id, _)| id == "apply_patch")
997            .map(|(_, r)| r);
998        assert_eq!(reason, Some(&InclusionReason::NameMentioned));
999
1000        let always_on: HashSet<String> = ["bash".into()].into();
1001        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1002
1003        // After gate: apply_patch must be excluded (read not completed)
1004        assert!(!result.included.contains("apply_patch"));
1005        assert_eq!(result.dependency_exclusions.len(), 1);
1006        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
1007    }
1008
1009    // --- Multi-turn dependency chain integration tests ---
1010    //
1011    // These tests simulate the session lifecycle: `completed_tool_ids` grows
1012    // across turns, unlocking downstream tools one step at a time.
1013
1014    /// Turn 1: only `read` is available (no completed tools yet).
1015    /// Turn 2: after `read` completes, `apply_patch` unlocks.
1016    #[test]
1017    fn multi_turn_chain_two_steps() {
1018        // read → apply_patch (linear dependency)
1019        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1020        let always_on: HashSet<String> = ["bash".into()].into();
1021
1022        // --- Turn 1: nothing completed yet ---
1023        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1024        let all_ids = vec!["bash", "read", "apply_patch"];
1025        let q = vec![0.5, 0.5, 0.0];
1026        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
1027        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1028
1029        // apply_patch should be excluded (read not completed)
1030        assert!(!result.included.contains("apply_patch"));
1031        assert_eq!(result.dependency_exclusions.len(), 1);
1032
1033        // --- Turn 2: `read` was executed successfully ---
1034        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
1035        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);
1036
1037        // apply_patch should now be included
1038        assert!(result2.included.contains("apply_patch"));
1039        assert!(result2.dependency_exclusions.is_empty());
1040    }
1041
1042    /// Three-step linear chain: `read` → `search` → `apply_patch`.
1043    /// Each turn unlocks exactly one more tool.
1044    #[test]
1045    fn multi_turn_chain_three_steps() {
1046        let graph = make_dep_graph(&[
1047            ("search", vec!["read"], vec![]),
1048            ("apply_patch", vec!["search"], vec![]),
1049        ]);
1050        let always_on: HashSet<String> = ["bash".into()].into();
1051        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1052        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1053        let q = vec![0.5, 0.5, 0.0];
1054
1055        // Turn 1: only read available
1056        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1057        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
1058        assert!(r1.included.contains("read"));
1059        assert!(!r1.included.contains("search"));
1060        assert!(!r1.included.contains("apply_patch"));
1061
1062        // Turn 2: read done, search unlocked
1063        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1064        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
1065        assert!(r2.included.contains("search"));
1066        assert!(!r2.included.contains("apply_patch"));
1067
1068        // Turn 3: search done, apply_patch unlocked
1069        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1070        graph.apply(
1071            &mut r3,
1072            &completed(&["read", "search"]),
1073            0.15,
1074            0.2,
1075            &always_on,
1076        );
1077        assert!(r3.included.contains("apply_patch"));
1078        assert!(r3.dependency_exclusions.is_empty());
1079    }
1080
1081    /// Multi-requires: `apply_patch` needs both `read` AND `search` to be done.
1082    #[test]
1083    fn multi_turn_multi_requires_both_must_complete() {
1084        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
1085        let always_on: HashSet<String> = ["bash".into()].into();
1086        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1087        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1088        let q = vec![0.5, 0.5, 0.0];
1089
1090        // Only `read` done — not enough
1091        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1092        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
1093        assert!(!r1.included.contains("apply_patch"));
1094        let excl = &r1.dependency_exclusions[0];
1095        assert_eq!(excl.unmet_requires, vec!["search"]);
1096
1097        // Both done — unlocked
1098        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1099        graph.apply(
1100            &mut r2,
1101            &completed(&["read", "search"]),
1102            0.15,
1103            0.2,
1104            &always_on,
1105        );
1106        assert!(r2.included.contains("apply_patch"));
1107        assert!(r2.dependency_exclusions.is_empty());
1108    }
1109
1110    /// Preference boost increases across turns as soft deps are satisfied.
1111    ///
1112    /// A tool must have a cached embedding to appear in `scores` and receive a
1113    /// score adjustment from `apply()`. This test uses a filter with an explicit
1114    /// embedding for `format` so the score is trackable.
1115    #[test]
1116    fn multi_turn_preference_boost_accumulates() {
1117        // format prefers search and grep (soft deps)
1118        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
1119        let always_on: HashSet<String> = HashSet::new();
1120        // Give format a real embedding so it appears in `scores`.
1121        let filter = ToolSchemaFilter::new(
1122            vec![],
1123            5,
1124            5,
1125            vec![
1126                ToolEmbedding {
1127                    tool_id: "format".into(),
1128                    embedding: vec![0.6, 0.4, 0.0],
1129                },
1130                ToolEmbedding {
1131                    tool_id: "search".into(),
1132                    embedding: vec![0.7, 0.3, 0.0],
1133                },
1134                ToolEmbedding {
1135                    tool_id: "grep".into(),
1136                    embedding: vec![0.8, 0.2, 0.0],
1137                },
1138            ],
1139        );
1140        let all_ids = vec!["format", "search", "grep"];
1141        let q = vec![0.5, 0.5, 0.0];
1142        let boost_per = 0.15_f32;
1143        let max_boost = 0.3_f32;
1144
1145        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
1146            result
1147                .scores
1148                .iter()
1149                .find(|(tid, _)| tid == id)
1150                .map_or(0.0, |(_, s)| *s)
1151        };
1152
1153        // Turn 1: no prefs satisfied — no boost
1154        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1155        let base_score = score_of(&r1, "format");
1156        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
1157        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);
1158
1159        // Turn 2: search done → +0.15
1160        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1161        graph.apply(
1162            &mut r2,
1163            &completed(&["search"]),
1164            boost_per,
1165            max_boost,
1166            &always_on,
1167        );
1168        let delta2 = score_of(&r2, "format") - base_score;
1169        assert!(
1170            (delta2 - 0.15).abs() < 1e-4,
1171            "expected +0.15 boost, got {delta2}"
1172        );
1173
1174        // Turn 3: both done → +0.30 (2 * 0.15, within max_boost=0.30)
1175        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1176        graph.apply(
1177            &mut r3,
1178            &completed(&["search", "grep"]),
1179            boost_per,
1180            max_boost,
1181            &always_on,
1182        );
1183        let delta3 = score_of(&r3, "format") - base_score;
1184        assert!(
1185            (delta3 - 0.30).abs() < 1e-4,
1186            "expected +0.30 boost, got {delta3}"
1187        );
1188    }
1189
1190    /// `filter_tool_names` used for iteration 1+ gating in the native tool loop.
1191    /// Simulates: iteration 0 executes `read`, iteration 1 should now allow `apply_patch`.
1192    #[test]
1193    fn filter_tool_names_multi_turn_unlocks_after_completion() {
1194        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1195        let always_on: HashSet<String> = ["bash".into()].into();
1196        let all_names = vec!["bash", "read", "apply_patch"];
1197
1198        // Before read completes
1199        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1200        assert!(filtered_before.contains(&"bash")); // always-on
1201        assert!(filtered_before.contains(&"read")); // no deps
1202        assert!(!filtered_before.contains(&"apply_patch")); // gated
1203
1204        // After read completes
1205        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
1206        assert!(filtered_after.contains(&"bash"));
1207        assert!(filtered_after.contains(&"read"));
1208        assert!(filtered_after.contains(&"apply_patch")); // unlocked
1209    }
1210
1211    /// Deadlock fallback in `filter_tool_names`: if all non-always-on names would
1212    /// be filtered, return them all unfiltered.
1213    #[test]
1214    fn filter_tool_names_deadlock_fallback_passes_all() {
1215        // only_tool requires `missing` which is never completed
1216        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
1217        let always_on: HashSet<String> = ["bash".into()].into();
1218        let all_names = vec!["bash", "only_tool"];
1219
1220        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1221
1222        // bash is always-on, only_tool would be gated.
1223        // filter_tool_names does NOT implement deadlock fallback itself —
1224        // it is the caller's responsibility. Verify gating behaviour here:
1225        // only_tool is excluded, only bash passes.
1226        assert!(filtered.contains(&"bash"));
1227        assert!(!filtered.contains(&"only_tool"));
1228    }
1229}