1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
/// Serializable description of a callable tool, as registered with the
/// selector. Only `id` and `description` participate in similarity
/// scoring; the remaining fields are carried along for callers.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolDefinition {
    /// Stable identifier; used as the document key in the corpus.
    pub id: ToolId,
    /// Human-readable display name (not used for scoring).
    pub name: String,
    /// Free-text description; the sole input to tokenization and scoring.
    pub description: String,
    /// Schema for the tool's input (presumably JSON Schema — confirm with
    /// callers). `#[serde(default)]` makes it `Value::Null` when absent.
    #[serde(default)]
    pub input_schema: serde_json::Value,
    /// Schema for the tool's output; `Value::Null` when absent.
    #[serde(default)]
    pub output_schema: serde_json::Value,
    /// Names of compression transforms associated with this tool.
    /// NOTE(review): not consumed anywhere in this file — semantics live
    /// with the caller; verify before relying on them.
    #[serde(default)]
    pub compression_transforms: Vec<String>,
}
25
/// A set of lowercased alphanumeric tokens extracted from a text.
type BagOfWords = HashSet<String>;

/// Splits `text` on every non-alphanumeric character and collects the
/// lowercased tokens into a set (duplicates collapse).
fn tokenize(text: &str) -> BagOfWords {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(|s| s.to_lowercase())
        .collect()
}

/// Like [`tokenize`], but keeps a per-token occurrence count
/// (term frequencies) instead of collapsing duplicates.
fn tokenize_tf(text: &str) -> HashMap<String, u32> {
    let mut freq = HashMap::new();
    for word in text.split(|c: char| !c.is_alphanumeric()).filter(|s| !s.is_empty()) {
        *freq.entry(word.to_lowercase()).or_insert(0) += 1;
    }
    freq
}

/// Jaccard similarity `|a ∩ b| / |a ∪ b|`, in `[0.0, 1.0]`.
///
/// Two empty bags are treated as dissimilar (`0.0`) rather than
/// identical, so an empty intent never matches an empty description.
fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
    // By convention "no evidence" means no similarity; this also makes
    // the division below safe.
    if a.is_empty() && b.is_empty() {
        return 0.0;
    }
    let intersection = a.intersection(b).count() as f64;
    // At least one bag is non-empty here, so the union is >= 1 and the
    // former `union == 0.0` guard was unreachable — removed.
    let union = a.union(b).count() as f64;
    intersection / union
}
60
/// Sparse TF-IDF weight vector, keyed by term.
type TfIdfVector = HashMap<String, f64>;

/// Builds the TF-IDF vector for one document.
///
/// `tf = count / doc_len` and `idf = ln(total_docs / df)`. Terms whose
/// weight is not strictly positive (e.g. terms appearing in every
/// document, whose IDF is zero) are omitted from the sparse vector.
/// An empty document yields an empty vector.
fn compute_tfidf(
    term_freq: &HashMap<String, u32>,
    doc_freq: &HashMap<String, u32>,
    total_docs: u32,
) -> TfIdfVector {
    let doc_len: u32 = term_freq.values().sum();
    if doc_len == 0 {
        return TfIdfVector::new();
    }

    term_freq
        .iter()
        .filter_map(|(term, &count)| {
            let tf = f64::from(count) / f64::from(doc_len);
            // Terms absent from the table fall back to df = 1 so the
            // IDF stays finite.
            let df = doc_freq.get(term).copied().unwrap_or(1).max(1);
            let idf = (f64::from(total_docs) / f64::from(df)).ln();
            let weight = tf * idf;
            (weight > 0.0).then(|| (term.clone(), weight))
        })
        .collect()
}

/// Cosine similarity of two sparse vectors; `0.0` when either input is
/// empty or has zero norm.
fn cosine_similarity(a: &TfIdfVector, b: &TfIdfVector) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Dot product over the (sparse) intersection of terms.
    let mut dot = 0.0;
    for (term, wa) in a {
        if let Some(wb) = b.get(term) {
            dot += wa * wb;
        }
    }

    let norm = |v: &TfIdfVector| v.values().map(|w| w * w).sum::<f64>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
