1use std::borrow::Cow;
20use std::collections::{BTreeMap, HashMap, HashSet};
21use std::fmt::Write as _;
22
/// How much detail `render` includes when turning a registered tool into text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DisclosureDepth {
    /// Only the tool name.
    Minimal,
    /// `name: description`, with a long description truncated and suffixed with `…`.
    Summary,
    /// The summary plus `(params: …)` listing the property names of the tool's JSON schema, if any.
    Parameters,
    /// Name, complete description and, when present, the raw JSON schema.
    Full,
}
40
/// One registered tool, as stored inside [`ToolSearchIndex`].
struct ToolEntry {
    /// Tool name; registering the same name again replaces this entry.
    name: String,
    /// Namespace (group) the tool belongs to, e.g. `vfs` or `lsp`.
    namespace: String,
    /// Human-readable description; its words take part in keyword scoring.
    description: String,
    /// Extra keywords used for scoring and for suggesting alternative keywords.
    tags: Vec<String>,
    /// Distinct queries recorded by `record_success` (at most 10 are kept).
    example_queries: Vec<String>,
    /// Raw JSON schema of the tool's parameters, if one was registered.
    schema_json: Option<String>,
    /// Number of successful uses recorded; feeds the frequency boost in keyword scoring.
    call_count: u64,
    /// Output of `compute_embedding`, which currently always returns an empty vector.
    #[allow(dead_code)] // nothing reads the embedding yet
    embedding: Vec<f32>,
}
72
/// One tool returned by a search or namespace browse.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolSearchResult {
    /// Name of the matched tool.
    pub name: String,
    /// Namespace of the matched tool.
    pub namespace: String,
    /// Relevance score (higher is better); namespace browsing always reports 1.0.
    pub score: f32,
    /// Summary-depth rendering of the tool (`name: description`).
    pub rendered: String,
    /// For low-confidence search results: namespace of the top-ranked result, as a hint.
    pub nearest_namespace: Option<String>,
    /// Up to five tags of the returned tools that the query did not already contain.
    pub alternative_keywords: Vec<String>,
    /// `"high"` (score > 5), `"medium"` (score > 2) or `"low"`; browsing always reports `"high"`.
    pub confidence_level: String,
}
96
/// Keyword-based search index over registered tools.
///
/// Besides the tools themselves it remembers which tools earlier searches returned (so later
/// searches can demote them) and keeps a content hash of the registry.
pub struct ToolSearchIndex {
    /// Registered tools in registration order (re-registering a name moves it to the end).
    entries: Vec<ToolEntry>,
    /// Names of tools returned by previous searches.
    loaded_schemas: HashSet<String>,
    /// Hex-encoded SHA-256 of the registry contents; empty until the first registration.
    registry_hash: String,
}
124
125impl Default for ToolSearchIndex {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl ToolSearchIndex {
132 pub fn new() -> Self {
134 Self {
135 entries: Vec::new(),
136 loaded_schemas: HashSet::new(),
137 registry_hash: String::new(),
138 }
139 }
140
141 pub fn register(
148 &mut self,
149 name: &str,
150 namespace: &str,
151 description: &str,
152 tags: &[&str],
153 schema_json: Option<&str>,
154 ) {
155 self.entries.retain(|e| e.name != name);
157
158 let embedding = compute_embedding(name, description);
159 self.entries.push(ToolEntry {
160 name: name.to_owned(),
161 namespace: namespace.to_owned(),
162 description: description.to_owned(),
163 tags: tags.iter().map(|&t| t.to_owned()).collect(),
164 example_queries: Vec::new(),
165 schema_json: schema_json.map(str::to_owned),
166 call_count: 0,
167 embedding,
168 });
169 self.recompute_hash();
170 }
171
172 pub fn search(&mut self, query: &str, top_k: usize) -> Vec<ToolSearchResult> {
178 let query_words: HashSet<&str> = query.split_whitespace().collect();
179
180 let mut scored: Vec<(usize, f32)> = self
181 .entries
182 .iter()
183 .enumerate()
184 .map(|(i, entry)| {
185 let mut score = keyword_score(entry, &query_words);
186 if self.loaded_schemas.contains(&entry.name) {
188 score *= 0.8;
189 }
190 (i, score)
191 })
192 .collect();
193
194 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
195 scored.truncate(top_k);
196
197 let top_namespace = scored
199 .first()
200 .map(|(i, _)| self.entries[*i].namespace.clone());
201
202 let alternative_keywords: Vec<String> = scored
204 .iter()
205 .flat_map(|(i, _)| self.entries[*i].tags.iter().cloned())
206 .filter(|t| {
207 !query_words
208 .iter()
209 .any(|&w| w.eq_ignore_ascii_case(t.as_str()))
210 })
211 .collect::<HashSet<_>>()
212 .into_iter()
213 .take(5)
214 .collect();
215
216 let results: Vec<ToolSearchResult> = scored
217 .into_iter()
218 .map(|(i, score)| {
219 let entry = &self.entries[i];
220
221 let confidence_level = if score > 5.0 {
222 "high"
223 } else if score > 2.0 {
224 "medium"
225 } else {
226 "low"
227 };
228
229 let nearest_namespace = if confidence_level == "low" {
230 top_namespace.clone()
231 } else {
232 None
233 };
234
235 let rendered = render(entry, DisclosureDepth::Summary);
236
237 ToolSearchResult {
238 name: entry.name.clone(),
239 namespace: entry.namespace.clone(),
240 score,
241 rendered,
242 nearest_namespace,
243 alternative_keywords: alternative_keywords.clone(),
244 confidence_level: confidence_level.to_owned(),
245 }
246 })
247 .collect();
248
249 for r in &results {
251 let _ = self.loaded_schemas.insert(r.name.clone());
252 }
253
254 results
255 }
256
257 pub fn browse_namespace(&self, namespace: &str) -> Vec<ToolSearchResult> {
259 self.entries
260 .iter()
261 .filter(|e| e.namespace == namespace)
262 .map(|e| ToolSearchResult {
263 name: e.name.clone(),
264 namespace: e.namespace.clone(),
265 score: 1.0,
266 rendered: render(e, DisclosureDepth::Summary),
267 nearest_namespace: None,
268 alternative_keywords: Vec::new(),
269 confidence_level: "high".to_owned(),
270 })
271 .collect()
272 }
273
274 pub fn list_compact(&self) -> Vec<(String, String)> {
276 self.entries
277 .iter()
278 .map(|e| (e.name.clone(), e.description.clone()))
279 .collect()
280 }
281
282 pub fn record_success(&mut self, query: &str, tool_name: &str) {
287 if let Some(entry) = self.entries.iter_mut().find(|e| e.name == tool_name) {
288 entry.call_count = entry.call_count.saturating_add(1);
289 if !entry.example_queries.iter().any(|q| q == query) && entry.example_queries.len() < 10
290 {
291 entry.example_queries.push(query.to_owned());
292 }
293 }
294 }
295
296 pub fn search_progressive(
300 &mut self,
301 query: &str,
302 steps: usize,
303 per_step_k: usize,
304 ) -> Vec<ToolSearchResult> {
305 let mut seen: HashSet<String> = HashSet::new();
306 let mut all_results: Vec<ToolSearchResult> = Vec::new();
307 let mut remaining_query = query.to_owned();
308
309 for _ in 0..steps {
310 let step_results = self.search(&remaining_query, per_step_k);
311 for r in step_results {
312 if seen.insert(r.name.clone()) {
313 all_results.push(r);
314 }
315 }
316 let found_names: Vec<&str> = all_results.iter().map(|r| r.name.as_str()).collect();
318 remaining_query = format!("{query} -{}", found_names.join(" -"));
319 }
320 all_results
321 }
322
323 pub fn registry_hash(&self) -> &str {
325 &self.registry_hash
326 }
327
328 fn recompute_hash(&mut self) {
333 use sha2::{Digest, Sha256};
334
335 let sorted: BTreeMap<&str, &str> = self
336 .entries
337 .iter()
338 .map(|e| (e.name.as_str(), e.description.as_str()))
339 .collect();
340
341 let data = serde_json::to_string(&sorted).unwrap_or_else(|_| format!("{sorted:?}"));
342
343 let mut hasher = Sha256::new();
344 hasher.update(data.as_bytes());
345 self.registry_hash = format!("{:x}", hasher.finalize());
346 }
347}
348
/// Placeholder for a semantic embedding of a tool's name and description.
///
/// Both inputs are ignored and an empty vector is returned; the stored embedding is not read
/// anywhere yet.
const fn compute_embedding(_tool_name: &str, _tool_description: &str) -> Vec<f32> {
    Vec::new()
}
357
358fn keyword_score(entry: &ToolEntry, query_words: &HashSet<&str>) -> f32 {
364 let name_words: HashSet<&str> = entry.name.split(['-', '_', ' ']).collect();
365 let desc_words: HashSet<&str> = entry.description.split_whitespace().collect();
366 let ns_words: HashSet<&str> = entry.namespace.split(['-', '_']).collect();
367
368 #[allow(clippy::cast_precision_loss)]
369 let ns_score: f32 = ns_words.intersection(query_words).count() as f32 * 5.0;
370 #[allow(clippy::cast_precision_loss)]
371 let name_score: f32 = name_words.intersection(query_words).count() as f32 * 3.0;
372 #[allow(clippy::cast_precision_loss)]
373 let desc_score: f32 = desc_words.intersection(query_words).count() as f32 * 2.0;
374 #[allow(clippy::cast_precision_loss)]
375 let tag_score: f32 = entry
376 .tags
377 .iter()
378 .flat_map(|t| t.split(['-', '_', ' ']))
379 .filter(|w| query_words.contains(w))
380 .count() as f32
381 * 1.5;
382
383 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
387 let freq_boost = ((1.0 + entry.call_count as f64).log2() * 0.5).min(3.0) as f32;
388
389 ns_score + name_score + desc_score + tag_score + freq_boost
390}
391
392fn render(entry: &ToolEntry, depth: DisclosureDepth) -> String {
398 match depth {
399 DisclosureDepth::Minimal => entry.name.clone(),
400 DisclosureDepth::Summary => {
401 let desc = if entry.description.len() > 100 {
402 format!("{}…", &entry.description[..100])
403 } else {
404 entry.description.clone()
405 };
406 format!("{}: {desc}", entry.name)
407 }
408 DisclosureDepth::Parameters => {
409 let summary = if entry.description.len() > 100 {
410 format!("{}…", &entry.description[..100])
411 } else {
412 entry.description.clone()
413 };
414 let params = extract_parameter_names(entry.schema_json.as_deref());
415 if params.is_empty() {
416 format!("{}: {summary}", entry.name)
417 } else {
418 format!("{}: {summary} (params: {})", entry.name, params.join(", "))
419 }
420 }
421 DisclosureDepth::Full => {
422 let mut out = format!("name: {}\ndescription: {}\n", entry.name, entry.description);
423 if let Some(ref schema) = entry.schema_json {
424 out.push_str("schema: ");
425 out.push_str(schema);
426 }
427 out
428 }
429 }
430}
431
432fn extract_parameter_names(schema_json: Option<&str>) -> Vec<String> {
434 let Some(json) = schema_json else {
435 return Vec::new();
436 };
437 let Ok(value) = serde_json::from_str::<serde_json::Value>(json) else {
438 return Vec::new();
439 };
440 value
441 .get("properties")
442 .and_then(|p| p.as_object())
443 .map(|props| props.keys().cloned().collect())
444 .unwrap_or_default()
445}
446
/// Arguments accepted by [`run_tool_search`]; build with [`ToolSearchArgs::new`] and the `with_*` methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ToolSearchArgs {
    /// Free-text query to search for.
    pub query: Option<String>,
    /// Namespace to browse; takes precedence over `query` in `run_tool_search`.
    pub namespace: Option<String>,
    /// Maximum number of search results; `run_tool_search` uses 5 when unset.
    pub top_k: Option<usize>,
}
462
463impl ToolSearchArgs {
464 pub const fn new() -> Self {
466 Self {
467 query: None,
468 namespace: None,
469 top_k: None,
470 }
471 }
472
473 #[must_use]
475 pub fn with_query(mut self, query: impl Into<String>) -> Self {
476 self.query = Some(query.into());
477 self
478 }
479
480 #[must_use]
482 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
483 self.namespace = Some(namespace.into());
484 self
485 }
486
487 #[must_use]
489 pub const fn with_top_k(mut self, top_k: usize) -> Self {
490 self.top_k = Some(top_k);
491 self
492 }
493}
494
495impl Default for ToolSearchArgs {
496 fn default() -> Self {
497 Self::new()
498 }
499}
500
501pub fn run_tool_search(index: &mut ToolSearchIndex, args: &ToolSearchArgs) -> String {
505 let top_k = args.top_k.unwrap_or(5);
506
507 if let Some(ref ns) = args.namespace {
508 let results = index.browse_namespace(ns);
509 if results.is_empty() {
510 return format!("No tools found in namespace '{ns}'.");
511 }
512 let mut out = format!("Tools in namespace '{ns}':\n");
513 for r in &results {
514 let _ = writeln!(out, " - {}", r.rendered);
515 }
516 return out;
517 }
518
519 let query = match &args.query {
520 Some(q) => q.clone(),
521 None => return "Provide either 'query' or 'namespace'.".to_owned(),
522 };
523
524 let results = index.search(&query, top_k);
525 if results.is_empty() {
526 return format!("No tools matched '{query}'.");
527 }
528
529 let mut out = format!("Tool search results for '{query}':\n");
530 for r in &results {
531 let _ = writeln!(out, " [{}] {}", r.confidence_level, r.rendered);
532 }
533 out
534}
535
536pub fn run_tool_list(index: &ToolSearchIndex) -> String {
538 let mut grouped: BTreeMap<&str, Vec<(&str, &str)>> = BTreeMap::new();
540 for entry in &index.entries {
541 grouped
542 .entry(entry.namespace.as_str())
543 .or_default()
544 .push((entry.name.as_str(), entry.description.as_str()));
545 }
546
547 let mut out = String::new();
548 for (ns, tools) in &grouped {
549 out.push_str(ns);
550 out.push_str(":\n");
551 for (name, desc) in tools {
552 let short_desc = if desc.len() > 80 {
553 format!("{}…", &desc[..80])
554 } else {
555 (*desc).to_owned()
556 };
557 let _ = writeln!(out, " - {name}: {short_desc}");
558 }
559 }
560 out
561}
562
563pub fn allocate_budget(results: &[ToolSearchResult]) -> String {
573 const TOKEN_CAP: usize = 5_000;
574 const CHARS_PER_TOKEN: usize = 4;
575
576 let mut out = String::new();
577 let mut tokens_used: usize = 0;
578
579 for (i, r) in results.iter().enumerate() {
580 let depth_label = if i < 5 {
581 "full"
582 } else if i < 15 {
583 "summary"
584 } else {
585 "minimal"
586 };
587 let line = format!("[{depth_label}] {} (score={:.2})\n", r.rendered, r.score);
588 tokens_used += line.len() / CHARS_PER_TOKEN;
589 if tokens_used > TOKEN_CAP {
590 break;
591 }
592 out.push_str(&line);
593 }
594 out
595}
596
597pub fn verify_parameter_types(results: &mut [ToolSearchResult], query: &str) {
607 let query_lower = query.to_lowercase();
608 for r in results.iter_mut() {
609 let looks_like_path = query_lower.contains('/')
610 || query_lower.contains(".rs")
611 || query_lower.contains(".py")
612 || query_lower.contains("file");
613 let is_file_tool = r.name.contains("read")
614 || r.name.contains("write")
615 || r.name.contains("file")
616 || r.name.contains("glob")
617 || r.namespace == "vfs";
618
619 if looks_like_path && !is_file_tool {
620 r.score *= 0.7;
621 }
622
623 let looks_like_symbol = query_lower.contains("function")
624 || query_lower.contains("method")
625 || query_lower.contains("struct")
626 || query_lower.contains("class");
627 let is_lsp_tool =
628 r.namespace == "lsp" || r.name.contains("symbol") || r.name.contains("goto");
629
630 if looks_like_symbol && !is_lsp_tool {
631 r.score *= 0.85;
632 }
633 }
634 results.sort_by(|a, b| {
635 b.score
636 .partial_cmp(&a.score)
637 .unwrap_or(std::cmp::Ordering::Equal)
638 });
639}
640
/// Counts observed tool-to-tool transitions to suggest which tool is likely to be used next.
pub struct ToolTransitionGraph {
    /// `from` tool -> (`to` tool -> number of recorded transitions).
    transitions: HashMap<String, HashMap<String, f64>>,
    /// Recorded invocations per halving of the scores returned by `successors` (0 acts as 1).
    half_life: usize,
    /// Total number of transitions recorded across all tools.
    total_invocations: usize,
}
656
657impl ToolTransitionGraph {
658 pub fn new(half_life: usize) -> Self {
660 Self {
661 transitions: HashMap::new(),
662 half_life,
663 total_invocations: 0,
664 }
665 }
666
667 pub fn record_transition(&mut self, from: &str, to: &str) {
669 let _ = self
670 .transitions
671 .entry(from.to_owned())
672 .or_default()
673 .entry(to.to_owned())
674 .and_modify(|c| *c += 1.0)
675 .or_insert(1.0);
676 self.total_invocations += 1;
677 }
678
679 pub fn successors(&self, current: &str) -> Vec<(String, f32)> {
683 let Some(counts) = self.transitions.get(current) else {
684 return Vec::new();
685 };
686 let total: f64 = counts.values().sum();
687 let exponent = self.total_invocations / self.half_life.max(1);
688 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
689 let decay = 0.5_f64.powi(exponent as i32);
690 let mut results: Vec<(String, f32)> = counts
691 .iter()
692 .map(|(name, count)| {
693 #[allow(clippy::cast_possible_truncation)]
694 let score = ((count / total) * decay) as f32;
695 (name.clone(), score)
696 })
697 .collect();
698 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
699 results
700 }
701}
702
/// Transforms a raw query string, e.g. before it is searched.
///
/// Implementations return `Cow::Borrowed` when the query is used unchanged and `Cow::Owned`
/// when they build a new string. The `Send + Sync` supertraits let implementations be shared
/// across threads.
pub trait QueryPreprocessor: Send + Sync {
    /// Returns the preprocessed form of `query`.
    fn preprocess<'q>(&self, query: &'q str) -> Cow<'q, str>;
}
712
/// [`QueryPreprocessor`] that reduces long natural-language queries to a few content words.
pub struct IntentExtractor;
715
716impl QueryPreprocessor for IntentExtractor {
717 fn preprocess<'a>(&self, query: &'a str) -> Cow<'a, str> {
718 const STOP_WORDS: &[&str] = &[
719 "the", "a", "an", "in", "for", "to", "of", "that", "which", "with", "from",
720 ];
721 let words: Vec<&str> = query.split_whitespace().collect();
722 if words.len() <= 5 {
723 return Cow::Borrowed(query);
724 }
725 let content_words: Vec<&str> = words
726 .iter()
727 .copied()
728 .filter(|w| !STOP_WORDS.contains(&w.to_lowercase().as_str()))
729 .take(5)
730 .collect();
731 Cow::Owned(content_words.join(" "))
732 }
733}
734
/// Unit tests for the tool search index, its helper functions and the query preprocessor.
#[cfg(test)]
// `unwrap` is acceptable in tests; `i as f32` casts only convert small loop counters.
#[allow(clippy::unwrap_used, clippy::cast_precision_loss)]
mod tests {
    use super::*;

