//! sqz_engine/tool_selector.rs
//!
//! Lexical tool selection: TF-IDF + cosine similarity over tool
//! descriptions, with a Jaccard bag-of-words fallback for short intents.
1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
/// A tool definition with an id, name, and description.
///
/// Only `description` feeds the selector's lexical scoring (it is the text
/// tokenized in `register_tools` / `update_tool`); the remaining fields are
/// carried as metadata.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolDefinition {
    /// Unique identifier; used as the key in the selector's internal maps.
    pub id: ToolId,
    /// Human-readable name (not used for similarity scoring).
    pub name: String,
    /// Free-text description; tokenized for TF-IDF / Jaccard matching.
    pub description: String,
    /// JSON Schema for the tool's input parameters.
    /// With `#[serde(default)]`, absent fields deserialize to `Value`'s
    /// default (JSON null).
    #[serde(default)]
    pub input_schema: serde_json::Value,
    /// JSON Schema describing the structure of the compressed output.
    #[serde(default)]
    pub output_schema: serde_json::Value,
    /// Description of what sqz does to this tool's output.
    #[serde(default)]
    pub compression_transforms: Vec<String>,
}
25
/// Bag-of-words representation: a set of lowercase word tokens.
type BagOfWords = HashSet<String>;

/// Tokenize a string into a bag of lowercase words, splitting on any
/// non-alphanumeric character; empty fragments are discarded.
fn tokenize(text: &str) -> BagOfWords {
    // Delegate to the frequency tokenizer so both tokenizers share a single
    // split-and-lowercase rule (previously duplicated).
    tokenize_tf(text).into_keys().collect()
}

/// Tokenize into a frequency map (term → count) for TF-IDF.
///
/// Splitting and lowercasing rules are identical to `tokenize`.
fn tokenize_tf(text: &str) -> HashMap<String, u32> {
    let mut freq = HashMap::new();
    for word in text
        .split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
    {
        *freq.entry(word.to_lowercase()).or_insert(0) += 1;
    }
    freq
}

/// Jaccard similarity between two bags of words: |A ∩ B| / |A ∪ B|.
///
/// Returns 0.0 when both sets are empty. (The union is empty iff both sets
/// are, so a single check suffices — the old code's second `union == 0.0`
/// branch after the both-empty early return was unreachable.)
fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
    let union = a.union(b).count();
    if union == 0 {
        // Both sets empty: define similarity as 0.0, matching prior behavior.
        return 0.0;
    }
    a.intersection(b).count() as f64 / union as f64
}
60
// ── TF-IDF + Cosine Similarity ────────────────────────────────────────────────

/// Sparse TF-IDF vector: term → weight.
type TfIdfVector = HashMap<String, f64>;

/// Compute the TF-IDF vector for one document.
///
/// TF(t,d) = count(t,d) / |d|
/// IDF(t) = ln(N / DF(t))
/// TF-IDF(t,d) = TF(t,d) × IDF(t)
///
/// Terms whose weight is not strictly positive (e.g. IDF = 0 because the
/// term appears in every document) are omitted from the result. An empty
/// document yields an empty vector.
fn compute_tfidf(
    term_freq: &HashMap<String, u32>,
    doc_freq: &HashMap<String, u32>,
    total_docs: u32,
) -> TfIdfVector {
    let doc_len: u32 = term_freq.values().sum();
    if doc_len == 0 {
        return HashMap::new();
    }

    term_freq
        .iter()
        .filter_map(|(term, &count)| {
            let tf = f64::from(count) / f64::from(doc_len);
            // Unknown terms get DF = 1; clamp to >= 1 to guard division.
            let df = doc_freq.get(term).copied().unwrap_or(1).max(1);
            let idf = (f64::from(total_docs) / f64::from(df)).ln();
            let weight = tf * idf;
            (weight > 0.0).then(|| (term.clone(), weight))
        })
        .collect()
}

