winx_code_agent/utils/
bash_parser.rs1use tree_sitter::{Node, Parser};
2
3use crate::errors::{Result, WinxError};
4
5pub fn assert_single_statement(command: &str) -> Result<()> {
6 let trimmed = command.trim();
7 if trimmed.is_empty() {
8 return Ok(());
9 }
10 if trimmed.contains('\0') {
11 return Err(WinxError::CommandExecutionError(
12 "Command contains a NUL byte. JSON escape \\u0000 becomes an actual NUL before bash sees it; write \\\\0 or \\\\x00 in the command string instead.".to_string(),
13 ));
14 }
15
16 let mut parser = Parser::new();
17 let language: tree_sitter::Language = tree_sitter_bash::LANGUAGE.into();
18 parser.set_language(&language).map_err(|error| {
19 WinxError::CommandExecutionError(format!("Failed to load bash parser: {error}"))
20 })?;
21
22 let tree = parser.parse(trimmed, None).ok_or_else(|| {
23 WinxError::CommandExecutionError("Failed to parse bash command".to_string())
24 })?;
25 let root = tree.root_node();
26
27 if root.has_error() {
28 if bash_accepts_syntax(trimmed) {
29 return Ok(());
30 }
31
32 return Err(WinxError::CommandExecutionError(
33 "Command contains invalid bash syntax. If this is a complex script, pass it as multiline bash, avoid NUL bytes, or set allow_multi=true after verifying the quoting.".to_string(),
34 ));
35 }
36
37 let statement_count = top_level_statement_count(trimmed, root);
38
39 if statement_count > 1 && !trimmed.contains('\n') {
40 return Err(WinxError::CommandExecutionError(
41 "Command should contain a single top-level bash statement. For deliberate scripts, split statements across lines or set allow_multi=true.".to_string(),
42 ));
43 }
44
45 Ok(())
46}
47
/// Asks the local `bash` binary whether `command` is syntactically valid.
///
/// Runs `bash -n -c -- <command>`. With `-n`, bash reads the command without
/// executing it. The `--` ends option parsing, so a command that starts with
/// `-` or `+` is syntax-checked instead of being taken as shell options.
///
/// All three standard streams are attached to the null device. The caller
/// reports its own error, so bash's diagnostics would only be noise on the
/// host process's stderr. The child also never touches the host's stdin or
/// stdout.
///
/// Returns `false` when bash rejects the syntax or cannot be spawned at all.
fn bash_accepts_syntax(command: &str) -> bool {
    use std::process::{Command, Stdio};

    Command::new("bash")
        .args(["-n", "-c", "--", command])
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()
        .is_ok_and(|status| status.success())
}
56
/// Kind and byte span of a statement node, copied out of the tree-sitter tree.
#[derive(Clone, Debug)]
struct StatementNode {
    /// tree-sitter node kind, e.g. `"command"` or `"list"`.
    kind: String,
    /// Byte offset in the parsed source where the node starts (inclusive).
    start_byte: usize,
    /// Byte offset in the parsed source where the node ends (exclusive).
    end_byte: usize,
}
63
64fn top_level_statement_count(source: &str, root: Node<'_>) -> usize {
65 let mut statements = Vec::new();
66 collect_statement_nodes(root, &mut statements);
67
68 statements
69 .iter()
70 .filter(|stmt| stmt.kind != "comment")
71 .filter(|stmt| !statements.iter().any(|other| is_contained_statement(source, stmt, other)))
72 .count()
73}
74
75fn collect_statement_nodes(node: Node<'_>, statements: &mut Vec<StatementNode>) {
76 if is_statement_node(node.kind()) {
77 statements.push(StatementNode {
78 kind: node.kind().to_string(),
79 start_byte: node.start_byte(),
80 end_byte: node.end_byte(),
81 });
82 }
83
84 for index in 0..node.named_child_count() as u32 {
85 if let Some(child) = node.named_child(index) {
86 collect_statement_nodes(child, statements);
87 }
88 }
89}
90
/// Returns `true` when `kind` is one of the tree-sitter-bash node kinds this
/// module treats as a statement.
///
/// `comment` is included here; callers that count statements filter it out
/// separately.
fn is_statement_node(kind: &str) -> bool {
    const STATEMENT_KINDS: [&str; 16] = [
        "command",
        "variable_assignment",
        "declaration_command",
        "unset_command",
        "comment",
        "for_statement",
        "c_style_for_statement",
        "while_statement",
        "if_statement",
        "case_statement",
        "function_definition",
        "pipeline",
        "list",
        "compound_statement",
        "subshell",
        "redirected_statement",
    ];

    STATEMENT_KINDS.contains(&kind)
}
112
113fn is_contained_statement(source: &str, stmt: &StatementNode, other: &StatementNode) -> bool {
114 if stmt.start_byte == other.start_byte
115 && stmt.end_byte == other.end_byte
116 && stmt.kind == other.kind
117 {
118 return false;
119 }
120
121 let other_text = &source[other.start_byte..other.end_byte];
122 if other.kind == "list" && other_text.contains(';') {
123 return false;
124 }
125
126 other.start_byte <= stmt.start_byte
127 && other.end_byte >= stmt.end_byte
128 && other.end_byte - other.start_byte > stmt.end_byte - stmt.start_byte
129 && other_text.contains(&source[stmt.start_byte..stmt.end_byte])
130}
131
#[cfg(test)]
mod tests {
    use super::assert_single_statement;

    /// True when the validator accepts `command`.
    fn accepted(command: &str) -> bool {
        assert_single_statement(command).is_ok()
    }

    #[test]
    fn accepts_shell_chains_as_single_statement() {
        assert!(accepted("cargo test && cargo clippy"));
    }

    #[test]
    fn accepts_heredocs_as_single_statement() {
        assert!(accepted("cat <<'EOF'\nhello\nEOF"));
    }

    #[test]
    fn accepts_for_loop_as_single_compound_statement() {
        assert!(accepted("for i in 1 2 3; do echo tick; sleep 1; done"));
    }

    #[test]
    fn rejects_semicolon_separated_top_level_statements() {
        assert!(!accepted("pwd; ls"));
    }

    #[test]
    fn accepts_multiline_scripts() {
        assert!(accepted("pwd\nls"));
    }

    #[test]
    fn accepts_bash_lc_script_when_tree_sitter_reports_error() {
        let script = "bash -lc 'printf \"%s\\n\" \"-- drm connectors --\"; for s in /sys/class/drm/card*-*/status; do [ -e \"$s\" ] || continue; c=${s%/status}; printf \"%s: %s\" \"${c##*/}\" \"$(cat \"$s\")\"; done'";
        assert!(accepted(script));
    }

    #[test]
    fn rejects_nul_with_actionable_message() {
        let message = assert_single_statement("printf '\0'")
            .err()
            .map(|error| error.to_string())
            .unwrap_or_default();
        assert!(message.contains("NUL byte"));
        assert!(message.contains("\\\\x00"));
    }
}