1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
/// Serializable description of a callable tool as exchanged with clients.
///
/// Only `description` feeds the similarity scoring in `ToolSelector`;
/// the schemas and transforms ride along for other subsystems.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolDefinition {
    /// Stable identifier; used as the key in all of `ToolSelector`'s indexes.
    pub id: ToolId,
    /// Human-readable display name (not used for similarity scoring).
    pub name: String,
    /// Natural-language description; this is the text the selector tokenizes.
    pub description: String,
    /// JSON schema for the tool's input; defaults to JSON null when absent.
    #[serde(default)]
    pub input_schema: serde_json::Value,
    /// JSON schema for the tool's output; defaults to JSON null when absent.
    #[serde(default)]
    pub output_schema: serde_json::Value,
    /// Names of compression transforms associated with this tool's payloads.
    #[serde(default)]
    pub compression_transforms: Vec<String>,
}
33
/// A piece of text reduced to its set of lowercase alphanumeric words.
type BagOfWords = HashSet<String>;

/// Splits `text` on non-alphanumeric characters and collects the
/// lowercased, non-empty words into a set (duplicates collapse).
fn tokenize(text: &str) -> BagOfWords {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(|s| s.to_lowercase())
        .collect()
}

/// Like [`tokenize`], but keeps per-word occurrence counts instead of a set.
fn tokenize_tf(text: &str) -> HashMap<String, u32> {
    let mut freq = HashMap::new();
    for word in text.split(|c: char| !c.is_alphanumeric()).filter(|s| !s.is_empty()) {
        *freq.entry(word.to_lowercase()).or_insert(0) += 1;
    }
    freq
}

/// Jaccard similarity `|A ∩ B| / |A ∪ B|` between two word sets.
///
/// Returns 0.0 when both sets are empty: an empty intent carries no
/// signal, and scoring it as 0 lets `select` fall back to the defaults.
fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
    // The union is empty exactly when both sets are empty, so a single
    // integer check replaces the original's two separate tests (one of
    // which was an unreachable `==` comparison on a float).
    let union = a.union(b).count();
    if union == 0 {
        return 0.0;
    }
    a.intersection(b).count() as f64 / union as f64
}
68
/// Sparse TF-IDF weight vector, keyed by term.
type TfIdfVector = HashMap<String, f64>;

/// Builds the TF-IDF vector for one document.
///
/// Weights are `tf * ln(total_docs / df)`. Zero-weight terms (present in
/// every document) are dropped to keep vectors sparse; terms missing from
/// `doc_freq` fall back to a document frequency of 1. An empty document
/// yields an empty vector.
fn compute_tfidf(
    term_freq: &HashMap<String, u32>,
    doc_freq: &HashMap<String, u32>,
    total_docs: u32,
) -> TfIdfVector {
    let doc_len: u32 = term_freq.values().sum();
    if doc_len == 0 {
        return TfIdfVector::new();
    }

    term_freq
        .iter()
        .filter_map(|(term, &count)| {
            let tf = f64::from(count) / f64::from(doc_len);
            // Missing or zeroed-out entries are clamped to 1 so idf stays finite.
            let df = doc_freq.get(term).copied().unwrap_or(1).max(1);
            let idf = (f64::from(total_docs) / f64::from(df)).ln();
            let weight = tf * idf;
            (weight > 0.0).then(|| (term.clone(), weight))
        })
        .collect()
}