116
/// Lexical tool selector: ranks registered tools against a free-text
/// intent using TF-IDF cosine similarity, falling back to Jaccard
/// bag-of-words overlap for very short intents or tiny corpora.
pub struct ToolSelector {
    /// Per-tool set of description tokens (Jaccard scoring).
    bags: HashMap<ToolId, BagOfWords>,
    /// Per-tool TF-IDF weight vector (cosine scoring).
    tfidf_vectors: HashMap<ToolId, TfIdfVector>,
    /// Number of registered descriptions containing each term.
    doc_freq: HashMap<String, u32>,
    /// Total number of registered tools (documents in the corpus).
    total_docs: u32,
    /// Tool ids in registration order; defines scoring iteration order.
    tool_ids: Vec<ToolId>,
    /// Minimum best score; below it `select` returns `default_tools`.
    threshold: f64,
    /// Fallback ids returned when the corpus is empty or confidence is low.
    default_tools: Vec<ToolId>,
    /// Per-tool term frequencies, kept so `doc_freq` can be retracted
    /// when a tool's description is replaced.
    term_freqs: HashMap<ToolId, HashMap<String, u32>>,
}
146
147impl ToolSelector {
148 pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
152 let threshold = preset.tool_selection.similarity_threshold;
153 let default_tools = preset.tool_selection.default_tools.clone();
154 Ok(Self {
155 bags: HashMap::new(),
156 tfidf_vectors: HashMap::new(),
157 doc_freq: HashMap::new(),
158 total_docs: 0,
159 tool_ids: Vec::new(),
160 threshold,
161 default_tools,
162 term_freqs: HashMap::new(),
163 })
164 }
165
166 pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
168 for tool in tools {
170 let bag = tokenize(&tool.description);
171 let tf = tokenize_tf(&tool.description);
172
173 if !self.bags.contains_key(&tool.id) {
175 self.tool_ids.push(tool.id.clone());
176 self.total_docs += 1;
177 for term in tf.keys() {
178 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
179 }
180 } else {
181 if let Some(old_tf) = self.term_freqs.get(&tool.id) {
183 for term in old_tf.keys() {
184 if let Some(count) = self.doc_freq.get_mut(term) {
185 *count = count.saturating_sub(1);
186 }
187 }
188 }
189 for term in tf.keys() {
190 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
191 }
192 }
193
194 self.bags.insert(tool.id.clone(), bag);
195 self.term_freqs.insert(tool.id.clone(), tf);
196 }
197
198 self.recompute_tfidf();
200 Ok(())
201 }
202
203 fn recompute_tfidf(&mut self) {
205 for id in &self.tool_ids {
206 if let Some(tf) = self.term_freqs.get(id) {
207 let vector = compute_tfidf(tf, &self.doc_freq, self.total_docs);
208 self.tfidf_vectors.insert(id.clone(), vector);
209 }
210 }
211 }
212
213 pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
220 let tool_count = self.tool_ids.len();
221 if tool_count == 0 {
222 return Ok(self.default_tools.clone());
223 }
224
225 let intent_words: Vec<&str> = intent
226 .split(|c: char| !c.is_alphanumeric())
227 .filter(|s| !s.is_empty())
228 .collect();
229
230 let use_tfidf = intent_words.len() >= 3 && self.total_docs >= 2;
232
233 let mut scored: Vec<(f64, &ToolId)> = if use_tfidf {
234 let intent_tf = tokenize_tf(intent);
235 let intent_vector = compute_tfidf(&intent_tf, &self.doc_freq, self.total_docs);
236
237 self.tool_ids
238 .iter()
239 .map(|id| {
240 let score = self
241 .tfidf_vectors
242 .get(id)
243 .map(|v| cosine_similarity(&intent_vector, v))
244 .unwrap_or(0.0);
245 (score, id)
246 })
247 .collect()
248 } else {
249 let intent_bag = tokenize(intent);
251 self.tool_ids
252 .iter()
253 .map(|id| {
254 let bag = self.bags.get(id).expect("bag must exist for registered tool");
255 let score = jaccard(&intent_bag, bag);
256 (score, id)
257 })
258 .collect()
259 };
260
261 scored.sort_by(|a, b| {
263 b.0.partial_cmp(&a.0)
264 .unwrap_or(std::cmp::Ordering::Equal)
265 .then_with(|| a.1.cmp(b.1))
266 });
267
268 let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
270 if best_score < self.threshold {
271 return Ok(self.default_tools.clone());
272 }
273
274 let upper = max_tools.min(5).min(tool_count);
276 let lower = 3_usize.min(tool_count);
277 let count = upper.max(lower);
278
279 let result = scored
280 .into_iter()
281 .take(count)
282 .map(|(_, id)| id.clone())
283 .collect();
284
285 Ok(result)
286 }
287
288 pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
290 if !self.bags.contains_key(&tool.id) {
291 return Err(SqzError::Other(format!(
292 "tool '{}' is not registered; use register_tools first",
293 tool.id
294 )));
295 }
296
297 if let Some(old_tf) = self.term_freqs.get(&tool.id) {
299 for term in old_tf.keys() {
300 if let Some(count) = self.doc_freq.get_mut(term) {
301 *count = count.saturating_sub(1);
302 }
303 }
304 }
305
306 let bag = tokenize(&tool.description);
307 let tf = tokenize_tf(&tool.description);
308 for term in tf.keys() {
309 *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
310 }
311
312 self.bags.insert(tool.id.clone(), bag);
313 self.term_freqs.insert(tool.id.clone(), tf);
314
315 self.recompute_tfidf();
317 Ok(())
318 }
319}
320
#[cfg(test)]
mod tests {
    //! Unit and property tests for lexical tool selection.

    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    /// Builds a default `Preset`, overriding only the similarity
    /// threshold and the fallback tool list.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    /// Generates `n` tools whose descriptions share a long common core
    /// of words, so any intent scores comparably against all of them.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
                ..Default::default()
            })
            .collect()
    }

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    #[test]
    fn test_tokenize_punctuation() {
        // Underscores, colons and periods are all token separators.
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        // threshold 0.0 disables the low-confidence fallback.
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        // With only 2 tools registered, the [3, 5] clamp cannot exceed 2.
        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        // threshold 1.0: any best score strictly below 1.0 triggers the
        // default-tools fallback.
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        }];
        selector.register_tools(&tools).unwrap();

        // Replace the description entirely; the old tokens must vanish.
        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&updated).unwrap();

        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    #[test]
    fn test_tfidf_discriminative_ranking() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        // Three tools sharing common words ("this tool performs ...") but
        // with distinct rare terms; IDF should favor the rare terms.
        let tools = vec![
            ToolDefinition {
                id: "generic".to_string(),
                name: "Generic".to_string(),
                description: "this tool performs common operations on files and data".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "specific".to_string(),
                name: "Specific".to_string(),
                description: "this tool performs kubernetes pod deployment orchestration".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "other".to_string(),
                name: "Other".to_string(),
                description: "this tool handles database migration and schema updates".to_string(),
                ..Default::default()
            },
        ];
        selector.register_tools(&tools).unwrap();

        let result = selector
            .select("deploy kubernetes pods to the cluster", 5)
            .unwrap();
        assert_eq!(
            result[0], "specific",
            "TF-IDF should rank the tool with rare matching terms first"
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 2.0);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        // No shared terms -> zero dot product -> zero similarity.
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: TfIdfVector = HashMap::new();
        let b: TfIdfVector = HashMap::new();
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_tfidf_vectors_populated() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();
        assert_eq!(selector.tfidf_vectors.len(), 5);
        assert_eq!(selector.total_docs, 5);
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    /// Strategy: 5–20 tools plus a short random lowercase intent.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    /// Strategy: fewer tools (1–4) than the selector's minimum shortlist.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        // With a zero threshold and >= 5 registered tools, the shortlist
        // must always land in [3, 5] regardless of the intent text.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        // With fewer tools than the minimum shortlist, the result may
        // never exceed the number of registered tools.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}