Skip to main content

safe_chains/
parse.rs

/// Splits a shell command line into its individual command segments.
///
/// Segments are separated by `|`, `||`, `&`, `&&`, `;` and newlines that
/// appear outside quoting. The scan follows bash's lexical rules closely
/// enough that separators are neither missed nor invented:
///
/// * single quotes, double quotes and ANSI-C quotes (`$'...'`, where `\'`
///   does not close the quote) suppress splitting;
/// * a backslash escapes the next character outside single quotes, and
///   `\<newline>` is a line continuation rather than a separator;
/// * `$$` is consumed as a unit so a following `'` is an ordinary quote;
/// * a `#` at the start of a word begins a comment, which is dropped up to
///   (not including) the next newline;
/// * an `&` directly after an unquoted, unescaped `>` is kept, so descriptor
///   duplications such as `2>&1` stay inside their segment.
///
/// Each segment is trimmed, and empty segments are dropped.
pub fn split_outside_quotes(cmd: &str) -> Vec<String> {
    /// Quoting context of the character being examined.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Quote {
        Unquoted,
        Single,
        Double,
        AnsiC,
    }

    let mut segments = Vec::new();
    let mut current = String::new();
    let mut quote = Quote::Unquoted;
    let mut escaped = false;
    // The previous character was an unquoted, unescaped `>`, so an `&` here is
    // part of a `>&` redirection rather than a background operator.
    let mut after_redirect = false;
    // The next unquoted character begins a new word; only there does `#`
    // start a comment.
    let mut word_start = true;
    let mut chars = cmd.chars().peekable();

    while let Some(c) = chars.next() {
        let follows_redirect = std::mem::replace(&mut after_redirect, false);

        if escaped {
            escaped = false;
            word_start = false;
            current.push(c);
            continue;
        }

        match quote {
            Quote::Single => {
                if c == '\'' {
                    quote = Quote::Unquoted;
                }
                current.push(c);
                continue;
            }
            Quote::Double | Quote::AnsiC => {
                let closer = if quote == Quote::Double { '"' } else { '\'' };
                if c == '\\' {
                    escaped = true;
                } else if c == closer {
                    quote = Quote::Unquoted;
                }
                current.push(c);
                continue;
            }
            Quote::Unquoted => {}
        }

        let at_word_start = std::mem::replace(&mut word_start, false);
        match c {
            '|' | ';' | '\n' => {
                segments.push(std::mem::take(&mut current));
                word_start = true;
                continue;
            }
            '&' if !follows_redirect => {
                segments.push(std::mem::take(&mut current));
                if chars.peek() == Some(&'&') {
                    chars.next();
                }
                word_start = true;
                continue;
            }
            '#' if at_word_start => {
                // Comment text is ignored by the shell; drop it so quote
                // characters inside it cannot hide the newline separator.
                while chars.peek().is_some_and(|&next| next != '\n') {
                    chars.next();
                }
                continue;
            }
            '\\' if chars.peek() == Some(&'\n') => {
                // Line continuation: the shell deletes `\<newline>`, so it
                // neither ends nor starts a word.
                chars.next();
                current.push_str("\\\n");
                word_start = at_word_start;
                after_redirect = follows_redirect;
                continue;
            }
            '\\' => escaped = true,
            '\'' => quote = Quote::Single,
            '"' => quote = Quote::Double,
            '$' if chars.peek() == Some(&'$') => {
                chars.next();
                current.push_str("$$");
                continue;
            }
            '$' if chars.peek() == Some(&'\'') => {
                chars.next();
                current.push_str("$'");
                quote = Quote::AnsiC;
                continue;
            }
            '>' => {
                after_redirect = true;
                word_start = true;
            }
            ' ' | '\t' | '&' | '<' | '(' | ')' => word_start = true,
            _ => {}
        }
        current.push(c);
    }
    segments.push(current);

    segments
        .into_iter()
        .map(|segment| segment.trim().to_string())
        .filter(|segment| !segment.is_empty())
        .collect()
}
59
60pub fn tokenize(segment: &str) -> Option<Vec<String>> {
61    shell_words::split(segment).ok()
62}
63
/// Reports whether a single command segment uses shell syntax that could
/// write files or run additional commands.
///
/// The scan follows bash quoting rules: single quotes, double quotes, ANSI-C
/// quotes (`$'...'`), backslash escapes, line continuations, `$$`, and `#`
/// comments at the start of a word. It returns `true` for:
///
/// * command substitution (`$(...)` or backticks), either unquoted or inside
///   double quotes, where bash still performs it;
/// * any unquoted `>` or `<` redirection, except:
///   * descriptor duplication whose target is a complete `&N`, `&N-` or `&-`
///     word;
///   * redirection to or from exactly `/dev/null`.
///
/// Text inside single or ANSI-C quotes, and comment text, is ignored.
pub fn has_unsafe_shell_syntax(segment: &str) -> bool {
    /// Quoting context of the character being examined.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Quote {
        Unquoted,
        Single,
        Double,
        AnsiC,
    }

    let chars: Vec<char> = segment.chars().collect();
    let mut quote = Quote::Unquoted;
    // The next unquoted character begins a new word; only there does `#`
    // start a comment.
    let mut word_start = true;
    let mut i = 0;

    while i < chars.len() {
        let c = chars[i];
        let next = chars.get(i + 1).copied();

        match quote {
            Quote::Single => {
                if c == '\'' {
                    quote = Quote::Unquoted;
                }
                i += 1;
                continue;
            }
            Quote::AnsiC => {
                match c {
                    // Skip the escaped character so `\'` does not close the quote.
                    '\\' => i += 1,
                    '\'' => quote = Quote::Unquoted,
                    _ => {}
                }
                i += 1;
                continue;
            }
            Quote::Double => {
                match c {
                    '\\' => i += 1,
                    '"' => quote = Quote::Unquoted,
                    // Command substitution is still performed inside double quotes.
                    '`' => return true,
                    '$' if next == Some('(') => return true,
                    _ => {}
                }
                i += 1;
                continue;
            }
            Quote::Unquoted => {}
        }

        let at_word_start = std::mem::replace(&mut word_start, false);
        match c {
            '\\' => {
                // `\<newline>` is deleted by the shell, so it neither ends nor
                // starts a word; any other escaped character is literal.
                if next == Some('\n') {
                    word_start = at_word_start;
                }
                i += 2;
                continue;
            }
            '#' if at_word_start => {
                // Comment: ignored up to the end of the line.
                while i < chars.len() && chars[i] != '\n' {
                    i += 1;
                }
                continue;
            }
            '\'' => quote = Quote::Single,
            '"' => quote = Quote::Double,
            '`' => return true,
            '$' => match next {
                Some('(') => return true,
                // `$$` is a parameter; consume both so a following `'` is an
                // ordinary single quote rather than the start of `$'...'`.
                Some('$') => {
                    i += 2;
                    continue;
                }
                Some('\'') => {
                    quote = Quote::AnsiC;
                    i += 2;
                    continue;
                }
                _ => {}
            },
            '>' | '<' => {
                if !is_fd_dup_target(&chars, i + 1) && !is_dev_null_target(&chars, i + 1, c) {
                    return true;
                }
                word_start = true;
            }
            ' ' | '\t' | '\n' | '|' | '&' | ';' | '(' | ')' => word_start = true,
            _ => {}
        }
        i += 1;
    }
    false
}