/// Cosine similarity between two sparse vectors.
///
/// Returns 0.0 when either vector is empty or has zero magnitude.
fn cosine_similarity(a: &TfIdfVector, b: &TfIdfVector) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    let magnitude = |v: &TfIdfVector| v.values().map(|w| w * w).sum::<f64>().sqrt();
    let norm_a = magnitude(a);
    let norm_b = magnitude(b);
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let mut dot = 0.0;
    for (term, wa) in a {
        if let Some(wb) = b.get(term) {
            dot += wa * wb;
        }
    }
    dot / (norm_a * norm_b)
}
124
/// Ranks registered tools against a natural-language intent.
///
/// Scoring uses TF-IDF cosine similarity for longer intents over a
/// non-trivial corpus, falling back to Jaccard word overlap otherwise.
pub struct ToolSelector {
    /// Word set per tool description; backs the Jaccard fallback path.
    bags: HashMap<ToolId, BagOfWords>,
    /// Precomputed TF-IDF vector per tool; rebuilt by `recompute_tfidf`.
    tfidf_vectors: HashMap<ToolId, TfIdfVector>,
    /// For each term, how many registered tool descriptions contain it.
    doc_freq: HashMap<String, u32>,
    /// Number of distinct registered tools (documents in the corpus).
    total_docs: u32,
    /// Tool ids in registration order; drives deterministic iteration.
    tool_ids: Vec<ToolId>,
    /// Minimum best-match score; below it `select` returns the defaults.
    threshold: f64,
    /// Tools returned when nothing is registered or confidence is low.
    default_tools: Vec<ToolId>,
    /// Raw term frequencies per tool, kept so updates can retract their
    /// old contribution from `doc_freq`.
    term_freqs: HashMap<ToolId, HashMap<String, u32>>,
}
154
155impl ToolSelector {
156 pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
160 let threshold = preset.tool_selection.similarity_threshold;
161 let default_tools = preset.tool_selection.default_tools.clone();
162 Ok(Self {
163 bags: HashMap::new(),
164 tfidf_vectors: HashMap::new(),
165 doc_freq: HashMap::new(),
166 total_docs: 0,
167 tool_ids: Vec::new(),
168 threshold,
169 default_tools,
170 term_freqs: HashMap::new(),
171 })
172 }
173
174 pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
176 for tool in tools {
178 let bag = tokenize(&tool.description);
179 let tf = tokenize_tf(&tool.description);
180
181 if !self.bags.contains_key(&tool.id) {
183 self.tool_ids.push(tool.id.clone());
184 self.total_docs += 1;
185 for term in tf.keys() {
186 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
187 }
188 } else {
189 if let Some(old_tf) = self.term_freqs.get(&tool.id) {
191 for term in old_tf.keys() {
192 if let Some(count) = self.doc_freq.get_mut(term) {
193 *count = count.saturating_sub(1);
194 }
195 }
196 }
197 for term in tf.keys() {
198 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
199 }
200 }
201
202 self.bags.insert(tool.id.clone(), bag);
203 self.term_freqs.insert(tool.id.clone(), tf);
204 }
205
206 self.recompute_tfidf();
208 Ok(())
209 }
210
211 fn recompute_tfidf(&mut self) {
213 for id in &self.tool_ids {
214 if let Some(tf) = self.term_freqs.get(id) {
215 let vector = compute_tfidf(tf, &self.doc_freq, self.total_docs);
216 self.tfidf_vectors.insert(id.clone(), vector);
217 }
218 }
219 }
220
221 pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
228 let tool_count = self.tool_ids.len();
229 if tool_count == 0 {
230 return Ok(self.default_tools.clone());
231 }
232
233 let intent_words: Vec<&str> = intent
234 .split(|c: char| !c.is_alphanumeric())
235 .filter(|s| !s.is_empty())
236 .collect();
237
238 let use_tfidf = intent_words.len() >= 3 && self.total_docs >= 2;
240
241 let mut scored: Vec<(f64, &ToolId)> = if use_tfidf {
242 let intent_tf = tokenize_tf(intent);
243 let intent_vector = compute_tfidf(&intent_tf, &self.doc_freq, self.total_docs);
244
245 self.tool_ids
246 .iter()
247 .map(|id| {
248 let score = self
249 .tfidf_vectors
250 .get(id)
251 .map(|v| cosine_similarity(&intent_vector, v))
252 .unwrap_or(0.0);
253 (score, id)
254 })
255 .collect()
256 } else {
257 let intent_bag = tokenize(intent);
259 self.tool_ids
260 .iter()
261 .map(|id| {
262 let bag = self.bags.get(id).expect("bag must exist for registered tool");
263 let score = jaccard(&intent_bag, bag);
264 (score, id)
265 })
266 .collect()
267 };
268
269 scored.sort_by(|a, b| {
271 b.0.partial_cmp(&a.0)
272 .unwrap_or(std::cmp::Ordering::Equal)
273 .then_with(|| a.1.cmp(b.1))
274 });
275
276 let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
278 if best_score < self.threshold {
279 return Ok(self.default_tools.clone());
280 }
281
282 let upper = max_tools.min(5).min(tool_count);
284 let lower = 3_usize.min(tool_count);
285 let count = upper.max(lower);
286
287 let result = scored
288 .into_iter()
289 .take(count)
290 .map(|(_, id)| id.clone())
291 .collect();
292
293 Ok(result)
294 }
295
296 pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
298 if !self.bags.contains_key(&tool.id) {
299 return Err(SqzError::Other(format!(
300 "tool '{}' is not registered; use register_tools first",
301 tool.id
302 )));
303 }
304
305 if let Some(old_tf) = self.term_freqs.get(&tool.id) {
307 for term in old_tf.keys() {
308 if let Some(count) = self.doc_freq.get_mut(term) {
309 *count = count.saturating_sub(1);
310 }
311 }
312 }
313
314 let bag = tokenize(&tool.description);
315 let tf = tokenize_tf(&tool.description);
316 for term in tf.keys() {
317 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
318 }
319
320 self.bags.insert(tool.id.clone(), bag);
321 self.term_freqs.insert(tool.id.clone(), tf);
322
323 self.recompute_tfidf();
325 Ok(())
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    // Builds a Preset whose tool-selection section uses the given
    // similarity threshold and fallback tool list; everything else stays
    // at its default.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    // Generates `n` tools with long, heavily-overlapping descriptions so
    // selection has a realistic multi-document corpus to score against.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
                ..Default::default()
            })
            .collect()
    }

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    // Punctuation and underscores both act as word separators, and words
    // are lowercased before being collected.
    #[test]
    fn test_tokenize_punctuation() {
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    // With a large corpus and threshold 0, select must return between 3
    // and 5 tools (the documented cardinality clamp).
    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    // The minimum of 3 results is capped by the number of registered tools.
    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    // A threshold of 1.0 cannot be met by a non-identical intent, so the
    // configured default tools must be returned instead of ranked ones.
    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // update_tool must replace the old word bag, not merge into it.
    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&updated).unwrap();

        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    // An intent containing rare, discriminative terms ("kubernetes",
    // "pod") should rank the tool that uniquely mentions them first,
    // ahead of tools sharing only common words.
    #[test]
    fn test_tfidf_discriminative_ranking() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let tools = vec![
            ToolDefinition {
                id: "generic".to_string(),
                name: "Generic".to_string(),
                description: "this tool performs common operations on files and data".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "specific".to_string(),
                name: "Specific".to_string(),
                description: "this tool performs kubernetes pod deployment orchestration".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "other".to_string(),
                name: "Other".to_string(),
                description: "this tool handles database migration and schema updates".to_string(),
                ..Default::default()
            },
        ];
        selector.register_tools(&tools).unwrap();

        let result = selector
            .select("deploy kubernetes pods to the cluster", 5)
            .unwrap();
        assert_eq!(
            result[0], "specific",
            "TF-IDF should rank the tool with rare matching terms first"
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 2.0);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: TfIdfVector = HashMap::new();
        let b: TfIdfVector = HashMap::new();
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // Registration must produce one TF-IDF vector per tool and keep the
    // document count in sync.
    #[test]
    fn test_tfidf_vectors_populated() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();
        assert_eq!(selector.tfidf_vectors.len(), 5);
        assert_eq!(selector.total_docs, 5);
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // Strategy: 5-20 tools plus a short random lowercase intent (trimmed
    // so it may even be empty after stripping spaces).
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    // Strategy: 1-4 tools, i.e. fewer than the usual minimum of 3 results.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        // Property: with >= 5 tools and threshold 0, any intent yields
        // between 3 and 5 selected tools.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        // Property: the result never exceeds the number of registered
        // tools, even when that number is below the usual minimum of 3.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}