Skip to main content

winx_code_agent/utils/
bash_parser.rs

1use tree_sitter::{Node, Parser};
2
3use crate::errors::{Result, WinxError};
4
5pub fn assert_single_statement(command: &str) -> Result<()> {
6    let trimmed = command.trim();
7    if trimmed.is_empty() {
8        return Ok(());
9    }
10    if trimmed.contains('\0') {
11        return Err(WinxError::CommandExecutionError(
12            "Command contains a NUL byte. JSON escape \\u0000 becomes an actual NUL before bash sees it; write \\\\0 or \\\\x00 in the command string instead.".to_string(),
13        ));
14    }
15
16    let mut parser = Parser::new();
17    let language: tree_sitter::Language = tree_sitter_bash::LANGUAGE.into();
18    parser.set_language(&language).map_err(|error| {
19        WinxError::CommandExecutionError(format!("Failed to load bash parser: {error}"))
20    })?;
21
22    let tree = parser.parse(trimmed, None).ok_or_else(|| {
23        WinxError::CommandExecutionError("Failed to parse bash command".to_string())
24    })?;
25    let root = tree.root_node();
26
27    if root.has_error() {
28        if bash_accepts_syntax(trimmed) {
29            return Ok(());
30        }
31
32        return Err(WinxError::CommandExecutionError(
33            "Command contains invalid bash syntax. If this is a complex script, pass it as multiline bash, avoid NUL bytes, or set allow_multi=true after verifying the quoting.".to_string(),
34        ));
35    }
36
37    let statement_count = top_level_statement_count(trimmed, root);
38
39    if statement_count > 1 && !trimmed.contains('\n') {
40        return Err(WinxError::CommandExecutionError(
41            "Command should contain a single top-level bash statement. For deliberate scripts, split statements across lines or set allow_multi=true.".to_string(),
42        ));
43    }
44
45    Ok(())
46}
47
/// Asks the system `bash` whether `command` is syntactically valid.
///
/// Runs `bash -n -c <command>`. `-n` makes bash read and parse the command
/// without executing it. Returns `true` only when bash could be spawned and
/// exited successfully. A missing `bash` binary therefore counts as
/// "not accepted".
///
/// All three standard streams are detached (`Stdio::null`). Bash's
/// syntax-error diagnostics are then not written to this process's
/// stdout/stderr, and the child cannot read from this process's stdin.
fn bash_accepts_syntax(command: &str) -> bool {
    std::process::Command::new("bash")
        .args(["-n", "-c", command])
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status()
        .is_ok_and(|status| status.success())
}
56
/// A statement-like syntax node, detached from the tree-sitter tree.
///
/// Only the node kind and its byte span in the parsed source are kept. That is
/// all the top-level counting needs to decide which statements nest inside
/// which.
#[derive(Debug, Clone)]
struct StatementNode {
    /// tree-sitter node kind, e.g. `"command"` or `"list"`.
    kind: String,
    /// Byte offset where the node starts in the parsed source (inclusive).
    start_byte: usize,
    /// Byte offset where the node ends in the parsed source (exclusive).
    end_byte: usize,
}
63
64fn top_level_statement_count(source: &str, root: Node<'_>) -> usize {
65    let mut statements = Vec::new();
66    collect_statement_nodes(root, &mut statements);
67
68    statements
69        .iter()
70        .filter(|stmt| stmt.kind != "comment")
71        .filter(|stmt| !statements.iter().any(|other| is_contained_statement(source, stmt, other)))
72        .count()
73}
74
75fn collect_statement_nodes(node: Node<'_>, statements: &mut Vec<StatementNode>) {
76    if is_statement_node(node.kind()) {
77        statements.push(StatementNode {
78            kind: node.kind().to_string(),
79            start_byte: node.start_byte(),
80            end_byte: node.end_byte(),
81        });
82    }
83
84    for index in 0..node.named_child_count() as u32 {
85        if let Some(child) = node.named_child(index) {
86            collect_statement_nodes(child, statements);
87        }
88    }
89}
90
/// Returns `true` for tree-sitter-bash node kinds that are treated as statements.
///
/// The list is intended to cover the statement alternatives of the
/// tree-sitter-bash grammar. NOTE(review): verify it against the pinned
/// grammar version. It also includes `comment`; comments are collected but
/// ignored when counting. A kind that a given grammar version never produces is
/// simply never matched, so listing it is harmless.
fn is_statement_node(kind: &str) -> bool {
    matches!(
        kind,
        "command"
            | "variable_assignment"
            // `A=1 B=2` is one statement wrapping several assignments; without
            // this wrapper kind each assignment counted as top-level.
            | "variable_assignments"
            | "declaration_command"
            | "unset_command"
            // `[ … ]` / `[[ … ]]` contain no `command` node, so without this
            // `[[ -f x ]]; rm y` was counted as a single statement.
            | "test_command"
            | "negated_command"
            | "comment"
            | "for_statement"
            | "c_style_for_statement"
            | "while_statement"
            | "if_statement"
            | "case_statement"
            | "function_definition"
            | "pipeline"
            | "list"
            | "compound_statement"
            | "subshell"
            | "redirected_statement"
    )
}
112
113fn is_contained_statement(source: &str, stmt: &StatementNode, other: &StatementNode) -> bool {
114    if stmt.start_byte == other.start_byte
115        && stmt.end_byte == other.end_byte
116        && stmt.kind == other.kind
117    {
118        return false;
119    }
120
121    let other_text = &source[other.start_byte..other.end_byte];
122    if other.kind == "list" && other_text.contains(';') {
123        return false;
124    }
125
126    other.start_byte <= stmt.start_byte
127        && other.end_byte >= stmt.end_byte
128        && other.end_byte - other.start_byte > stmt.end_byte - stmt.start_byte
129        && other_text.contains(&source[stmt.start_byte..stmt.end_byte])
130}
131
#[cfg(test)]
mod tests {
    use super::assert_single_statement;

    /// Shorthand: does the single-statement guard accept `command`?
    fn accepted(command: &str) -> bool {
        assert_single_statement(command).is_ok()
    }

    #[test]
    fn accepts_shell_chains_as_single_statement() {
        assert!(accepted("cargo test && cargo clippy"));
    }

    #[test]
    fn accepts_heredocs_as_single_statement() {
        assert!(accepted("cat <<'EOF'\nhello\nEOF"));
    }

    #[test]
    fn accepts_for_loop_as_single_compound_statement() {
        assert!(accepted("for i in 1 2 3; do echo tick; sleep 1; done"));
    }

    #[test]
    fn rejects_semicolon_separated_top_level_statements() {
        assert!(!accepted("pwd; ls"));
    }

    #[test]
    fn accepts_multiline_scripts() {
        assert!(accepted("pwd\nls"));
    }

    #[test]
    fn accepts_bash_lc_script_when_tree_sitter_reports_error() {
        let script = "bash -lc 'printf \"%s\\n\" \"-- drm connectors --\"; for s in /sys/class/drm/card*-*/status; do [ -e \"$s\" ] || continue; c=${s%/status}; printf \"%s: %s\" \"${c##*/}\" \"$(cat \"$s\")\"; done'";
        assert!(accepted(script));
    }

    #[test]
    fn rejects_nul_with_actionable_message() {
        // An unexpected `Ok` yields an empty message, which fails both checks.
        let message = assert_single_statement("printf '\0'")
            .err()
            .map(|error| error.to_string())
            .unwrap_or_default();
        assert!(message.contains("NUL byte"));
        assert!(message.contains("\\\\x00"));
    }
}