1use crate::error::Error;
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::str::FromStr;
7
/// HTTP request method used by [`BuiltinTool::Curl`].
///
/// Serde (de)serializes each variant as its upper-case name (`"GET"`, `"POST"`,
/// ...), the same spelling returned by `HttpMethod::as_str`. `Get` is the
/// [`Default`], which `curl` parsing falls back to when no `-X`/`--request`
/// flag is given.
// NOTE(review): variant order fixes the discriminants and the serde variant
// index (used by non-self-describing formats); append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "UPPERCASE")]
#[non_exhaustive]
pub enum HttpMethod {
    #[default]
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Head,
    Options,
}
21
22impl HttpMethod {
23 #[must_use]
24 pub const fn as_str(&self) -> &'static str {
25 match self {
26 Self::Get => "GET",
27 Self::Post => "POST",
28 Self::Put => "PUT",
29 Self::Patch => "PATCH",
30 Self::Delete => "DELETE",
31 Self::Head => "HEAD",
32 Self::Options => "OPTIONS",
33 }
34 }
35}
36
37impl fmt::Display for HttpMethod {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 write!(f, "{}", self.as_str())
40 }
41}
42
43impl FromStr for HttpMethod {
44 type Err = Error;
45
46 fn from_str(s: &str) -> Result<Self, Self::Err> {
47 match s.to_uppercase().as_str() {
48 "GET" => Ok(Self::Get),
49 "POST" => Ok(Self::Post),
50 "PUT" => Ok(Self::Put),
51 "PATCH" => Ok(Self::Patch),
52 "DELETE" => Ok(Self::Delete),
53 "HEAD" => Ok(Self::Head),
54 "OPTIONS" => Ok(Self::Options),
55 _ => Err(Error::InvalidCommand(format!("Unknown HTTP method: {s}"))),
56 }
57 }
58}
59
/// A tool invocation, usually built from an argv-style list by
/// `BuiltinTool::from_command`.
///
/// Serde represents it as an internally tagged enum: the lower-case variant
/// name is stored in a `"type"` field next to the variant's own fields
/// (in JSON, e.g. `{"type": "glob", "pattern": "*.md"}`).
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[non_exhaustive]
pub enum BuiltinTool {
    /// `glob <pattern>`: the pattern argument, stored verbatim.
    Glob { pattern: String },
    /// `curl [-X|--request METHOD] <url>`: target URL and HTTP method.
    Curl { url: String, method: HttpMethod },
    /// Any other program name, with the remaining arguments stored as-is.
    Exec { command: String, args: Vec<String> },
}
68
69impl BuiltinTool {
70 pub fn from_command(command: &[String]) -> Result<Self, Error> {
74 if command.is_empty() {
75 return Err(Error::InvalidCommand("Command cannot be empty".to_string()));
76 }
77
78 match command[0].as_str() {
79 "glob" => {
80 if command.len() < 2 {
81 return Err(Error::InvalidCommand(
82 "glob requires a pattern argument".to_string(),
83 ));
84 }
85 Ok(Self::Glob {
86 pattern: command[1].clone(),
87 })
88 }
89 "curl" => Self::parse_curl(&command[1..]),
90 cmd => Ok(Self::Exec {
91 command: cmd.to_string(),
92 args: command[1..].to_vec(),
93 }),
94 }
95 }
96
97 fn parse_curl(args: &[String]) -> Result<Self, Error> {
98 #[derive(Debug)]
99 struct CurlArgs {
100 url: Option<String>,
101 method: Option<String>,
102 }
103
104 let parsed = args.iter().try_fold(
105 (
106 CurlArgs {
107 url: None,
108 method: None,
109 },
110 None::<&str>,
111 ),
112 |(mut acc, expecting_value), arg| match expecting_value {
113 Some("-X" | "--request") => {
114 acc.method = Some(arg.clone());
115 Ok((acc, None))
116 }
117 Some(flag) => Err(Error::InvalidCommand(format!("Unknown flag: {flag}"))),
118 None if arg == "-X" || arg == "--request" => Ok((acc, Some(arg.as_str()))),
119 None if !arg.starts_with('-') && acc.url.is_none() => {
120 acc.url = Some(arg.clone());
121 Ok((acc, None))
122 }
123 None if arg.starts_with('-') => {
124 Err(Error::InvalidCommand(format!("Unknown flag: {arg}")))
125 }
126 None => Ok((acc, None)),
127 },
128 );
129
130 let (args, expecting) = parsed?;
131
132 if let Some(flag) = expecting {
133 return Err(Error::InvalidCommand(format!(
134 "{flag} requires a method argument"
135 )));
136 }
137
138 let url = args
139 .url
140 .ok_or_else(|| Error::InvalidCommand("curl requires a URL argument".to_string()))?;
141
142 let method = match args.method {
143 Some(m) => m.parse()?,
144 None => HttpMethod::default(),
145 };
146
147 Ok(Self::Curl { url, method })
148 }
149
150 #[must_use]
151 pub const fn name(&self) -> &'static str {
152 match self {
153 Self::Glob { .. } => "glob",
154 Self::Curl { .. } => "curl",
155 Self::Exec { .. } => "exec",
156 }
157 }
158
159 pub const fn requires_egress(&self) -> bool {
160 matches!(self, Self::Curl { .. })
161 }
162
163 pub fn is_free_tier_allowed(&self) -> bool {
164 match self {
165 Self::Glob { .. } => true,
166 Self::Curl { .. } => false,
167 Self::Exec { command, .. } => FREE_TIER_COMMAND_ALLOWLIST.contains(&command.as_str()),
168 }
169 }
170}
171
/// Program names that a [`BuiltinTool::Exec`] may run on the free tier.
///
/// `BuiltinTool::is_free_tier_allowed` checks an `Exec`'s `command` field for
/// an exact, case-sensitive match against these entries. Arguments are not
/// inspected, and a path such as `/bin/ls` does not match `ls`.
///
/// NOTE(review): some entries can start arbitrary other programs through
/// their arguments: `env <cmd>`, `find ... -exec`, awk's `system()`, GNU sed's
/// `e` command, `tar --to-command`, and the interactive shell escapes of
/// `less`/`more`. `env`/`printenv` also print the process environment, which
/// may contain secrets. Unless the exec sandbox enforces its own limits (not
/// visible here), this list alone may not keep otherwise-blocked programs or
/// data out of reach. Confirm before treating it as a security boundary.
pub const FREE_TIER_COMMAND_ALLOWLIST: &[&str] = &[
    "cat",
    "head",
    "tail",
    "less",
    "more",
    "wc",
    "sort",
    "uniq",
    "cut",
    "paste",
    "tr",
    "tee",
    "split",
    "csplit",
    "ls",
    "stat",
    "file",
    "du",
    "df",
    "find",
    "which",
    "whereis",
    "cp",
    "mv",
    "rm",
    "mkdir",
    "rmdir",
    "touch",
    "ln",
    "grep",
    "egrep",
    "fgrep",
    "sed",
    "awk",
    "diff",
    "comm",
    "cmp",
    "jq",
    "tar",
    "gzip",
    "gunzip",
    "zcat",
    "bzip2",
    "bunzip2",
    "xz",
    "unxz",
    "echo",
    "printf",
    "true",
    "false",
    "yes",
    "date",
    "cal",
    "env",
    "printenv",
    "basename",
    "dirname",
    "realpath",
    "readlink",
    "pwd",
    "id",
    "whoami",
    "uname",
    "hostname",
    "md5sum",
    "sha256sum",
    "base64",
    "xxd",
    "hexdump",
    "od",
];
244
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds an owned argv vector from string literals.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    /// Asserts that parsing `parts` fails with `Error::InvalidCommand(expected)`.
    fn assert_invalid_command(parts: &[&str], expected: &str) {
        let err = BuiltinTool::from_command(&argv(parts)).unwrap_err();
        assert!(
            matches!(&err, Error::InvalidCommand(msg) if msg == expected),
            "{parts:?}: expected InvalidCommand({expected:?})"
        );
    }

    #[test]
    fn test_builtin_tool_name() {
        let exec = BuiltinTool::Exec {
            command: "ls".to_string(),
            args: vec![],
        };
        assert_eq!(exec.name(), "exec");
    }

    #[test]
    fn test_from_command_ls() {
        let tool = BuiltinTool::from_command(&["ls".to_string(), "docs/".to_string()]).unwrap();
        assert!(matches!(
            tool,
            BuiltinTool::Exec { command, args } if command == "ls" && args == vec!["docs/"]
        ));
    }

    #[test]
    fn test_from_command_cat() {
        let tool = BuiltinTool::from_command(&["cat".to_string(), "file.md".to_string()]).unwrap();
        assert!(matches!(
            tool,
            BuiltinTool::Exec { command, args } if command == "cat" && args == vec!["file.md"]
        ));

        let tool = BuiltinTool::from_command(&["cat".to_string()]).unwrap();
        assert!(matches!(
            tool,
            BuiltinTool::Exec { command, args } if command == "cat" && args.is_empty()
        ));
    }

    #[test]
    fn test_from_command_empty() {
        assert_invalid_command(&[], "Command cannot be empty");
    }

    #[test]
    fn test_from_command_glob() {
        let tool = BuiltinTool::from_command(&["glob".to_string(), "*.md".to_string()]).unwrap();
        assert_eq!(
            tool,
            BuiltinTool::Glob {
                pattern: "*.md".to_string()
            }
        );
    }

    #[test]
    fn test_from_command_glob_requires_pattern() {
        assert_invalid_command(&["glob"], "glob requires a pattern argument");
    }

    #[test]
    fn test_from_command_curl() {
        let tool =
            BuiltinTool::from_command(&["curl".to_string(), "https://api.github.com".to_string()])
                .unwrap();
        assert_eq!(
            tool,
            BuiltinTool::Curl {
                url: "https://api.github.com".to_string(),
                method: HttpMethod::Get,
            }
        );

        let tool = BuiltinTool::from_command(&[
            "curl".to_string(),
            "-X".to_string(),
            "POST".to_string(),
            "https://api.github.com".to_string(),
        ])
        .unwrap();
        assert_eq!(
            tool,
            BuiltinTool::Curl {
                url: "https://api.github.com".to_string(),
                method: HttpMethod::Post,
            }
        );
    }

    #[test]
    fn test_from_command_curl_long_flag_after_url() {
        let tool = BuiltinTool::from_command(&argv(&[
            "curl",
            "https://example.com",
            "--request",
            "delete",
        ]))
        .unwrap();
        assert_eq!(
            tool,
            BuiltinTool::Curl {
                url: "https://example.com".to_string(),
                method: HttpMethod::Delete,
            }
        );
    }

    #[test]
    fn test_from_command_curl_errors() {
        assert_invalid_command(&["curl"], "curl requires a URL argument");
        assert_invalid_command(
            &["curl", "https://example.com", "-X"],
            "-X requires a method argument",
        );
        assert_invalid_command(
            &["curl", "https://example.com", "--request"],
            "--request requires a method argument",
        );
        assert_invalid_command(&["curl", "-v", "https://example.com"], "Unknown flag: -v");
        assert_invalid_command(
            &["curl", "-X", "FETCH", "https://example.com"],
            "Unknown HTTP method: FETCH",
        );
    }

    #[test]
    fn test_from_command_custom() {
        let tool = BuiltinTool::from_command(&["jq".to_string(), ".".to_string()]).unwrap();
        assert!(matches!(
            tool,
            BuiltinTool::Exec { command, args } if command == "jq" && args == vec!["."]
        ));

        let tool =
            BuiltinTool::from_command(&["node".to_string(), "script.js".to_string()]).unwrap();
        assert!(matches!(
            tool,
            BuiltinTool::Exec { command, args } if command == "node" && args == vec!["script.js"]
        ));
    }

    #[test]
    fn test_http_method_parsing() {
        assert_eq!("GET".parse::<HttpMethod>().unwrap(), HttpMethod::Get);
        assert_eq!("post".parse::<HttpMethod>().unwrap(), HttpMethod::Post);
        assert!("INVALID".parse::<HttpMethod>().is_err());
    }

    #[test]
    fn test_http_method_display_round_trips() {
        for method in [
            HttpMethod::Get,
            HttpMethod::Post,
            HttpMethod::Put,
            HttpMethod::Patch,
            HttpMethod::Delete,
            HttpMethod::Head,
            HttpMethod::Options,
        ] {
            assert_eq!(method.to_string(), method.as_str());
            assert_eq!(method.as_str().parse::<HttpMethod>().unwrap(), method);
        }
    }

    #[test]
    fn test_is_free_tier_allowed_glob() {
        let tool = BuiltinTool::Glob {
            pattern: "*.md".to_string(),
        };
        assert!(tool.is_free_tier_allowed());
    }

    #[test]
    fn test_is_free_tier_allowed_curl_blocked() {
        let tool = BuiltinTool::Curl {
            url: "https://example.com".to_string(),
            method: HttpMethod::Get,
        };
        assert!(!tool.is_free_tier_allowed());
    }

    #[test]
    fn test_is_free_tier_allowed_allowlisted_commands() {
        for cmd in ["cat", "ls", "grep", "sed", "awk", "jq", "head", "tail"] {
            let tool = BuiltinTool::Exec {
                command: cmd.to_string(),
                args: vec![],
            };
            assert!(tool.is_free_tier_allowed(), "{cmd} should be allowed");
        }
    }

    #[test]
    fn test_is_free_tier_blocked_dangerous_commands() {
        for cmd in [
            "wget", "nc", "ssh", "node", "ruby", "curl", "apt", "pip", "npm",
        ] {
            let tool = BuiltinTool::Exec {
                command: cmd.to_string(),
                args: vec![],
            };
            assert!(!tool.is_free_tier_allowed(), "{cmd} should be blocked");
        }
    }

    #[test]
    fn test_requires_egress() {
        assert!(
            BuiltinTool::Curl {
                url: "https://example.com".to_string(),
                method: HttpMethod::Get,
            }
            .requires_egress()
        );

        assert!(
            !BuiltinTool::Glob {
                pattern: "*.md".to_string(),
            }
            .requires_egress()
        );

        assert!(
            !BuiltinTool::Exec {
                command: "ls".to_string(),
                args: vec![],
            }
            .requires_egress()
        );
    }
}