Skip to main content

statespace_tool_runtime/
spec.rs

1//! Tool specification parsing and validation.
2//!
3//! ```yaml
4//! tools:
5//!   - [ls]                                 # Simple command, extra args allowed
6//!   - [cat, { }]                           # Placeholder accepts any value
7//!   - [cat, { regex: ".*\\.md$" }]         # Regex-constrained placeholder
8//!   - [psql, -c, { regex: "^SELECT" }, ;]  # Trailing ; disables extra args
9//! ```
10
11use fancy_regex::Regex;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum ToolPart {
15    Literal(String),
16    Placeholder { regex: Option<CompiledRegex> },
17}
18
19#[derive(Debug, Clone)]
20pub struct CompiledRegex {
21    pub pattern: String,
22    pub regex: Regex,
23}
24
25impl PartialEq for CompiledRegex {
26    fn eq(&self, other: &Self) -> bool {
27        self.pattern == other.pattern
28    }
29}
30
31impl Eq for CompiledRegex {}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct ToolSpec {
35    pub parts: Vec<ToolPart>,
36    pub options_disabled: bool,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
40#[non_exhaustive]
41pub enum SpecError {
42    #[error("invalid regex pattern '{pattern}': {message}")]
43    InvalidRegex { pattern: String, message: String },
44    #[error("empty tool specification")]
45    EmptySpec,
46    #[error("invalid tool part: {0}")]
47    InvalidPart(String),
48}
49
50pub type SpecResult<T> = Result<T, SpecError>;
51
52impl ToolSpec {
53    /// # Errors
54    ///
55    /// Returns `SpecError` when the tool specification is empty or invalid.
56    pub fn parse(raw: &[serde_json::Value]) -> SpecResult<Self> {
57        if raw.is_empty() {
58            return Err(SpecError::EmptySpec);
59        }
60
61        let options_disabled = raw.last().is_some_and(|v| v.as_str() == Some(";"));
62
63        let parts = raw
64            .iter()
65            .filter(|v| v.as_str() != Some(";"))
66            .map(Self::parse_part)
67            .collect::<SpecResult<Vec<_>>>()?;
68
69        if parts.is_empty() {
70            return Err(SpecError::EmptySpec);
71        }
72
73        Ok(Self {
74            parts,
75            options_disabled,
76        })
77    }
78
79    fn parse_part(value: &serde_json::Value) -> SpecResult<ToolPart> {
80        match value {
81            serde_json::Value::String(s) => Ok(ToolPart::Literal(s.clone())),
82
83            serde_json::Value::Object(obj) => {
84                if obj.is_empty() {
85                    return Ok(ToolPart::Placeholder { regex: None });
86                }
87
88                if let Some(pattern) = obj.get("regex").and_then(|v| v.as_str()) {
89                    let regex = Regex::new(pattern).map_err(|e| SpecError::InvalidRegex {
90                        pattern: pattern.to_string(),
91                        message: e.to_string(),
92                    })?;
93                    return Ok(ToolPart::Placeholder {
94                        regex: Some(CompiledRegex {
95                            pattern: pattern.to_string(),
96                            regex,
97                        }),
98                    });
99                }
100
101                Err(SpecError::InvalidPart(format!(
102                    "unknown object keys: {:?}",
103                    obj.keys().collect::<Vec<_>>()
104                )))
105            }
106
107            _ => Err(SpecError::InvalidPart(format!(
108                "expected string or object, got: {value}"
109            ))),
110        }
111    }
112}
113
114#[must_use]
115pub fn is_valid_tool_call(command: &[String], specs: &[ToolSpec]) -> bool {
116    if command.is_empty() {
117        return false;
118    }
119    find_matching_spec(command, specs).is_some()
120}
121
122#[must_use]
123pub fn find_matching_spec<'a>(command: &[String], specs: &'a [ToolSpec]) -> Option<&'a ToolSpec> {
124    specs.iter().find(|spec| matches_spec(command, spec))
125}
126
127fn matches_spec(command: &[String], spec: &ToolSpec) -> bool {
128    if command.len() < spec.parts.len() {
129        return false;
130    }
131
132    if command.len() > spec.parts.len() && spec.options_disabled {
133        return false;
134    }
135
136    for (i, part) in spec.parts.iter().enumerate() {
137        let cmd_part = &command[i];
138
139        match part {
140            ToolPart::Literal(lit) => {
141                if cmd_part != lit {
142                    return false;
143                }
144            }
145            ToolPart::Placeholder { regex: None } => {}
146            ToolPart::Placeholder {
147                regex: Some(compiled),
148            } => {
149                if !compiled.regex.is_match(cmd_part).unwrap_or(false) {
150                    return false;
151                }
152            }
153        }
154    }
155
156    true
157}
158
#[cfg(test)]
// Fixtures compile hard-coded, known-valid patterns; an `unwrap` failure is a bug in the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    fn make_spec(parts: Vec<ToolPart>, options_disabled: bool) -> ToolSpec {
        ToolSpec {
            parts,
            options_disabled,
        }
    }

    fn lit(s: &str) -> ToolPart {
        ToolPart::Literal(s.to_string())
    }

    fn placeholder() -> ToolPart {
        ToolPart::Placeholder { regex: None }
    }

    fn regex_placeholder(pattern: &str) -> ToolPart {
        ToolPart::Placeholder {
            regex: Some(CompiledRegex {
                pattern: pattern.to_string(),
                regex: Regex::new(pattern).unwrap(),
            }),
        }
    }

    /// Builds an owned argv from string slices.
    fn cmd(args: &[&str]) -> Vec<String> {
        args.iter().map(|a| (*a).to_owned()).collect()
    }

    // ----- matching -----

    #[test]
    fn validate_simple_match() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(&cmd(&["ls"]), &specs));
    }

    #[test]
    fn validate_with_extra_args_allowed() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(&cmd(&["ls", "-la"]), &specs));
    }

    #[test]
    fn validate_with_extra_args_disabled() {
        let specs = vec![make_spec(vec![lit("ls")], true)];
        assert!(!is_valid_tool_call(&cmd(&["ls", "-la"]), &specs));
    }

    #[test]
    fn validate_placeholder_matches_any() {
        let specs = vec![make_spec(vec![lit("cat"), placeholder()], false)];

        assert!(is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["cat", "anything"]), &specs));
    }

    #[test]
    fn validate_regex_placeholder() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            false,
        )];

        assert!(is_valid_tool_call(&cmd(&["cat", "README.md"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "README.txt"]), &specs));
    }

    #[test]
    fn validate_regex_with_options_disabled() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            true,
        )];

        assert!(is_valid_tool_call(&cmd(&["cat", "file.md"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "file.md", "-n"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
    }

    #[test]
    fn validate_complex_psql_spec() {
        let specs = vec![make_spec(
            vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
            true,
        )];

        assert!(is_valid_tool_call(
            &cmd(&["psql", "-c", "SELECT * FROM users"]),
            &specs
        ));
        assert!(!is_valid_tool_call(
            &cmd(&["psql", "-c", "INSERT INTO users VALUES (1)"]),
            &specs
        ));
        assert!(!is_valid_tool_call(
            &cmd(&["psql", "-c", "SELECT 1", "--extra"]),
            &specs
        ));
    }

    #[test]
    fn validate_regex_is_unanchored_search() {
        // Patterns match anywhere in the argument; only explicit anchors constrain the rest.
        // A `^SELECT` constraint therefore does not forbid trailing statements.
        let specs = vec![make_spec(
            vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
            true,
        )];
        assert!(is_valid_tool_call(
            &cmd(&["psql", "-c", "SELECT 1; DROP TABLE users"]),
            &specs
        ));

        let specs = vec![make_spec(vec![lit("cat"), regex_placeholder(r"\.md")], false)];
        assert!(is_valid_tool_call(&cmd(&["cat", "notes.md.bak"]), &specs));
    }

    #[test]
    fn validate_empty_command() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(!is_valid_tool_call(&[], &specs));
    }

    #[test]
    fn validate_placeholder_is_required() {
        let specs = vec![make_spec(vec![lit("ls"), placeholder()], false)];

        assert!(!is_valid_tool_call(&cmd(&["ls"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["ls", "dir"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["ls", "dir", "-la"]), &specs));
    }

    #[test]
    fn validate_multiple_specs() {
        let specs = vec![
            make_spec(vec![lit("ls")], false),
            make_spec(vec![lit("cat"), placeholder()], false),
        ];

        assert!(is_valid_tool_call(&cmd(&["ls"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["rm"]), &specs));
    }

    #[test]
    fn find_matching_spec_returns_first_match() {
        let specs = vec![
            make_spec(vec![lit("cat"), placeholder()], false),
            make_spec(vec![lit("cat"), regex_placeholder(r"\.md$")], false),
        ];

        assert_eq!(
            find_matching_spec(&cmd(&["cat", "a.md"]), &specs),
            Some(&specs[0])
        );
        assert_eq!(find_matching_spec(&cmd(&["ls"]), &specs), None);
    }

    // ----- parsing -----

    #[test]
    fn parse_full_spec_with_trailing_semicolon() {
        let spec = ToolSpec::parse(&[
            json!("psql"),
            json!("-c"),
            json!({ "regex": "^SELECT" }),
            json!(";"),
        ])
        .unwrap();

        assert_eq!(
            spec,
            make_spec(
                vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
                true
            )
        );
    }

    #[test]
    fn parse_without_semicolon_allows_extra_args() {
        let spec = ToolSpec::parse(&[json!("ls")]).unwrap();
        assert_eq!(spec, make_spec(vec![lit("ls")], false));
    }

    #[test]
    fn parse_empty_object_is_unconstrained_placeholder() {
        let spec = ToolSpec::parse(&[json!("cat"), json!({})]).unwrap();
        assert_eq!(spec, make_spec(vec![lit("cat"), placeholder()], false));
    }

    #[test]
    fn parse_rejects_empty_spec() {
        assert_eq!(ToolSpec::parse(&[]), Err(SpecError::EmptySpec));
        assert_eq!(ToolSpec::parse(&[json!(";")]), Err(SpecError::EmptySpec));
    }

    #[test]
    fn parse_rejects_invalid_regex() {
        let err = ToolSpec::parse(&[json!("cat"), json!({ "regex": "(" })]).unwrap_err();
        assert!(
            matches!(err, SpecError::InvalidRegex { ref pattern, .. } if pattern == "("),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn parse_rejects_unsupported_parts() {
        for bad in [
            json!(42),
            json!(null),
            json!(["nested"]),
            json!({ "pattern": "x" }),
        ] {
            let err = ToolSpec::parse(&[json!("cat"), bad]).unwrap_err();
            assert!(
                matches!(err, SpecError::InvalidPart(_)),
                "unexpected error: {err:?}"
            );
        }
    }
}