    /// Index with three tools across the `vfs` and `index` namespaces.
    fn sample_index() -> ToolSearchIndex {
        let mut idx = ToolSearchIndex::new();
        idx.register(
            "read_file",
            "vfs",
            "Read the contents of a file",
            &["file", "read"],
            None,
        );
        idx.register(
            "search_code",
            "index",
            "Semantic code search using embeddings",
            &["search", "semantic"],
            None,
        );
        idx.register(
            "list_dir",
            "vfs",
            "List directory contents",
            &["ls", "directory"],
            None,
        );
        idx
    }

    /// Name, description and tag matches together put `read_file` on top.
    #[test]
    fn tool_search_finds_by_query() {
        let mut idx = sample_index();
        let results = idx.search("read file contents", 3);
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "read_file");
    }

    /// Browsing returns exactly the tools registered under the namespace.
    #[test]
    fn namespace_browse_returns_all() {
        let mut idx = ToolSearchIndex::new();
        idx.register("read_file", "vfs", "Read file", &[], None);
        idx.register("write_file", "vfs", "Write file", &[], None);
        idx.register("search_code", "index", "Search code", &[], None);

        let vfs_tools = idx.browse_namespace("vfs");
        assert_eq!(vfs_tools.len(), 2);
    }

    /// A query naming every part of a tool's name ranks that tool first.
    #[test]
    fn exact_name_match_ranks_first() {
        let mut idx = ToolSearchIndex::new();
        idx.register(
            "search_code",
            "index",
            "Semantic code search",
            &["search"],
            None,
        );
        idx.register("read_file", "vfs", "Read file", &["file"], None);
        idx.register("list_dir", "vfs", "List directory contents", &[], None);

        let results = idx.search("search code", 3);
        assert!(!results.is_empty());
        assert_eq!(results[0].name, "search_code");
    }

    /// The first search marks `find_func` as loaded, so the identical second search scores it lower.
    #[test]
    fn adaptive_scoring_penalises_loaded_schemas() {
        let mut idx = ToolSearchIndex::new();
        idx.register(
            "find_func",
            "lsp",
            "Find function definition",
            &["function", "find"],
            None,
        );

        let first_results = idx.search("find function", 1);
        assert!(!first_results.is_empty());
        let score_before = first_results[0].score;

        let second_results = idx.search("find function", 1);
        assert!(!second_results.is_empty());
        let score_after = second_results[0].score;

        assert!(
            score_after < score_before,
            "expected penalised score {score_after} < {score_before}"
        );
    }

    /// Every registration of a new tool changes the registry hash.
    #[test]
    fn registry_hash_changes_on_registration() {
        let mut idx = ToolSearchIndex::new();
        let h0 = idx.registry_hash().to_owned();
        idx.register("read_file", "vfs", "Read file", &[], None);
        let h1 = idx.registry_hash().to_owned();
        idx.register("write_file", "vfs", "Write file", &[], None);
        let h2 = idx.registry_hash().to_owned();
        assert_ne!(h0, h1);
        assert_ne!(h1, h2);
    }

    /// At most 10 example queries are kept per tool.
    #[test]
    fn record_success_capped() {
        let mut idx = ToolSearchIndex::new();
        idx.register("read_file", "vfs", "Read file", &[], None);
        for i in 0..15 {
            idx.record_success(&format!("query {i}"), "read_file");
        }
        let entry = idx.entries.iter().find(|e| e.name == "read_file").unwrap();
        assert_eq!(entry.example_queries.len(), 10);
    }

    /// Repeating the same query stores it only once.
    #[test]
    fn record_success_no_duplicates() {
        let mut idx = ToolSearchIndex::new();
        idx.register("read_file", "vfs", "Read file", &[], None);
        for _ in 0..5 {
            idx.record_success("read a file", "read_file");
        }
        let entry = idx.entries.iter().find(|e| e.name == "read_file").unwrap();
        assert_eq!(entry.example_queries.len(), 1);
    }

    /// Progressive retrieval never returns the same tool twice.
    #[test]
    fn progressive_retrieval_deduplicates() {
        let mut idx = ToolSearchIndex::new();
        idx.register("read_file", "vfs", "Read file contents", &["file"], None);
        idx.register("write_file", "vfs", "Write file contents", &["file"], None);
        idx.register(
            "search_code",
            "index",
            "Search code semantically",
            &["search"],
            None,
        );

        let results = idx.search_progressive("file operations", 2, 2);
        let names: Vec<&String> = results.iter().map(|r| &r.name).collect();
        let unique_names: HashSet<&String> = names.iter().copied().collect();
        assert_eq!(names.len(), unique_names.len());
    }

    /// The more frequent transition (2 of 3) ranks first.
    #[test]
    fn transition_graph_boosts_successors() {
        let mut g = ToolTransitionGraph::new(100);
        g.record_transition("read_file", "write_file");
        g.record_transition("read_file", "write_file");
        g.record_transition("read_file", "search_code");

        let successors = g.successors("read_file");
        assert!(!successors.is_empty());
        assert_eq!(successors[0].0, "write_file");
    }

    /// Queries longer than five words are cut down to at most five words.
    #[test]
    fn intent_extractor_shortens_long_query() {
        let extractor = IntentExtractor;
        let long = "I need to find the function that handles authentication in the codebase";
        let short = extractor.preprocess(long);
        assert!(short.split_whitespace().count() <= 5);
    }

    /// Short queries pass through unchanged.
    #[test]
    fn intent_extractor_passthrough_short_query() {
        let extractor = IntentExtractor;
        let q = "read file";
        let result = extractor.preprocess(q);
        assert_eq!(result, q);
    }

    /// For a path-like query the non-file LSP tool (6.0 × 0.7 = 4.2) drops below `read_file` (5.0).
    #[test]
    fn parameter_verification_demotes_non_file_tools() {
        let mut results = vec![
            ToolSearchResult {
                name: "go_to_definition".to_owned(),
                namespace: "lsp".to_owned(),
                score: 6.0,
                rendered: String::new(),
                nearest_namespace: None,
                alternative_keywords: Vec::new(),
                confidence_level: "high".to_owned(),
            },
            ToolSearchResult {
                name: "read_file".to_owned(),
                namespace: "vfs".to_owned(),
                score: 5.0,
                rendered: String::new(),
                nearest_namespace: None,
                alternative_keywords: Vec::new(),
                confidence_level: "high".to_owned(),
            },
        ];
        verify_parameter_types(&mut results, "read /src/main.rs file");
        assert_eq!(results[0].name, "read_file");
    }

    /// The tool list has one header per namespace.
    #[test]
    fn run_tool_list_grouped_output() {
        let mut idx = ToolSearchIndex::new();
        idx.register("read_file", "vfs", "Read file", &[], None);
        idx.register("write_file", "vfs", "Write file", &[], None);
        idx.register("search_code", "index", "Search code", &[], None);

        let output = run_tool_list(&idx);
        assert!(output.contains("vfs:"));
        assert!(output.contains("index:"));
    }

    /// Twenty short results fit the budget and span both the full and summary tiers.
    #[test]
    fn allocate_budget_produces_output() {
        let results: Vec<ToolSearchResult> = (0..20)
            .map(|i| ToolSearchResult {
                name: format!("tool_{i}"),
                namespace: "vfs".to_owned(),
                score: 10.0 - i as f32,
                rendered: format!("tool_{i}: does something useful"),
                nearest_namespace: None,
                alternative_keywords: Vec::new(),
                confidence_level: "high".to_owned(),
            })
            .collect();
        let output = allocate_budget(&results);
        assert!(!output.is_empty());
        assert!(output.contains("[full]"));
        assert!(output.contains("[summary]"));
    }
}