Skip to main content

statespace_tool_runtime/
tools.rs

//! Tool domain models
2
3use crate::error::Error;
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::str::FromStr;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
9#[serde(rename_all = "UPPERCASE")]
10#[non_exhaustive]
11pub enum HttpMethod {
12    #[default]
13    Get,
14    Post,
15    Put,
16    Patch,
17    Delete,
18    Head,
19    Options,
20}
21
22impl HttpMethod {
23    #[must_use]
24    pub const fn as_str(&self) -> &'static str {
25        match self {
26            Self::Get => "GET",
27            Self::Post => "POST",
28            Self::Put => "PUT",
29            Self::Patch => "PATCH",
30            Self::Delete => "DELETE",
31            Self::Head => "HEAD",
32            Self::Options => "OPTIONS",
33        }
34    }
35}
36
37impl fmt::Display for HttpMethod {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        write!(f, "{}", self.as_str())
40    }
41}
42
43impl FromStr for HttpMethod {
44    type Err = Error;
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        match s.to_uppercase().as_str() {
48            "GET" => Ok(Self::Get),
49            "POST" => Ok(Self::Post),
50            "PUT" => Ok(Self::Put),
51            "PATCH" => Ok(Self::Patch),
52            "DELETE" => Ok(Self::Delete),
53            "HEAD" => Ok(Self::Head),
54            "OPTIONS" => Ok(Self::Options),
55            _ => Err(Error::InvalidCommand(format!("Unknown HTTP method: {s}"))),
56        }
57    }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
61#[serde(tag = "type", rename_all = "lowercase")]
62#[non_exhaustive]
63pub enum BuiltinTool {
64    Glob { pattern: String },
65    Curl { url: String, method: HttpMethod },
66    Exec { command: String, args: Vec<String> },
67}
68
69impl BuiltinTool {
70    /// # Errors
71    ///
72    /// Returns an error when the command is empty or malformed.
73    pub fn from_command(command: &[String]) -> Result<Self, Error> {
74        if command.is_empty() {
75            return Err(Error::InvalidCommand("Command cannot be empty".to_string()));
76        }
77
78        match command[0].as_str() {
79            "glob" => {
80                if command.len() < 2 {
81                    return Err(Error::InvalidCommand(
82                        "glob requires a pattern argument".to_string(),
83                    ));
84                }
85                Ok(Self::Glob {
86                    pattern: command[1].clone(),
87                })
88            }
89            "curl" => Self::parse_curl(&command[1..]),
90            cmd => Ok(Self::Exec {
91                command: cmd.to_string(),
92                args: command[1..].to_vec(),
93            }),
94        }
95    }
96
97    fn parse_curl(args: &[String]) -> Result<Self, Error> {
98        #[derive(Debug)]
99        struct CurlArgs {
100            url: Option<String>,
101            method: Option<String>,
102        }
103
104        let parsed = args.iter().try_fold(
105            (
106                CurlArgs {
107                    url: None,
108                    method: None,
109                },
110                None::<&str>,
111            ),
112            |(mut acc, expecting_value), arg| match expecting_value {
113                Some("-X" | "--request") => {
114                    acc.method = Some(arg.clone());
115                    Ok((acc, None))
116                }
117                Some(flag) => Err(Error::InvalidCommand(format!("Unknown flag: {flag}"))),
118                None if arg == "-X" || arg == "--request" => Ok((acc, Some(arg.as_str()))),
119                None if !arg.starts_with('-') && acc.url.is_none() => {
120                    acc.url = Some(arg.clone());
121                    Ok((acc, None))
122                }
123                None if arg.starts_with('-') => {
124                    Err(Error::InvalidCommand(format!("Unknown flag: {arg}")))
125                }
126                None => Ok((acc, None)),
127            },
128        );
129
130        let (args, expecting) = parsed?;
131
132        if let Some(flag) = expecting {
133            return Err(Error::InvalidCommand(format!(
134                "{flag} requires a method argument"
135            )));
136        }
137
138        let url = args
139            .url
140            .ok_or_else(|| Error::InvalidCommand("curl requires a URL argument".to_string()))?;
141
142        let method = match args.method {
143            Some(m) => m.parse()?,
144            None => HttpMethod::default(),
145        };
146
147        Ok(Self::Curl { url, method })
148    }
149
150    #[must_use]
151    pub const fn name(&self) -> &'static str {
152        match self {
153            Self::Glob { .. } => "glob",
154            Self::Curl { .. } => "curl",
155            Self::Exec { .. } => "exec",
156        }
157    }
158
159    pub const fn requires_egress(&self) -> bool {
160        matches!(self, Self::Curl { .. })
161    }
162
163    pub fn is_free_tier_allowed(&self) -> bool {
164        match self {
165            Self::Glob { .. } => true,
166            Self::Curl { .. } => false,
167            Self::Exec { command, .. } => FREE_TIER_COMMAND_ALLOWLIST.contains(&command.as_str()),
168        }
169    }
170}
171
/// Program names that a [`BuiltinTool::Exec`] may use on the free tier.
///
/// NOTE(review): several entries can start other programs through their arguments
/// (`env PROGRAM`, `find -exec`, awk `system()`, GNU sed `e`, GNU tar
/// `--to-command`, GNU sort `--compress-program`, GNU split `--filter`), so
/// membership in this list alone does not mean an invocation only runs listed
/// programs.
#[rustfmt::skip]
pub const FREE_TIER_COMMAND_ALLOWLIST: &[&str] = &[
    // Viewing, slicing and reshaping text.
    "cat", "head", "tail", "less", "more", "wc", "sort", "uniq", "cut", "paste", "tr", "tee",
    "split", "csplit",
    // Inspecting and managing the filesystem.
    "ls", "stat", "file", "du", "df", "find", "which", "whereis", "cp", "mv", "rm", "mkdir",
    "rmdir", "touch", "ln",
    // Searching, editing and comparing.
    "grep", "egrep", "fgrep", "sed", "awk", "diff", "comm", "cmp", "jq",
    // Archives and compression.
    "tar", "gzip", "gunzip", "zcat", "bzip2", "bunzip2", "xz", "unxz",
    // Output, time and environment.
    "echo", "printf", "true", "false", "yes", "date", "cal", "env", "printenv",
    // Paths and identity.
    "basename", "dirname", "realpath", "readlink", "pwd", "id", "whoami", "uname", "hostname",
    // Hashing and encoding.
    "md5sum", "sha256sum", "base64", "xxd", "hexdump", "od",
];
244
#[cfg(test)]
// Tests unwrap parse results directly: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Turns string literals into the owned words `from_command` expects.
    fn words(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    /// Parses `items` with `BuiltinTool::from_command`, panicking on error.
    fn parse(items: &[&str]) -> BuiltinTool {
        BuiltinTool::from_command(&words(items)).unwrap()
    }

    /// Builds an `Exec` tool directly, without going through the parser.
    fn exec(command: &str, args: &[&str]) -> BuiltinTool {
        BuiltinTool::Exec {
            command: command.to_owned(),
            args: words(args),
        }
    }

    #[test]
    fn test_builtin_tool_name() {
        assert_eq!(exec("ls", &[]).name(), "exec");
    }

    #[test]
    fn test_from_command_ls() {
        assert_eq!(parse(&["ls", "docs/"]), exec("ls", &["docs/"]));
    }

    #[test]
    fn test_from_command_cat() {
        assert_eq!(parse(&["cat", "file.md"]), exec("cat", &["file.md"]));
        assert_eq!(parse(&["cat"]), exec("cat", &[]));
    }

    #[test]
    fn test_from_command_glob() {
        let expected = BuiltinTool::Glob {
            pattern: "*.md".to_owned(),
        };
        assert_eq!(parse(&["glob", "*.md"]), expected);
    }

    #[test]
    fn test_from_command_curl() {
        let github = |method: HttpMethod| BuiltinTool::Curl {
            url: "https://api.github.com".to_owned(),
            method,
        };

        assert_eq!(parse(&["curl", "https://api.github.com"]), github(HttpMethod::Get));
        assert_eq!(
            parse(&["curl", "-X", "POST", "https://api.github.com"]),
            github(HttpMethod::Post)
        );
    }

    #[test]
    fn test_from_command_custom() {
        assert_eq!(parse(&["jq", "."]), exec("jq", &["."]));
        assert_eq!(parse(&["node", "script.js"]), exec("node", &["script.js"]));
    }

    #[test]
    fn test_http_method_parsing() {
        assert_eq!("GET".parse::<HttpMethod>().unwrap(), HttpMethod::Get);
        assert_eq!("post".parse::<HttpMethod>().unwrap(), HttpMethod::Post);
        assert!("INVALID".parse::<HttpMethod>().is_err());
    }

    #[test]
    fn test_is_free_tier_allowed_glob() {
        let glob = BuiltinTool::Glob {
            pattern: "*.md".to_owned(),
        };
        assert!(glob.is_free_tier_allowed());
    }

    #[test]
    fn test_is_free_tier_allowed_curl_blocked() {
        let curl = BuiltinTool::Curl {
            url: "https://example.com".to_owned(),
            method: HttpMethod::Get,
        };
        assert!(!curl.is_free_tier_allowed());
    }

    #[test]
    fn test_is_free_tier_allowed_allowlisted_commands() {
        let allowed = ["cat", "ls", "grep", "sed", "awk", "jq", "head", "tail"];
        let rejected: Vec<&str> = allowed
            .into_iter()
            .filter(|cmd| !exec(cmd, &[]).is_free_tier_allowed())
            .collect();
        assert!(rejected.is_empty(), "should be allowed: {rejected:?}");
    }

    #[test]
    fn test_is_free_tier_blocked_dangerous_commands() {
        let blocked = ["wget", "nc", "ssh", "node", "ruby", "curl", "apt", "pip", "npm"];
        let leaked: Vec<&str> = blocked
            .into_iter()
            .filter(|cmd| exec(cmd, &[]).is_free_tier_allowed())
            .collect();
        assert!(leaked.is_empty(), "should be blocked: {leaked:?}");
    }

    #[test]
    fn test_requires_egress() {
        let curl = BuiltinTool::Curl {
            url: "https://example.com".to_owned(),
            method: HttpMethod::Get,
        };
        let glob = BuiltinTool::Glob {
            pattern: "*.md".to_owned(),
        };

        assert!(curl.requires_egress());
        assert!(!glob.requires_egress());
        assert!(!exec("ls", &[]).requires_egress());
    }
}