winx_code_agent/utils/
bash_parser.rs1use tree_sitter::{Node, Parser};
2
3use crate::errors::{Result, WinxError};
4
5pub fn assert_single_statement(command: &str) -> Result<()> {
6 let trimmed = command.trim();
7 if trimmed.is_empty() {
8 return Ok(());
9 }
10
11 let mut parser = Parser::new();
12 let language: tree_sitter::Language = tree_sitter_bash::LANGUAGE.into();
13 parser.set_language(&language).map_err(|error| {
14 WinxError::CommandExecutionError(format!("Failed to load bash parser: {error}"))
15 })?;
16
17 let tree = parser.parse(trimmed, None).ok_or_else(|| {
18 WinxError::CommandExecutionError("Failed to parse bash command".to_string())
19 })?;
20 let root = tree.root_node();
21
22 if root.has_error() {
23 return Err(WinxError::CommandExecutionError(
24 "Command contains invalid bash syntax.".to_string(),
25 ));
26 }
27
28 let statement_count = top_level_statement_count(trimmed, root);
29
30 if statement_count > 1 {
31 return Err(WinxError::CommandExecutionError(
32 "Command should contain a single top-level bash statement.".to_string(),
33 ));
34 }
35
36 Ok(())
37}
38
/// Owned snapshot of a syntax-tree node that was classified as a statement.
///
/// Only the data needed to compare nodes against each other is kept, so the
/// snapshot does not borrow from the parsed tree.
#[derive(Clone, Debug)]
struct StatementNode {
    /// Node kind name as reported by the parser (for example `"command"`).
    kind: String,
    /// Byte offset in the parsed source where the node starts.
    start_byte: usize,
    /// Byte offset in the parsed source where the node ends.
    end_byte: usize,
}
45
46fn top_level_statement_count(source: &str, root: Node<'_>) -> usize {
47 let mut statements = Vec::new();
48 collect_statement_nodes(root, &mut statements);
49
50 statements
51 .iter()
52 .filter(|stmt| stmt.kind != "comment")
53 .filter(|stmt| !statements.iter().any(|other| is_contained_statement(source, stmt, other)))
54 .count()
55}
56
57fn collect_statement_nodes(node: Node<'_>, statements: &mut Vec<StatementNode>) {
58 if is_statement_node(node.kind()) {
59 statements.push(StatementNode {
60 kind: node.kind().to_string(),
61 start_byte: node.start_byte(),
62 end_byte: node.end_byte(),
63 });
64 }
65
66 for index in 0..node.named_child_count() as u32 {
67 if let Some(child) = node.named_child(index) {
68 collect_statement_nodes(child, statements);
69 }
70 }
71}
72
/// Returns `true` when `kind` names a syntax node that is treated as a
/// statement (or a comment) while counting top-level statements.
///
/// Comments are included here so they get collected; the counting step
/// filters them out again.
///
/// NOTE(review): `variable_assignments`, `test_command` and `negated_command`
/// were added to match the statement alternatives of the tree-sitter-bash
/// grammar. Without them, top-level `[[ .. ]]` / `(( .. ))` tests were never
/// counted, and `A=1 B=2` was counted as two statements. Confirm the names
/// against the grammar version pinned by the project.
fn is_statement_node(kind: &str) -> bool {
    matches!(
        kind,
        "command"
            | "variable_assignment"
            | "variable_assignments"
            | "declaration_command"
            | "unset_command"
            | "test_command"
            | "negated_command"
            | "comment"
            | "for_statement"
            | "c_style_for_statement"
            | "while_statement"
            | "if_statement"
            | "case_statement"
            | "function_definition"
            | "pipeline"
            | "list"
            | "compound_statement"
            | "subshell"
            | "redirected_statement"
    )
}
94
95fn is_contained_statement(source: &str, stmt: &StatementNode, other: &StatementNode) -> bool {
96 if stmt.start_byte == other.start_byte
97 && stmt.end_byte == other.end_byte
98 && stmt.kind == other.kind
99 {
100 return false;
101 }
102
103 let other_text = &source[other.start_byte..other.end_byte];
104 if other.kind == "list" && other_text.contains(';') {
105 return false;
106 }
107
108 other.start_byte <= stmt.start_byte
109 && other.end_byte >= stmt.end_byte
110 && other.end_byte - other.start_byte > stmt.end_byte - stmt.start_byte
111 && other_text.contains(&source[stmt.start_byte..stmt.end_byte])
112}
113
#[cfg(test)]
mod tests {
    use super::assert_single_statement;

    /// Returns whether `command` passes the single-statement check.
    fn is_accepted(command: &str) -> bool {
        assert_single_statement(command).is_ok()
    }

    /// A `&&` chain must be accepted.
    #[test]
    fn accepts_shell_chains_as_single_statement() {
        assert!(is_accepted("cargo test && cargo clippy"));
    }

    /// A command with a heredoc body spanning several lines must be accepted.
    #[test]
    fn accepts_heredocs_as_single_statement() {
        let heredoc = "cat <<'EOF'\nhello\nEOF";
        assert!(is_accepted(heredoc));
    }

    /// Semicolons inside a `for` loop body must not cause a rejection.
    #[test]
    fn accepts_for_loop_as_single_compound_statement() {
        let looped = "for i in 1 2 3; do echo tick; sleep 1; done";
        assert!(is_accepted(looped));
    }

    /// Two commands separated by `;` must be rejected.
    #[test]
    fn rejects_semicolon_separated_top_level_statements() {
        assert!(!is_accepted("pwd; ls"));
    }

    /// Two commands separated by a newline must be rejected.
    #[test]
    fn rejects_multiple_top_level_statements() {
        assert!(!is_accepted("pwd\nls"));
    }
}