1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
/// A tool that can be registered with [`ToolSelector`].
///
/// In this file only `id` and `description` are read (by
/// `ToolSelector::register_tools` and `ToolSelector::update_tool`); the other
/// fields are stored and carried through (de)serialization without being
/// interpreted here.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolDefinition {
    /// Key the selector stores this tool under; registering a tool with an id
    /// that is already known replaces that tool's word bag.
    pub id: ToolId,
    /// Display name of the tool; not read by `ToolSelector`.
    pub name: String,
    /// Free-text description whose words are matched against an intent by
    /// `ToolSelector`.
    pub description: String,
    /// Input schema as raw JSON (presumably a JSON Schema document — TODO
    /// confirm). `#[serde(default)]`: deserializes to `null` when absent.
    #[serde(default)]
    pub input_schema: serde_json::Value,
    /// Output schema as raw JSON (presumably a JSON Schema document — TODO
    /// confirm). `#[serde(default)]`: deserializes to `null` when absent.
    #[serde(default)]
    pub output_schema: serde_json::Value,
    /// Transform names attached to this tool. NOTE(review): presumably the
    /// compression steps to apply to its output; not interpreted in this file.
    /// `#[serde(default)]`: deserializes to an empty list when absent.
    #[serde(default)]
    pub compression_transforms: Vec<String>,
}
25
/// Set of distinct lowercase words extracted from a piece of text by `tokenize`;
/// word order and repetition are deliberately discarded.
type BagOfWords = HashSet<String>;
28
29fn tokenize(text: &str) -> BagOfWords {
31 text.split(|c: char| !c.is_alphanumeric())
32 .filter(|s| !s.is_empty())
33 .map(|s| s.to_lowercase())
34 .collect()
35}
36
37fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
40 if a.is_empty() && b.is_empty() {
41 return 0.0;
42 }
43 let intersection = a.intersection(b).count() as f64;
44 let union = a.union(b).count() as f64;
45 if union == 0.0 {
46 0.0
47 } else {
48 intersection / union
49 }
50}
51
/// Picks the registered tools whose descriptions best match a free-text intent.
///
/// Each tool description is reduced to a bag of lowercase words and scored
/// against the intent's words with Jaccard similarity; `select` documents how
/// scores are turned into a result.
pub struct ToolSelector {
    /// Word bag of every registered tool, keyed by tool id.
    bags: HashMap<ToolId, BagOfWords>,
    /// Registered tool ids in first-registration order; every id here has an
    /// entry in `bags` (an invariant `select` relies on).
    tool_ids: Vec<ToolId>,
    /// Best-match score below which `select` returns `default_tools` instead of
    /// the ranked tools.
    threshold: f64,
    /// Tools returned when nothing is registered or no tool scores high enough.
    default_tools: Vec<ToolId>,
}
69
70impl ToolSelector {
71 pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
76 let threshold = preset.tool_selection.similarity_threshold;
77 let default_tools = preset.tool_selection.default_tools.clone();
78 Ok(Self {
79 bags: HashMap::new(),
80 tool_ids: Vec::new(),
81 threshold,
82 default_tools,
83 })
84 }
85
86 pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
88 for tool in tools {
89 let bag = tokenize(&tool.description);
90 if !self.bags.contains_key(&tool.id) {
91 self.tool_ids.push(tool.id.clone());
92 }
93 self.bags.insert(tool.id.clone(), bag);
94 }
95 Ok(())
96 }
97
98 pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
102 let tool_count = self.tool_ids.len();
103 if tool_count == 0 {
104 return Ok(self.default_tools.clone());
105 }
106
107 let intent_bag = tokenize(intent);
108
109 let mut scored: Vec<(f64, &ToolId)> = self
111 .tool_ids
112 .iter()
113 .map(|id| {
114 let bag = self.bags.get(id).expect("bag must exist for registered tool");
115 let score = jaccard(&intent_bag, bag);
116 (score, id)
117 })
118 .collect();
119
120 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
122 .then_with(|| a.1.cmp(b.1)));
123
124 let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
128 if best_score < self.threshold {
129 return Ok(self.default_tools.clone());
130 }
131
132 let upper = max_tools.min(5).min(tool_count);
134 let lower = 3_usize.min(tool_count);
135 let count = upper.max(lower);
136
137 let result = scored
138 .into_iter()
139 .take(count)
140 .map(|(_, id)| id.clone())
141 .collect();
142
143 Ok(result)
144 }
145
146 pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
148 if !self.bags.contains_key(&tool.id) {
149 return Err(SqzError::Other(format!(
150 "tool '{}' is not registered; use register_tools first",
151 tool.id
152 )));
153 }
154 let bag = tokenize(&tool.description);
155 self.bags.insert(tool.id.clone(), bag);
156 Ok(())
157 }
158}
159
// Unit tests for the bag-of-words helpers and `ToolSelector`, plus proptest
// properties pinning how many tools `select` returns.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    // Builds a default preset with the given similarity threshold and
    // fallback tool list.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    // Builds `n` tools with ids `tool_0..tool_{n-1}` whose descriptions share
    // a long common vocabulary and differ only in the embedded number.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
                ..Default::default()
            })
            .collect()
    }

    // Whitespace-separated words all end up in the bag.
    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    // `_`, `:` and `.` are separators, so `read_file` splits into two words.
    #[test]
    fn test_tokenize_punctuation() {
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    // With a zero threshold and 10 tools, asking for 5 yields 3..=5 ids.
    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    // The result can never contain more tools than were registered.
    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        // A threshold of 1.0 needs a perfect Jaccard match; this intent shares
        // no words with any description, so the defaults must come back.
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // `update_tool` replaces the stored word bag rather than merging into it.
    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&updated).unwrap();

        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        });
        assert!(result.is_err());
    }

    // With nothing registered, `select` short-circuits to the defaults even
    // though the threshold (0.0) could never reject a match.
    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // 5..=20 tools paired with a lowercase intent; after trimming, the intent
    // may be empty, which exercises the zero-score path.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    // Same intent generator, but only 1..=4 tools (fewer than the upper bound of 5).
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        // Property: with at least 5 tools and a zero threshold, `select(_, 5)`
        // always returns between 3 and 5 ids, whatever the intent.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            // Threshold 0.0 means the ranked path is always taken (scores are >= 0).
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        // Property: with 1..=4 tools, `select` never returns more ids than
        // were registered.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}