1use std::collections::{HashMap, HashSet};
11
12use zeph_common::ToolName;
13use zeph_common::math::cosine_similarity;
14
15use crate::config::ToolDependency;
16
/// Embedding vector associated with a single tool.
///
/// [`ToolSchemaFilter::filter`] scores each entry against the query embedding using
/// cosine similarity.
#[derive(Debug, Clone)]
pub struct ToolEmbedding {
    /// Identifier of the tool this embedding belongs to.
    pub tool_id: ToolName,
    /// Embedding vector that is compared against the query embedding.
    pub embedding: Vec<f32>,
}
23
/// Why a tool was placed in [`ToolFilterResult::included`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InclusionReason {
    /// The tool id is part of the filter's always-on set.
    AlwaysOn,
    /// The tool id appears in the query text (case-insensitive, word-boundary match).
    NameMentioned,
    /// The tool ranked within the top-K by cosine similarity to the query embedding.
    SimilarityRank,
    /// The tool's description has fewer words than the configured minimum.
    ShortDescription,
    /// No embedding exists for the tool, so it is included instead of being ranked.
    NoEmbedding,
    /// NOTE(review): never assigned in this file; presumably marks a tool whose hard
    /// `requires` dependencies were satisfied — confirm against callers.
    DependencyMet,
    /// A preference boost was applied to the tool's score and no other reason had been
    /// recorded for it.
    PreferenceBoost,
}
42
/// Record of a tool removed from the included set by the hard dependency gate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyExclusion {
    /// Tool that was excluded.
    pub tool_id: ToolName,
    /// The tool's `requires` entries that were missing from the completed set.
    pub unmet_requires: Vec<String>,
}
50
/// Outcome of one filtering pass, optionally post-processed by
/// [`ToolDependencyGraph::apply`].
#[derive(Debug, Clone)]
pub struct ToolFilterResult {
    /// Tool ids that passed the filter.
    pub included: HashSet<String>,
    /// Tool ids that did not pass; `ToolDependencyGraph::apply` appends the tools it gates.
    pub excluded: Vec<String>,
    /// `(tool id, score)` pairs for embedded tools that were not already force-included
    /// when ranking ran, sorted from highest to lowest score.
    pub scores: Vec<(String, f32)>,
    /// `(tool id, reason)` pairs recorded when a tool was added to `included`.
    pub inclusion_reasons: Vec<(String, InclusionReason)>,
    /// Tools removed by the dependency gate; empty until `ToolDependencyGraph::apply` runs.
    pub dependency_exclusions: Vec<DependencyExclusion>,
}
65
/// Per-tool dependency rules: hard `requires` gates and soft `prefers` boosts.
///
/// Instances built through [`ToolDependencyGraph::new`] have the `requires` list of every
/// detected cycle participant cleared.
#[derive(Debug, Clone, Default)]
pub struct ToolDependencyGraph {
    /// Dependency rule keyed by tool id; tools without an entry are unconstrained.
    deps: HashMap<String, ToolDependency>,
}
84
85impl ToolDependencyGraph {
86 #[must_use]
91 pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
92 if deps.is_empty() {
93 return Self { deps };
94 }
95 let cycled = detect_cycles(&deps);
96 if !cycled.is_empty() {
97 tracing::warn!(
98 tools = ?cycled,
99 "tool dependency graph: cycles detected, removing requires for cycle participants"
100 );
101 }
102 let mut resolved = deps;
103 for tool_id in &cycled {
104 if let Some(dep) = resolved.get_mut(tool_id) {
105 dep.requires.clear();
106 }
107 }
108 Self { deps: resolved }
109 }
110
111 #[must_use]
113 pub fn is_empty(&self) -> bool {
114 self.deps.is_empty()
115 }
116
117 #[must_use]
122 pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
123 self.deps
124 .get(tool_id)
125 .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
126 }
127
128 #[must_use]
130 pub fn unmet_requires<'a>(
131 &'a self,
132 tool_id: &str,
133 completed: &HashSet<String>,
134 ) -> Vec<&'a str> {
135 self.deps.get(tool_id).map_or_else(Vec::new, |d| {
136 d.requires
137 .iter()
138 .filter(|r| !completed.contains(r.as_str()))
139 .map(String::as_str)
140 .collect()
141 })
142 }
143
144 #[must_use]
148 pub fn preference_boost(
149 &self,
150 tool_id: &str,
151 completed: &HashSet<String>,
152 boost_per_dep: f32,
153 max_total_boost: f32,
154 ) -> f32 {
155 self.deps.get(tool_id).map_or(0.0, |d| {
156 let met = d
157 .prefers
158 .iter()
159 .filter(|p| completed.contains(p.as_str()))
160 .count();
161 #[allow(clippy::cast_precision_loss)]
162 let boost = met as f32 * boost_per_dep;
163 boost.min(max_total_boost)
164 })
165 }
166
167 pub fn apply(
181 &self,
182 result: &mut ToolFilterResult,
183 completed: &HashSet<String>,
184 boost_per_dep: f32,
185 max_total_boost: f32,
186 always_on: &HashSet<String>,
187 ) {
188 if self.deps.is_empty() {
189 return;
190 }
191
192 let bypassed: HashSet<&str> = result
196 .inclusion_reasons
197 .iter()
198 .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
199 .map(|(id, _)| id.as_str())
200 .collect();
201
202 let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
203 for tool_id in &result.included {
204 if bypassed.contains(tool_id.as_str()) {
205 continue;
206 }
207 let unmet: Vec<String> = self
208 .unmet_requires(tool_id, completed)
209 .into_iter()
210 .map(str::to_owned)
211 .collect();
212 if !unmet.is_empty() {
213 to_exclude.push(DependencyExclusion {
214 tool_id: tool_id.as_str().into(),
215 unmet_requires: unmet,
216 });
217 }
218 }
219
220 let non_always_on_included: usize = result
223 .included
224 .iter()
225 .filter(|id| !always_on.contains(id.as_str()))
226 .count();
227 if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
228 tracing::warn!(
229 gated = to_exclude.len(),
230 non_always_on = non_always_on_included,
231 "tool dependency graph: all non-always-on tools would be blocked; \
232 disabling hard gates for this turn"
233 );
234 to_exclude.clear();
235 }
236
237 for excl in &to_exclude {
239 result.included.remove(excl.tool_id.as_str());
240 result.excluded.push(excl.tool_id.to_string());
241 tracing::debug!(
242 tool_id = %excl.tool_id,
243 unmet = ?excl.unmet_requires,
244 "tool dependency gate: excluded (requires not met)"
245 );
246 }
247 result.dependency_exclusions = to_exclude;
248
249 for (tool_id, score) in &mut result.scores {
251 if !result.included.contains(tool_id) {
252 continue;
253 }
254 let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
255 if boost > 0.0 {
256 *score += boost;
257 let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
259 if !already_recorded {
260 result
261 .inclusion_reasons
262 .push((tool_id.clone(), InclusionReason::PreferenceBoost));
263 }
264 tracing::debug!(
265 tool_id = %tool_id,
266 boost,
267 "tool dependency: preference boost applied"
268 );
269 }
270 }
271 result
273 .scores
274 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
275 }
276
277 #[must_use]
282 pub fn filter_tool_names<'a>(
283 &self,
284 names: &[&'a str],
285 completed: &HashSet<String>,
286 always_on: &HashSet<String>,
287 ) -> Vec<&'a str> {
288 names
289 .iter()
290 .copied()
291 .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
292 .collect()
293 }
294}
295
296fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
303 #[derive(Clone, Copy, PartialEq)]
304 enum State {
305 Unvisited,
306 InProgress,
307 Done,
308 }
309
310 let mut state: HashMap<&str, State> = HashMap::new();
311 let mut cycled: HashSet<String> = HashSet::new();
312
313 for start in deps.keys() {
314 if state
315 .get(start.as_str())
316 .copied()
317 .unwrap_or(State::Unvisited)
318 != State::Unvisited
319 {
320 continue;
321 }
322 let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
323 state.insert(start.as_str(), State::InProgress);
324
325 while let Some((node, child_idx)) = stack.last_mut() {
326 let node = *node;
327 let requires = deps
328 .get(node)
329 .map_or(&[] as &[String], |d| d.requires.as_slice());
330
331 if *child_idx >= requires.len() {
332 state.insert(node, State::Done);
333 stack.pop();
334 continue;
335 }
336
337 let child = requires[*child_idx].as_str();
338 *child_idx += 1;
339
340 match state.get(child).copied().unwrap_or(State::Unvisited) {
341 State::InProgress => {
342 let cycle_start = stack.iter().position(|(n, _)| *n == child);
346 if let Some(start) = cycle_start {
347 for (path_node, _) in &stack[start..] {
348 cycled.insert((*path_node).to_owned());
349 }
350 }
351 cycled.insert(child.to_owned());
352 }
353 State::Unvisited => {
354 state.insert(child, State::InProgress);
355 stack.push((child, 0));
356 }
357 State::Done => {}
358 }
359 }
360 }
361
362 cycled
363}
364
/// Selects the subset of tools to expose for a query.
///
/// Combines an always-on set, name mentions in the query, short descriptions and
/// embedding similarity ranking (see [`ToolSchemaFilter::filter`]).
pub struct ToolSchemaFilter {
    /// Tool ids that are always included.
    always_on: HashSet<String>,
    /// Number of similarity-ranked tools to include; `0` means no limit.
    top_k: usize,
    /// Tools whose description has fewer words than this are included unconditionally.
    min_description_words: usize,
    /// Per-tool embeddings used for similarity ranking.
    embeddings: Vec<ToolEmbedding>,
    /// Starts at 0 and is incremented on every `recompute` call.
    version: u64,
}
373
374impl ToolSchemaFilter {
375 #[must_use]
377 pub fn new(
378 always_on: Vec<String>,
379 top_k: usize,
380 min_description_words: usize,
381 embeddings: Vec<ToolEmbedding>,
382 ) -> Self {
383 Self {
384 always_on: always_on.into_iter().collect(),
385 top_k,
386 min_description_words,
387 embeddings,
388 version: 0,
389 }
390 }
391
392 #[must_use]
394 pub fn version(&self) -> u64 {
395 self.version
396 }
397
398 #[must_use]
400 pub fn embedding_count(&self) -> usize {
401 self.embeddings.len()
402 }
403
404 #[must_use]
406 pub fn top_k(&self) -> usize {
407 self.top_k
408 }
409
410 #[must_use]
412 pub fn always_on_count(&self) -> usize {
413 self.always_on.len()
414 }
415
416 pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
418 self.embeddings = embeddings;
419 self.version += 1;
420 }
421
422 #[must_use]
428 pub fn filter(
429 &self,
430 all_tool_ids: &[&str],
431 tool_descriptions: &[(&str, &str)],
432 query: &str,
433 query_embedding: &[f32],
434 ) -> ToolFilterResult {
435 let mut included = HashSet::new();
436 let mut inclusion_reasons = Vec::new();
437
438 for id in all_tool_ids {
440 if self.always_on.contains(*id) {
441 included.insert((*id).to_owned());
442 inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
443 }
444 }
445
446 let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
448 for id in &mentioned {
449 if included.insert(id.clone()) {
450 inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
451 }
452 }
453
454 for &(id, desc) in tool_descriptions {
456 let word_count = desc.split_whitespace().count();
457 if word_count < self.min_description_words && included.insert(id.to_owned()) {
458 inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
459 }
460 }
461
462 let mut scores: Vec<(String, f32)> = self
464 .embeddings
465 .iter()
466 .filter(|e| !included.contains(e.tool_id.as_str()))
467 .map(|e| {
468 let score = cosine_similarity(query_embedding, &e.embedding);
469 (e.tool_id.to_string(), score)
470 })
471 .collect();
472
473 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
474
475 let take = if self.top_k == 0 {
476 scores.len()
477 } else {
478 self.top_k.min(scores.len())
479 };
480
481 for (id, _score) in scores.iter().take(take) {
482 if included.insert(id.clone()) {
483 inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
484 }
485 }
486
487 let embedded_ids: HashSet<&str> =
489 self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
490 for id in all_tool_ids {
491 if !included.contains(*id) && !embedded_ids.contains(*id) {
492 included.insert((*id).to_owned());
493 inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
494 }
495 }
496
497 let excluded: Vec<String> = all_tool_ids
499 .iter()
500 .filter(|id| !included.contains(**id))
501 .map(|id| (*id).to_owned())
502 .collect();
503
504 ToolFilterResult {
505 included,
506 excluded,
507 scores,
508 inclusion_reasons,
509 dependency_exclusions: Vec::new(),
510 }
511 }
512}
513
/// Returns the tool ids that are mentioned in `query`.
///
/// Matching is case-insensitive and requires a word boundary on both sides: the byte
/// before and the byte after the match must not be an ASCII letter, ASCII digit or `_`
/// (the start and end of the query also count as boundaries). Overlapping occurrences are
/// all considered. Empty ids never match. Results keep the order and original casing of
/// `all_tool_ids`.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    let query_lower = query.to_lowercase();
    let haystack = query_lower.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    all_tool_ids
        .iter()
        .filter(|id| {
            let needle = id.to_lowercase();
            // An empty needle "matches" at every offset and used to push the search
            // cursor past the end of the query; it can never be a real mention.
            let Some(first_char) = needle.chars().next() else {
                return false;
            };
            // Advance by one whole character so the next slice starts on a char boundary.
            let step = first_char.len_utf8();

            let mut start = 0;
            while let Some(pos) = query_lower[start..].find(&needle) {
                let abs_pos = start + pos;
                let end_pos = abs_pos + needle.len();
                let before_ok = abs_pos == 0 || !is_word_byte(haystack[abs_pos - 1]);
                let after_ok = end_pos >= haystack.len() || !is_word_byte(haystack[end_pos]);
                if before_ok && after_ok {
                    return true;
                }
                start = abs_pos + step;
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
545
// Unit tests for the schema filter, the mention matcher and the dependency graph.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a filter with `min_description_words = 5` and five fixed 3-dimensional
    // embeddings (grep, write, find_path, web_scrape, diagnostics).
    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
        ToolSchemaFilter::new(
            always_on.into_iter().map(String::from).collect(),
            top_k,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.9, 0.1, 0.0],
                },
                ToolEmbedding {
                    tool_id: "write".into(),
                    embedding: vec![0.1, 0.9, 0.0],
                },
                ToolEmbedding {
                    tool_id: "find_path".into(),
                    embedding: vec![0.5, 0.5, 0.0],
                },
                ToolEmbedding {
                    tool_id: "web_scrape".into(),
                    embedding: vec![0.0, 0.0, 1.0],
                },
                ToolEmbedding {
                    tool_id: "diagnostics".into(),
                    embedding: vec![0.0, 0.1, 0.9],
                },
            ],
        )
    }

    // With top_k = 2 only the two embeddings closest to the query join the always-on tool.
    #[test]
    fn top_k_ranking_selects_most_similar() {
        let filter = make_filter(vec!["bash"], 2);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.8, 0.2, 0.0];
        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("find_path"));
        assert!(!result.included.contains("web_scrape"));
        assert!(!result.included.contains("diagnostics"));
    }

    // Always-on tools are included regardless of similarity; top_k = 1 adds only `write`.
    #[test]
    fn always_on_tools_always_included() {
        let filter = make_filter(vec!["bash", "read"], 1);
        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "test query", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("read"));
        assert!(result.included.contains("write"));
        assert!(!result.included.contains("grep"));
    }

    // A tool named in the query is included even though its embedding is far from the query.
    #[test]
    fn name_mention_force_includes() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);

        assert!(result.included.contains("web_scrape"));
        assert!(result.included.contains("write"));
        assert!(result.included.contains("bash"));
    }

    // A two-word description is below the 5-word minimum, so the tool is auto-included.
    #[test]
    fn short_mcp_description_auto_included() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
        let descriptions: Vec<(&str, &str)> = vec![
            ("mcp_query", "Run query"),
            ("grep", "Search file contents recursively"),
        ];
        let query_emb = vec![0.9, 0.1, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);

        assert!(result.included.contains("mcp_query"));
    }

    // Without any embeddings every tool falls through to the NoEmbedding rule.
    #[test]
    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("write"));
        assert!(result.excluded.is_empty());
    }

    // top_k = 0 disables the limit: every ranked tool is included.
    #[test]
    fn top_k_zero_includes_all_filterable() {
        let filter = make_filter(vec!["bash"], 0);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert_eq!(result.included.len(), 6);
        assert!(result.excluded.is_empty());
    }

    // A top_k larger than the number of ranked tools simply includes all of them.
    #[test]
    fn top_k_exceeds_filterable_count_includes_all() {
        let filter = make_filter(vec!["bash"], 100);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert_eq!(result.included.len(), 6);
    }

    // Accessors report the values passed to the constructor.
    #[test]
    fn accessors_return_configured_values() {
        let filter = make_filter(vec!["bash", "read"], 7);
        assert_eq!(filter.top_k(), 7);
        assert_eq!(filter.always_on_count(), 2);
        assert_eq!(filter.embedding_count(), 5);
    }

    // Each `recompute` call bumps the version by one.
    #[test]
    fn version_counter_incremented_on_recompute() {
        let mut filter = make_filter(vec![], 3);
        assert_eq!(filter.version(), 0);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 1);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 2);
    }

    // Each rule records its own reason: always-on, short description, similarity rank.
    #[test]
    fn inclusion_reason_correctness() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")];
        let query_emb = vec![0.1, 0.9, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);

        let reasons: std::collections::HashMap<String, InclusionReason> =
            result.inclusion_reasons.into_iter().collect();
        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
        assert_eq!(
            reasons.get("web_scrape"),
            Some(&InclusionReason::ShortDescription)
        );
        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
    }

    // A vector compared with itself has similarity 1.
    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    // Orthogonal vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    // Empty inputs must yield a result below `f32::EPSILON`.
    #[test]
    fn cosine_similarity_empty_returns_zero() {
        assert!(cosine_similarity(&[], &[]) < f32::EPSILON);
    }

    // Inputs of different lengths must yield a result below `f32::EPSILON`.
    #[test]
    fn cosine_similarity_mismatched_length_returns_zero() {
        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]) < f32::EPSILON);
    }

    // Matching ignores case on both sides; the returned id keeps its original casing.
    #[test]
    fn find_mentioned_tool_ids_case_insensitive() {
        let ids = vec!["web_scrape", "grep", "Bash"];
        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
        assert!(found.contains(&"web_scrape".to_owned()));
        assert!(found.contains(&"Bash".to_owned()));
        assert!(!found.contains(&"grep".to_owned()));
    }

    // "read" inside "thread"/"breadcrumb" is not a mention.
    #[test]
    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
        let ids = vec!["read", "edit", "fetch", "grep"];
        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
        assert!(found.is_empty());
    }

    // Standalone words are matched.
    #[test]
    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
        let ids = vec!["read", "edit"];
        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
        assert!(found.contains(&"read".to_owned()));
        assert!(found.contains(&"edit".to_owned()));
    }

    // Builds a graph from `(tool id, requires, prefers)` triples.
    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
        let deps = rules
            .iter()
            .map(|(id, requires, prefers)| {
                (
                    (*id).to_owned(),
                    crate::config::ToolDependency {
                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
                    },
                )
            })
            .collect();
        ToolDependencyGraph::new(deps)
    }

    // Builds the set of completed tool ids.
    fn completed(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| (*s).to_owned()).collect()
    }

    // An empty graph never blocks a tool.
    #[test]
    fn requirements_met_no_deps() {
        let graph = make_dep_graph(&[]);
        assert!(graph.requirements_met("any_tool", &completed(&[])));
    }

    // A tool is unblocked once its single requirement has completed.
    #[test]
    fn requirements_met_all_satisfied() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
    }

    // A tool is blocked while its requirement has not completed.
    #[test]
    fn requirements_met_unmet() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
    }

    // Tools without a rule are never blocked.
    #[test]
    fn requirements_met_unconfigured_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("grep", &completed(&[])));
    }

    // No completed preference means no boost.
    #[test]
    fn preference_boost_none_met() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
        assert!(boost < f32::EPSILON);
    }

    // One completed preference yields exactly one `boost_per_dep`.
    #[test]
    fn preference_boost_partial() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
        assert!((boost - 0.15).abs() < 1e-5);
    }

    // Three completed preferences (3 * 0.15) are capped at the 0.2 maximum.
    #[test]
    fn preference_boost_capped_at_max() {
        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
        assert!((boost - 0.2).abs() < 1e-5);
    }

    // Both members of a two-node cycle have their requirements cleared.
    #[test]
    fn cycle_detection_simple_cycle() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_a"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_a", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&[])));
    }

    // Only the c <-> d cycle is cleared; a and b, which merely lead into it, stay gated.
    #[test]
    fn cycle_detection_does_not_affect_non_cycle_tools() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
    }

    // A tool with an unmet requirement is removed and reported in `dependency_exclusions`.
    #[test]
    fn apply_excludes_gated_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
    }

    // Once the requirement has completed the tool stays included.
    #[test]
    fn apply_includes_gated_tool_when_dep_met() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("apply_patch"));
        assert!(result.dependency_exclusions.is_empty());
    }

    // When gating would remove every non-always-on tool, the gate is disabled for the call.
    #[test]
    fn apply_deadlock_fallback_when_all_gated() {
        let filter = ToolSchemaFilter::new(
            vec!["bash".into()],
            5,
            5,
            vec![],
        );
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let all_ids = vec!["bash", "only_tool"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("only_tool"));
        assert!(result.included.contains("bash"));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("only_tool"));
        assert!(result.dependency_exclusions.is_empty());
    }

    // An always-on tool is kept even if its own requirement can never be met.
    #[test]
    fn apply_always_on_bypasses_gate() {
        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("bash"));
    }

    // Ancestors of a cycle keep their requirements and unlock normally once those complete.
    #[test]
    fn cycle_detection_does_not_clear_ancestor_requires() {
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
    }

    // Being named in the query does not exempt a tool from the hard gate.
    #[test]
    fn name_mentioned_does_not_bypass_hard_gate() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);

        assert!(result.included.contains("apply_patch"));
        let reason = result
            .inclusion_reasons
            .iter()
            .find(|(id, _)| id == "apply_patch")
            .map(|(_, r)| r);
        assert_eq!(reason, Some(&InclusionReason::NameMentioned));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
    }

    // Turn 1 gates `apply_patch`; turn 2 (after `read` completed) unlocks it.
    #[test]
    fn multi_turn_chain_two_steps() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();

        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);

        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);

        assert!(result2.included.contains("apply_patch"));
        assert!(result2.dependency_exclusions.is_empty());
    }

    // read -> search -> apply_patch unlock one step per turn.
    #[test]
    fn multi_turn_chain_three_steps() {
        let graph = make_dep_graph(&[
            ("search", vec!["read"], vec![]),
            ("apply_patch", vec!["search"], vec![]),
        ]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(r1.included.contains("read"));
        assert!(!r1.included.contains("search"));
        assert!(!r1.included.contains("apply_patch"));

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(r2.included.contains("search"));
        assert!(!r2.included.contains("apply_patch"));

        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r3.included.contains("apply_patch"));
        assert!(r3.dependency_exclusions.is_empty());
    }

    // A tool with two requirements stays gated until both have completed.
    #[test]
    fn multi_turn_multi_requires_both_must_complete() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(!r1.included.contains("apply_patch"));
        let excl = &r1.dependency_exclusions[0];
        assert_eq!(excl.unmet_requires, vec!["search"]);

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r2.included.contains("apply_patch"));
        assert!(r2.dependency_exclusions.is_empty());
    }

    // The score delta grows by `boost_per` for each completed preference: 0, 0.15, 0.30.
    #[test]
    fn multi_turn_preference_boost_accumulates() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let always_on: HashSet<String> = HashSet::new();
        let filter = ToolSchemaFilter::new(
            vec![],
            5,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "format".into(),
                    embedding: vec![0.6, 0.4, 0.0],
                },
                ToolEmbedding {
                    tool_id: "search".into(),
                    embedding: vec![0.7, 0.3, 0.0],
                },
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.8, 0.2, 0.0],
                },
            ],
        );
        let all_ids = vec!["format", "search", "grep"];
        let q = vec![0.5, 0.5, 0.0];
        let boost_per = 0.15_f32;
        let max_boost = 0.3_f32;

        // Looks up a tool's score, defaulting to 0.0 when it has no entry.
        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
            result
                .scores
                .iter()
                .find(|(tid, _)| tid == id)
                .map_or(0.0, |(_, s)| *s)
        };

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        let base_score = score_of(&r1, "format");
        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["search"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta2 = score_of(&r2, "format") - base_score;
        assert!(
            (delta2 - 0.15).abs() < 1e-4,
            "expected +0.15 boost, got {delta2}"
        );

        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["search", "grep"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta3 = score_of(&r3, "format") - base_score;
        assert!(
            (delta3 - 0.30).abs() < 1e-4,
            "expected +0.30 boost, got {delta3}"
        );
    }

    // `filter_tool_names` hides a gated tool until its requirement has completed.
    #[test]
    fn filter_tool_names_multi_turn_unlocks_after_completion() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "read", "apply_patch"];

        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
        assert!(filtered_before.contains(&"bash"));
        assert!(filtered_before.contains(&"read"));
        assert!(!filtered_before.contains(&"apply_patch"));

        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
        assert!(filtered_after.contains(&"bash"));
        assert!(filtered_after.contains(&"read"));
        assert!(filtered_after.contains(&"apply_patch"));
    }

    // NOTE: despite the name, this asserts that `filter_tool_names` has NO deadlock
    // fallback: the only non-always-on tool stays filtered out while its requirement is
    // unmet.
    #[test]
    fn filter_tool_names_deadlock_fallback_passes_all() {
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "only_tool"];

        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);

        assert!(filtered.contains(&"bash"));
        assert!(!filtered.contains(&"only_tool"));
    }
}