Skip to main content

sqz_engine/
tool_selector.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::error::{Result, SqzError};
4use crate::preset::Preset;
5use crate::types::ToolId;
6
7/// A tool definition with an id, name, and description.
8#[derive(Debug, Clone)]
9pub struct ToolDefinition {
10    pub id: ToolId,
11    pub name: String,
12    pub description: String,
13}
14
/// Bag-of-words representation: the set of distinct lowercase word tokens of a text.
type BagOfWords = HashSet<String>;

/// Break `text` into a [`BagOfWords`].
///
/// Every character that is not alphanumeric (whitespace, punctuation, `_`, ...)
/// acts as a separator. Empty fragments between adjacent separators are
/// dropped, each remaining fragment is lowercased, and duplicates collapse
/// because the result is a set.
fn tokenize(text: &str) -> BagOfWords {
    let is_separator = |ch: char| !ch.is_alphanumeric();
    let mut words = BagOfWords::new();
    for fragment in text.split(is_separator) {
        if fragment.is_empty() {
            continue;
        }
        words.insert(fragment.to_lowercase());
    }
    words
}
25
26/// Jaccard similarity between two bags of words: |A ∩ B| / |A ∪ B|.
27/// Returns 0.0 if both sets are empty.
28fn jaccard(a: &BagOfWords, b: &BagOfWords) -> f64 {
29    if a.is_empty() && b.is_empty() {
30        return 0.0;
31    }
32    let intersection = a.intersection(b).count() as f64;
33    let union = a.union(b).count() as f64;
34    if union == 0.0 {
35        0.0
36    } else {
37        intersection / union
38    }
39}
40
41/// Selects 3–5 relevant tools per task using TF-IDF-style word-overlap (Jaccard) similarity.
42///
43/// # Selection rules
44/// - Compute Jaccard similarity between the intent query and each tool description.
45/// - Sort tools by descending similarity score.
46/// - Return between 3 and min(5, tool_count) tools.
47/// - If no tool has similarity > threshold, return the default tool set instead.
48pub struct ToolSelector {
49    /// Bag-of-words for each registered tool description.
50    bags: HashMap<ToolId, BagOfWords>,
51    /// Ordered list of registered tool ids (preserves insertion order for determinism).
52    tool_ids: Vec<ToolId>,
53    /// Similarity threshold below which we fall back to defaults.
54    threshold: f64,
55    /// Default tool ids returned when confidence is low.
56    default_tools: Vec<ToolId>,
57}
58
59impl ToolSelector {
60    /// Create a new `ToolSelector` from a `Preset`.
61    ///
62    /// `model_path` is accepted for API compatibility but is unused because we use
63    /// a TF-IDF/Jaccard approach rather than a neural embedding model.
64    pub fn new(_model_path: &std::path::Path, preset: &Preset) -> Result<Self> {
65        let threshold = preset.tool_selection.similarity_threshold;
66        let default_tools = preset.tool_selection.default_tools.clone();
67        Ok(Self {
68            bags: HashMap::new(),
69            tool_ids: Vec::new(),
70            threshold,
71            default_tools,
72        })
73    }
74
75    /// Register a slice of tools, computing bag-of-words for each description.
76    pub fn register_tools(&mut self, tools: &[ToolDefinition]) -> Result<()> {
77        for tool in tools {
78            let bag = tokenize(&tool.description);
79            if !self.bags.contains_key(&tool.id) {
80                self.tool_ids.push(tool.id.clone());
81            }
82            self.bags.insert(tool.id.clone(), bag);
83        }
84        Ok(())
85    }
86
87    /// Select between 3 and min(5, tool_count) tools whose descriptions best match `intent`.
88    ///
89    /// Returns the default tool set when no tool exceeds the similarity threshold.
90    pub fn select(&self, intent: &str, max_tools: usize) -> Result<Vec<ToolId>> {
91        let tool_count = self.tool_ids.len();
92        if tool_count == 0 {
93            return Ok(self.default_tools.clone());
94        }
95
96        let intent_bag = tokenize(intent);
97
98        // Score every registered tool.
99        let mut scored: Vec<(f64, &ToolId)> = self
100            .tool_ids
101            .iter()
102            .map(|id| {
103                let bag = self.bags.get(id).expect("bag must exist for registered tool");
104                let score = jaccard(&intent_bag, bag);
105                (score, id)
106            })
107            .collect();
108
109        // Sort descending by score, then ascending by id for determinism on ties.
110        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
111            .then_with(|| a.1.cmp(b.1)));
112
113        // Check whether any tool exceeds the threshold.
114        // We use strict less-than so that a score equal to the threshold is still
115        // considered "confident enough" (threshold = 0.0 means always select).
116        let best_score = scored.first().map(|(s, _)| *s).unwrap_or(0.0);
117        if best_score < self.threshold {
118            return Ok(self.default_tools.clone());
119        }
120
121        // Cardinality: return between 3 and min(max_tools, 5, tool_count) tools.
122        let upper = max_tools.min(5).min(tool_count);
123        let lower = 3_usize.min(tool_count);
124        let count = upper.max(lower);
125
126        let result = scored
127            .into_iter()
128            .take(count)
129            .map(|(_, id)| id.clone())
130            .collect();
131
132        Ok(result)
133    }
134
135    /// Re-embed a single tool on description change.
136    pub fn update_tool(&mut self, tool: &ToolDefinition) -> Result<()> {
137        if !self.bags.contains_key(&tool.id) {
138            return Err(SqzError::Other(format!(
139                "tool '{}' is not registered; use register_tools first",
140                tool.id
141            )));
142        }
143        let bag = tokenize(&tool.description);
144        self.bags.insert(tool.id.clone(), bag);
145        Ok(())
146    }
147}
148
149// ---------------------------------------------------------------------------
150// Tests
151// ---------------------------------------------------------------------------
152
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::path::Path;

    fn make_preset_with_threshold(threshold: f64, default_tools: Vec<String>) -> Preset {
        let mut p = Preset::default();
        p.tool_selection.similarity_threshold = threshold;
        p.tool_selection.default_tools = default_tools;
        p
    }

    fn make_tools(n: usize) -> Vec<ToolDefinition> {
        (0..n)
            .map(|i| ToolDefinition {
                id: format!("tool_{i}"),
                name: format!("Tool {i}"),
                description: format!(
                    "This tool performs operation number {i} for task category alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu nu xi omicron pi rho sigma tau upsilon phi chi psi omega {i}"
                ),
            })
            .collect()
    }

    /// Build a `ToolDefinition` whose display name is derived from its id.
    fn tool(id: &str, description: &str) -> ToolDefinition {
        ToolDefinition {
            id: id.to_string(),
            name: id.to_uppercase(),
            description: description.to_string(),
        }
    }

    // -----------------------------------------------------------------------
    // Unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_tokenize_basic() {
        let bag = tokenize("hello world foo");
        assert!(bag.contains("hello"));
        assert!(bag.contains("world"));
        assert!(bag.contains("foo"));
    }

    #[test]
    fn test_tokenize_punctuation() {
        let bag = tokenize("read_file: reads a file.");
        assert!(bag.contains("read"));
        assert!(bag.contains("file"));
        assert!(bag.contains("reads"));
        assert!(bag.contains("a"));
    }

    #[test]
    fn test_jaccard_identical() {
        let a = tokenize("read file");
        let b = tokenize("read file");
        assert!((jaccard(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let a = tokenize("alpha beta");
        let b = tokenize("gamma delta");
        assert!((jaccard(&a, &b)).abs() < 1e-9);
    }

    #[test]
    fn test_jaccard_empty_sets() {
        let empty = BagOfWords::new();
        let non_empty = tokenize("alpha");
        assert_eq!(jaccard(&empty, &empty), 0.0);
        assert_eq!(jaccard(&empty, &non_empty), 0.0);
        assert_eq!(jaccard(&non_empty, &empty), 0.0);
    }

    #[test]
    fn test_jaccard_partial_overlap() {
        // {a, b, c} vs {b, c, d}: intersection 2, union 4.
        let a = tokenize("a b c");
        let b = tokenize("b c d");
        assert!((jaccard(&a, &b) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_select_returns_between_3_and_5_for_large_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(10);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("operation task alpha beta", 5).unwrap();
        assert!(result.len() >= 3, "expected >= 3, got {}", result.len());
        assert!(result.len() <= 5, "expected <= 5, got {}", result.len());
    }

    #[test]
    fn test_select_returns_at_most_tool_count_for_small_set() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(2);
        selector.register_tools(&tools).unwrap();

        // With threshold 0.0 there is no fallback, so every registered tool is returned.
        let result = selector.select("operation task", 5).unwrap();
        assert_eq!(result.len(), 2, "expected exactly 2, got {}", result.len());
    }

    #[test]
    fn test_select_ranks_by_score_then_id() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        // Registered out of id order so the tie-break is by id, not insertion order.
        selector
            .register_tools(&[
                tool("d", "list directory"),
                tool("b", "write file contents"),
                tool("c", "delete directory"),
                tool("a", "read file contents"),
            ])
            .unwrap();

        // Scores for "read file": a = 2/3, b = 1/4, c = d = 0 (tie broken by id).
        assert_eq!(selector.select("read file", 5).unwrap(), vec!["a", "b", "c", "d"]);
        assert_eq!(selector.select("read file", 3).unwrap(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_select_max_tools_caps_but_never_below_three() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        selector.register_tools(&make_tools(10)).unwrap();

        assert_eq!(selector.select("alpha", 4).unwrap().len(), 4);
        assert_eq!(selector.select("alpha", 1).unwrap().len(), 3);
        assert_eq!(selector.select("alpha", 0).unwrap().len(), 3);
        assert_eq!(selector.select("alpha", 100).unwrap().len(), 5);
    }

    #[test]
    fn test_empty_intent_orders_by_id() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        selector.register_tools(&make_tools(10)).unwrap();

        // Every score is 0.0, so the id tie-break alone decides the order.
        let result = selector.select("", 5).unwrap();
        assert_eq!(result, vec!["tool_0", "tool_1", "tool_2", "tool_3", "tool_4"]);
    }

    #[test]
    fn test_fallback_to_defaults_on_low_confidence() {
        let defaults = vec!["default_a".to_string(), "default_b".to_string()];
        // threshold = 1.0: any best score strictly below it falls back; only an
        // exact word-set match (score 1.0) would be accepted.
        let preset = make_preset_with_threshold(1.0, defaults.clone());
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = make_tools(5);
        selector.register_tools(&tools).unwrap();

        let result = selector.select("completely unrelated xyz", 5).unwrap();
        assert_eq!(result, defaults);
    }

    #[test]
    fn test_score_equal_to_threshold_is_accepted() {
        let preset = make_preset_with_threshold(1.0, vec!["fallback".to_string()]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        selector.register_tools(&[tool("t1", "alpha beta")]).unwrap();

        // Identical word sets score exactly 1.0, which is not below the threshold.
        assert_eq!(selector.select("Beta, alpha!", 5).unwrap(), vec!["t1"]);
    }

    #[test]
    fn test_register_tools_reregistration_replaces_bag_and_keeps_order() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        selector
            .register_tools(&[tool("t1", "alpha"), tool("t2", "beta")])
            .unwrap();
        selector
            .register_tools(&[tool("t1", "gamma"), tool("t3", "delta"), tool("t3", "epsilon")])
            .unwrap();

        assert_eq!(selector.tool_ids, vec!["t1", "t2", "t3"]);
        assert!(selector.bags["t1"].contains("gamma"));
        assert!(!selector.bags["t1"].contains("alpha"));
        // A repeated id within one call: the later definition wins.
        assert!(selector.bags["t3"].contains("epsilon"));
        assert!(!selector.bags["t3"].contains("delta"));
    }

    #[test]
    fn test_update_tool_changes_embedding() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let tools = vec![ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "alpha beta gamma".to_string(),
        }];
        selector.register_tools(&tools).unwrap();

        let updated = ToolDefinition {
            id: "t1".to_string(),
            name: "T1".to_string(),
            description: "delta epsilon zeta".to_string(),
        };
        selector.update_tool(&updated).unwrap();

        let bag = selector.bags.get("t1").unwrap();
        assert!(bag.contains("delta"));
        assert!(!bag.contains("alpha"));
    }

    #[test]
    fn test_update_tool_unregistered_returns_error() {
        let preset = make_preset_with_threshold(0.0, vec![]);
        let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.update_tool(&ToolDefinition {
            id: "nonexistent".to_string(),
            name: "X".to_string(),
            description: "desc".to_string(),
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_tool_set_returns_defaults() {
        let defaults = vec!["fallback".to_string()];
        let preset = make_preset_with_threshold(0.0, defaults.clone());
        let selector = ToolSelector::new(Path::new(""), &preset).unwrap();
        let result = selector.select("anything", 5).unwrap();
        assert_eq!(result, defaults);
    }

    // -----------------------------------------------------------------------
    // Property 2: Tool selection cardinality
    // Validates: Requirements 3.1, 26.3
    // -----------------------------------------------------------------------

    /// Strategy: generate a tool count in [5, 20] and an intent string.
    fn arb_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (5usize..=20usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    /// Strategy: generate a small tool count in [1, 4] and an intent string.
    fn arb_small_tool_count_and_intent() -> impl Strategy<Value = (usize, String)> {
        (1usize..=4usize, "[a-z ]{5,40}".prop_map(|s| s.trim().to_string()))
    }

    proptest! {
        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2: Tool selection cardinality.
        ///
        /// For any intent string and any tool set of size >= 5, the ToolSelector
        /// SHALL return between 3 and 5 tools (inclusive).
        ///
        /// For tool sets smaller than 5, it SHALL return at most the size of the
        /// tool set.
        #[test]
        fn prop_tool_selection_cardinality_large(
            (tool_count, intent) in arb_tool_count_and_intent()
        ) {
            // Use threshold = 0.0 so all tools are eligible (no fallback to defaults).
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() >= 3,
                "expected >= 3 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
            prop_assert!(
                result.len() <= 5,
                "expected <= 5 tools, got {} (tool_count={}, intent='{}')",
                result.len(), tool_count, intent
            );
        }

        /// **Validates: Requirements 3.1, 26.3**
        ///
        /// Property 2b: For tool sets smaller than 5, ToolSelector returns at most
        /// the size of the tool set.
        #[test]
        fn prop_tool_selection_cardinality_small(
            (tool_count, intent) in arb_small_tool_count_and_intent()
        ) {
            let preset = make_preset_with_threshold(0.0, vec![]);
            let mut selector = ToolSelector::new(Path::new(""), &preset).unwrap();
            let tools = make_tools(tool_count);
            selector.register_tools(&tools).unwrap();

            let result = selector.select(&intent, 5).unwrap();

            prop_assert!(
                result.len() <= tool_count,
                "expected <= {} tools, got {} (intent='{}')",
                tool_count, result.len(), intent
            );
        }
    }
}