Skip to main content

soul_coder/tools/
bash.rs

1//! Bash tool — execute shell commands with output truncation and timeout.
2//!
3//! Delegates to [`soul_core::executor::ShellExecutor`] for command execution,
4//! then applies ANSI stripping and tail truncation on top.
5
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use serde_json::json;
10use tokio::sync::mpsc;
11
12use soul_core::error::SoulResult;
13use soul_core::executor::shell::ShellExecutor;
14use soul_core::executor::ToolExecutor;
15use soul_core::tool::{Tool, ToolOutput};
16use soul_core::types::ToolDefinition;
17use soul_core::vexec::VirtualExecutor;
18
19use crate::truncate::{truncate_tail, MAX_BYTES};
20
/// Timeout applied to every command, in seconds (two minutes).
const DEFAULT_TIMEOUT: u64 = 2 * 60;

/// How many trailing lines of command output survive truncation.
const BASH_MAX_LINES: usize = 50;
26
27pub struct BashTool {
28    shell: ShellExecutor,
29    definition: ToolDefinition,
30}
31
32impl BashTool {
33    pub fn new(executor: Arc<dyn VirtualExecutor>, cwd: impl Into<String>) -> Self {
34        let shell = ShellExecutor::new(executor)
35            .with_timeout(DEFAULT_TIMEOUT)
36            .with_cwd(cwd);
37
38        let definition = ToolDefinition {
39            name: "bash".into(),
40            description: "Execute a shell command. Returns stdout and stderr. Output is truncated to the last 50 lines.".into(),
41            input_schema: json!({
42                "type": "object",
43                "properties": {
44                    "command": {
45                        "type": "string",
46                        "description": "The shell command to execute"
47                    },
48                    "timeout": {
49                        "type": "integer",
50                        "description": "Timeout in seconds (default: 120)"
51                    }
52                },
53                "required": ["command"]
54            }),
55        };
56
57        Self { shell, definition }
58    }
59}
60
/// Strip terminal escape sequences and carriage returns from command output.
///
/// Recognised sequences (all introduced by `ESC`, `\x1b`):
/// - CSI: `ESC [`, parameter/intermediate bytes (`0x20..=0x3F`), then one final
///   byte (`0x40..=0x7E`). This covers SGR colours (`\x1b[31m`) and sequences whose
///   final byte is not a letter, such as bracketed paste (`\x1b[200~`).
/// - OSC: `ESC ]` up to BEL (`\x07`) or the next `ESC` (normally the string
///   terminator `ESC \`), e.g. window-title updates emitted by shells.
/// - nF: `ESC`, intermediate bytes (`0x20..=0x2F`), then one final byte
///   (`0x30..=0x7E`), e.g. the charset selector `\x1b(B` emitted by `tput sgr0`.
/// - Two-character escapes: `ESC` followed by one byte in `0x30..=0x7E`
///   (e.g. `\x1b7`, `\x1b=`).
///
/// A malformed sequence ends at the first character that cannot belong to it, and
/// that character is kept as ordinary text. A stray `ESC` therefore never swallows
/// newlines or the output that follows. Every `\r` is dropped, turning CRLF into LF.
fn strip_ansi(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '\x1b' => skip_escape_sequence(&mut chars),
            // Carriage returns are dropped so CRLF output reads as plain LF.
            '\r' => {}
            _ => result.push(ch),
        }
    }

    result
}

