1use std::collections::{HashMap, HashSet};
2
3use crate::error::{Result, SqzError};
4use crate::preset::Preset;
5use crate::types::ToolId;
6
7#[derive(Debug, Clone)]
9pub struct ToolDefinition {
10 pub id: ToolId,
11 pub name: String,
12 pub description: String,
13}
14
/// Set of distinct words extracted from a piece of text (see `tokenize`).
///
/// Because it is a set, it records only which words occur, not how often.
type BagOfWords = HashSet<String>;
17
18fn tokenize(text: &str) -> BagOfWords {
20 text.split(|c: char| !c.is_alphanumeric())
21 .filter(|s| !s.is_empty())
22 .map(|s| s.to_lowercase())
23 .collect()
24}
25
26fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
29 if a.is_empty() && b.is_empty() {
30 return 0.0;
31 }
32 let intersection = a.intersection(b).count() as f64;
33 let union = a.union(b).count() as f64;
34 if union == 0.0 {
35 0.0
36 } else {
37 intersection / union
38 }
39}
40
41pub struct ToolSelector {
49 bags: HashMap<ToolId, BagOfWords>,
51 tool_ids: Vec<ToolId>,
53 threshold: f64,
55 default_tools: Vec<ToolId>,
57}
58
59impl ToolSelector {
60 pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
65 let threshold = preset.tool_selection.similarity_threshold;
66 let default_tools = preset.tool_selection.default_tools.clone();
67 Ok(Self {
68 bags: HashMap::new(),
69 tool_ids: Vec::new(),
70 threshold,
71 default_tools,
72 })
73 }
74
75 pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
77 for tool in tools {
78 let bag = tokenize(&tool.description);
79 if !self.bags.contains_key(&tool.id) {
80 self.tool_ids.push(tool.id.clone());
81 }
82 self.bags.insert(tool.id.clone(), bag);
83 }
84 Ok(())
85 }
86
87 pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
91 let tool_count = self.tool_ids.len();
92 if tool_count == 0 {
93 return Ok(self.default_tools.clone());
94 }
95
96 let intent_bag = tokenize(intent);
97
98 let mut scored: Vec<(f64, &ToolId)> = self
100 .tool_ids
101 .iter()
102 .map(|id| {
103 let bag = self.bags.get(id).expect("bag must exist for registered tool");
104 let score = jaccard(&intent_bag, bag);
105 (score, id)
106 })
107 .collect();
108
109 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
111 .then_with(|| a.1.cmp(b.1)));
112
113 let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
117 if best_score < self.threshold {
118 return Ok(self.default_tools.clone());
119 }
120
121 let upper = max_tools.min(5).min(tool_count);
123 let lower = 3_usize.min(tool_count);
124 let count = upper.max(lower);
125
126 let result = scored
127 .into_iter()
128 .take(count)
129 .map(|(_, id)| id.clone())
130 .collect();
131
132 Ok(result)
133 }
134
135 pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
137 if !self.bags.contains_key(&tool.id) {
138 return Err(SqzError::Other(format!(
139 "tool '{}' is not registered; use register_tools first",
140 tool.id
141 )));
142 }
143 let bag = tokenize(&tool.description);
144 self.bags.insert(tool.id.clone(), bag);
145 Ok(())
146 }
147}
148
// Unit and property tests for tokenization, similarity and tool selection.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    /// Builds a default `Preset` overriding only the selection threshold and the
    /// fallback tool list.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    /// Builds `n` tools with ids `tool_0 .. tool_{n-1}`. Their descriptions share a
    /// long common vocabulary and differ only in the embedded index.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
            })
            .collect()
    }

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    // Underscore, colon and period are all non-alphanumeric, so they act as separators.
    #[test]
    fn test_tokenize_punctuation() {
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    // Threshold 0.0 disables the fallback (no score is below 0.0), so the
    // [3, 5] cardinality clamp is what is being exercised.
    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    // With fewer tools than the lower bound of 3, the result is capped at the tool count.
    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    // A threshold of 1.0 requires the intent's word set to equal a description's
    // word set. An unrelated intent therefore falls back to the defaults.
    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // "Embedding" here is the tool's bag of words; the update must fully replace it.
    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
        };
        selector.update_tool(&updated).unwrap();

        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
        });
        assert!(result.is_err());
    }

    // With no registered tools, `select` returns the defaults regardless of threshold.
    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    /// Strategy: 5..=20 tools plus a random lowercase intent. The intent may be
    /// empty after `trim` because the pattern allows all-space strings.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    /// Strategy: 1..=4 tools (below the lower clamp of 3 for some cases) plus a
    /// random lowercase intent.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        // Property: with at least 5 tools and the fallback disabled (threshold 0.0),
        // `select(_, 5)` always returns between 3 and 5 tools.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        // Property: the result never contains more tools than are registered.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}
364}