/// Returns whether position `pos` ends a shell word: end of input, a blank,
/// a newline, or one of bash's operator metacharacters.
fn is_word_boundary(chars: &[char], pos: usize) -> bool {
    chars.get(pos).map_or(true, |&ch| {
        matches!(ch, ' ' | '\t' | '\n' | ';' | '|' | '&' | '(' | ')' | '<' | '>')
    })
}

/// Returns whether `chars[start..]` is a complete descriptor-duplication
/// target: `&` followed by `-`, or by one or more digits optionally followed
/// by `-`, and then a word boundary. Anything else after `>&` (e.g. `>&1file`)
/// names a file and is therefore not accepted.
fn is_fd_dup_target(chars: &[char], start: usize) -> bool {
    if chars.get(start) != Some(&'&') {
        return false;
    }
    let digits_start = start + 1;
    let mut end = digits_start;
    while chars.get(end).is_some_and(|ch| ch.is_ascii_digit()) {
        end += 1;
    }
    let has_digits = end > digits_start;
    if chars.get(end) == Some(&'-') {
        end += 1;
    } else if !has_digits {
        return false;
    }
    is_word_boundary(chars, end)
}

const DEV_NULL: [char; 9] = ['/', 'd', 'e', 'v', '/', 'n', 'u', 'l', 'l'];