/// Consume the body of an escape sequence whose leading `ESC` has already been read.
///
/// Only characters that belong to the sequence are consumed; anything else stays in
/// `chars` for the caller to emit as normal text.
fn skip_escape_sequence(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    match chars.peek().copied() {
        // CSI: parameter/intermediate bytes, then a single final byte.
        Some('[') => {
            chars.next();
            while let Some(&c) = chars.peek() {
                match c {
                    ' '..='?' => {
                        chars.next();
                    }
                    '@'..='~' => {
                        chars.next();
                        break;
                    }
                    _ => break,
                }
            }
        }
        // OSC: text up to BEL. A following ESC is left for the caller, so the
        // `ESC \` terminator is removed as a two-character escape. A newline also
        // ends it, so an unterminated OSC cannot eat the rest of the output.
        Some(']') => {
            chars.next();
            while let Some(&c) = chars.peek() {
                match c {
                    '\x07' => {
                        chars.next();
                        break;
                    }
                    '\x1b' | '\n' => break,
                    _ => {
                        chars.next();
                    }
                }
            }
        }
        // nF: one or more intermediate bytes, then a single final byte.
        Some(' '..='/') => {
            chars.next();
            while let Some(&c) = chars.peek() {
                match c {
                    ' '..='/' => {
                        chars.next();
                    }
                    '0'..='~' => {
                        chars.next();
                        break;
                    }
                    _ => break,
                }
            }
        }
        // Two-character escape: the byte after ESC is the whole sequence.
        Some('0'..='~') => {
            chars.next();
        }
        // Lone ESC (end of input, or followed by a non-sequence character):
        // only the ESC itself is dropped.
        _ => {}
    }
}
88
89#[async_trait]
90impl Tool for BashTool {
91    fn name(&self) -> &str {
92        "bash"
93    }
94
95    fn definition(&self) -> ToolDefinition {
96        self.definition.clone()
97    }
98
99    async fn execute(
100        &self,
101        call_id: &str,
102        arguments: serde_json::Value,
103        partial_tx: Option<mpsc::UnboundedSender<String>>,
104    ) -> SoulResult<ToolOutput> {
105        // Delegate to ShellExecutor from soul-core
106        let result = self
107            .shell
108            .execute(&self.definition, call_id, arguments, partial_tx.clone())
109            .await;
110
111        match result {
112            Ok(output) => {
113                // Stream partial output if channel available
114                if let Some(ref tx) = partial_tx {
115                    let _ = tx.send(output.content.clone());
116                }
117
118                // Apply ANSI stripping
119                let cleaned = strip_ansi(&output.content);
120
121                // Apply tail truncation (errors/final output matter most)
122                let truncated = truncate_tail(&cleaned, BASH_MAX_LINES, MAX_BYTES);
123
124                let notice = truncated.truncation_notice();
125                let is_truncated = truncated.is_truncated();
126                let mut result_content = truncated.content;
127                if let Some(notice) = notice {
128                    result_content = format!("{}\n{}", notice, result_content);
129                }
130
131                let tool_output = if output.is_error {
132                    ToolOutput::error(result_content)
133                } else {
134                    ToolOutput::success(result_content)
135                };
136
137                Ok(tool_output.with_metadata(json!({
138                    "truncated": is_truncated,
139                })))
140            }
141            Err(e) => Ok(ToolOutput::error(format!("Command failed: {}", e))),
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vexec::{ExecOutput, MockExecutor};

    /// Build a tool backed by `MockExecutor::always_ok(stdout)`.
    fn setup_ok(stdout: &str) -> BashTool {
        let executor = Arc::new(MockExecutor::always_ok(stdout));
        BashTool::new(executor as Arc<dyn VirtualExecutor>, "/project")
    }

    /// Build a tool whose mock executor replays `responses`.
    fn setup_with(responses: Vec<ExecOutput>) -> BashTool {
        let executor = Arc::new(MockExecutor::new(responses));
        BashTool::new(executor as Arc<dyn VirtualExecutor>, "/project")
    }

    #[tokio::test]
    async fn execute_simple_command() {
        let tool = setup_ok("hello world\n");
        let result = tool
            .execute("c1", json!({"command": "echo hello world"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("hello world"));
    }

    #[tokio::test]
    async fn execute_with_error_exit() {
        let tool = setup_with(vec![ExecOutput {
            stdout: String::new(),
            stderr: "command not found".into(),
            exit_code: 127,
        }]);

        let result = tool
            .execute("c2", json!({"command": "nonexistent"}), None)
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.content.contains("command not found"));
    }

    #[tokio::test]
    async fn execute_empty_command() {
        let tool = setup_ok("");
        let result = tool
            .execute("c3", json!({"command": ""}), None)
            .await
            .unwrap();
        // ShellExecutor only requires the "command" key to exist, not to be
        // non-empty, and empty output is not an error.
        assert!(!result.is_error);
    }

    // Pure string transformation: no async runtime needed.
    #[test]
    fn strips_ansi() {
        assert_eq!(strip_ansi("\x1b[31mred\x1b[0m"), "red");
        assert_eq!(strip_ansi("no ansi"), "no ansi");
        assert_eq!(strip_ansi("line\r\n"), "line\n");
    }

    #[tokio::test]
    async fn zero_exit_with_stderr_is_success() {
        let tool = setup_with(vec![ExecOutput {
            stdout: "out\n".into(),
            stderr: "warn\n".into(),
            exit_code: 0,
        }]);

        let result = tool
            .execute("c4", json!({"command": "test"}), None)
            .await
            .unwrap();

        // ShellExecutor returns stdout on success (exit_code 0).
        assert!(!result.is_error);
        assert!(result.content.contains("out"));
    }

    #[tokio::test]
    async fn streaming_output() {
        let tool = setup_ok("streamed\n");
        let (tx, mut rx) = mpsc::unbounded_channel();

        let result = tool
            .execute("c5", json!({"command": "echo streamed"}), Some(tx))
            .await
            .unwrap();

        assert!(!result.is_error);
        let partial = rx.recv().await.unwrap();
        assert!(partial.contains("streamed"));
    }

    #[tokio::test]
    async fn output_is_ansi_stripped() {
        let tool = setup_ok("\x1b[32mok\x1b[0m\n");
        let result = tool
            .execute("c6", json!({"command": "ls --color"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("ok"));
        assert!(!result.content.contains('\x1b'));
    }

    #[tokio::test]
    async fn long_output_keeps_tail() {
        let stdout: String = (0..200).map(|i| format!("row-{:03}\n", i)).collect();
        let tool = setup_ok(&stdout);
        let result = tool
            .execute("c7", json!({"command": "seq 200"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("row-199"));
        assert!(!result.content.contains("row-000"));
    }

    #[tokio::test]
    async fn tool_name_and_definition() {
        let tool = setup_ok("");
        assert_eq!(tool.name(), "bash");
        let def = tool.definition();
        assert_eq!(def.name, "bash");
        assert!(def.input_schema["required"].as_array().unwrap().contains(&json!("command")));
        assert!(def.input_schema["properties"]["timeout"].is_object());
        // The advertised line limit must match the one actually enforced.
        assert!(def.description.contains(&BASH_MAX_LINES.to_string()));
    }
}