Skip to main content

sqz_engine/
tool_selector.rs

1use std::collections::{HashMap, HashSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, SqzError};
6use crate::preset::Preset;
7use crate::types::ToolId;
8
9/// A tool definition with an id, name, and description.
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct ToolDefinition {
12    pub id: ToolId,
13    pub name: String,
14    pub description: String,
15    /// JSON Schema for the tool's input parameters.
16    #[serde(default)]
17    pub input_schema: serde_json::Value,
18    /// JSON Schema describing the structure of the compressed output.
19    #[serde(default)]
20    pub output_schema: serde_json::Value,
21    /// Description of what sqz does to this tool's output.
22    #[serde(default)]
23    pub compression_transforms: Vec<String>,
24}
25
/// Set of distinct lowercase word tokens extracted from a piece of text.
type BagOfWords = HashSet<String>;

/// Break `text` into its distinct lowercase words.
///
/// Every non-alphanumeric character (whitespace, punctuation, underscores, …)
/// acts as a separator. Empty fragments produced by consecutive separators are
/// dropped, and duplicates collapse because the result is a set.
fn tokenize(text: &str) -> BagOfWords {
    let mut words = BagOfWords::new();
    for fragment in text.split(|ch: char| !ch.is_alphanumeric()) {
        if fragment.is_empty() {
            continue;
        }
        words.insert(fragment.to_lowercase());
    }
    words
}
36
37/// Jaccard similarity between two bags of words: |A ∩ B| / |A ∪ B|.
38/// Returns 0.0 if both sets are empty.
39fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
40    if a.is_empty() && b.is_empty() {
41        return 0.0;
42    }
43    let intersection = a.intersection(b).count() as f64;
44    let union = a.union(b).count() as f64;
45    if union == 0.0 {
46        0.0
47    } else {
48        intersection / union
49    }
50}
51
52/// Selects 3–5 relevant tools per task using TF-IDF-style word-overlap (Jaccard) similarity.
53///
54/// # Selection rules
55/// - Compute Jaccard similarity between the intent query and each tool description.
56/// - Sort tools by descending similarity score.
57/// - Return between 3 and min(5, tool_count) tools.
58/// - If no tool has similarity > threshold, return the default tool set instead.
59pub struct ToolSelector {
60    /// Bag-of-words for each registered tool description.
61    bags: HashMap<ToolId, BagOfWords>,
62    /// Ordered list of registered tool ids (preserves insertion order for determinism).
63    tool_ids: Vec<ToolId>,
64    /// Similarity threshold below which we fall back to defaults.
65    threshold: f64,
66    /// Default tool ids returned when confidence is low.
67    default_tools: Vec<ToolId>,
68}
69
70impl ToolSelector {
71    /// Create a new `ToolSelector` from a `Preset`.
72    ///
73    /// `model_path` is accepted for API compatibility but is unused because we use
74    /// a TF-IDF/Jaccard approach rather than a neural embedding model.
75    pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
76        let threshold = preset.tool_selection.similarity_threshold;
77        let default_tools = preset.tool_selection.default_tools.clone();
78        Ok(Self {
79            bags: HashMap::new(),
80            tool_ids: Vec::new(),
81            threshold,
82            default_tools,
83        })
84    }
85
86    /// Register a slice of tools, computing bag-of-words for each description.
87    pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
88        for tool in tools {
89            let bag = tokenize(&tool.description);
90            if !self.bags.contains_key(&tool.id) {
91                self.tool_ids.push(tool.id.clone());
92            }
93            self.bags.insert(tool.id.clone(), bag);
94        }
95        Ok(())
96    }
97
98    /// Select between 3 and min(5, tool_count) tools whose descriptions best match `intent`.
99    ///
100    /// Returns the default tool set when no tool exceeds the similarity threshold.
101    pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
102        let tool_count = self.tool_ids.len();
103        if tool_count == 0 {
104            return Ok(self.default_tools.clone());
105        }
106
107        let intent_bag = tokenize(intent);
108
109        // Score every registered tool.
110        let mut scored: Vec<(f64, &ToolId)> = self
111            .tool_ids
112            .iter()
113            .map(|id| {
114                let bag = self.bags.get(id).expect("bag must exist for registered tool");
115                let score = jaccard(&intent_bag, bag);
116                (score, id)
117            })
118            .collect();
119
120        // Sort descending by score, then ascending by id for determinism on ties.
121        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
122            .then_with(|| a.1.cmp(b.1)));
123
124        // Check whether any tool exceeds the threshold.
125        // We use strict less-than so that a score equal to the threshold is still
126        // considered "confident enough" (threshold = 0.0 means always select).
127        let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
128        if best_score < self.threshold {
129            return Ok(self.default_tools.clone());
130        }
131
132        // Cardinality: return between 3 and min(max_tools, 5, tool_count) tools.
133        let upper = max_tools.min(5).min(tool_count);
134        let lower = 3_usize.min(tool_count);
135        let count = upper.max(lower);
136
137        let result = scored
138            .into_iter()
139            .take(count)
140            .map(|(_, id)| id.clone())
141            .collect();
142
143        Ok(result)
144    }
145
146    /// Re-embed a single tool on description change.
147    pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
148        if !self.bags.contains_key(&tool.id) {
149            return Err(SqzError::Other(format!(
150                "tool '{}' is not registered; use register_tools first",
151                tool.id
152            )));
153        }
154        let bag = tokenize(&tool.description);
155        self.bags.insert(tool.id.clone(), bag);
156        Ok(())
157    }
158}
159
160// ---------------------------------------------------------------------------
161// Tests
162// ---------------------------------------------------------------------------
163
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    /// Start from `Preset::default()` and override only the two tool-selection
    /// settings that the selector reads.
    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut preset = Preset::default();
        preset.tool_selection.default_tools = default_tools;
        preset.tool_selection.similarity_threshold = threshold;
        preset
    }

    /// Produce `n` synthetic tools (`tool_0`, `tool_1`, …) whose descriptions
    /// share a long common vocabulary and differ only in the embedded index.
    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        let mut tools = Vec::with_capacity(n);
        for i in 0..n {
            let description = format!(
                "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
            );
            tools.push(ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description,
                ..Default::default()
            });
        }
        tools
    }

    /// Build a selector with the given settings and `n` synthetic tools
    /// already registered.
    fn selector_with_tools(threshold: f64, default_tools: Vec<String>, n: usize) -> ToolSelector {
        let preset = make_preset_with_threshold(threshold, default_tools);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        selector.register_tools(&make_tools(n)).unwrap();
        selector
    }

    // -----------------------------------------------------------------------
    // Unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_tokenize_basic() {
        let words = tokenize("hello world foo");
        for expected in ["hello", "world", "foo"] {
            assert!(words.contains(expected));
        }
    }

    #[test]
    fn test_tokenize_punctuation() {
        let words = tokenize("read_file: reads a file.");
        for expected in ["read", "file", "reads", "a"] {
            assert!(words.contains(expected));
        }
    }

    #[test]
    fn test_jaccard_identical() {
        let left = tokenize("read file");
        let right = tokenize("read file");
        let score = jaccard(&left, &right);
        assert!((score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let left = tokenize("alpha beta");
        let right = tokenize("gamma delta");
        let score = jaccard(&left, &right);
        assert!(score.abs() < 1e-9);
    }

    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let selector = selector_with_tools(0.0, vec![], 10);

        let picked = selector.select("operation task alpha beta", 5).unwrap();
        assert!(picked.len() >= 3, "expected >= 3, got {}", picked.len());
        assert!(picked.len() <= 5, "expected <= 5, got {}", picked.len());
    }

    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let selector = selector_with_tools(0.0, vec![], 2);

        let picked = selector.select("operation task", 5).unwrap();
        assert!(picked.len() <= 2, "expected <= 2, got {}", picked.len());
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        // With threshold 1.0 only a description identical to the intent could
        // qualify, so an unrelated intent must fall back to the defaults.
        let selector = selector_with_tools(1.0, defaults.clone(), 5);

        let picked = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(picked, defaults);
    }

    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let original = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
            ..Default::default()
        };
        selector.register_tools(&[original]).unwrap();

        let replacement = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
            ..Default::default()
        };
        selector.update_tool(&replacement).unwrap();

        // The stored words must reflect the new description only.
        let words = selector.bags.get("t1").unwrap();
        assert!(words.contains("delta"));
        assert!(!words.contains("alpha"));
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let unknown = ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
            ..Default::default()
        };
        assert!(selector.update_tool(&unknown).is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();

        let picked = selector.select("anything", 5).unwrap();
        assert_eq!(picked, defaults);
    }

    // -----------------------------------------------------------------------
    // Property 2: Tool selection cardinality
    // Validates: Requirements 3.1, 26.3
    // -----------------------------------------------------------------------

    /// Strategy: a tool count in [5, 20] paired with a trimmed intent string.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        let count = 5usize..=20usize;
        let intent = "[a-z ]{5,40}".prop_map(|raw| raw.trim().to_string());
        (count, intent)
    }

    /// Strategy: a small tool count in [1, 4] paired with a trimmed intent string.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        let count = 1usize..=4usize;
        let intent = "[a-z ]{5,40}".prop_map(|raw| raw.trim().to_string());
        (count, intent)
    }

    proptest! {
        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2: Tool selection cardinality.
        ///
        /// For any intent string and any tool set of size >= 5, the selector
        /// SHALL return between 3 and 5 tools (inclusive).
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            // Threshold 0.0 keeps every tool eligible, so the default-set
            // fallback never triggers.
            let selector = selector_with_tools(0.0, vec![], tool_count);

            let picked = selector.select(&intent, 5).unwrap();

            prop_assert!(
                picked.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                picked.len(), tool_count, intent
            );
            prop_assert!(
                picked.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                picked.len(), tool_count, intent
            );
        }

        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2b: For tool sets smaller than 5, the selector returns at
        /// most as many tools as are registered.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let selector = selector_with_tools(0.0, vec![], tool_count);

            let picked = selector.select(&intent, 5).unwrap();

            prop_assert!(
                picked.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, picked.len(), intent
            );
        }
    }
}