/// Returns whether the redirection operator just before `start` targets
/// exactly `/dev/null`.
///
/// For `>`, a second `>` (append) is skipped first. Spaces or tabs between
/// the operator and the path are allowed. The path must be followed by a word
/// boundary, so `/dev/nullx` and `/dev/null/x` are rejected.
fn is_dev_null_target(chars: &[char], start: usize, redirect_char: char) -> bool {
    let mut j = start;
    if redirect_char == '>' && chars.get(j) == Some(&'>') {
        j += 1;
    }
    while matches!(chars.get(j), Some(' ' | '\t')) {
        j += 1;
    }
    let end = j + DEV_NULL.len();
    chars.get(j..end) == Some(&DEV_NULL[..]) && is_word_boundary(chars, end)
}
132
/// Reports whether `tokens` (a command name followed by its arguments)
/// contains a given option.
///
/// * `short` is a short option such as `"-i"`. It matches as a standalone
///   token or inside a combined cluster like `-ni`.
/// * `long`, when given, is a long option such as `"--in-place"`. It matches:
///   * exactly;
///   * with an `=value` suffix;
///   * as an abbreviation (`--in`, `--in-pl=.bak`), because `getopt_long`-style
///     parsers accept unambiguous prefixes. A prefix that is ambiguous for
///     the real program only makes it error out, so flagging it is
///     conservative.
///
/// Scanning stops at `--`, after which everything is an operand. The first
/// token (the command name) is never inspected. An empty slice yields `false`
/// instead of panicking.
pub fn has_flag(tokens: &[String], short: &str, long: Option<&str>) -> bool {
    let short_char = short.trim_start_matches('-');
    for token in tokens.iter().skip(1) {
        if token == "--" {
            return false;
        }
        if let Some(long_flag) = long {
            let name = token.split_once('=').map_or(token.as_str(), |(name, _)| name);
            let is_abbreviation =
                name.len() > 2 && name.starts_with("--") && long_flag.starts_with(name);
            if token == long_flag || name == long_flag || is_abbreviation {
                return true;
            }
        }
        if let Some(cluster) = token.strip_prefix('-') {
            if !cluster.starts_with('-') && cluster.contains(short_char) {
                return true;
            }
        }
    }
    false
}
150
/// Returns whether `token` is a complete output descriptor duplication such
/// as `2>&1`, `>&2`, `2>&-` or `2>&1-`.
///
/// The accepted shape is:
///
/// 1. an optional source descriptor (ASCII digits);
/// 2. `>&`;
/// 3. either `-` (close), or one or more digits optionally followed by `-`
///    (move).
///
/// Anything else is rejected, including an empty target (`2>&`), a malformed
/// one (`2>&--`, `2>&-1`) or a file name (`>&out`).
pub fn is_fd_redirect(token: &str) -> bool {
    let after_source = token.trim_start_matches(|c: char| c.is_ascii_digit());
    let Some(target) = after_source.strip_prefix(">&") else {
        return false;
    };
    if target == "-" {
        return true;
    }
    let digits = target.strip_suffix('-').unwrap_or(target);
    !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
}
161
162pub fn strip_fd_redirects(s: &str) -> String {
163    match tokenize(s) {
164        Some(tokens) => tokens
165            .into_iter()
166            .filter(|t| !is_fd_redirect(t))
167            .collect::<Vec<_>>()
168            .join(" "),
169        None => s.to_string(),
170    }
171}
172
/// Strips leading `NAME=value` environment assignments from a command segment
/// and returns the remainder, with leading whitespace removed.
///
/// A prefix is only stripped when `NAME` starts with an uppercase ASCII
/// letter or `_` and contains only uppercase ASCII letters, digits and `_`.
///
/// The value ends at the first space or tab that is not quoted or escaped.
/// Single quotes, double quotes, ANSI-C quotes (`$'...'`) and backslashes are
/// honoured, so `FOO='a b' cmd` yields `cmd` and `FOO=x\ y cmd` yields `cmd`.
///
/// The assignment is left in place, and the segment is returned from that
/// point, when the value:
///
/// * runs to the end of the segment;
/// * leaves a quote open;
/// * uses syntax whose word extent is not tracked here: `$(`, `${`, `$[`,
///   backticks, or an unquoted `(`, `)`, `;`, `|`, `&`, `<`, `>` or newline.
pub fn strip_env_prefix(segment: &str) -> &str {
    let mut rest = segment;
    loop {
        let trimmed = rest.trim_start();
        let Some(eq_pos) = trimmed.find('=') else {
            return trimmed;
        };
        if !is_env_name(&trimmed[..eq_pos]) {
            return trimmed;
        }
        let value_start = eq_pos + 1;
        match assignment_value_len(&trimmed[value_start..]) {
            Some(len) => rest = &trimmed[value_start + len..],
            None => return trimmed,
        }
    }
}