/// Cosine similarity between two sparse TF-IDF vectors.
///
/// cosine(a, b) = (a · b) / (||a|| × ||b||)
///
/// Returns 0.0 if either vector is empty or has zero magnitude.
fn cosine_similarity(a: &TfIdfVector, b: &TfIdfVector) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Dot product over the (sparse) intersection of the two term sets.
    let mut dot = 0.0;
    for (term, &wa) in a {
        if let Some(&wb) = b.get(term) {
            dot += wa * wb;
        }
    }

    let magnitude = |v: &TfIdfVector| v.values().map(|w| w * w).sum::<f64>().sqrt();
    let norm_a = magnitude(a);
    let norm_b = magnitude(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
116
// ── ToolSelector ──────────────────────────────────────────────────────────────

/// Selects 3–5 relevant tools per task using TF-IDF + cosine similarity,
/// with Jaccard as a fallback for very short intents.
///
/// # Selection rules
/// - Compute TF-IDF vectors for each tool description at registration time.
/// - At query time, compute the TF-IDF vector for the intent and score via cosine.
/// - Sort tools by descending similarity score.
/// - Return between 3 and min(5, tool_count) tools.
/// - If the best similarity is strictly below the threshold, return the
///   default tool set instead (a tie at exactly the threshold passes).
pub struct ToolSelector {
    /// Bag-of-words for each registered tool description (kept for backward compat).
    /// Consumed by the Jaccard fallback path in `select`.
    bags: HashMap<ToolId, BagOfWords>,
    /// TF-IDF vectors for each registered tool description.
    tfidf_vectors: HashMap<ToolId, TfIdfVector>,
    /// Document frequency: term → number of tools containing that term.
    /// May retain zero-count entries after tool updates (decrements do not
    /// remove keys); `compute_tfidf` clamps DF to >= 1 so this is harmless.
    doc_freq: HashMap<String, u32>,
    /// Total number of registered tools (for IDF computation).
    total_docs: u32,
    /// Ordered list of registered tool ids (preserves insertion order for determinism).
    tool_ids: Vec<ToolId>,
    /// Similarity threshold below which we fall back to defaults.
    threshold: f64,
    /// Default tool ids returned when confidence is low.
    default_tools: Vec<ToolId>,
    /// Raw term frequencies per tool (needed for recomputing TF-IDF on updates).
    term_freqs: HashMap<ToolId, HashMap<String, u32>>,
}
146
impl ToolSelector {
    /// Create a new `ToolSelector` from a `Preset`.
    ///
    /// `model_path` is accepted for API compatibility but is unused — no
    /// model file is read here; scoring is purely lexical.
    ///
    /// Reads `preset.tool_selection.similarity_threshold` and
    /// `preset.tool_selection.default_tools`; always returns `Ok`.
    pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
        let threshold = preset.tool_selection.similarity_threshold;
        let default_tools = preset.tool_selection.default_tools.clone();
        Ok(Self {
            bags: HashMap::new(),
            tfidf_vectors: HashMap::new(),
            doc_freq: HashMap::new(),
            total_docs: 0,
            tool_ids: Vec::new(),
            threshold,
            default_tools,
            term_freqs: HashMap::new(),
        })
    }

    /// Register a slice of tools, computing TF-IDF vectors for each description.
    ///
    /// Re-registering an existing id replaces its description (its old terms
    /// are removed from the document-frequency table and the new ones added);
    /// `total_docs` is only incremented for previously unseen ids. Always
    /// returns `Ok`.
    pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
        // First pass: collect term frequencies and update doc_freq
        for tool in tools {
            let bag = tokenize(&tool.description);
            let tf = tokenize_tf(&tool.description);

            // Update document frequency for new terms
            if !self.bags.contains_key(&tool.id) {
                self.tool_ids.push(tool.id.clone());
                self.total_docs += 1;
                for term in tf.keys() {
                    *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
                }
            } else {
                // Tool already registered — update doc_freq (remove old, add new)
                // Decremented counts may reach 0 but stay in the map;
                // compute_tfidf clamps DF to >= 1 so stale keys are harmless.
                if let Some(old_tf) = self.term_freqs.get(&tool.id) {
                    for term in old_tf.keys() {
                        if let Some(count) = self.doc_freq.get_mut(term) {
                            *count = count.saturating_sub(1);
                        }
                    }
                }
                for term in tf.keys() {
                    *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
                }
            }

            self.bags.insert(tool.id.clone(), bag);
            self.term_freqs.insert(tool.id.clone(), tf);
        }

        // Second pass: recompute all TF-IDF vectors (IDF changed)
        // IDF is global, so any doc_freq change can alter every vector.
        self.recompute_tfidf();
        Ok(())
    }

    /// Recompute TF-IDF vectors for all registered tools.
    ///
    /// Called after any doc_freq / total_docs change, because IDF is a
    /// corpus-wide quantity.
    fn recompute_tfidf(&mut self) {
        for id in &self.tool_ids {
            if let Some(tf) = self.term_freqs.get(id) {
                let vector = compute_tfidf(tf, &self.doc_freq, self.total_docs);
                self.tfidf_vectors.insert(id.clone(), vector);
            }
        }
    }

    /// Select between 3 and min(5, tool_count) tools whose descriptions best match `intent`.
    ///
    /// Uses TF-IDF + cosine similarity for scoring. Falls back to Jaccard
    /// when the intent has fewer than 3 words (insufficient TF-IDF signal)
    /// or fewer than 2 tools are registered (IDF needs multiple documents).
    ///
    /// Returns the default tool set when the best score is strictly below
    /// the similarity threshold, or when no tools are registered.
    pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
        let tool_count = self.tool_ids.len();
        if tool_count == 0 {
            return Ok(self.default_tools.clone());
        }

        // Same split rule as tokenize/tokenize_tf; used only to count words.
        let intent_words: Vec<&str> = intent
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| !s.is_empty())
            .collect();

        // Use TF-IDF + cosine for intents with enough signal, Jaccard for short ones
        let use_tfidf = intent_words.len() >= 3 && self.total_docs >= 2;

        let mut scored: Vec<(f64, &ToolId)> = if use_tfidf {
            let intent_tf = tokenize_tf(intent);
            let intent_vector = compute_tfidf(&intent_tf, &self.doc_freq, self.total_docs);

            self.tool_ids
                .iter()
                .map(|id| {
                    // Missing vector (shouldn't happen for registered ids)
                    // scores 0.0 rather than panicking.
                    let score = self
                        .tfidf_vectors
                        .get(id)
                        .map(|v| cosine_similarity(&intent_vector, v))
                        .unwrap_or(0.0);
                    (score, id)
                })
                .collect()
        } else {
            // Fallback to Jaccard for short intents
            let intent_bag = tokenize(intent);
            self.tool_ids
                .iter()
                .map(|id| {
                    let bag = self.bags.get(id).expect("bag must exist for registered tool");
                    let score = jaccard(&intent_bag, bag);
                    (score, id)
                })
                .collect()
        };

        // Sort descending by score, then ascending by id for determinism on ties.
        // partial_cmp only fails on NaN, which the scorers never produce from
        // finite inputs; ties degrade to id order in that case.
        scored.sort_by(|a, b| {
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.1.cmp(b.1))
        });

        // Check whether any tool exceeds the threshold.
        // Strict comparison: a best score exactly at the threshold passes.
        let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
        if best_score < self.threshold {
            return Ok(self.default_tools.clone());
        }

        // Cardinality: return between 3 and min(max_tools, 5, tool_count) tools.
        // NOTE(review): when max_tools < 3 and tool_count >= 3, `count` is
        // still at least 3, i.e. the 3-tool floor overrides the caller's
        // `max_tools` — confirm callers expect this.
        let upper = max_tools.min(5).min(tool_count);
        let lower = 3_usize.min(tool_count);
        let count = upper.max(lower);

        let result = scored
            .into_iter()
            .take(count)
            .map(|(_, id)| id.clone())
            .collect();

        Ok(result)
    }

    /// Re-embed a single tool on description change.
    ///
    /// # Errors
    ///
    /// Returns `SqzError::Other` if the tool id was never registered via
    /// `register_tools`.
    pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
        if !self.bags.contains_key(&tool.id) {
            return Err(SqzError::Other(format!(
                "tool '{}' is not registered; use register_tools first",
                tool.id
            )));
        }

        // Update doc_freq: remove old terms, add new.
        // As in register_tools, decremented counts may linger at 0;
        // compute_tfidf's DF clamp makes that harmless.
        if let Some(old_tf) = self.term_freqs.get(&tool.id) {
            for term in old_tf.keys() {
                if let Some(count) = self.doc_freq.get_mut(term) {
                    *count = count.saturating_sub(1);
                }
            }
        }

        let bag = tokenize(&tool.description);
        let tf = tokenize_tf(&tool.description);
        for term in tf.keys() {
            *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
        }

        self.bags.insert(tool.id.clone(), bag);
        self.term_freqs.insert(tool.id.clone(), tf);

        // Recompute all vectors since IDF may have changed
        self.recompute_tfidf();
        Ok(())
    }
}
320
321// ---------------------------------------------------------------------------
322// Tests
323// ---------------------------------------------------------------------------
324
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    /// Build a minimal preset: only the similarity threshold and the
    /// fallback tool list matter for these tests.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    /// Generate `n` tools sharing a long common vocabulary, so every tool
    /// overlaps generic intents; only the interpolated `{i}` tokens differ.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
                ..Default::default()
            })
            .collect()
    }

    // -----------------------------------------------------------------------
    // Unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    #[test]
    fn test_tokenize_punctuation() {
        // Underscores, colons, and periods are all split points.
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        // threshold = 1.0 means nothing will ever exceed it (Jaccard < 1 unless identical)
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&updated).unwrap();

        // The bag must reflect the new description, with old terms gone.
        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    // ── TF-IDF specific tests ─────────────────────────────────────────────

    #[test]
    fn test_tfidf_discriminative_ranking() {
        // TF-IDF should rank a tool with a rare matching term higher than
        // one with only common terms.
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let tools = vec![
            ToolDefinition {
                id: "generic".to_string(),
                name: "Generic".to_string(),
                description: "this tool performs common operations on files and data".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "specific".to_string(),
                name: "Specific".to_string(),
                description: "this tool performs kubernetes pod deployment orchestration".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "other".to_string(),
                name: "Other".to_string(),
                description: "this tool handles database migration and schema updates".to_string(),
                ..Default::default()
            },
        ];
        selector.register_tools(&tools).unwrap();

        // "kubernetes deployment" should rank "specific" highest because
        // "kubernetes" and "deployment" are rare (discriminative) terms
        let result = selector
            .select("deploy kubernetes pods to the cluster", 5)
            .unwrap();
        assert_eq!(
            result[0], "specific",
            "TF-IDF should rank the tool with rare matching terms first"
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 2.0);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: TfIdfVector = HashMap::new();
        let b: TfIdfVector = HashMap::new();
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_tfidf_vectors_populated() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();
        assert_eq!(selector.tfidf_vectors.len(), 5);
        assert_eq!(selector.total_docs, 5);
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // -----------------------------------------------------------------------
    // Property 2: Tool selection cardinality
    // Validates: Requirements 3.1, 26.3
    // -----------------------------------------------------------------------

    /// Strategy: generate a tool count in [5, 20] and an intent string.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    /// Strategy: generate a small tool count in [1, 4] and an intent string.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2: Tool selection cardinality.
        ///
        /// For any intent string and any tool set of size >= 5, the ToolSelector
        /// SHALL return between 3 and 5 tools (inclusive).
        ///
        /// For tool sets smaller than 5, it SHALL return at most the size of the
        /// tool set.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            // Use threshold = 0.0 so all tools are eligible (no fallback to defaults).
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2b: For tool sets smaller than 5, ToolSelector returns at most
        /// the size of the tool set.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}