1use std::collections::{HashMap, HashSet};
11
12use zeph_common::ToolName;
13use zeph_common::math::cosine_similarity;
14
15use crate::config::ToolDependency;
16
17#[derive(Debug, Clone)]
19pub struct ToolEmbedding {
20 pub tool_id: ToolName,
21 pub embedding: Vec<f32>,
22}
23
/// Why a tool ended up in the filtered tool set for a turn.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum InclusionReason {
    /// The tool is configured as always-on.
    AlwaysOn,
    /// The tool id appears as a standalone word in the user query.
    NameMentioned,
    /// The tool was among the top-k most similar tools to the query.
    SimilarityRank,
    /// The tool's description is shorter than the configured minimum word count.
    ShortDescription,
    /// No embedding exists for the tool, so it is included as a fallback.
    NoEmbedding,
    /// NOTE(review): not produced by the code in this file; presumably set
    /// elsewhere when a tool's dependencies are satisfied — confirm.
    DependencyMet,
    /// The tool received a preference (soft-dependency) score boost.
    PreferenceBoost,
}
43
44#[derive(Debug, Clone, PartialEq, Eq)]
46pub struct DependencyExclusion {
47 pub tool_id: ToolName,
48 pub unmet_requires: Vec<String>,
50}
51
52#[derive(Debug, Clone)]
54pub struct ToolFilterResult {
55 pub included: HashSet<String>,
57 pub excluded: Vec<String>,
59 pub scores: Vec<(String, f32)>,
61 pub inclusion_reasons: Vec<(String, InclusionReason)>,
63 pub dependency_exclusions: Vec<DependencyExclusion>,
65}
66
67#[derive(Debug, Clone, Default)]
80pub struct ToolDependencyGraph {
81 deps: HashMap<String, ToolDependency>,
84}
85
86impl ToolDependencyGraph {
87 #[must_use]
92 pub fn new(deps: HashMap<String, ToolDependency>) -> Self {
93 if deps.is_empty() {
94 return Self { deps };
95 }
96 let cycled = detect_cycles(&deps);
97 if !cycled.is_empty() {
98 tracing::warn!(
99 tools = ?cycled,
100 "tool dependency graph: cycles detected, removing requires for cycle participants"
101 );
102 }
103 let mut resolved = deps;
104 for tool_id in &cycled {
105 if let Some(dep) = resolved.get_mut(tool_id) {
106 dep.requires.clear();
107 }
108 }
109 Self { deps: resolved }
110 }
111
112 #[must_use]
114 pub fn is_empty(&self) -> bool {
115 self.deps.is_empty()
116 }
117
118 #[must_use]
123 pub fn requirements_met(&self, tool_id: &str, completed: &HashSet<String>) -> bool {
124 self.deps
125 .get(tool_id)
126 .is_none_or(|d| d.requires.iter().all(|r| completed.contains(r)))
127 }
128
129 #[must_use]
131 pub fn unmet_requires<'a>(
132 &'a self,
133 tool_id: &str,
134 completed: &HashSet<String>,
135 ) -> Vec<&'a str> {
136 self.deps.get(tool_id).map_or_else(Vec::new, |d| {
137 d.requires
138 .iter()
139 .filter(|r| !completed.contains(r.as_str()))
140 .map(String::as_str)
141 .collect()
142 })
143 }
144
145 #[must_use]
149 pub fn preference_boost(
150 &self,
151 tool_id: &str,
152 completed: &HashSet<String>,
153 boost_per_dep: f32,
154 max_total_boost: f32,
155 ) -> f32 {
156 self.deps.get(tool_id).map_or(0.0, |d| {
157 let met = d
158 .prefers
159 .iter()
160 .filter(|p| completed.contains(p.as_str()))
161 .count();
162 #[allow(clippy::cast_precision_loss)]
163 let boost = met as f32 * boost_per_dep;
164 boost.min(max_total_boost)
165 })
166 }
167
168 pub fn apply(
182 &self,
183 result: &mut ToolFilterResult,
184 completed: &HashSet<String>,
185 boost_per_dep: f32,
186 max_total_boost: f32,
187 always_on: &HashSet<String>,
188 ) {
189 if self.deps.is_empty() {
190 return;
191 }
192
193 let bypassed: HashSet<&str> = result
197 .inclusion_reasons
198 .iter()
199 .filter(|(_, r)| matches!(r, InclusionReason::AlwaysOn))
200 .map(|(id, _)| id.as_str())
201 .collect();
202
203 let mut to_exclude: Vec<DependencyExclusion> = Vec::new();
204 for tool_id in &result.included {
205 if bypassed.contains(tool_id.as_str()) {
206 continue;
207 }
208 let unmet: Vec<String> = self
209 .unmet_requires(tool_id, completed)
210 .into_iter()
211 .map(str::to_owned)
212 .collect();
213 if !unmet.is_empty() {
214 to_exclude.push(DependencyExclusion {
215 tool_id: tool_id.as_str().into(),
216 unmet_requires: unmet,
217 });
218 }
219 }
220
221 let non_always_on_included: usize = result
224 .included
225 .iter()
226 .filter(|id| !always_on.contains(id.as_str()))
227 .count();
228 if !to_exclude.is_empty() && to_exclude.len() >= non_always_on_included {
229 tracing::warn!(
230 gated = to_exclude.len(),
231 non_always_on = non_always_on_included,
232 "tool dependency graph: all non-always-on tools would be blocked; \
233 disabling hard gates for this turn"
234 );
235 to_exclude.clear();
236 }
237
238 for excl in &to_exclude {
240 result.included.remove(excl.tool_id.as_str());
241 result.excluded.push(excl.tool_id.to_string());
242 tracing::debug!(
243 tool_id = %excl.tool_id,
244 unmet = ?excl.unmet_requires,
245 "tool dependency gate: excluded (requires not met)"
246 );
247 }
248 result.dependency_exclusions = to_exclude;
249
250 for (tool_id, score) in &mut result.scores {
252 if !result.included.contains(tool_id) {
253 continue;
254 }
255 let boost = self.preference_boost(tool_id, completed, boost_per_dep, max_total_boost);
256 if boost > 0.0 {
257 *score += boost;
258 let already_recorded = result.inclusion_reasons.iter().any(|(id, _)| id == tool_id);
260 if !already_recorded {
261 result
262 .inclusion_reasons
263 .push((tool_id.clone(), InclusionReason::PreferenceBoost));
264 }
265 tracing::debug!(
266 tool_id = %tool_id,
267 boost,
268 "tool dependency: preference boost applied"
269 );
270 }
271 }
272 result
274 .scores
275 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276 }
277
278 #[must_use]
283 pub fn filter_tool_names<'a>(
284 &self,
285 names: &[&'a str],
286 completed: &HashSet<String>,
287 always_on: &HashSet<String>,
288 ) -> Vec<&'a str> {
289 names
290 .iter()
291 .copied()
292 .filter(|n| always_on.contains(*n) || self.requirements_met(n, completed))
293 .collect()
294 }
295}
296
/// Returns the ids of tools found on `requires` cycles by a depth-first search.
///
/// The DFS is iterative: the stack holds `(node, next_child_index)` pairs for the
/// current path. When an edge leads back to a node that is still on the path
/// (`InProgress`), every node from that node up to the top of the stack is
/// recorded. The top of the stack is the source of that back edge, so it is always
/// included; clearing the `requires` of the returned tools therefore removes
/// every back edge the search found.
///
/// Ids that appear in a `requires` list but are not keys of `deps` are treated
/// as having no requirements. The exact set returned can depend on traversal
/// order, which follows `HashMap` key order.
fn detect_cycles(deps: &HashMap<String, ToolDependency>) -> HashSet<String> {
    // DFS colouring: Unvisited -> InProgress (on the current path) -> Done.
    #[derive(Clone, Copy, PartialEq)]
    enum State {
        Unvisited,
        InProgress,
        Done,
    }

    let mut state: HashMap<&str, State> = HashMap::new();
    let mut cycled: HashSet<String> = HashSet::new();

    for start in deps.keys() {
        // Skip roots already explored from an earlier start.
        if state
            .get(start.as_str())
            .copied()
            .unwrap_or(State::Unvisited)
            != State::Unvisited
        {
            continue;
        }
        let mut stack: Vec<(&str, usize)> = vec![(start.as_str(), 0)];
        state.insert(start.as_str(), State::InProgress);

        while let Some((node, child_idx)) = stack.last_mut() {
            let node = *node;
            // Tools without a rule have no outgoing edges.
            let requires = deps
                .get(node)
                .map_or(&[] as &[String], |d| d.requires.as_slice());

            // All children explored: the node leaves the current path.
            if *child_idx >= requires.len() {
                state.insert(node, State::Done);
                stack.pop();
                continue;
            }

            let child = requires[*child_idx].as_str();
            *child_idx += 1;

            match state.get(child).copied().unwrap_or(State::Unvisited) {
                State::InProgress => {
                    // Back edge: the path segment from `child` to the top of the
                    // stack forms a cycle.
                    let cycle_start = stack.iter().position(|(n, _)| *n == child);
                    if let Some(start) = cycle_start {
                        for (path_node, _) in &stack[start..] {
                            cycled.insert((*path_node).to_owned());
                        }
                    }
                    cycled.insert(child.to_owned());
                }
                State::Unvisited => {
                    state.insert(child, State::InProgress);
                    stack.push((child, 0));
                }
                // Already fully explored; cannot close a cycle with the current path.
                State::Done => {}
            }
        }
    }

    cycled
}
365
366pub struct ToolSchemaFilter {
368 always_on: HashSet<String>,
369 top_k: usize,
370 min_description_words: usize,
371 embeddings: Vec<ToolEmbedding>,
372 version: u64,
373}
374
375impl ToolSchemaFilter {
376 #[must_use]
378 pub fn new(
379 always_on: Vec<String>,
380 top_k: usize,
381 min_description_words: usize,
382 embeddings: Vec<ToolEmbedding>,
383 ) -> Self {
384 Self {
385 always_on: always_on.into_iter().collect(),
386 top_k,
387 min_description_words,
388 embeddings,
389 version: 0,
390 }
391 }
392
393 #[must_use]
395 pub fn version(&self) -> u64 {
396 self.version
397 }
398
399 #[must_use]
401 pub fn embedding_count(&self) -> usize {
402 self.embeddings.len()
403 }
404
405 #[must_use]
407 pub fn top_k(&self) -> usize {
408 self.top_k
409 }
410
411 #[must_use]
413 pub fn always_on_count(&self) -> usize {
414 self.always_on.len()
415 }
416
417 pub fn recompute(&mut self, embeddings: Vec<ToolEmbedding>) {
419 self.embeddings = embeddings;
420 self.version += 1;
421 }
422
423 #[must_use]
429 pub fn filter(
430 &self,
431 all_tool_ids: &[&str],
432 tool_descriptions: &[(&str, &str)],
433 query: &str,
434 query_embedding: &[f32],
435 ) -> ToolFilterResult {
436 let mut included = HashSet::new();
437 let mut inclusion_reasons = Vec::new();
438
439 for id in all_tool_ids {
441 if self.always_on.contains(*id) {
442 included.insert((*id).to_owned());
443 inclusion_reasons.push(((*id).to_owned(), InclusionReason::AlwaysOn));
444 }
445 }
446
447 let mentioned = find_mentioned_tool_ids(query, all_tool_ids);
449 for id in &mentioned {
450 if included.insert(id.clone()) {
451 inclusion_reasons.push((id.clone(), InclusionReason::NameMentioned));
452 }
453 }
454
455 for &(id, desc) in tool_descriptions {
457 let word_count = desc.split_whitespace().count();
458 if word_count < self.min_description_words && included.insert(id.to_owned()) {
459 inclusion_reasons.push((id.to_owned(), InclusionReason::ShortDescription));
460 }
461 }
462
463 let mut scores: Vec<(String, f32)> = self
465 .embeddings
466 .iter()
467 .filter(|e| !included.contains(e.tool_id.as_str()))
468 .map(|e| {
469 let score = cosine_similarity(query_embedding, &e.embedding);
470 (e.tool_id.to_string(), score)
471 })
472 .collect();
473
474 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
475
476 let take = if self.top_k == 0 {
477 scores.len()
478 } else {
479 self.top_k.min(scores.len())
480 };
481
482 for (id, _score) in scores.iter().take(take) {
483 if included.insert(id.clone()) {
484 inclusion_reasons.push((id.clone(), InclusionReason::SimilarityRank));
485 }
486 }
487
488 let embedded_ids: HashSet<&str> =
490 self.embeddings.iter().map(|e| e.tool_id.as_str()).collect();
491 for id in all_tool_ids {
492 if !included.contains(*id) && !embedded_ids.contains(*id) {
493 included.insert((*id).to_owned());
494 inclusion_reasons.push(((*id).to_owned(), InclusionReason::NoEmbedding));
495 }
496 }
497
498 let excluded: Vec<String> = all_tool_ids
500 .iter()
501 .filter(|id| !included.contains(**id))
502 .map(|id| (*id).to_owned())
503 .collect();
504
505 ToolFilterResult {
506 included,
507 excluded,
508 scores,
509 inclusion_reasons,
510 dependency_exclusions: Vec::new(),
511 }
512 }
513}
514
/// Returns the ids from `all_tool_ids` that appear as standalone words in `query`.
///
/// Matching is case-insensitive (both sides are lowercased). A match counts only
/// when the bytes immediately before and after it are not ASCII alphanumerics
/// or `_`, so `read` matches in "please read it" but not in "thread" or
/// "read_file". Overlapping occurrences are examined, so a rejected occurrence
/// does not hide a later valid one. Empty ids never match. Ids are returned in
/// input order, with their original casing.
#[must_use]
pub fn find_mentioned_tool_ids(query: &str, all_tool_ids: &[&str]) -> Vec<String> {
    let query_lower = query.to_lowercase();
    let bytes = query_lower.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    all_tool_ids
        .iter()
        .filter(|id| {
            let id_lower = id.to_lowercase();
            // An empty id would "match" at every position.
            if id_lower.is_empty() {
                return false;
            }
            let mut start = 0;
            while let Some(pos) = query_lower[start..].find(&id_lower) {
                let abs_pos = start + pos;
                let end_pos = abs_pos + id_lower.len();
                let before_ok = abs_pos == 0 || !is_word_byte(bytes[abs_pos - 1]);
                let after_ok = end_pos >= bytes.len() || !is_word_byte(bytes[end_pos]);
                if before_ok && after_ok {
                    return true;
                }
                // Advance by one whole character so the next slice starts on a
                // UTF-8 boundary (`abs_pos + 1` panics for multi-byte chars).
                start = abs_pos
                    + query_lower[abs_pos..]
                        .chars()
                        .next()
                        .map_or(1, char::len_utf8);
            }
            false
        })
        .map(|id| (*id).to_owned())
        .collect()
}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter with five fixed 3-d embeddings and a minimum description
    /// length of 5 words.
    fn make_filter(always_on: Vec<&str>, top_k: usize) -> ToolSchemaFilter {
        ToolSchemaFilter::new(
            always_on.into_iter().map(String::from).collect(),
            top_k,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.9, 0.1, 0.0],
                },
                ToolEmbedding {
                    tool_id: "write".into(),
                    embedding: vec![0.1, 0.9, 0.0],
                },
                ToolEmbedding {
                    tool_id: "find_path".into(),
                    embedding: vec![0.5, 0.5, 0.0],
                },
                ToolEmbedding {
                    tool_id: "web_scrape".into(),
                    embedding: vec![0.0, 0.0, 1.0],
                },
                ToolEmbedding {
                    tool_id: "diagnostics".into(),
                    embedding: vec![0.0, 0.1, 0.9],
                },
            ],
        )
    }

    #[test]
    fn top_k_ranking_selects_most_similar() {
        let filter = make_filter(vec!["bash"], 2);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        // Mostly along the first axis: grep is closest, then find_path.
        let query_emb = vec![0.8, 0.2, 0.0];
        let result = filter.filter(&all_ids, &[], "search for pattern", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("find_path"));
        assert!(!result.included.contains("web_scrape"));
        assert!(!result.included.contains("diagnostics"));
    }

    #[test]
    fn always_on_tools_always_included() {
        let filter = make_filter(vec!["bash", "read"], 1);
        let all_ids: Vec<&str> = vec!["bash", "read", "grep", "write"];
        // Along the second axis: write is the single top-1 pick.
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "test query", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("read"));
        assert!(result.included.contains("write"));
        assert!(!result.included.contains("grep"));
    }

    #[test]
    fn name_mention_force_includes() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        let query_emb = vec![0.0, 1.0, 0.0];
        let result = filter.filter(&all_ids, &[], "use web_scrape to fetch", &query_emb);

        // Mentioned by name despite a low similarity score.
        assert!(result.included.contains("web_scrape"));
        // Top-1 by similarity.
        assert!(result.included.contains("write"));
        // Always-on.
        assert!(result.included.contains("bash"));
    }

    #[test]
    fn short_mcp_description_auto_included() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "mcp_query"];
        let descriptions: Vec<(&str, &str)> = vec![
            ("mcp_query", "Run query"),
            ("grep", "Search file contents recursively"),
        ];
        let query_emb = vec![0.9, 0.1, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test", &query_emb);

        // Two words is below the 5-word minimum.
        assert!(result.included.contains("mcp_query"));
    }

    #[test]
    fn empty_embeddings_includes_all_via_no_embedding_fallback() {
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 6, 5, vec![]);
        let all_ids: Vec<&str> = vec!["bash", "grep", "write"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("bash"));
        assert!(result.included.contains("grep"));
        assert!(result.included.contains("write"));
        assert!(result.excluded.is_empty());
    }

    #[test]
    fn top_k_zero_includes_all_filterable() {
        let filter = make_filter(vec!["bash"], 0);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        // top_k == 0 means "no limit".
        assert_eq!(result.included.len(), 6);
        assert!(result.excluded.is_empty());
    }

    #[test]
    fn top_k_exceeds_filterable_count_includes_all() {
        let filter = make_filter(vec!["bash"], 100);
        let all_ids: Vec<&str> = vec![
            "bash",
            "grep",
            "write",
            "find_path",
            "web_scrape",
            "diagnostics",
        ];
        let query_emb = vec![0.1, 0.1, 0.1];
        let result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert_eq!(result.included.len(), 6);
    }

    #[test]
    fn accessors_return_configured_values() {
        let filter = make_filter(vec!["bash", "read"], 7);
        assert_eq!(filter.top_k(), 7);
        assert_eq!(filter.always_on_count(), 2);
        assert_eq!(filter.embedding_count(), 5);
    }

    #[test]
    fn version_counter_incremented_on_recompute() {
        let mut filter = make_filter(vec![], 3);
        assert_eq!(filter.version(), 0);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 1);
        filter.recompute(vec![]);
        assert_eq!(filter.version(), 2);
    }

    #[test]
    fn inclusion_reason_correctness() {
        let filter = make_filter(vec!["bash"], 1);
        let all_ids: Vec<&str> = vec!["bash", "grep", "web_scrape", "write"];
        // A one-word description forces web_scrape in as ShortDescription.
        let descriptions: Vec<(&str, &str)> = vec![("web_scrape", "Scrape")];
        // Closest to write, which wins the single similarity slot.
        let query_emb = vec![0.1, 0.9, 0.0];
        let result = filter.filter(&all_ids, &descriptions, "test query", &query_emb);

        let reasons: std::collections::HashMap<String, InclusionReason> =
            result.inclusion_reasons.into_iter().collect();
        assert_eq!(reasons.get("bash"), Some(&InclusionReason::AlwaysOn));
        assert_eq!(
            reasons.get("web_scrape"),
            Some(&InclusionReason::ShortDescription)
        );
        assert_eq!(reasons.get("write"), Some(&InclusionReason::SimilarityRank));
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn cosine_similarity_empty_returns_zero() {
        assert!(cosine_similarity(&[], &[]) < f32::EPSILON);
    }

    #[test]
    fn cosine_similarity_mismatched_length_returns_zero() {
        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]) < f32::EPSILON);
    }

    #[test]
    fn find_mentioned_tool_ids_case_insensitive() {
        let ids = vec!["web_scrape", "grep", "Bash"];
        let found = find_mentioned_tool_ids("use WEB_SCRAPE and BASH", &ids);
        assert!(found.contains(&"web_scrape".to_owned()));
        assert!(found.contains(&"Bash".to_owned()));
        assert!(!found.contains(&"grep".to_owned()));
    }

    #[test]
    fn find_mentioned_tool_ids_word_boundary_no_false_positives() {
        let ids = vec!["read", "edit", "fetch", "grep"];
        // "read" occurs inside "thread" and "breadcrumb" but never as a word.
        let found = find_mentioned_tool_ids("thread breadcrumb", &ids);
        assert!(found.is_empty());
    }

    #[test]
    fn find_mentioned_tool_ids_word_boundary_matches_standalone() {
        let ids = vec!["read", "edit"];
        let found = find_mentioned_tool_ids("please read and edit the file", &ids);
        assert!(found.contains(&"read".to_owned()));
        assert!(found.contains(&"edit".to_owned()));
    }

    #[test]
    fn find_mentioned_tool_ids_matches_after_rejected_occurrence() {
        // The embedded occurrence in "threaded" is rejected; the later word matches.
        let found = find_mentioned_tool_ids("threaded read", &["read"]);
        assert_eq!(found, vec!["read".to_owned()]);
        // `_` counts as a word character.
        assert!(find_mentioned_tool_ids("use read_file", &["read"]).is_empty());
        // Overlapping occurrences are examined.
        assert_eq!(
            find_mentioned_tool_ids("xa-a-a", &["a-a"]),
            vec!["a-a".to_owned()]
        );
    }

    /// Builds a dependency graph from `(tool, requires, prefers)` rules.
    fn make_dep_graph(rules: &[(&str, Vec<&str>, Vec<&str>)]) -> ToolDependencyGraph {
        let deps = rules
            .iter()
            .map(|(id, requires, prefers)| {
                (
                    (*id).to_owned(),
                    crate::config::ToolDependency {
                        requires: requires.iter().map(|s| (*s).to_owned()).collect(),
                        prefers: prefers.iter().map(|s| (*s).to_owned()).collect(),
                    },
                )
            })
            .collect();
        ToolDependencyGraph::new(deps)
    }

    /// Builds a set of completed tool ids.
    fn completed(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn requirements_met_no_deps() {
        let graph = make_dep_graph(&[]);
        assert!(graph.requirements_met("any_tool", &completed(&[])));
    }

    #[test]
    fn requirements_met_all_satisfied() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(graph.requirements_met("apply_patch", &completed(&["read"])));
    }

    #[test]
    fn requirements_met_unmet() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        assert!(!graph.requirements_met("apply_patch", &completed(&[])));
    }

    #[test]
    fn requirements_met_unconfigured_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        // Tools without a rule are never gated.
        assert!(graph.requirements_met("grep", &completed(&[])));
    }

    #[test]
    fn unmet_requires_lists_only_missing() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
        assert_eq!(
            graph.unmet_requires("apply_patch", &completed(&["read"])),
            vec!["search"]
        );
        assert_eq!(
            graph.unmet_requires("apply_patch", &completed(&[])),
            vec!["read", "search"]
        );
        assert!(graph.unmet_requires("grep", &completed(&[])).is_empty());
    }

    #[test]
    fn preference_boost_none_met() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&[]), 0.15, 0.2);
        assert!(boost < f32::EPSILON);
    }

    #[test]
    fn preference_boost_partial() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let boost = graph.preference_boost("format", &completed(&["search"]), 0.15, 0.2);
        assert!((boost - 0.15).abs() < 1e-5);
    }

    #[test]
    fn preference_boost_capped_at_max() {
        let graph = make_dep_graph(&[("format", vec![], vec!["a", "b", "c"])]);
        // 3 * 0.15 = 0.45 is capped at 0.2.
        let boost = graph.preference_boost("format", &completed(&["a", "b", "c"]), 0.15, 0.2);
        assert!((boost - 0.2).abs() < 1e-5);
    }

    #[test]
    fn preference_boost_unconfigured_tool_is_zero() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search"])]);
        let boost = graph.preference_boost("grep", &completed(&["search"]), 0.15, 0.2);
        assert!(boost < f32::EPSILON);
    }

    #[test]
    fn cycle_detection_simple_cycle() {
        // a <-> b would deadlock; both requires lists are cleared.
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_a"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_a", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&[])));
    }

    #[test]
    fn cycle_detection_self_loop() {
        let graph = make_dep_graph(&[("tool_a", vec!["tool_a"], vec![])]);
        assert!(graph.requirements_met("tool_a", &completed(&[])));
    }

    #[test]
    fn cycle_detection_does_not_affect_non_cycle_tools() {
        // Only c <-> d is a cycle; a -> b -> c is an acyclic chain into it.
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
    }

    #[test]
    fn apply_excludes_gated_tool() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
        assert_eq!(result.dependency_exclusions[0].unmet_requires, vec!["read"]);
    }

    #[test]
    fn apply_includes_gated_tool_when_dep_met() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        result.included.insert("apply_patch".into());

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&["read"]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("apply_patch"));
        assert!(result.dependency_exclusions.is_empty());
    }

    #[test]
    fn apply_deadlock_fallback_when_all_gated() {
        // No embeddings: every non-always-on tool comes in via NoEmbedding.
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let all_ids = vec!["bash", "only_tool"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        assert!(result.included.contains("only_tool"));
        assert!(result.included.contains("bash"));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        // Gating only_tool would leave nothing but always-on tools, so gates are lifted.
        assert!(result.included.contains("only_tool"));
        assert!(result.dependency_exclusions.is_empty());
    }

    #[test]
    fn apply_always_on_bypasses_gate() {
        let graph = make_dep_graph(&[("bash", vec!["nonexistent"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "grep"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert!(result.included.contains("bash"));
    }

    #[test]
    fn apply_empty_graph_is_noop() {
        let graph = make_dep_graph(&[]);
        assert!(graph.is_empty());
        let filter = make_filter(vec!["bash"], 2);
        let all_ids = vec!["bash", "grep", "write"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "test", &query_emb);
        let before = result.clone();

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        assert_eq!(result.included, before.included);
        assert_eq!(result.scores, before.scores);
        assert!(result.dependency_exclusions.is_empty());
    }

    #[test]
    fn cycle_detection_does_not_clear_ancestor_requires() {
        // Breaking the c <-> d cycle must not also clear a's and b's requires.
        let graph = make_dep_graph(&[
            ("tool_a", vec!["tool_b"], vec![]),
            ("tool_b", vec!["tool_c"], vec![]),
            ("tool_c", vec!["tool_d"], vec![]),
            ("tool_d", vec!["tool_c"], vec![]),
        ]);
        assert!(graph.requirements_met("tool_c", &completed(&[])));
        assert!(graph.requirements_met("tool_d", &completed(&[])));
        assert!(!graph.requirements_met("tool_a", &completed(&[])));
        assert!(!graph.requirements_met("tool_b", &completed(&[])));
        assert!(graph.requirements_met("tool_b", &completed(&["tool_c"])));
        assert!(graph.requirements_met("tool_a", &completed(&["tool_b"])));
    }

    #[test]
    fn name_mentioned_does_not_bypass_hard_gate() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let filter = make_filter(vec!["bash"], 5);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let query_emb = vec![0.5, 0.5, 0.0];
        let mut result = filter.filter(&all_ids, &[], "use apply_patch to fix the bug", &query_emb);

        assert!(result.included.contains("apply_patch"));
        let reason = result
            .inclusion_reasons
            .iter()
            .find(|(id, _)| id == "apply_patch")
            .map(|(_, r)| r);
        assert_eq!(reason, Some(&InclusionReason::NameMentioned));

        let always_on: HashSet<String> = ["bash".into()].into();
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);

        // A name mention is not an exemption: the hard gate still applies.
        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);
        assert_eq!(result.dependency_exclusions[0].tool_id, "apply_patch");
    }

    #[test]
    fn multi_turn_chain_two_steps() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();

        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        // Turn 1: nothing completed yet.
        let mut result = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(!result.included.contains("apply_patch"));
        assert_eq!(result.dependency_exclusions.len(), 1);

        // Turn 2: `read` has completed, unlocking apply_patch.
        let mut result2 = filter.filter(&all_ids, &[], "fix bug", &q);
        graph.apply(&mut result2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(result2.included.contains("apply_patch"));
        assert!(result2.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_chain_three_steps() {
        let graph = make_dep_graph(&[
            ("search", vec!["read"], vec![]),
            ("apply_patch", vec!["search"], vec![]),
        ]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&[]), 0.15, 0.2, &always_on);
        assert!(r1.included.contains("read"));
        assert!(!r1.included.contains("search"));
        assert!(!r1.included.contains("apply_patch"));

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r2, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(r2.included.contains("search"));
        assert!(!r2.included.contains("apply_patch"));

        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r3.included.contains("apply_patch"));
        assert!(r3.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_multi_requires_both_must_complete() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read", "search"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filter = ToolSchemaFilter::new(vec!["bash".into()], 5, 5, vec![]);
        let all_ids = vec!["bash", "read", "search", "apply_patch"];
        let q = vec![0.5, 0.5, 0.0];

        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(&mut r1, &completed(&["read"]), 0.15, 0.2, &always_on);
        assert!(!r1.included.contains("apply_patch"));
        let excl = &r1.dependency_exclusions[0];
        assert_eq!(excl.unmet_requires, vec!["search"]);

        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["read", "search"]),
            0.15,
            0.2,
            &always_on,
        );
        assert!(r2.included.contains("apply_patch"));
        assert!(r2.dependency_exclusions.is_empty());
    }

    #[test]
    fn multi_turn_preference_boost_accumulates() {
        let graph = make_dep_graph(&[("format", vec![], vec!["search", "grep"])]);
        let always_on: HashSet<String> = HashSet::new();
        let filter = ToolSchemaFilter::new(
            vec![],
            5,
            5,
            vec![
                ToolEmbedding {
                    tool_id: "format".into(),
                    embedding: vec![0.6, 0.4, 0.0],
                },
                ToolEmbedding {
                    tool_id: "search".into(),
                    embedding: vec![0.7, 0.3, 0.0],
                },
                ToolEmbedding {
                    tool_id: "grep".into(),
                    embedding: vec![0.8, 0.2, 0.0],
                },
            ],
        );
        let all_ids = vec!["format", "search", "grep"];
        let q = vec![0.5, 0.5, 0.0];
        let boost_per = 0.15_f32;
        let max_boost = 0.3_f32;

        let score_of = |result: &ToolFilterResult, id: &str| -> f32 {
            result
                .scores
                .iter()
                .find(|(tid, _)| tid == id)
                .map_or(0.0, |(_, s)| *s)
        };

        // No preferences met: score unchanged.
        let mut r1 = filter.filter(&all_ids, &[], "q", &q);
        let base_score = score_of(&r1, "format");
        graph.apply(&mut r1, &completed(&[]), boost_per, max_boost, &always_on);
        assert!((score_of(&r1, "format") - base_score).abs() < 1e-5);

        // One preference met: +0.15.
        let mut r2 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r2,
            &completed(&["search"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta2 = score_of(&r2, "format") - base_score;
        assert!(
            (delta2 - 0.15).abs() < 1e-4,
            "expected +0.15 boost, got {delta2}"
        );

        // Both preferences met: +0.30 (exactly the cap).
        let mut r3 = filter.filter(&all_ids, &[], "q", &q);
        graph.apply(
            &mut r3,
            &completed(&["search", "grep"]),
            boost_per,
            max_boost,
            &always_on,
        );
        let delta3 = score_of(&r3, "format") - base_score;
        assert!(
            (delta3 - 0.30).abs() < 1e-4,
            "expected +0.30 boost, got {delta3}"
        );
    }

    #[test]
    fn filter_tool_names_multi_turn_unlocks_after_completion() {
        let graph = make_dep_graph(&[("apply_patch", vec!["read"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "read", "apply_patch"];

        let filtered_before = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);
        assert!(filtered_before.contains(&"bash"));
        assert!(filtered_before.contains(&"read"));
        assert!(!filtered_before.contains(&"apply_patch"));

        let filtered_after = graph.filter_tool_names(&all_names, &completed(&["read"]), &always_on);
        assert!(filtered_after.contains(&"bash"));
        assert!(filtered_after.contains(&"read"));
        assert!(filtered_after.contains(&"apply_patch"));
    }

    #[test]
    fn filter_tool_names_always_on_bypasses_requires() {
        let graph = make_dep_graph(&[("bash", vec!["missing"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let filtered = graph.filter_tool_names(&["bash"], &completed(&[]), &always_on);
        assert_eq!(filtered, vec!["bash"]);
    }

    /// Unlike `apply`, `filter_tool_names` has no deadlock fallback: a gated
    /// tool is dropped even when it is the only non-always-on name.
    #[test]
    fn filter_tool_names_has_no_deadlock_fallback() {
        let graph = make_dep_graph(&[("only_tool", vec!["missing"], vec![])]);
        let always_on: HashSet<String> = ["bash".into()].into();
        let all_names = vec!["bash", "only_tool"];

        let filtered = graph.filter_tool_names(&all_names, &completed(&[]), &always_on);

        assert!(filtered.contains(&"bash"));
        assert!(!filtered.contains(&"only_tool"));
    }
}