/// Returns whether `key` is a variable name that [`strip_env_prefix`] strips.
/// It must start with an uppercase ASCII letter or `_`, and contain only
/// uppercase ASCII letters, digits and `_`.
fn is_env_name(key: &str) -> bool {
    let mut bytes = key.bytes();
    bytes
        .next()
        .is_some_and(|b| b.is_ascii_uppercase() || b == b'_')
        && bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
}

/// Returns the byte offset of the first unquoted, unescaped space or tab in
/// `value`, i.e. the length of the assignment value.
///
/// Returns `None` if:
///
/// * there is no such separator;
/// * a quote is left open;
/// * the value uses syntax whose word extent this scanner does not model.
fn assignment_value_len(value: &str) -> Option<usize> {
    /// Quoting context of the character being examined.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Quote {
        Unquoted,
        Single,
        Double,
        AnsiC,
    }

    let mut quote = Quote::Unquoted;
    let mut chars = value.char_indices().peekable();
    while let Some((pos, c)) = chars.next() {
        let next = chars.peek().map(|&(_, n)| n);
        match (quote, c) {
            (Quote::Single | Quote::AnsiC, '\'') | (Quote::Double, '"') => {
                quote = Quote::Unquoted;
            }
            (Quote::Single, _) => {}
            // A backslash escapes the next character everywhere except inside
            // single quotes.
            (_, '\\') => {
                chars.next();
            }
            (Quote::AnsiC, _) => {}
            // Command substitution and parameter/arithmetic expansion can
            // contain blanks; give up rather than guess where the word ends.
            (_, '`') => return None,
            (_, '$') if matches!(next, Some('(' | '{' | '[')) => return None,
            (Quote::Double, _) => {}
            // `$$` is a parameter; consume both so a following `'` is an
            // ordinary single quote rather than the start of `$'...'`.
            (Quote::Unquoted, '$') if next == Some('$') => {
                chars.next();
            }
            (Quote::Unquoted, '$') if next == Some('\'') => {
                chars.next();
                quote = Quote::AnsiC;
            }
            (Quote::Unquoted, '\'') => quote = Quote::Single,
            (Quote::Unquoted, '"') => quote = Quote::Double,
            (Quote::Unquoted, ' ' | '\t') => return Some(pos),
            (Quote::Unquoted, '(' | ')' | ';' | '|' | '&' | '<' | '>' | '\n') => return None,
            (Quote::Unquoted, _) => {}
        }
    }
    None
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned word list in the shape `tokenize` returns.
    fn words(list: &[&str]) -> Vec<String> {
        list.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn split_pipe() {
        assert_eq!(split_outside_quotes("grep foo | head -5"), ["grep foo", "head -5"]);
    }

    #[test]
    fn split_and() {
        assert_eq!(split_outside_quotes("ls && echo done"), ["ls", "echo done"]);
    }

    #[test]
    fn split_semicolon() {
        assert_eq!(split_outside_quotes("ls; echo done"), ["ls", "echo done"]);
    }

    #[test]
    fn split_preserves_quoted_pipes() {
        assert_eq!(split_outside_quotes("echo 'a | b' foo"), ["echo 'a | b' foo"]);
    }

    #[test]
    fn split_background_operator() {
        assert_eq!(split_outside_quotes("cat file & rm -rf /"), ["cat file", "rm -rf /"]);
    }

    #[test]
    fn split_newline() {
        assert_eq!(split_outside_quotes("echo foo\necho bar"), ["echo foo", "echo bar"]);
    }

    #[test]
    fn unsafe_redirect() {
        assert!(has_unsafe_shell_syntax("echo hello > file.txt"));
    }

    #[test]
    fn safe_fd_redirect_stderr_to_stdout() {
        assert!(!has_unsafe_shell_syntax("cargo clippy 2>&1"));
    }

    #[test]
    fn safe_fd_redirect_close() {
        assert!(!has_unsafe_shell_syntax("cmd 2>&-"));
    }

    #[test]
    fn unsafe_redirect_ampersand_no_digit() {
        assert!(has_unsafe_shell_syntax("echo hello >& file.txt"));
    }

    #[test]
    fn unsafe_backtick() {
        assert!(has_unsafe_shell_syntax("echo `rm -rf /`"));
    }

    #[test]
    fn unsafe_command_substitution() {
        assert!(has_unsafe_shell_syntax("echo $(rm -rf /)"));
    }

    #[test]
    fn safe_quoted_dollar_paren() {
        assert!(!has_unsafe_shell_syntax("echo '$(safe)' arg"));
    }

    #[test]
    fn safe_quoted_redirect() {
        assert!(!has_unsafe_shell_syntax("echo 'greater > than' test"));
    }

    #[test]
    fn safe_no_special_chars() {
        assert!(!has_unsafe_shell_syntax("grep pattern file"));
    }

    #[test]
    fn safe_redirect_to_dev_null() {
        assert!(!has_unsafe_shell_syntax("cmd >/dev/null"));
    }

    #[test]
    fn safe_redirect_stderr_to_dev_null() {
        assert!(!has_unsafe_shell_syntax("cmd 2>/dev/null"));
    }

    #[test]
    fn safe_redirect_append_to_dev_null() {
        assert!(!has_unsafe_shell_syntax("cmd >>/dev/null"));
    }

    #[test]
    fn safe_redirect_space_dev_null() {
        assert!(!has_unsafe_shell_syntax("cmd > /dev/null"));
    }

    #[test]
    fn safe_redirect_input_dev_null() {
        assert!(!has_unsafe_shell_syntax("cmd < /dev/null"));
    }

    #[test]
    fn safe_redirect_both_dev_null() {
        // stdout to /dev/null, then stderr duplicated onto it.
        assert!(!has_unsafe_shell_syntax("cmd >/dev/null 2>&1"));
    }

    #[test]
    fn unsafe_redirect_dev_null_prefix() {
        assert!(has_unsafe_shell_syntax("cmd > /dev/nullicious"));
    }

    #[test]
    fn unsafe_redirect_dev_null_path_traversal() {
        assert!(has_unsafe_shell_syntax("cmd > /dev/null/../etc/passwd"));
    }

    #[test]
    fn unsafe_redirect_dev_null_subpath() {
        assert!(has_unsafe_shell_syntax("cmd > /dev/null/foo"));
    }

    #[test]
    fn unsafe_redirect_to_file() {
        assert!(has_unsafe_shell_syntax("cmd > output.txt"));
    }

    #[test]
    fn has_flag_short() {
        let tokens = words(&["sed", "-i", "s/foo/bar/"]);
        assert!(has_flag(&tokens, "-i", Some("--in-place")));
    }

    #[test]
    fn has_flag_long_with_eq() {
        let tokens = words(&["sed", "--in-place=.bak", "s/foo/bar/"]);
        assert!(has_flag(&tokens, "-i", Some("--in-place")));
    }

    #[test]
    fn has_flag_combined_short() {
        let tokens = words(&["sed", "-ni", "s/foo/bar/p"]);
        assert!(has_flag(&tokens, "-i", Some("--in-place")));
    }

    #[test]
    fn has_flag_stops_at_double_dash() {
        let tokens = words(&["cmd", "--", "-i"]);
        assert!(!has_flag(&tokens, "-i", Some("--in-place")));
    }

    #[test]
    fn strip_single_env_var() {
        assert_eq!(strip_env_prefix("RACK_ENV=test bundle exec rspec"), "bundle exec rspec");
    }

    #[test]
    fn strip_multiple_env_vars() {
        let stripped = strip_env_prefix("RACK_ENV=test RAILS_ENV=test bundle exec rspec");
        assert_eq!(stripped, "bundle exec rspec");
    }

    #[test]
    fn strip_no_env_var() {
        assert_eq!(strip_env_prefix("bundle exec rspec"), "bundle exec rspec");
    }

    #[test]
    fn tokenize_simple() {
        let expected = words(&["grep", "foo", "file.txt"]);
        assert_eq!(tokenize("grep foo file.txt"), Some(expected));
    }

    #[test]
    fn tokenize_quoted() {
        let expected = words(&["echo", "hello world"]);
        assert_eq!(tokenize("echo 'hello world'"), Some(expected));
    }
}