1use std::collections::{HashMap, HashSet};
11
12use zeph_common::math::cosine_similarity;
13
14use crate::config::ToolDependency;
15
/// A tool identifier paired with its embedding vector.
///
/// `ToolSchemaFilter::filter` scores every entry against the query embedding
/// with `cosine_similarity`.
#[derive(Debug, Clone)]
pub struct ToolEmbedding {
    /// Identifier of the tool this embedding belongs to.
    pub tool_id: String,
    /// Embedding vector compared against the query embedding.
    // NOTE(review): how the vector is produced (e.g. from the tool description)
    // is not visible in this file.
    pub embedding: Vec<f32>,
}
22
/// Why a tool ended up in `ToolFilterResult::included`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InclusionReason {
    /// The tool id is in the filter's always-on set.
    AlwaysOn,
    /// The tool id appears in the query text as a whole word.
    NameMentioned,
    /// The tool was among the top-ranked tools by embedding similarity.
    SimilarityRank,
    /// The tool's description has fewer words than the configured minimum.
    ShortDescription,
    /// The filter holds no embedding for the tool, so it cannot be ranked.
    NoEmbedding,
    /// NOTE(review): never produced in this file; presumably reserved for tools
    /// admitted because their dependencies are satisfied — confirm with callers.
    DependencyMet,
    /// Recorded by `ToolDependencyGraph::apply` when a preference boost was
    /// applied to an included tool that had no reason recorded yet.
    PreferenceBoost,
}
41
/// Record of a tool removed by the dependency gate in
/// `ToolDependencyGraph::apply`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyExclusion {
    /// Identifier of the tool that was excluded.
    pub tool_id: String,
    /// The `requires` entries of that tool that were not in the completed set.
    pub unmet_requires: Vec<String>,
}
49
/// Outcome of `ToolSchemaFilter::filter`, optionally refined afterwards by
/// `ToolDependencyGraph::apply`.
#[derive(Debug, Clone)]
pub struct ToolFilterResult {
    /// Ids of the tools that passed the filter.
    pub included: HashSet<String>,
    /// Ids from the input list that did not pass, plus any tools later removed
    /// by the dependency gate.
    pub excluded: Vec<String>,
    /// `(tool_id, score)` pairs for every tool that was ranked by similarity,
    /// sorted by descending score.
    pub scores: Vec<(String, f32)>,
    /// `(tool_id, reason)` pairs, pushed in the order tools were admitted.
    pub inclusion_reasons: Vec<(String, InclusionReason)>,
    /// Tools removed by the dependency gate; empty until
    /// `ToolDependencyGraph::apply` runs.
    pub dependency_exclusions: Vec<DependencyExclusion>,
}
64
/// Per-tool dependency rules (`requires` and `prefers` lists) keyed by tool id.
///
/// The derived `Default` yields an empty graph.
#[derive(Debug, Clone, Default)]
pub struct ToolDependencyGraph {
    /// Dependency rule for each configured tool id; tools without an entry are
    /// unconstrained.
    deps: HashMap<String, ToolDependency>,
}
83
84impl ToolDependencyGraph {
85 #[must_use]
90 pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
91 if deps.is_empty() {
92 return Self { deps };
93 }
94 let cycled = detect_cycles(&deps);
95 if !cycled.is_empty() {
96 tracing::warn!(
97 tools = ?cycled,
98 "tool dependency graph: cycles detected, removing requires for cycle participants"
99 );
100 }
101 let mut resolved = deps;
102 for tool_id in &cycled {
103 if let Some(dep) = resolved.get_mut(tool_id) {
104 dep.requires.clear();
105 }
106 }
107 Self { deps: resolved }
108 }
109
110 #[must_use]
112 pub fn is_empty(&self) -> bool {
113 self.deps.is_empty()
114 }
115
116 #[must_use]
121 pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
122 self.deps
123 .get(tool_id)
124 .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
125 }
126
127 #[must_use]
129 pub fn unmet_requires<'a>(
130 &'a self,
131 tool_id: &str,
132 completed: &HashSet<String>,
133 ) -> Vec<&'a str> {
134 self.deps.get(tool_id).map_or_else(Vec::new, |d| {
135 d.requires
136 .iter()
137 .filter(|r| !completed.contains(r.as_str()))
138 .map(String::as_str)
139 .collect()
140 })
141 }
142
143 #[must_use]
147 pub fn preference_boost(
148 &self,
149 tool_id: &str,
150 completed: &HashSet<String>,
151 boost_per_dep: f32,
152 max_total_boost: f32,
153 ) -> f32 {
154 self.deps.get(tool_id).map_or(0.0, |d| {
155 let met = d
156 .prefers
157 .iter()
158 .filter(|p| completed.contains(p.as_str()))
159 .count();
160 #[allow(clippy::cast_precision_loss)]
161 let boost = met as f32 * boost_per_dep;
162 boost.min(max_total_boost)
163 })
164 }
165
166 pub fn apply(
180 &self,
181 result: &mut ToolFilterResult,
182 completed: &HashSet<String>,
183 boost_per_dep: f32,
184 max_total_boost: f32,
185 always_on: &HashSet<String>,
186 ) {
187 if self.deps.is_empty() {
188 return;
189 }
190
191 let bypassed: HashSet<&str> = result
195 .inclusion_reasons
196 .iter()
197 .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
198 .map(|(id, _)| id.as_str())
199 .collect();
200
201 let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
202 for tool_id in &result.included {
203 if bypassed.contains(tool_id.as_str()) {
204 continue;
205 }
206 let unmet: Vec<String> = self
207 .unmet_requires(tool_id, completed)
208 .into_iter()
209 .map(str::to_owned)
210 .collect();
211 if !unmet.is_empty() {
212 to_exclude.push(DependencyExclusion {
213 tool_id: tool_id.clone(),
214 unmet_requires: unmet,
215 });
216 }
217 }
218
219 let non_always_on_included: usize = result
222 .included
223 .iter()
224 .filter(|id| !always_on.contains(id.as_str()))
225 .count();
226 if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
227 tracing::warn!(
228 gated = to_exclude.len(),
229 non_always_on = non_always_on_included,
230 "tool dependency graph: all non-always-on tools would be blocked; \
231 disabling hard gates for this turn"
232 );
233 to_exclude.clear();
234 }
235
236 for excl in &to_exclude {
238 result.included.remove(&excl.tool_id);
239 result.excluded.push(excl.tool_id.clone());
240 tracing::debug!(
241 tool_id = %excl.tool_id,
242 unmet = ?excl.unmet_requires,
243 "tool dependency gate: excluded (requires not met)"
244 );
245 }
246 result.dependency_exclusions = to_exclude;
247
248 for (tool_id, score) in &mut result.scores {
250 if !result.included.contains(tool_id) {
251 continue;
252 }
253 let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
254 if boost > 0.0 {
255 *score += boost;
256 let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
258 if !already_recorded {
259 result
260 .inclusion_reasons
261 .push((tool_id.clone(), InclusionReason::PreferenceBoost));
262 }
263 tracing::debug!(
264 tool_id = %tool_id,
265 boost,
266 "tool dependency: preference boost applied"
267 );
268 }
269 }
270 result
272 .scores
273 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274 }
275
276 #[must_use]
281 pub fn filter_tool_names<'a>(
282 &self,
283 names: &[&'a str],
284 completed: &HashSet<String>,
285 always_on: &HashSet<String>,
286 ) -> Vec<&'a str> {
287 names
288 .iter()
289 .copied()
290 .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
291 .collect()
292 }
293}
294
295fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
302 #[derive(Clone, Copy, PartialEq)]
303 enum State {
304 Unvisited,
305 InProgress,
306 Done,
307 }
308
309 let mut state: HashMap<&str, State> = HashMap::new();
310 let mut cycled: HashSet<String> = HashSet::new();
311
312 for start in deps.keys() {
313 if state
314 .get(start.as_str())
315 .copied()
316 .unwrap_or(State::Unvisited)
317 != State::Unvisited
318 {
319 continue;
320 }
321 let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
322 state.insert(start.as_str(), State::InProgress);
323
324 while let Some((node, child_idx)) = stack.last_mut() {
325 let node = *node;
326 let requires = deps
327 .get(node)
328 .map_or(&[] as &[String], |d| d.requires.as_slice());
329
330 if *child_idx >= requires.len() {
331 state.insert(node, State::Done);
332 stack.pop();
333 continue;
334 }
335
336 let child = requires[*child_idx].as_str();
337 *child_idx += 1;
338
339 match state.get(child).copied().unwrap_or(State::Unvisited) {
340 State::InProgress => {
341 let cycle_start = stack.iter().position(|(n, _)| *n == child);
345 if let Some(start) = cycle_start {
346 for (path_node, _) in &stack[start..] {
347 cycled.insert((*path_node).to_owned());
348 }
349 }
350 cycled.insert(child.to_owned());
351 }
352 State::Unvisited => {
353 state.insert(child, State::InProgress);
354 stack.push((child, 0));
355 }
356 State::Done => {}
357 }
358 }
359 }
360
361 cycled
362}
363
/// Selects which tools to expose for a query, based on always-on ids, name
/// mentions, description length and embedding similarity.
pub struct ToolSchemaFilter {
    /// Tool ids that are included unconditionally.
    always_on: HashSet<String>,
    /// Maximum number of tools admitted by similarity rank; `0` admits every
    /// ranked tool.
    top_k: usize,
    /// Tools whose description has fewer words than this are auto-included.
    min_description_words: usize,
    /// Embeddings used to rank tools against the query embedding.
    embeddings: Vec<ToolEmbedding>,
    /// Starts at 0 and is incremented on every `recompute` call.
    version: u64,
}
372
373impl ToolSchemaFilter {
374 #[must_use]
376 pub fn new(
377 always_on: Vec<String>,
378 top_k: usize,
379 min_description_words: usize,
380 embeddings: Vec<ToolEmbedding>,
381 ) -> Self {
382 Self {
383 always_on: always_on.into_iter().collect(),
384 top_k,
385 min_description_words,
386 embeddings,
387 version: 0,
388 }
389 }
390
391 #[must_use]
393 pub fn version(&self) -> u64 {
394 self.version
395 }
396
397 #[must_use]
399 pub fn embedding_count(&self) -> usize {
400 self.embeddings.len()
401 }
402
403 #[must_use]
405 pub fn top_k(&self) -> usize {
406 self.top_k
407 }
408
409 #[must_use]
411 pub fn always_on_count(&self) -> usize {
412 self.always_on.len()
413 }
414
415 pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
417 self.embeddings = embeddings;
418 self.version += 1;
419 }
420
421 #[must_use]
427 pub fn filter(
428 &self,
429 all_tool_ids: &[&str],
430 tool_descriptions: &[(&str, &str)],
431 query: &str,
432 query_embedding: &[f32],
433 ) -> ToolFilterResult {
434 let mut included = HashSet::new();
435 let mut inclusion_reasons = Vec::new();
436
437 for id in all_tool_ids {
439 if self.always_on.contains(*id) {
440 included.insert((*id).to_owned());
441 inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
442 }
443 }
444
445 let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
447 for id in &mentioned {
448 if included.insert(id.clone()) {
449 inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
450 }
451 }
452
453 for &(id, desc) in tool_descriptions {
455 let word_count = desc.split_whitespace().count();
456 if word_count < self.min_description_words && included.insert(id.to_owned()) {
457 inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
458 }
459 }
460
461 let mut scores: Vec<(String, f32)> = self
463 .embeddings
464 .iter()
465 .filter(|e| !included.contains(&e.tool_id))
466 .map(|e| {
467 let score = cosine_similarity(query_embedding, &e.embedding);
468 (e.tool_id.clone(), score)
469 })
470 .collect();
471
472 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
473
474 let take = if self.top_k == 0 {
475 scores.len()
476 } else {
477 self.top_k.min(scores.len())
478 };
479
480 for (id, _score) in scores.iter().take(take) {
481 if included.insert(id.clone()) {
482 inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
483 }
484 }
485
486 let embedded_ids: HashSet<&str> =
488 self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
489 for id in all_tool_ids {
490 if !included.contains(*id) && !embedded_ids.contains(*id) {
491 included.insert((*id).to_owned());
492 inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
493 }
494 }
495
496 let excluded: Vec<String> = all_tool_ids
498 .iter()
499 .filter(|id| !included.contains(**id))
500 .map(|id| (*id).to_owned())
501 .collect();
502
503 ToolFilterResult {
504 included,
505 excluded,
506 scores,
507 inclusion_reasons,
508 dependency_exclusions: Vec::new(),
509 }
510 }
511}
512
/// Returns the ids from `all_tool_ids` that appear in `query` as whole words.
///
/// Matching is case-insensitive (both sides are lowercased). An occurrence
/// counts only when the bytes immediately before and after it are not ASCII
/// alphanumerics or `_`, so `read` does not match inside `thread`. The start
/// and end of the query count as boundaries. Empty ids never match.
///
/// Ids are returned with their original casing, in input order.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    // True when `byte` can be part of an identifier-like word.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    let query_lower = query.to_lowercase();
    let haystack = query_lower.as_bytes();

    all_tool_ids
        .iter()
        .filter(|id| {
            let id_lower = id.to_lowercase();
            // An empty needle matches at every offset, which would report a
            // mention that is not there; treat it as never mentioned.
            let Some(first_char) = id_lower.chars().next() else {
                return false;
            };

            let mut start = 0;
            while let Some(pos) = query_lower[start..].find(&id_lower) {
                let abs_pos = start + pos;
                let end_pos = abs_pos + id_lower.len();
                let before_ok = abs_pos == 0 || !is_word_byte(haystack[abs_pos - 1]);
                let after_ok = end_pos >= haystack.len() || !is_word_byte(haystack[end_pos]);
                if before_ok && after_ok {
                    return true;
                }
                // Resume just past the first char of this occurrence. Stepping
                // by a single byte would land inside a multi-byte char and make
                // the next slice panic.
                start = abs_pos + first_char.len_utf8();
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
544
// Unit tests for the schema filter, the mention matcher and the dependency graph.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a filter with `min_description_words = 5` and fixed 3-dimensional
    // embeddings for grep, write, find_path, web_scrape and diagnostics.
    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
        ToolSchemaFilter::new(
            always_on.into_iter().map(String::from).collect(),
            top_k,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.9, 0.1, 0.0],
                },
                ToolEmbedding {
                    tool_id: "write".into(),
                    embedding: vec![0.1, 0.9, 0.0],
                },
                ToolEmbedding {
                    tool_id: "find_path".into(),
                    embedding: vec![0.5, 0.5, 0.0],
                },
                ToolEmbedding {
                    tool_id: "web_scrape".into(),
                    embedding: vec![0.0, 0.0, 1.0],
                },
                ToolEmbedding {
                    tool_id: "diagnostics".into(),
                    embedding: vec![0.0, 0.1, 0.9],
                },
            ],
        )
    }

    // With top_k = 2 only the two tools expected to rank highest for the query
    // embedding (grep, find_path) are admitted besides the always-on tool.
    #[test]
    fn top_k_ranking_selects_most_similar() {
        let filter = make_filter(vec!["bash"], 2);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.8, 0.2, 0.0];
        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("find_path"));
        assert!(!result.included.contains("web_scrape"));
        assert!(!result.included.contains("diagnostics"));
    }

    // Always-on tools are included regardless of similarity; the single
    // similarity slot goes to `write`.
    #[test]
    fn always_on_tools_always_included() {
        let filter = make_filter(vec!["bash", "read"], 1);
        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "test query", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("read"));
        assert!(result.included.contains("write"));
        assert!(!result.included.contains("grep"));
    }

    // A tool named in the query is included without using the similarity slot.
    #[test]
    fn name_mention_force_includes() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);

        assert!(result.included.contains("web_scrape"));
        assert!(result.included.contains("write"));
        assert!(result.included.contains("bash"));
    }

    // A description shorter than `min_description_words` (5) auto-includes the tool.
    #[test]
    fn short_mcp_description_auto_included() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
        let descriptions: Vec<(&str, &str)> = vec![
            ("mcp_query", "Run query"),
            ("grep", "Search file contents recursively"),
        ];
        let query_emb = vec![0.9, 0.1, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);

        assert!(result.included.contains("mcp_query"));
    }

    // With no embeddings at all, every tool falls through to the
    // no-embedding rule and nothing is excluded.
    #[test]
    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("write"));
        assert!(result.excluded.is_empty());
    }

    // top_k = 0 disables the similarity cut-off.
    #[test]
    fn top_k_zero_includes_all_filterable() {
        let filter = make_filter(vec!["bash"], 0);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert_eq!(result.included.len(), 6);
        assert!(result.excluded.is_empty());
    }

    // A top_k larger than the number of ranked tools is clamped.
    #[test]
    fn top_k_exceeds_filterable_count_includes_all() {
        let filter = make_filter(vec!["bash"], 100);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert_eq!(result.included.len(), 6);
    }

    // Accessors report the constructor arguments.
    #[test]
    fn accessors_return_configured_values() {
        let filter = make_filter(vec!["bash", "read"], 7);
        assert_eq!(filter.top_k(), 7);
        assert_eq!(filter.always_on_count(), 2);
        assert_eq!(filter.embedding_count(), 5);
    }

    // Each `recompute` call bumps the version by one.
    #[test]
    fn version_counter_incremented_on_recompute() {
        let mut filter = make_filter(vec![], 3);
        assert_eq!(filter.version(), 0);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 1);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 2);
    }

    // Each admitted tool carries the reason of the first rule that admitted it.
    #[test]
    fn inclusion_reason_correctness() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")];
        let query_emb = vec![0.1, 0.9, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);

        let reasons: std::collections::HashMap<String, InclusionReason> =
            result.inclusion_reasons.into_iter().collect();
        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
        assert_eq!(
            reasons.get("web_scrape"),
            Some(&InclusionReason::ShortDescription)
        );
        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
    }

    // The next four tests pin the behaviour of the imported `cosine_similarity`
    // helper that `filter` relies on.
    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_empty_returns_zero() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn cosine_similarity_mismatched_length_returns_zero() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
    }

    // Mentions match regardless of case; ids come back in their original casing.
    #[test]
    fn find_mentioned_tool_ids_case_insensitive() {
        let ids = vec!["web_scrape", "grep", "Bash"];
        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
        assert!(found.contains(&"web_scrape".to_owned()));
        assert!(found.contains(&"Bash".to_owned()));
        assert!(!found.contains(&"grep".to_owned()));
    }

    // `read` inside "thread"/"breadcrumb" is not a whole-word match.
    #[test]
    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
        let ids = vec!["read", "edit", "fetch", "grep"];
        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
        assert!(found.is_empty());
    }

    // Standalone words do match.
    #[test]
    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
        let ids = vec!["read", "edit"];
        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
        assert!(found.contains(&"read".to_owned()));
        assert!(found.contains(&"edit".to_owned()));
    }

    // Builds a graph from `(tool_id, requires, prefers)` triples.
    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
        let deps = rules
            .iter()
            .map(|(id, requires, prefers)| {
                (
                    (*id).to_owned(),
                    crate::config::ToolDependency {
                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
                    },
                )
            })
            .collect();
        ToolDependencyGraph::new(deps)
    }

    // Builds the "completed tools" set from string slices.
    fn completed(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| (*s).to_owned()).collect()
    }

    // An empty graph gates nothing.
    #[test]
    fn requirements_met_no_deps() {
        let graph = make_dep_graph(&[]);
        assert!(graph.requirements_met("any_tool", &completed(&[])));
    }

    #[test]
    fn requirements_met_all_satisfied() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
    }

    #[test]
    fn requirements_met_unmet() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
    }

    // A tool without a rule is never gated.
    #[test]
    fn requirements_met_unconfigured_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("grep", &completed(&[])));
    }

    #[test]
    fn preference_boost_none_met() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
        assert_eq!(boost, 0.0);
    }

    // One of two preferred tools completed: a single `boost_per_dep`.
    #[test]
    fn preference_boost_partial() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
        assert!((boost - 0.15).abs() < 1e-5);
    }

    // Three met preferences would give 0.45; the cap limits it to 0.2.
    #[test]
    fn preference_boost_capped_at_max() {
        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
        assert!((boost - 0.2).abs() < 1e-5);
    }

    // Both members of a two-node cycle lose their `requires`, so both are
    // satisfied with nothing completed.
    #[test]
    fn cycle_detection_simple_cycle() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_a"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_a", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&[])));
    }

    // Only the cycle members (c, d) are cleared; the chain leading into the
    // cycle (a, b) stays gated.
    #[test]
    fn cycle_detection_does_not_affect_non_cycle_tools() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
    }

    // `apply_patch` is inserted into `included` by hand and must be removed by
    // the gate because `read` has not completed.
    #[test]
    fn apply_excludes_gated_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
    }

    // Same setup, but `read` has completed, so the gate lets the tool through.
    #[test]
    fn apply_includes_gated_tool_when_dep_met() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("apply_patch"));
        assert!(result.dependency_exclusions.is_empty());
    }

    // When gating would remove every tool outside the always-on set, `apply`
    // disables the gates instead.
    #[test]
    fn apply_deadlock_fallback_when_all_gated() {
        let filter = ToolSchemaFilter::new(
            vec!["bash".into()],
            5,
            5,
            vec![],
        );
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let all_ids = vec!["bash", "only_tool"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("only_tool"));
        assert!(result.included.contains("bash"));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("only_tool"));
        assert!(result.dependency_exclusions.is_empty());
    }

    // A tool admitted as always-on is never gated, even with unmet `requires`.
    #[test]
    fn apply_always_on_bypasses_gate() {
        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("bash"));
    }

    // Same graph as `cycle_detection_does_not_affect_non_cycle_tools`, and also
    // checks that the ancestors unlock normally once their requirement completes.
    #[test]
    fn cycle_detection_does_not_clear_ancestor_requires() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
    }

    // Being named in the query admits the tool in `filter`, but `apply` still
    // gates it: only `AlwaysOn` bypasses the gate.
    #[test]
    fn name_mentioned_does_not_bypass_hard_gate() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);

        assert!(result.included.contains("apply_patch"));
        let reason = result
            .inclusion_reasons
            .iter()
            .find(|(id, _)| id == "apply_patch")
            .map(|(_, r)| r);
        assert_eq!(reason, Some(&InclusionReason::NameMentioned));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
    }

    // Turn 1: `apply_patch` is gated. Turn 2 (after `read` completed): it is available.
    #[test]
    fn multi_turn_chain_two_steps() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();

        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);

        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);

        assert!(result2.included.contains("apply_patch"));
        assert!(result2.dependency_exclusions.is_empty());
    }

    // read -> search -> apply_patch unlock one step per turn.
    #[test]
    fn multi_turn_chain_three_steps() {
        let graph = make_dep_graph(&[
            ("search", vec!["read"], vec![]),
            ("apply_patch", vec!["search"], vec![]),
        ]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(r1.included.contains("read"));
        assert!(!r1.included.contains("search"));
        assert!(!r1.included.contains("apply_patch"));

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(r2.included.contains("search"));
        assert!(!r2.included.contains("apply_patch"));

        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r3.included.contains("apply_patch"));
        assert!(r3.dependency_exclusions.is_empty());
    }

    // A tool with two requirements stays gated until both have completed.
    #[test]
    fn multi_turn_multi_requires_both_must_complete() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(!r1.included.contains("apply_patch"));
        let excl = &r1.dependency_exclusions[0];
        assert_eq!(excl.unmet_requires, vec!["search"]);

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r2.included.contains("apply_patch"));
        assert!(r2.dependency_exclusions.is_empty());
    }

    // The boost on `format` grows by `boost_per` for each preferred tool that
    // has completed. `filter` is re-run each turn, so every turn starts from
    // the un-boosted base score.
    #[test]
    fn multi_turn_preference_boost_accumulates() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let always_on: HashSet<String> = HashSet::new();
        let filter = ToolSchemaFilter::new(
            vec![],
            5,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "format".into(),
                    embedding: vec![0.6, 0.4, 0.0],
                },
                ToolEmbedding {
                    tool_id: "search".into(),
                    embedding: vec![0.7, 0.3, 0.0],
                },
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.8, 0.2, 0.0],
                },
            ],
        );
        let all_ids = vec!["format", "search", "grep"];
        let q = vec![0.5, 0.5, 0.0];
        let boost_per = 0.15_f32;
        let max_boost = 0.3_f32;

        // Looks up a tool's score, defaulting to 0.0 when it was not ranked.
        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
            result
                .scores
                .iter()
                .find(|(tid, _)| tid == id)
                .map(|(_, s)| *s)
                .unwrap_or(0.0)
        };

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        let base_score = score_of(&r1, "format");
        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["search"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta2 = score_of(&r2, "format") - base_score;
        assert!(
            (delta2 - 0.15).abs() < 1e-4,
            "expected +0.15 boost, got {delta2}"
        );

        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["search", "grep"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta3 = score_of(&r3, "format") - base_score;
        assert!(
            (delta3 - 0.30).abs() < 1e-4,
            "expected +0.30 boost, got {delta3}"
        );
    }

    // `filter_tool_names` hides `apply_patch` until `read` has completed.
    #[test]
    fn filter_tool_names_multi_turn_unlocks_after_completion() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "read", "apply_patch"];

        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
        assert!(filtered_before.contains(&"bash"));
        assert!(filtered_before.contains(&"read"));
        assert!(!filtered_before.contains(&"apply_patch"));
        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
        assert!(filtered_after.contains(&"bash"));
        assert!(filtered_after.contains(&"read"));
        assert!(filtered_after.contains(&"apply_patch"));
    }

    // NOTE(review): despite the name, this asserts that `filter_tool_names` has
    // NO deadlock fallback: the only gated tool is dropped and just the
    // always-on tool remains. The name looks stale — consider renaming.
    #[test]
    fn filter_tool_names_deadlock_fallback_passes_all() {
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "only_tool"];

        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);

        assert!(filtered.contains(&"bash"));
        assert!(!filtered.contains(&"only_tool"));
    }
}