Skip to main content

winx_code_agent/utils/
bash_parser.rs

1use tree_sitter::{Node, Parser};
2
3use crate::errors::{Result, WinxError};
4
5pub fn assert_single_statement(command: &str) -> Result<()> {
6    let trimmed = command.trim();
7    if trimmed.is_empty() {
8        return Ok(());
9    }
10
11    let mut parser = Parser::new();
12    let language: tree_sitter::Language = tree_sitter_bash::LANGUAGE.into();
13    parser.set_language(&language).map_err(|error| {
14        WinxError::CommandExecutionError(format!("Failed to load bash parser: {error}"))
15    })?;
16
17    let tree = parser.parse(trimmed, None).ok_or_else(|| {
18        WinxError::CommandExecutionError("Failed to parse bash command".to_string())
19    })?;
20    let root = tree.root_node();
21
22    if root.has_error() {
23        return Err(WinxError::CommandExecutionError(
24            "Command contains invalid bash syntax.".to_string(),
25        ));
26    }
27
28    let statement_count = top_level_statement_count(trimmed, root);
29
30    if statement_count > 1 {
31        return Err(WinxError::CommandExecutionError(
32            "Command should contain a single top-level bash statement.".to_string(),
33        ));
34    }
35
36    Ok(())
37}
38
/// Grammar kind and byte span recorded for one parse-tree node that was
/// classified as a statement during collection.
#[derive(Clone, Debug)]
struct StatementNode {
    /// Node kind string as reported by the parser, e.g. `"command"`.
    kind: String,
    /// Byte offset at which the node starts in the parsed source.
    start_byte: usize,
    /// Byte offset at which the node ends in the parsed source.
    end_byte: usize,
}
45
46fn top_level_statement_count(source: &str, root: Node<'_>) -> usize {
47    let mut statements = Vec::new();
48    collect_statement_nodes(root, &mut statements);
49
50    statements
51        .iter()
52        .filter(|stmt| stmt.kind != "comment")
53        .filter(|stmt| !statements.iter().any(|other| is_contained_statement(source, stmt, other)))
54        .count()
55}
56
57fn collect_statement_nodes(node: Node<'_>, statements: &mut Vec<StatementNode>) {
58    if is_statement_node(node.kind()) {
59        statements.push(StatementNode {
60            kind: node.kind().to_string(),
61            start_byte: node.start_byte(),
62            end_byte: node.end_byte(),
63        });
64    }
65
66    for index in 0..node.named_child_count() as u32 {
67        if let Some(child) = node.named_child(index) {
68            collect_statement_nodes(child, statements);
69        }
70    }
71}
72
/// Returns `true` when `kind` names a node kind that is treated as a
/// statement while collecting nodes from the parse tree.
///
/// `"comment"` is included on purpose: comments are collected alongside real
/// statements and filtered out later when the statements are counted.
///
/// `"variable_assignments"`, `"test_command"` and `"negated_command"` are
/// statement kinds of the bash grammar as well. Without them, inputs such as
/// `A=1 B=2` or `[[ $(a) = $(b) ]]` have no enclosing statement node, so
/// their inner assignments/commands are each counted as a separate
/// top-level statement.
fn is_statement_node(kind: &str) -> bool {
    matches!(
        kind,
        "command"
            | "variable_assignment"
            | "variable_assignments"
            | "declaration_command"
            | "unset_command"
            | "test_command"
            | "negated_command"
            | "comment"
            | "for_statement"
            | "c_style_for_statement"
            | "while_statement"
            | "if_statement"
            | "case_statement"
            | "function_definition"
            | "pipeline"
            | "list"
            | "compound_statement"
            | "subshell"
            | "redirected_statement"
    )
}
94
95fn is_contained_statement(source: &str, stmt: &StatementNode, other: &StatementNode) -> bool {
96    if stmt.start_byte == other.start_byte
97        && stmt.end_byte == other.end_byte
98        && stmt.kind == other.kind
99    {
100        return false;
101    }
102
103    let other_text = &source[other.start_byte..other.end_byte];
104    if other.kind == "list" && other_text.contains(';') {
105        return false;
106    }
107
108    other.start_byte <= stmt.start_byte
109        && other.end_byte >= stmt.end_byte
110        && other.end_byte - other.start_byte > stmt.end_byte - stmt.start_byte
111        && other_text.contains(&source[stmt.start_byte..stmt.end_byte])
112}
113
#[cfg(test)]
mod tests {
    use super::assert_single_statement;

    /// Asserts that the validator lets `command` through.
    fn expect_accepted(command: &str) {
        assert!(assert_single_statement(command).is_ok());
    }

    /// Asserts that the validator refuses `command`.
    fn expect_rejected(command: &str) {
        assert!(assert_single_statement(command).is_err());
    }

    #[test]
    fn accepts_shell_chains_as_single_statement() {
        expect_accepted("cargo test && cargo clippy");
    }

    #[test]
    fn accepts_heredocs_as_single_statement() {
        expect_accepted("cat <<'EOF'\nhello\nEOF");
    }

    #[test]
    fn accepts_for_loop_as_single_compound_statement() {
        expect_accepted("for i in 1 2 3; do echo tick; sleep 1; done");
    }

    #[test]
    fn rejects_semicolon_separated_top_level_statements() {
        expect_rejected("pwd; ls");
    }

    #[test]
    fn rejects_multiple_top_level_statements() {
        expect_rejected("pwd\nls");
    }
}