Skip to main content

sgr_agent/
discovery.rs

//! Progressive discovery — filter tools by relevance to the current query.
//!
//! Use `ToolFilter` when your registry has many tools (20+) and you want to
//! limit what the LLM sees per step. System tools are always included.
//! Non-system tools are ranked by keyword overlap plus fuzzy matching (`strsim`).
//!
//! Usage: call `filter.select(user_query, &registry)` to get a subset of tools,
//! then pass those to your agent's `decide()` via a filtered registry.

10use crate::agent_tool::Tool;
11use crate::registry::ToolRegistry;
12
/// Tool filter for progressive discovery.
///
/// Holds a single knob, `max_visible`. System tools are never counted
/// against it. Construct with [`ToolFilter::new`] or `ToolFilter::default()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolFilter {
    /// Maximum number of non-system tools to expose.
    pub max_visible: usize,
}
18
19impl ToolFilter {
20    pub fn new(max_visible: usize) -> Self {
21        Self { max_visible }
22    }
23
24    /// Select relevant tools for a query. System tools always included.
25    pub fn select<'a>(&self, query: &str, registry: &'a ToolRegistry) -> Vec<&'a dyn Tool> {
26        let query_lower = query.to_lowercase();
27        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
28
29        let mut system_tools = Vec::new();
30        let mut scored: Vec<(&dyn Tool, f64)> = Vec::new();
31
32        for tool in registry.list() {
33            if tool.is_system() {
34                system_tools.push(tool);
35                continue;
36            }
37
38            let score = score_tool(tool, &query_lower, &query_words);
39            scored.push((tool, score));
40        }
41
42        // Sort by score descending
43        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
44
45        // Take top N
46        let mut result = system_tools;
47        for (tool, _score) in scored.into_iter().take(self.max_visible) {
48            result.push(tool);
49        }
50
51        result
52    }
53}
54
55impl Default for ToolFilter {
56    fn default() -> Self {
57        Self { max_visible: 10 }
58    }
59}
60
61/// Score a tool's relevance to a query.
62fn score_tool(tool: &dyn Tool, query_lower: &str, query_words: &[&str]) -> f64 {
63    let name = tool.name().to_lowercase();
64    let desc = tool.description().to_lowercase();
65    let combined = format!("{} {}", name, desc);
66    let tool_words: Vec<&str> = combined.split_whitespace().collect();
67
68    let mut score = 0.0;
69
70    // Exact name match
71    if query_lower.contains(&name) {
72        score += 5.0;
73    }
74
75    // Word intersection
76    for qw in query_words {
77        for tw in &tool_words {
78            if qw == tw {
79                score += 2.0;
80            } else {
81                let sim = strsim::normalized_levenshtein(qw, tw);
82                if sim > 0.7 {
83                    score += sim;
84                }
85            }
86        }
87    }
88
89    // Substring match in name
90    for qw in query_words {
91        if name.contains(qw) {
92            score += 1.5;
93        }
94    }
95
96    score
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_tool::{ToolError, ToolOutput};
    use crate::context::AgentContext;
    use serde_json::Value;

    /// Stub tool with a fixed name, description and system flag.
    struct TestTool {
        tool_name: &'static str,
        desc: &'static str,
        system: bool,
    }

    /// Shorthand constructor for `TestTool`.
    fn stub(tool_name: &'static str, desc: &'static str, system: bool) -> TestTool {
        TestTool {
            tool_name,
            desc,
            system,
        }
    }

    #[async_trait::async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.tool_name
        }

        fn description(&self) -> &str {
            self.desc
        }

        fn is_system(&self) -> bool {
            self.system
        }

        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &mut AgentContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[test]
    fn system_tools_always_included() {
        let reg = ToolRegistry::new()
            .register(stub("finish_task", "finish", true))
            .register(stub("read_file", "read a file from disk", false))
            .register(stub("bash", "run shell command", false));

        let selected = ToolFilter::new(1).select("read the file", &reg);

        // The system tool is present even though max_visible is only 1...
        assert!(selected.iter().any(|t| t.name() == "finish_task"));
        // ...and exactly one non-system tool fills the visible slot.
        let non_system = selected.iter().filter(|t| !t.is_system()).count();
        assert_eq!(non_system, 1);
    }

    #[test]
    fn relevant_tool_ranked_higher() {
        let reg = ToolRegistry::new()
            .register(stub("read_file", "read a file from disk", false))
            .register(stub("bash", "run shell command", false))
            .register(stub("write_file", "write content to a file", false));

        let selected = ToolFilter::new(2).select("read the file main.rs", &reg);

        // No system tools are registered, so the best match comes first.
        assert_eq!(selected.first().map(|t| t.name()), Some("read_file"));
    }

    #[test]
    fn empty_query_returns_all_up_to_max() {
        let reg = ToolRegistry::new()
            .register(stub("a", "tool a", false))
            .register(stub("b", "tool b", false))
            .register(stub("c", "tool c", false));

        let selected = ToolFilter::new(2).select("", &reg);
        assert_eq!(selected.len(), 2);
    }
}