Skip to main content

zeph_tools/
schema_filter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dynamic tool schema filtering based on query-tool embedding similarity (#2020).
5//!
6//! Filters the set of tool definitions sent to the LLM on each turn, selecting
7//! only the most relevant tools based on cosine similarity between the user query
8//! embedding and pre-computed tool description embeddings.
9
10use std::collections::{HashMap, HashSet};
11
12use zeph_common::math::cosine_similarity;
13
14use crate::config::ToolDependency;
15
/// A tool identifier paired with the embedding vector cached for it.
///
/// `ToolSchemaFilter` keeps a list of these and scores each `embedding`
/// against the query embedding with cosine similarity.
#[derive(Clone, Debug)]
pub struct ToolEmbedding {
    /// ID of the tool this embedding belongs to.
    pub tool_id: String,
    /// Embedding vector cached for the tool.
    pub embedding: Vec<f32>,
}
22
/// Why a particular tool ended up in the filtered set.
///
/// One reason is recorded per tool; the first rule that admits a tool wins.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InclusionReason {
    /// The tool appears in the configured always-on list.
    AlwaysOn,
    /// The user query explicitly names the tool.
    NameMentioned,
    /// The tool ranked within the top-K by embedding similarity.
    SimilarityRank,
    /// The tool's description is too short to be filtered reliably (MCP tools).
    ShortDescription,
    /// No cached embedding exists for the tool (e.g. added after startup via MCP).
    NoEmbedding,
    /// All hard requirements (`requires`) of the tool are satisfied.
    DependencyMet,
    /// Satisfied soft prerequisites (`prefers`) boosted the tool's similarity score.
    PreferenceBoost,
}
41
/// Record of a tool that the dependency gate removed from the filtered set.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DependencyExclusion {
    /// ID of the tool that was blocked.
    pub tool_id: String,
    /// The tool's `requires` entries that have not been satisfied yet.
    pub unmet_requires: Vec<String>,
}
49
50/// Result of filtering tool schemas against a query.
51#[derive(Debug, Clone)]
52pub struct ToolFilterResult {
53    /// Tool IDs that passed the filter.
54    pub included: HashSet<String>,
55    /// Tool IDs that were filtered out by similarity/embedding.
56    pub excluded: Vec<String>,
57    /// Per-tool similarity scores for filterable tools (sorted descending).
58    pub scores: Vec<(String, f32)>,
59    /// Reason each included tool was included.
60    pub inclusion_reasons: Vec<(String, InclusionReason)>,
61    /// Tools excluded specifically due to unmet hard dependencies.
62    pub dependency_exclusions: Vec<DependencyExclusion>,
63}
64
65/// Dependency graph for sequential tool availability (issue #2024).
66///
67/// Built once from `DependencyConfig` at agent start, reused across turns.
68/// Implements cycle detection via DFS topological sort: any tool in a detected
69/// cycle has all its `requires` removed (made unconditionally available) so it
70/// can never be permanently blocked by a dependency loop.
71///
72/// # Deadlock fallback
73///
74/// If all non-always-on tools would be blocked (either by config cycles or
75/// unreachable `requires` chains), `apply()` detects this at filter time and
76/// disables hard gates for that turn, logging a warning.
77#[derive(Debug, Clone, Default)]
78pub struct ToolDependencyGraph {
79    /// Map from `tool_id` -> its dependency spec.
80    /// Tools in cycles have their `requires` cleared at construction time.
81    deps: HashMap<String, ToolDependency>,
82}
83
84impl ToolDependencyGraph {
85    /// Build a dependency graph from a map of tool rules.
86    ///
87    /// Performs DFS-based cycle detection. All tools participating in any cycle
88    /// have their `requires` entries removed so they are always available.
89    #[must_use]
90    pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
91        if deps.is_empty() {
92            return Self { deps };
93        }
94        let cycled = detect_cycles(&deps);
95        if !cycled.is_empty() {
96            tracing::warn!(
97                tools = ?cycled,
98                "tool dependency graph: cycles detected, removing requires for cycle participants"
99            );
100        }
101        let mut resolved = deps;
102        for tool_id in &cycled {
103            if let Some(dep) = resolved.get_mut(tool_id) {
104                dep.requires.clear();
105            }
106        }
107        Self { deps: resolved }
108    }
109
110    /// Returns true if no dependency rules are configured.
111    #[must_use]
112    pub fn is_empty(&self) -> bool {
113        self.deps.is_empty()
114    }
115
116    /// Check if a tool's hard requirements are all satisfied.
117    ///
118    /// Returns `true` if the tool has no `requires` entries, or if all entries
119    /// are present in `completed`. Returns `true` for unconfigured tools.
120    #[must_use]
121    pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
122        self.deps
123            .get(tool_id)
124            .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
125    }
126
127    /// Returns the unmet `requires` entries for a tool, if any.
128    #[must_use]
129    pub fn unmet_requires<'a>(
130        &'a self,
131        tool_id: &str,
132        completed: &HashSet<String>,
133    ) -> Vec<&'a str> {
134        self.deps.get(tool_id).map_or_else(Vec::new, |d| {
135            d.requires
136                .iter()
137                .filter(|r| !completed.contains(r.as_str()))
138                .map(String::as_str)
139                .collect()
140        })
141    }
142
143    /// Calculate similarity boost for soft prerequisites.
144    ///
145    /// Returns `min(boost_per_dep * met_count, max_total_boost)`.
146    #[must_use]
147    pub fn preference_boost(
148        &self,
149        tool_id: &str,
150        completed: &HashSet<String>,
151        boost_per_dep: f32,
152        max_total_boost: f32,
153    ) -> f32 {
154        self.deps.get(tool_id).map_or(0.0, |d| {
155            let met = d
156                .prefers
157                .iter()
158                .filter(|p| completed.contains(p.as_str()))
159                .count();
160            #[allow(clippy::cast_precision_loss)]
161            let boost = met as f32 * boost_per_dep;
162            boost.min(max_total_boost)
163        })
164    }
165
166    /// Apply hard dependency gates and preference boosts to a `ToolFilterResult`.
167    ///
168    /// Called after `ToolSchemaFilter::filter()` returns so the filter signature
169    /// remains unchanged (HIGH-03 fix). Dependency gates are applied AFTER TAFC
170    /// augmentation to prevent re-adding gated tools through augmentation (MED-04 fix).
171    ///
172    /// Only `AlwaysOn` tools bypass hard gates. `NameMentioned` tools are still subject
173    /// to `requires` checks — a user mentioning a gated tool name does not grant access.
174    ///
175    /// # Deadlock fallback (CRIT-01)
176    ///
177    /// If applying hard gates would remove ALL non-always-on included tools, the
178    /// gates are disabled for this turn and a warning is logged.
179    pub fn apply(
180        &self,
181        result: &mut ToolFilterResult,
182        completed: &HashSet<String>,
183        boost_per_dep: f32,
184        max_total_boost: f32,
185        always_on: &HashSet<String>,
186    ) {
187        if self.deps.is_empty() {
188            return;
189        }
190
191        // Only AlwaysOn tools bypass the hard dependency gate.
192        // NameMentioned tools still respect `requires` constraints: a user mentioning a gated
193        // tool name in their query does not grant access to it before its prerequisites run.
194        let bypassed: HashSet<&str> = result
195            .inclusion_reasons
196            .iter()
197            .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
198            .map(|(id, _)| id.as_str())
199            .collect();
200
201        let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
202        for tool_id in &result.included {
203            if bypassed.contains(tool_id.as_str()) {
204                continue;
205            }
206            let unmet: Vec<String> = self
207                .unmet_requires(tool_id, completed)
208                .into_iter()
209                .map(str::to_owned)
210                .collect();
211            if !unmet.is_empty() {
212                to_exclude.push(DependencyExclusion {
213                    tool_id: tool_id.clone(),
214                    unmet_requires: unmet,
215                });
216            }
217        }
218
219        // CRIT-01: deadlock fallback — if gating would leave no non-always-on tools,
220        // skip hard gates for this turn.
221        let non_always_on_included: usize = result
222            .included
223            .iter()
224            .filter(|id| !always_on.contains(id.as_str()))
225            .count();
226        if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
227            tracing::warn!(
228                gated = to_exclude.len(),
229                non_always_on = non_always_on_included,
230                "tool dependency graph: all non-always-on tools would be blocked; \
231                 disabling hard gates for this turn"
232            );
233            to_exclude.clear();
234        }
235
236        // Apply hard gates.
237        for excl in &to_exclude {
238            result.included.remove(&excl.tool_id);
239            result.excluded.push(excl.tool_id.clone());
240            tracing::debug!(
241                tool_id = %excl.tool_id,
242                unmet = ?excl.unmet_requires,
243                "tool dependency gate: excluded (requires not met)"
244            );
245        }
246        result.dependency_exclusions = to_exclude;
247
248        // Apply preference boosts: adjust scores for tools with satisfied prefers deps.
249        for (tool_id, score) in &mut result.scores {
250            if !result.included.contains(tool_id) {
251                continue;
252            }
253            let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
254            if boost > 0.0 {
255                *score += boost;
256                // Record reason if not already recorded with a higher-priority reason.
257                let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
258                if !already_recorded {
259                    result
260                        .inclusion_reasons
261                        .push((tool_id.clone(), InclusionReason::PreferenceBoost));
262                }
263                tracing::debug!(
264                    tool_id = %tool_id,
265                    boost,
266                    "tool dependency: preference boost applied"
267                );
268            }
269        }
270        // Re-sort scores after boosts.
271        result
272            .scores
273            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274    }
275
276    /// Filter a slice of tool IDs to those whose hard requirements are met.
277    ///
278    /// Used on iterations 1+ in the native tool loop via the agent helper
279    /// `apply_hard_dependency_gate_to_names`. Returns only the IDs that pass.
280    #[must_use]
281    pub fn filter_tool_names<'a>(
282        &self,
283        names: &[&'a str],
284        completed: &HashSet<String>,
285        always_on: &HashSet<String>,
286    ) -> Vec<&'a str> {
287        names
288            .iter()
289            .copied()
290            .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
291            .collect()
292    }
293}
294
295/// DFS-based cycle detection for tool dependency graphs.
296///
297/// Returns the set of tool IDs that participate in any cycle.
298/// Algorithm: standard DFS with three states (unvisited/in-progress/done).
299/// When a back-edge is found (visiting an in-progress node), all nodes in the
300/// current DFS path that form part of the cycle are collected.
301fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
302    #[derive(Clone, Copy, PartialEq)]
303    enum State {
304        Unvisited,
305        InProgress,
306        Done,
307    }
308
309    let mut state: HashMap<&str, State> = HashMap::new();
310    let mut cycled: HashSet<String> = HashSet::new();
311
312    for start in deps.keys() {
313        if state
314            .get(start.as_str())
315            .copied()
316            .unwrap_or(State::Unvisited)
317            != State::Unvisited
318        {
319            continue;
320        }
321        let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
322        state.insert(start.as_str(), State::InProgress);
323
324        while let Some((node, child_idx)) = stack.last_mut() {
325            let node = *node;
326            let requires = deps
327                .get(node)
328                .map_or(&[] as &[String], |d| d.requires.as_slice());
329
330            if *child_idx >= requires.len() {
331                state.insert(node, State::Done);
332                stack.pop();
333                continue;
334            }
335
336            let child = requires[*child_idx].as_str();
337            *child_idx += 1;
338
339            match state.get(child).copied().unwrap_or(State::Unvisited) {
340                State::InProgress => {
341                    // Back-edge found: child is the cycle entry point already on the stack.
342                    // Only mark nodes from that entry point to the top of the stack as cycled.
343                    // Ancestors above the cycle entry are NOT part of the cycle.
344                    let cycle_start = stack.iter().position(|(n, _)| *n == child);
345                    if let Some(start) = cycle_start {
346                        for (path_node, _) in &stack[start..] {
347                            cycled.insert((*path_node).to_owned());
348                        }
349                    }
350                    cycled.insert(child.to_owned());
351                }
352                State::Unvisited => {
353                    state.insert(child, State::InProgress);
354                    stack.push((child, 0));
355                }
356                State::Done => {}
357            }
358        }
359    }
360
361    cycled
362}
363
364/// Core filter holding cached tool embeddings and config.
365pub struct ToolSchemaFilter {
366    always_on: HashSet<String>,
367    top_k: usize,
368    min_description_words: usize,
369    embeddings: Vec<ToolEmbedding>,
370    version: u64,
371}
372
373impl ToolSchemaFilter {
374    /// Create a new filter with pre-computed tool embeddings.
375    #[must_use]
376    pub fn new(
377        always_on: Vec<String>,
378        top_k: usize,
379        min_description_words: usize,
380        embeddings: Vec<ToolEmbedding>,
381    ) -> Self {
382        Self {
383            always_on: always_on.into_iter().collect(),
384            top_k,
385            min_description_words,
386            embeddings,
387            version: 0,
388        }
389    }
390
391    /// Current version counter. Incremented on recompute.
392    #[must_use]
393    pub fn version(&self) -> u64 {
394        self.version
395    }
396
397    /// Number of cached tool embeddings.
398    #[must_use]
399    pub fn embedding_count(&self) -> usize {
400        self.embeddings.len()
401    }
402
403    /// Configured top-K limit for similarity ranking.
404    #[must_use]
405    pub fn top_k(&self) -> usize {
406        self.top_k
407    }
408
409    /// Number of always-on tools in the filter config.
410    #[must_use]
411    pub fn always_on_count(&self) -> usize {
412        self.always_on.len()
413    }
414
415    /// Replace tool embeddings (e.g. after MCP tool changes) and bump version.
416    pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
417        self.embeddings = embeddings;
418        self.version += 1;
419    }
420
421    /// Filter tools for a given user query embedding.
422    ///
423    /// `all_tool_ids` is the full set of tool IDs currently available.
424    /// `tool_descriptions` maps tool ID to its description (for short-description check).
425    /// `query_embedding` is the embedded user query.
426    #[must_use]
427    pub fn filter(
428        &self,
429        all_tool_ids: &[&str],
430        tool_descriptions: &[(&str, &str)],
431        query: &str,
432        query_embedding: &[f32],
433    ) -> ToolFilterResult {
434        let mut included = HashSet::new();
435        let mut inclusion_reasons = Vec::new();
436
437        // 1. Always-on tools.
438        for id in all_tool_ids {
439            if self.always_on.contains(*id) {
440                included.insert((*id).to_owned());
441                inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
442            }
443        }
444
445        // 2. Name-mentioned tools.
446        let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
447        for id in &mentioned {
448            if included.insert(id.clone()) {
449                inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
450            }
451        }
452
453        // 3. Short-description MCP tools.
454        for &(id, desc) in tool_descriptions {
455            let word_count = desc.split_whitespace().count();
456            if word_count < self.min_description_words && included.insert(id.to_owned()) {
457                inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
458            }
459        }
460
461        // 4. Similarity-ranked filterable tools.
462        let mut scores: Vec<(String, f32)> = self
463            .embeddings
464            .iter()
465            .filter(|e| !included.contains(&e.tool_id))
466            .map(|e| {
467                let score = cosine_similarity(query_embedding, &e.embedding);
468                (e.tool_id.clone(), score)
469            })
470            .collect();
471
472        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
473
474        let take = if self.top_k == 0 {
475            scores.len()
476        } else {
477            self.top_k.min(scores.len())
478        };
479
480        for (id, _score) in scores.iter().take(take) {
481            if included.insert(id.clone()) {
482                inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
483            }
484        }
485
486        // 5. Auto-include tools without embeddings (e.g. new MCP tools added after startup).
487        let embedded_ids: HashSet<&str> =
488            self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
489        for id in all_tool_ids {
490            if !included.contains(*id) && !embedded_ids.contains(*id) {
491                included.insert((*id).to_owned());
492                inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
493            }
494        }
495
496        // Build excluded list.
497        let excluded: Vec<String> = all_tool_ids
498            .iter()
499            .filter(|id| !included.contains(**id))
500            .map(|id| (*id).to_owned())
501            .collect();
502
503        ToolFilterResult {
504            included,
505            excluded,
506            scores,
507            inclusion_reasons,
508            dependency_exclusions: Vec::new(),
509        }
510    }
511}
512
/// Find tool IDs explicitly mentioned in the query (case-insensitive, word-boundary aware).
///
/// Uses word-boundary checking: the byte before and after the match must not be an
/// ASCII alphanumeric character or underscore. This prevents false positives like
/// "read" matching "thread". Non-ASCII neighbours count as boundaries, so a tool
/// name embedded in text written without spaces (e.g. CJK) is still found.
///
/// Matching IDs are returned in the order they appear in `all_tool_ids`, with
/// their original casing. Empty IDs are never reported as mentioned.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    /// True for bytes that may be part of an identifier-like word.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    let query_lower = query.to_lowercase();
    let haystack = query_lower.as_bytes();
    all_tool_ids
        .iter()
        .filter(|id| {
            let id_lower = id.to_lowercase();
            // An empty needle matches at every offset, including the end of the
            // query, which would push the retry offset past the string. Treat an
            // empty ID as never mentioned.
            let Some(first_char_len) = id_lower.chars().next().map(char::len_utf8) else {
                return false;
            };
            let mut start = 0;
            while let Some(pos) = query_lower[start..].find(&id_lower) {
                let abs_pos = start + pos;
                let end_pos = abs_pos + id_lower.len();
                let before_ok = abs_pos == 0 || !is_word_byte(haystack[abs_pos - 1]);
                let after_ok = end_pos >= haystack.len() || !is_word_byte(haystack[end_pos]);
                if before_ok && after_ok {
                    return true;
                }
                // Retry just past the first character of this match. Stepping a
                // whole character (not one byte) keeps `start` on a UTF-8
                // boundary, so the slice above cannot panic.
                start = abs_pos + first_char_len;
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
544
545#[cfg(test)]
546mod tests {
547    use super::*;
548
549    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
550        ToolSchemaFilter::new(
551            always_on.into_iter().map(String::from).collect(),
552            top_k,
553            5,
554            vec![
555                ToolEmbedding {
556                    tool_id: "grep".into(),
557                    embedding: vec![0.9, 0.1, 0.0],
558                },
559                ToolEmbedding {
560                    tool_id: "write".into(),
561                    embedding: vec![0.1, 0.9, 0.0],
562                },
563                ToolEmbedding {
564                    tool_id: "find_path".into(),
565                    embedding: vec![0.5, 0.5, 0.0],
566                },
567                ToolEmbedding {
568                    tool_id: "web_scrape".into(),
569                    embedding: vec![0.0, 0.0, 1.0],
570                },
571                ToolEmbedding {
572                    tool_id: "diagnostics".into(),
573                    embedding: vec![0.0, 0.1, 0.9],
574                },
575            ],
576        )
577    }
578
579    #[test]
580    fn top_k_ranking_selects_most_similar() {
581        let filter = make_filter(vec!["bash"], 2);
582        let all_ids: Vec<&str> = vec![
583            "bash",
584            "grep",
585            "write",
586            "find_path",
587            "web_scrape",
588            "diagnostics",
589        ];
590        let query_emb = vec![0.8, 0.2, 0.0]; // close to grep
591        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);
592
593        assert!(result.included.contains("bash")); // always-on
594        assert!(result.included.contains("grep")); // top similarity
595        assert!(result.included.contains("find_path")); // 2nd top
596        // web_scrape and diagnostics should be excluded
597        assert!(!result.included.contains("web_scrape"));
598        assert!(!result.included.contains("diagnostics"));
599    }
600
601    #[test]
602    fn always_on_tools_always_included() {
603        let filter = make_filter(vec!["bash", "read"], 1);
604        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
605        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
606        let result = filter.filter(&all_ids, &[], "test query", &query_emb);
607
608        assert!(result.included.contains("bash"));
609        assert!(result.included.contains("read"));
610        assert!(result.included.contains("write")); // top-1
611        assert!(!result.included.contains("grep"));
612    }
613
614    #[test]
615    fn name_mention_force_includes() {
616        let filter = make_filter(vec!["bash"], 1);
617        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
618        let query_emb = vec![0.0, 1.0, 0.0]; // close to write
619        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);
620
621        assert!(result.included.contains("web_scrape")); // name match
622        assert!(result.included.contains("write")); // top-1
623        assert!(result.included.contains("bash")); // always-on
624    }
625
626    #[test]
627    fn short_mcp_description_auto_included() {
628        let filter = make_filter(vec!["bash"], 1);
629        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
630        let descriptions: Vec<(&str, &str)> = vec![
631            ("mcp_query", "Run query"),
632            ("grep", "Search file contents recursively"),
633        ];
634        let query_emb = vec![0.9, 0.1, 0.0];
635        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);
636
637        assert!(result.included.contains("mcp_query")); // short desc (2 words)
638    }
639
640    #[test]
641    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
642        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
643        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
644        let query_emb = vec![0.5, 0.5, 0.0];
645        let result = filter.filter(&all_ids, &[], "test", &query_emb);
646
647        // All tools included: bash (always-on), grep+write (NoEmbedding fallback)
648        assert!(result.included.contains("bash"));
649        assert!(result.included.contains("grep"));
650        assert!(result.included.contains("write"));
651        assert!(result.excluded.is_empty());
652    }
653
654    #[test]
655    fn top_k_zero_includes_all_filterable() {
656        let filter = make_filter(vec!["bash"], 0);
657        let all_ids: Vec<&str> = vec![
658            "bash",
659            "grep",
660            "write",
661            "find_path",
662            "web_scrape",
663            "diagnostics",
664        ];
665        let query_emb = vec![0.1, 0.1, 0.1];
666        let result = filter.filter(&all_ids, &[], "test", &query_emb);
667
668        assert_eq!(result.included.len(), 6); // all included
669        assert!(result.excluded.is_empty());
670    }
671
672    #[test]
673    fn top_k_exceeds_filterable_count_includes_all() {
674        let filter = make_filter(vec!["bash"], 100);
675        let all_ids: Vec<&str> = vec![
676            "bash",
677            "grep",
678            "write",
679            "find_path",
680            "web_scrape",
681            "diagnostics",
682        ];
683        let query_emb = vec![0.1, 0.1, 0.1];
684        let result = filter.filter(&all_ids, &[], "test", &query_emb);
685
686        assert_eq!(result.included.len(), 6);
687    }
688
689    #[test]
690    fn accessors_return_configured_values() {
691        let filter = make_filter(vec!["bash", "read"], 7);
692        assert_eq!(filter.top_k(), 7);
693        assert_eq!(filter.always_on_count(), 2);
694        assert_eq!(filter.embedding_count(), 5);
695    }
696
697    #[test]
698    fn version_counter_incremented_on_recompute() {
699        let mut filter = make_filter(vec![], 3);
700        assert_eq!(filter.version(), 0);
701        filter.recompute(vec![]);
702        assert_eq!(filter.version(), 1);
703        filter.recompute(vec![]);
704        assert_eq!(filter.version(), 2);
705    }
706
707    #[test]
708    fn inclusion_reason_correctness() {
709        let filter = make_filter(vec!["bash"], 1);
710        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
711        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")]; // 1 word
712        let query_emb = vec![0.1, 0.9, 0.0]; // close to write
713        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);
714
715        let reasons: std::collections::HashMap<String, InclusionReason> =
716            result.inclusion_reasons.into_iter().collect();
717        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
718        assert_eq!(
719            reasons.get("web_scrape"),
720            Some(&InclusionReason::ShortDescription)
721        );
722        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
723    }
724
725    #[test]
726    fn cosine_similarity_identical_vectors() {
727        let v = vec![1.0, 2.0, 3.0];
728        let sim = cosine_similarity(&v, &v);
729        assert!((sim - 1.0).abs() < 1e-5);
730    }
731
732    #[test]
733    fn cosine_similarity_orthogonal_vectors() {
734        let a = vec![1.0, 0.0];
735        let b = vec![0.0, 1.0];
736        let sim = cosine_similarity(&a, &b);
737        assert!(sim.abs() < 1e-5);
738    }
739
740    #[test]
741    fn cosine_similarity_empty_returns_zero() {
742        assert_eq!(cosine_similarity(&[], &[]), 0.0);
743    }
744
745    #[test]
746    fn cosine_similarity_mismatched_length_returns_zero() {
747        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
748    }
749
750    #[test]
751    fn find_mentioned_tool_ids_case_insensitive() {
752        let ids = vec!["web_scrape", "grep", "Bash"];
753        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
754        assert!(found.contains(&"web_scrape".to_owned()));
755        assert!(found.contains(&"Bash".to_owned()));
756        assert!(!found.contains(&"grep".to_owned()));
757    }
758
759    #[test]
760    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
761        let ids = vec!["read", "edit", "fetch", "grep"];
762        // "read" should NOT match inside "thread" or "breadcrumb"
763        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
764        assert!(found.is_empty());
765    }
766
767    #[test]
768    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
769        let ids = vec!["read", "edit"];
770        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
771        assert!(found.contains(&"read".to_owned()));
772        assert!(found.contains(&"edit".to_owned()));
773    }
774
775    // --- ToolDependencyGraph tests ---
776
777    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
778        let deps = rules
779            .iter()
780            .map(|(id, requires, prefers)| {
781                (
782                    (*id).to_owned(),
783                    crate::config::ToolDependency {
784                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
785                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
786                    },
787                )
788            })
789            .collect();
790        ToolDependencyGraph::new(deps)
791    }
792
793    fn completed(ids: &[&str]) -> HashSet<String> {
794        ids.iter().map(|s| (*s).to_owned()).collect()
795    }
796
797    #[test]
798    fn requirements_met_no_deps() {
799        let graph = make_dep_graph(&[]);
800        assert!(graph.requirements_met("any_tool", &completed(&[])));
801    }
802
803    #[test]
804    fn requirements_met_all_satisfied() {
805        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
806        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
807    }
808
809    #[test]
810    fn requirements_met_unmet() {
811        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
812        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
813    }
814
815    #[test]
816    fn requirements_met_unconfigured_tool() {
817        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
818        // tools not in the graph are always available
819        assert!(graph.requirements_met("grep", &completed(&[])));
820    }
821
822    #[test]
823    fn preference_boost_none_met() {
824        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
825        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
826        assert_eq!(boost, 0.0);
827    }
828
829    #[test]
830    fn preference_boost_partial() {
831        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
832        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
833        assert!((boost - 0.15).abs() < 1e-5);
834    }
835
836    #[test]
837    fn preference_boost_capped_at_max() {
838        // 3 prefs x 0.15 = 0.45 but max is 0.2
839        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
840        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
841        assert!((boost - 0.2).abs() < 1e-5);
842    }
843
844    #[test]
845    fn cycle_detection_simple_cycle() {
846        // A requires B, B requires A → both should have requires cleared
847        let graph = make_dep_graph(&[
848            ("tool_a", vec!["tool_b"], vec![]),
849            ("tool_b", vec!["tool_a"], vec![]),
850        ]);
851        // After cycle removal both should be unconditionally available
852        assert!(graph.requirements_met("tool_a", &completed(&[])));
853        assert!(graph.requirements_met("tool_b", &completed(&[])));
854    }
855
856    #[test]
857    fn cycle_detection_does_not_affect_non_cycle_tools() {
858        // A requires B, B requires C (no cycle), C requires D (cycle: D requires C)
859        let graph = make_dep_graph(&[
860            ("tool_a", vec!["tool_b"], vec![]),
861            ("tool_b", vec!["tool_c"], vec![]),
862            ("tool_c", vec!["tool_d"], vec![]),
863            ("tool_d", vec!["tool_c"], vec![]), // cycle: c <-> d
864        ]);
865        // tool_c and tool_d participate in cycle → unconditionally available
866        assert!(graph.requirements_met("tool_c", &completed(&[])));
867        assert!(graph.requirements_met("tool_d", &completed(&[])));
868        // tool_a and tool_b are NOT in a cycle → still gated
869        assert!(!graph.requirements_met("tool_a", &completed(&[])));
870        assert!(!graph.requirements_met("tool_b", &completed(&[])));
871    }
872
873    #[test]
874    fn apply_excludes_gated_tool() {
875        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
876        let filter = make_filter(vec!["bash"], 5);
877        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
878        let query_emb = vec![0.5, 0.5, 0.0];
879        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
880        // Ensure apply_patch is included before dependency gate
881        result.included.insert("apply_patch".into());
882
883        let always_on: HashSet<String> = ["bash".into()].into();
884        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
885
886        assert!(!result.included.contains("apply_patch"));
887        assert_eq!(result.dependency_exclusions.len(), 1);
888        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
889        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
890    }
891
892    #[test]
893    fn apply_includes_gated_tool_when_dep_met() {
894        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
895        let filter = make_filter(vec!["bash"], 5);
896        let all_ids = vec!["bash", "read", "apply_patch"];
897        let query_emb = vec![0.5, 0.5, 0.0];
898        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
899        result.included.insert("apply_patch".into());
900
901        let always_on: HashSet<String> = ["bash".into()].into();
902        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);
903
904        assert!(result.included.contains("apply_patch"));
905        assert!(result.dependency_exclusions.is_empty());
906    }
907
908    #[test]
909    fn apply_deadlock_fallback_when_all_gated() {
910        // Build a minimal filter with no embeddings so only bash (always-on) and
911        // only_tool (NoEmbedding) are in the result set.
912        let filter = ToolSchemaFilter::new(
913            vec!["bash".into()],
914            5,
915            5,
916            vec![], // no embeddings: only_tool will be included via NoEmbedding fallback
917        );
918        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
919        let all_ids = vec!["bash", "only_tool"];
920        let query_emb = vec![0.5, 0.5, 0.0];
921        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
922
923        // At this point: included = {bash, only_tool}, non_always_on_included = 1
924        assert!(result.included.contains("only_tool"));
925        assert!(result.included.contains("bash"));
926
927        let always_on: HashSet<String> = ["bash".into()].into();
928        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
929
930        // Deadlock fallback: only_tool remains included (all non-always-on would be blocked)
931        assert!(result.included.contains("only_tool"));
932        assert!(result.dependency_exclusions.is_empty());
933    }
934
935    #[test]
936    fn apply_always_on_bypasses_gate() {
937        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
938        let filter = make_filter(vec!["bash"], 5);
939        let all_ids = vec!["bash", "grep"];
940        let query_emb = vec![0.5, 0.5, 0.0];
941        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
942
943        let always_on: HashSet<String> = ["bash".into()].into();
944        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
945
946        // bash is always-on, bypasses hard gate
947        assert!(result.included.contains("bash"));
948    }
949
950    // --- Regression tests for HIGH-01 and HIGH-02 ---
951
952    /// HIGH-01 regression: ancestors of a cycle must NOT lose their `requires`.
953    ///
954    /// Graph: A requires B, B requires C, C requires D, D requires C (cycle: C↔D).
955    /// Before fix: A and B were marked cycled and had their requires cleared.
956    /// After fix: only C and D are in the cycle; A and B remain gated.
957    #[test]
958    fn cycle_detection_does_not_clear_ancestor_requires() {
959        let graph = make_dep_graph(&[
960            ("tool_a", vec!["tool_b"], vec![]),
961            ("tool_b", vec!["tool_c"], vec![]),
962            ("tool_c", vec!["tool_d"], vec![]),
963            ("tool_d", vec!["tool_c"], vec![]),
964        ]);
965        // Cycle participants (C, D) must be unconditionally available.
966        assert!(graph.requirements_met("tool_c", &completed(&[])));
967        assert!(graph.requirements_met("tool_d", &completed(&[])));
968        // Non-cycle ancestors (A, B) must still be gated.
969        assert!(!graph.requirements_met("tool_a", &completed(&[])));
970        assert!(!graph.requirements_met("tool_b", &completed(&[])));
971        // A unlocks when B completes; B unlocks when C completes (C is free).
972        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
973        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
974    }
975
976    /// HIGH-02 regression: NameMentioned tools must still respect hard gates.
977    ///
978    /// If the user says "use apply_patch to fix the bug", apply_patch is
979    /// NameMentioned but must NOT bypass its requires=[read] constraint.
980    #[test]
981    fn name_mentioned_does_not_bypass_hard_gate() {
982        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
983        let filter = make_filter(vec!["bash"], 5);
984        // Query explicitly mentions apply_patch → NameMentioned reason
985        let all_ids = vec!["bash", "read", "apply_patch"];
986        let query_emb = vec![0.5, 0.5, 0.0];
987        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);
988
989        // apply_patch must be in included (name-mentioned) before dependency gate
990        assert!(result.included.contains("apply_patch"));
991        let reason = result
992            .inclusion_reasons
993            .iter()
994            .find(|(id, _)| id == "apply_patch")
995            .map(|(_, r)| r);
996        assert_eq!(reason, Some(&InclusionReason::NameMentioned));
997
998        let always_on: HashSet<String> = ["bash".into()].into();
999        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1000
1001        // After gate: apply_patch must be excluded (read not completed)
1002        assert!(!result.included.contains("apply_patch"));
1003        assert_eq!(result.dependency_exclusions.len(), 1);
1004        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
1005    }
1006
1007    // --- Multi-turn dependency chain integration tests ---
1008    //
1009    // These tests simulate the session lifecycle: `completed_tool_ids` grows
1010    // across turns, unlocking downstream tools one step at a time.
1011
1012    /// Turn 1: only `read` is available (no completed tools yet).
1013    /// Turn 2: after `read` completes, `apply_patch` unlocks.
1014    #[test]
1015    fn multi_turn_chain_two_steps() {
1016        // read → apply_patch (linear dependency)
1017        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1018        let always_on: HashSet<String> = ["bash".into()].into();
1019
1020        // --- Turn 1: nothing completed yet ---
1021        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1022        let all_ids = vec!["bash", "read", "apply_patch"];
1023        let q = vec![0.5, 0.5, 0.0];
1024        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
1025        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
1026
1027        // apply_patch should be excluded (read not completed)
1028        assert!(!result.included.contains("apply_patch"));
1029        assert_eq!(result.dependency_exclusions.len(), 1);
1030
1031        // --- Turn 2: `read` was executed successfully ---
1032        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
1033        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);
1034
1035        // apply_patch should now be included
1036        assert!(result2.included.contains("apply_patch"));
1037        assert!(result2.dependency_exclusions.is_empty());
1038    }
1039
1040    /// Three-step linear chain: read → search → apply_patch.
1041    /// Each turn unlocks exactly one more tool.
1042    #[test]
1043    fn multi_turn_chain_three_steps() {
1044        let graph = make_dep_graph(&[
1045            ("search", vec!["read"], vec![]),
1046            ("apply_patch", vec!["search"], vec![]),
1047        ]);
1048        let always_on: HashSet<String> = ["bash".into()].into();
1049        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1050        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1051        let q = vec![0.5, 0.5, 0.0];
1052
1053        // Turn 1: only read available
1054        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1055        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
1056        assert!(r1.included.contains("read"));
1057        assert!(!r1.included.contains("search"));
1058        assert!(!r1.included.contains("apply_patch"));
1059
1060        // Turn 2: read done, search unlocked
1061        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1062        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
1063        assert!(r2.included.contains("search"));
1064        assert!(!r2.included.contains("apply_patch"));
1065
1066        // Turn 3: search done, apply_patch unlocked
1067        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1068        graph.apply(
1069            &mut r3,
1070            &completed(&["read", "search"]),
1071            0.15,
1072            0.2,
1073            &always_on,
1074        );
1075        assert!(r3.included.contains("apply_patch"));
1076        assert!(r3.dependency_exclusions.is_empty());
1077    }
1078
1079    /// Multi-requires: apply_patch needs both `read` AND `search` to be done.
1080    #[test]
1081    fn multi_turn_multi_requires_both_must_complete() {
1082        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
1083        let always_on: HashSet<String> = ["bash".into()].into();
1084        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
1085        let all_ids = vec!["bash", "read", "search", "apply_patch"];
1086        let q = vec![0.5, 0.5, 0.0];
1087
1088        // Only `read` done — not enough
1089        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1090        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
1091        assert!(!r1.included.contains("apply_patch"));
1092        let excl = &r1.dependency_exclusions[0];
1093        assert_eq!(excl.unmet_requires, vec!["search"]);
1094
1095        // Both done — unlocked
1096        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1097        graph.apply(
1098            &mut r2,
1099            &completed(&["read", "search"]),
1100            0.15,
1101            0.2,
1102            &always_on,
1103        );
1104        assert!(r2.included.contains("apply_patch"));
1105        assert!(r2.dependency_exclusions.is_empty());
1106    }
1107
1108    /// Preference boost increases across turns as soft deps are satisfied.
1109    ///
1110    /// A tool must have a cached embedding to appear in `scores` and receive a
1111    /// score adjustment from `apply()`. This test uses a filter with an explicit
1112    /// embedding for `format` so the score is trackable.
1113    #[test]
1114    fn multi_turn_preference_boost_accumulates() {
1115        // format prefers search and grep (soft deps)
1116        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
1117        let always_on: HashSet<String> = HashSet::new();
1118        // Give format a real embedding so it appears in `scores`.
1119        let filter = ToolSchemaFilter::new(
1120            vec![],
1121            5,
1122            5,
1123            vec![
1124                ToolEmbedding {
1125                    tool_id: "format".into(),
1126                    embedding: vec![0.6, 0.4, 0.0],
1127                },
1128                ToolEmbedding {
1129                    tool_id: "search".into(),
1130                    embedding: vec![0.7, 0.3, 0.0],
1131                },
1132                ToolEmbedding {
1133                    tool_id: "grep".into(),
1134                    embedding: vec![0.8, 0.2, 0.0],
1135                },
1136            ],
1137        );
1138        let all_ids = vec!["format", "search", "grep"];
1139        let q = vec![0.5, 0.5, 0.0];
1140        let boost_per = 0.15_f32;
1141        let max_boost = 0.3_f32;
1142
1143        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
1144            result
1145                .scores
1146                .iter()
1147                .find(|(tid, _)| tid == id)
1148                .map(|(_, s)| *s)
1149                .unwrap_or(0.0)
1150        };
1151
1152        // Turn 1: no prefs satisfied — no boost
1153        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
1154        let base_score = score_of(&r1, "format");
1155        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
1156        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);
1157
1158        // Turn 2: search done → +0.15
1159        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
1160        graph.apply(
1161            &mut r2,
1162            &completed(&["search"]),
1163            boost_per,
1164            max_boost,
1165            &always_on,
1166        );
1167        let delta2 = score_of(&r2, "format") - base_score;
1168        assert!(
1169            (delta2 - 0.15).abs() < 1e-4,
1170            "expected +0.15 boost, got {delta2}"
1171        );
1172
1173        // Turn 3: both done → +0.30 (2 * 0.15, within max_boost=0.30)
1174        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
1175        graph.apply(
1176            &mut r3,
1177            &completed(&["search", "grep"]),
1178            boost_per,
1179            max_boost,
1180            &always_on,
1181        );
1182        let delta3 = score_of(&r3, "format") - base_score;
1183        assert!(
1184            (delta3 - 0.30).abs() < 1e-4,
1185            "expected +0.30 boost, got {delta3}"
1186        );
1187    }
1188
1189    /// `filter_tool_names` used for iteration 1+ gating in the native tool loop.
1190    /// Simulates: iteration 0 executes `read`, iteration 1 should now allow `apply_patch`.
1191    #[test]
1192    fn filter_tool_names_multi_turn_unlocks_after_completion() {
1193        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
1194        let always_on: HashSet<String> = ["bash".into()].into();
1195        let all_names = vec!["bash", "read", "apply_patch"];
1196
1197        // Before read completes
1198        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1199        assert!(filtered_before.contains(&"bash")); // always-on
1200        assert!(filtered_before.contains(&"read")); // no deps
1201        assert!(!filtered_before.contains(&"apply_patch")); // gated
1202
1203        // After read completes
1204        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
1205        assert!(filtered_after.contains(&"bash"));
1206        assert!(filtered_after.contains(&"read"));
1207        assert!(filtered_after.contains(&"apply_patch")); // unlocked
1208    }
1209
1210    /// Deadlock fallback in `filter_tool_names`: if all non-always-on names would
1211    /// be filtered, return them all unfiltered.
1212    #[test]
1213    fn filter_tool_names_deadlock_fallback_passes_all() {
1214        // only_tool requires `missing` which is never completed
1215        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
1216        let always_on: HashSet<String> = ["bash".into()].into();
1217        let all_names = vec!["bash", "only_tool"];
1218
1219        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
1220
1221        // bash is always-on, only_tool would be gated.
1222        // filter_tool_names does NOT implement deadlock fallback itself —
1223        // it is the caller's responsibility. Verify gating behaviour here:
1224        // only_tool is excluded, only bash passes.
1225        assert!(filtered.contains(&"bash"));
1226        assert!(!filtered.contains(&"only_tool"));
1227    }
1228}