//! sqz_engine/tool_selector.rs — TF-IDF + cosine-similarity tool selection,
//! with a Jaccard fallback for very short intents.

1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
9/// A tool definition with an id, name, and description.
/// A tool definition with an id, name, and description.
///
/// Serialized via serde; the schema and transform fields default so that
/// manifests omitting them still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolDefinition {
    /// Stable identifier; used as the key in all `ToolSelector` maps.
    pub id: ToolId,
    /// Human-readable display name (not used for similarity scoring —
    /// only `description` is tokenized).
    pub name: String,
    /// Free-text description; this is the text the selector tokenizes and scores.
    pub description: String,
    /// JSON Schema for the tool's input parameters.
    #[serde(default)]
    pub input_schema: serde_json::Value,
    /// Optional JSON Schema describing the structure of a tool's structured
    /// output (see MCP spec `outputSchema`).
    ///
    /// Per the MCP 2025-06-18 specification, when present this schema's root
    /// `type` MUST be `"object"` and servers MUST return `structuredContent`
    /// matching the schema. Leave as `Value::Null` (the default) when the
    /// tool returns plain text via the `content` field — strict clients such
    /// as OpenCode validate this and will disable the whole server if a
    /// scalar-typed schema is advertised (reported in issue #5).
    #[serde(default)]
    pub output_schema: serde_json::Value,
    /// Description of what sqz does to this tool's output.
    #[serde(default)]
    pub compression_transforms: Vec<String>,
}
33
/// Bag-of-words representation: a set of lowercase word tokens.
/// Used by the Jaccard fallback scoring path for short intents.
type BagOfWords = HashSet<String>;
36
/// Tokenize a string into a bag of lowercase words, splitting on whitespace and punctuation.
///
/// Any run of non-alphanumeric characters acts as a separator; empty
/// fragments are dropped, so punctuation-only input yields an empty set.
fn tokenize(text: &str) -> HashSet<String> {
    let mut words = HashSet::new();
    for raw in text.split(|c: char| !c.is_alphanumeric()) {
        if !raw.is_empty() {
            words.insert(raw.to_lowercase());
        }
    }
    words
}
44
/// Tokenize into a frequency map (term → count) for TF-IDF.
///
/// Same separator rule as `tokenize`, but keeps a count per lowercase term
/// instead of collapsing duplicates into a set.
fn tokenize_tf(text: &str) -> HashMap<String, u32> {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(|token| token.to_lowercase())
        .fold(HashMap::new(), |mut counts, term| {
            *counts.entry(term).or_insert(0) += 1;
            counts
        })
}
53
/// Jaccard similarity between two bags of words: |A ∩ B| / |A ∪ B|.
/// Returns 0.0 if both sets are empty.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f64 {
    let union = a.union(b).count();
    if union == 0 {
        // Union is empty iff both sets are empty: define similarity as 0
        // instead of dividing 0/0.
        return 0.0;
    }
    let intersection = a.intersection(b).count();
    intersection as f64 / union as f64
}
68
69// ── TF-IDF + Cosine Similarity ────────────────────────────────────────────────
70
/// Sparse TF-IDF vector: term → weight.
/// Only terms with strictly positive weight are stored (see `compute_tfidf`).
type TfIdfVector = HashMap<String, f64>;
73
/// Compute TF-IDF weight for a term in a document.
///
/// TF(t,d) = count(t,d) / |d|
/// IDF(t) = ln(N / DF(t))
/// TF-IDF(t,d) = TF(t,d) × IDF(t)
///
/// Terms whose weight is not strictly positive (e.g. those appearing in
/// every document, where IDF = 0) are omitted from the sparse result.
/// Terms absent from `doc_freq` are treated as having document frequency 1.
/// An empty document yields an empty vector.
fn compute_tfidf(
    term_freq: &HashMap<String, u32>,
    doc_freq: &HashMap<String, u32>,
    total_docs: u32,
) -> HashMap<String, f64> {
    let doc_len: u32 = term_freq.values().sum();
    if doc_len == 0 {
        return HashMap::new();
    }
    let doc_len = doc_len as f64;

    term_freq
        .iter()
        .filter_map(|(term, &count)| {
            let tf = count as f64 / doc_len;
            // Clamp DF to >= 1 so the IDF ratio is always well-defined.
            let df = doc_freq.get(term).copied().unwrap_or(1).max(1) as f64;
            let weight = tf * (total_docs as f64 / df).ln();
            (weight > 0.0).then(|| (term.clone(), weight))
        })
        .collect()
}
101
/// Cosine similarity between two sparse TF-IDF vectors.
///
/// cosine(a, b) = (a · b) / (||a|| × ||b||)
///
/// Returns 0.0 when either vector is empty or has zero magnitude.
fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // Dot product over `a`'s support; terms absent from `b` contribute 0.
    let mut dot = 0.0;
    for (term, wa) in a {
        if let Some(wb) = b.get(term) {
            dot += wa * wb;
        }
    }

    let magnitude = |v: &HashMap<String, f64>| v.values().map(|w| w * w).sum::<f64>().sqrt();
    let norm_a = magnitude(a);
    let norm_b = magnitude(b);

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
124
125// ── ToolSelector ──────────────────────────────────────────────────────────────
126
/// Selects 3–5 relevant tools per task using TF-IDF + cosine similarity,
/// with Jaccard as a fallback for very short descriptions.
///
/// # Selection rules
/// - Compute TF-IDF vectors for each tool description at registration time.
/// - At query time, compute the TF-IDF vector for the intent and score via cosine.
/// - Sort tools by descending similarity score.
/// - Return between 3 and min(5, tool_count) tools.
/// - If no tool has similarity > threshold, return the default tool set instead.
pub struct ToolSelector {
    /// Bag-of-words for each registered tool description (kept for backward compat).
    bags: HashMap<ToolId, BagOfWords>,
    /// TF-IDF vectors for each registered tool description.
    tfidf_vectors: HashMap<ToolId, TfIdfVector>,
    /// Document frequency: term → number of tools containing that term.
    doc_freq: HashMap<String, u32>,
    /// Total number of registered tools (for IDF computation).
    total_docs: u32,
    /// Ordered list of registered tool ids (preserves insertion order for determinism).
    tool_ids: Vec<ToolId>,
    /// Similarity threshold below which we fall back to defaults.
    /// Sourced from `preset.tool_selection.similarity_threshold`.
    threshold: f64,
    /// Default tool ids returned when confidence is low.
    /// Sourced from `preset.tool_selection.default_tools`.
    default_tools: Vec<ToolId>,
    /// Raw term frequencies per tool (needed for recomputing TF-IDF on updates).
    term_freqs: HashMap<ToolId, HashMap<String, u32>>,
}
154
155impl ToolSelector {
156    /// Create a new `ToolSelector` from a `Preset`.
157    ///
158    /// `model_path` is accepted for API compatibility but is unused.
159    pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
160        let threshold = preset.tool_selection.similarity_threshold;
161        let default_tools = preset.tool_selection.default_tools.clone();
162        Ok(Self {
163            bags: HashMap::new(),
164            tfidf_vectors: HashMap::new(),
165            doc_freq: HashMap::new(),
166            total_docs: 0,
167            tool_ids: Vec::new(),
168            threshold,
169            default_tools,
170            term_freqs: HashMap::new(),
171        })
172    }
173
174    /// Register a slice of tools, computing TF-IDF vectors for each description.
175    pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
176        // First pass: collect term frequencies and update doc_freq
177        for tool in tools {
178            let bag = tokenize(&tool.description);
179            let tf = tokenize_tf(&tool.description);
180
181            // Update document frequency for new terms
182            if !self.bags.contains_key(&tool.id) {
183                self.tool_ids.push(tool.id.clone());
184                self.total_docs += 1;
185                for term in tf.keys() {
186                    *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
187                }
188            } else {
189                // Tool already registered — update doc_freq (remove old, add new)
190                if let Some(old_tf) = self.term_freqs.get(&tool.id) {
191                    for term in old_tf.keys() {
192                        if let Some(count) = self.doc_freq.get_mut(term) {
193                            *count = count.saturating_sub(1);
194                        }
195                    }
196                }
197                for term in tf.keys() {
198                    *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
199                }
200            }
201
202            self.bags.insert(tool.id.clone(), bag);
203            self.term_freqs.insert(tool.id.clone(), tf);
204        }
205
206        // Second pass: recompute all TF-IDF vectors (IDF changed)
207        self.recompute_tfidf();
208        Ok(())
209    }
210
211    /// Recompute TF-IDF vectors for all registered tools.
212    fn recompute_tfidf(&mut self) {
213        for id in &self.tool_ids {
214            if let Some(tf) = self.term_freqs.get(id) {
215                let vector = compute_tfidf(tf, &self.doc_freq, self.total_docs);
216                self.tfidf_vectors.insert(id.clone(), vector);
217            }
218        }
219    }
220
221    /// Select between 3 and min(5, tool_count) tools whose descriptions best match `intent`.
222    ///
223    /// Uses TF-IDF + cosine similarity for scoring. Falls back to Jaccard for
224    /// very short intents (< 3 words) where TF-IDF has insufficient signal.
225    ///
226    /// Returns the default tool set when no tool exceeds the similarity threshold.
227    pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
228        let tool_count = self.tool_ids.len();
229        if tool_count == 0 {
230            return Ok(self.default_tools.clone());
231        }
232
233        let intent_words: Vec<&str> = intent
234            .split(|c: char| !c.is_alphanumeric())
235            .filter(|s| !s.is_empty())
236            .collect();
237
238        // Use TF-IDF + cosine for intents with enough signal, Jaccard for short ones
239        let use_tfidf = intent_words.len() >= 3 && self.total_docs >= 2;
240
241        let mut scored: Vec<(f64, &ToolId)> = if use_tfidf {
242            let intent_tf = tokenize_tf(intent);
243            let intent_vector = compute_tfidf(&intent_tf, &self.doc_freq, self.total_docs);
244
245            self.tool_ids
246                .iter()
247                .map(|id| {
248                    let score = self
249                        .tfidf_vectors
250                        .get(id)
251                        .map(|v| cosine_similarity(&intent_vector, v))
252                        .unwrap_or(0.0);
253                    (score, id)
254                })
255                .collect()
256        } else {
257            // Fallback to Jaccard for short intents
258            let intent_bag = tokenize(intent);
259            self.tool_ids
260                .iter()
261                .map(|id| {
262                    let bag = self.bags.get(id).expect("bag must exist for registered tool");
263                    let score = jaccard(&intent_bag, bag);
264                    (score, id)
265                })
266                .collect()
267        };
268
269        // Sort descending by score, then ascending by id for determinism on ties.
270        scored.sort_by(|a, b| {
271            b.0.partial_cmp(&a.0)
272                .unwrap_or(std::cmp::Ordering::Equal)
273                .then_with(|| a.1.cmp(b.1))
274        });
275
276        // Check whether any tool exceeds the threshold.
277        let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
278        if best_score < self.threshold {
279            return Ok(self.default_tools.clone());
280        }
281
282        // Cardinality: return between 3 and min(max_tools, 5, tool_count) tools.
283        let upper = max_tools.min(5).min(tool_count);
284        let lower = 3_usize.min(tool_count);
285        let count = upper.max(lower);
286
287        let result = scored
288            .into_iter()
289            .take(count)
290            .map(|(_, id)| id.clone())
291            .collect();
292
293        Ok(result)
294    }
295
296    /// Re-embed a single tool on description change.
297    pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
298        if !self.bags.contains_key(&tool.id) {
299            return Err(SqzError::Other(format!(
300                "tool '{}' is not registered; use register_tools first",
301                tool.id
302            )));
303        }
304
305        // Update doc_freq: remove old terms, add new
306        if let Some(old_tf) = self.term_freqs.get(&tool.id) {
307            for term in old_tf.keys() {
308                if let Some(count) = self.doc_freq.get_mut(term) {
309                    *count = count.saturating_sub(1);
310                }
311            }
312        }
313
314        let bag = tokenize(&tool.description);
315        let tf = tokenize_tf(&tool.description);
316        for term in tf.keys() {
317            *self.doc_freq.entry(term.clone()).or_insert(0) += 1;
318        }
319
320        self.bags.insert(tool.id.clone(), bag);
321        self.term_freqs.insert(tool.id.clone(), tf);
322
323        // Recompute all vectors since IDF may have changed
324        self.recompute_tfidf();
325        Ok(())
326    }
327}
328
329// ---------------------------------------------------------------------------
330// Tests
331// ---------------------------------------------------------------------------
332
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    /// Build a `Preset` with the given similarity threshold and default tools.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    /// Build `n` tools with long, mostly-overlapping descriptions so TF-IDF
    /// has a realistic shared vocabulary plus a per-tool distinguishing token.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
                ..Default::default()
            })
            .collect()
    }

    // -----------------------------------------------------------------------
    // Unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    #[test]
    fn test_tokenize_punctuation() {
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task", 5).unwrap();
        assert!(result.len() <= 2, "expected <= 2, got {}", result.len());
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        // threshold = 1.0 means nothing will ever exceed it (Jaccard < 1 unless identical)
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&updated).unwrap();

        // The bag must reflect the new description only.
        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    // ── TF-IDF specific tests ─────────────────────────────────────────────

    #[test]
    fn test_tfidf_discriminative_ranking() {
        // TF-IDF should rank a tool with a rare matching term higher than
        // one with only common terms.
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let tools = vec![
            ToolDefinition {
                id: "generic".to_string(),
                name: "Generic".to_string(),
                description: "this tool performs common operations on files and data".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "specific".to_string(),
                name: "Specific".to_string(),
                description: "this tool performs kubernetes pod deployment orchestration".to_string(),
                ..Default::default()
            },
            ToolDefinition {
                id: "other".to_string(),
                name: "Other".to_string(),
                description: "this tool handles database migration and schema updates".to_string(),
                ..Default::default()
            },
        ];
        selector.register_tools(&tools).unwrap();

        // "kubernetes deployment" should rank "specific" highest because
        // "kubernetes" and "deployment" are rare (discriminative) terms
        let result = selector
            .select("deploy kubernetes pods to the cluster", 5)
            .unwrap();
        assert_eq!(
            result[0], "specific",
            "TF-IDF should rank the tool with rare matching terms first"
        );
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        a.insert("world".to_string(), 2.0);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let mut a = HashMap::new();
        a.insert("hello".to_string(), 1.0);
        let mut b = HashMap::new();
        b.insert("world".to_string(), 1.0);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-9);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let a: TfIdfVector = HashMap::new();
        let b: TfIdfVector = HashMap::new();
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_tfidf_vectors_populated() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();
        assert_eq!(selector.tfidf_vectors.len(), 5);
        assert_eq!(selector.total_docs, 5);
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // -----------------------------------------------------------------------
    // Property 2: Tool selection cardinality
    // Validates: Requirements 3.1, 26.3
    // -----------------------------------------------------------------------

    /// Strategy: generate a tool count in [5, 20] and an intent string.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    /// Strategy: generate a small tool count in [1, 4] and an intent string.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2: Tool selection cardinality.
        ///
        /// For any intent string and any tool set of size >= 5, the ToolSelector
        /// SHALL return between 3 and 5 tools (inclusive).
        ///
        /// For tool sets smaller than 5, it SHALL return at most the size of the
        /// tool set.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            // Use threshold = 0.0 so all tools are eligible (no fallback to defaults).
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2b: For tool sets smaller than 5, ToolSelector returns at most
        /// the size of the tool set.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}