Skip to main content

soul_coder/tools/
bash.rs

1//! Bash tool — execute shell commands with output truncation and timeout.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::json;
7use tokio::sync::mpsc;
8
9use soul_core::error::SoulResult;
10use soul_core::tool::{Tool, ToolOutput};
11use soul_core::types::ToolDefinition;
12use soul_core::vexec::VirtualExecutor;
13
14use crate::truncate::{truncate_tail, MAX_BYTES};
15
/// How many trailing lines of combined stdout/stderr survive truncation.
const BASH_MAX_LINES: usize = 50;

/// Timeout, in seconds, used when the call does not supply an integer `timeout`.
const DEFAULT_TIMEOUT: u64 = 2 * 60;
21
22pub struct BashTool {
23    executor: Arc<dyn VirtualExecutor>,
24    cwd: String,
25}
26
27impl BashTool {
28    pub fn new(executor: Arc<dyn VirtualExecutor>, cwd: impl Into<String>) -> Self {
29        Self {
30            executor,
31            cwd: cwd.into(),
32        }
33    }
34}
35
/// Strip ANSI/VT escape sequences and carriage returns from command output.
///
/// Recognised sequences (the leading `ESC` is always removed):
/// - CSI (`ESC [`): parameter/intermediate bytes, then one final byte in
///   `@`..=`~` — e.g. SGR colours `\x1b[31m` or key codes like `\x1b[3~`.
/// - String sequences (`ESC ]` OSC, `ESC P` DCS, `ESC X`, `ESC ^`, `ESC _`),
///   terminated by BEL or ST (`ESC \`) — e.g. window titles, OSC 8 hyperlinks.
/// - Any other escape (`ESC (B`, `ESC =`, `ESC 7`, ...): optional intermediate
///   bytes, then one final byte.
///
/// Every `\r` is dropped, so `\r\n` line endings become `\n`.
fn strip_ansi(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '\x1b' => skip_escape(&mut chars),
            '\r' => {}
            _ => result.push(ch),
        }
    }

    result
}

/// Consume the rest of an escape sequence whose leading `ESC` was already read.
///
/// Never consumes a byte that cannot belong to the sequence, so malformed or
/// truncated escapes do not swallow ordinary text. Iterative (no recursion), so
/// long runs of nested/aborted sequences cannot overflow the stack.
fn skip_escape(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    'sequence: loop {
        match chars.peek().copied() {
            // CSI: parameter bytes 0x30-0x3F and intermediate bytes 0x20-0x2F,
            // then a single final byte 0x40-0x7E (`@`..=`~`).
            Some('[') => {
                chars.next();
                while chars.next_if(|&c| matches!(c, '\x20'..='\x3f')).is_some() {}
                chars.next_if(|&c| matches!(c, '\x40'..='\x7e'));
                return;
            }
            // String sequences run until BEL or ST (`ESC \`).
            Some(']' | 'P' | 'X' | '^' | '_') => {
                chars.next();
                while let Some(c) = chars.next() {
                    match c {
                        '\x07' => return,
                        '\x1b' => {
                            if chars.next_if_eq(&'\\').is_some() {
                                return;
                            }
                            // A bare ESC aborts the string and begins a new sequence.
                            continue 'sequence;
                        }
                        _ => {}
                    }
                }
                return;
            }
            // Other escapes: intermediate bytes 0x20-0x2F, then one final byte
            // 0x30-0x7E (e.g. `ESC ( B` charset selection emitted by `tput sgr0`).
            _ => {
                while chars.next_if(|&c| matches!(c, '\x20'..='\x2f')).is_some() {}
                chars.next_if(|&c| matches!(c, '\x30'..='\x7e'));
                return;
            }
        }
    }
}
63
64#[async_trait]
65impl Tool for BashTool {
66    fn name(&self) -> &str {
67        "bash"
68    }
69
70    fn definition(&self) -> ToolDefinition {
71        ToolDefinition {
72            name: "bash".into(),
73            description: "Execute a shell command. Returns stdout and stderr. Output is truncated to the last 50 lines.".into(),
74            input_schema: json!({
75                "type": "object",
76                "properties": {
77                    "command": {
78                        "type": "string",
79                        "description": "The shell command to execute"
80                    },
81                    "timeout": {
82                        "type": "integer",
83                        "description": "Timeout in seconds (default: 120)"
84                    }
85                },
86                "required": ["command"]
87            }),
88        }
89    }
90
91    async fn execute(
92        &self,
93        _call_id: &str,
94        arguments: serde_json::Value,
95        partial_tx: Option<mpsc::UnboundedSender<String>>,
96    ) -> SoulResult<ToolOutput> {
97        let command = arguments
98            .get("command")
99            .and_then(|v| v.as_str())
100            .unwrap_or("");
101
102        if command.is_empty() {
103            return Ok(ToolOutput::error("Missing required parameter: command"));
104        }
105
106        let timeout = arguments
107            .get("timeout")
108            .and_then(|v| v.as_u64())
109            .unwrap_or(DEFAULT_TIMEOUT);
110
111        let exec_result = self
112            .executor
113            .exec_shell(command, timeout, Some(&self.cwd))
114            .await;
115
116        match exec_result {
117            Ok(output) => {
118                // Stream partial output if channel available
119                if let Some(ref tx) = partial_tx {
120                    let _ = tx.send(output.stdout.clone());
121                }
122
123                // Combine stdout + stderr
124                let mut combined = strip_ansi(&output.stdout);
125                if !output.stderr.is_empty() {
126                    if !combined.is_empty() {
127                        combined.push('\n');
128                    }
129                    combined.push_str("[stderr]\n");
130                    combined.push_str(&strip_ansi(&output.stderr));
131                }
132
133                // Truncate from tail (errors/final output matter most)
134                let truncated = truncate_tail(&combined, BASH_MAX_LINES, MAX_BYTES);
135
136                let notice = truncated.truncation_notice();
137                let is_truncated = truncated.is_truncated();
138                let mut result = truncated.content;
139                if let Some(notice) = notice {
140                    result = format!("{}\n{}", notice, result);
141                }
142
143                if output.exit_code != 0 {
144                    result.push_str(&format!("\n[exit code: {}]", output.exit_code));
145                }
146
147                let tool_output = if output.success() {
148                    ToolOutput::success(result)
149                } else {
150                    ToolOutput::error(result)
151                };
152
153                Ok(tool_output.with_metadata(json!({
154                    "exit_code": output.exit_code,
155                    "truncated": is_truncated,
156                })))
157            }
158            Err(e) => Ok(ToolOutput::error(format!("Command failed: {}", e))),
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use soul_core::vexec::{ExecOutput, MockExecutor};

    /// Tool whose mock executor always succeeds with the given `stdout`.
    fn setup_ok(stdout: &str) -> BashTool {
        let executor = Arc::new(MockExecutor::always_ok(stdout));
        BashTool::new(executor as Arc<dyn VirtualExecutor>, "/project")
    }

    /// Tool whose mock executor is built from the canned `responses`.
    fn setup_with(responses: Vec<ExecOutput>) -> BashTool {
        let executor = Arc::new(MockExecutor::new(responses));
        BashTool::new(executor as Arc<dyn VirtualExecutor>, "/project")
    }

    #[tokio::test]
    async fn execute_simple_command() {
        let tool = setup_ok("hello world\n");
        let result = tool
            .execute("c1", json!({"command": "echo hello world"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("hello world"));
    }

    #[tokio::test]
    async fn execute_with_error_exit() {
        let tool = setup_with(vec![ExecOutput {
            stdout: String::new(),
            stderr: "command not found".into(),
            exit_code: 127,
        }]);

        let result = tool
            .execute("c2", json!({"command": "nonexistent"}), None)
            .await
            .unwrap();

        assert!(result.is_error);
        assert!(result.content.contains("command not found"));
        assert!(result.content.contains("exit code: 127"));
    }

    #[tokio::test]
    async fn execute_empty_command() {
        let tool = setup_ok("");
        let result = tool
            .execute("c3", json!({"command": ""}), None)
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Missing"));
    }

    #[tokio::test]
    async fn execute_missing_command_key() {
        let tool = setup_ok("");
        let result = tool.execute("c6", json!({}), None).await.unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("Missing"));
    }

    // Pure function: no async runtime needed.
    #[test]
    fn strips_ansi() {
        assert_eq!(strip_ansi("\x1b[31mred\x1b[0m"), "red");
        assert_eq!(strip_ansi("no ansi"), "no ansi");
        assert_eq!(strip_ansi("line\r\n"), "line\n");
    }

    #[tokio::test]
    async fn stderr_included() {
        let tool = setup_with(vec![ExecOutput {
            stdout: "out\n".into(),
            stderr: "warn\n".into(),
            exit_code: 0,
        }]);

        let result = tool
            .execute("c4", json!({"command": "test"}), None)
            .await
            .unwrap();

        assert!(!result.is_error);
        assert!(result.content.contains("out"));
        assert!(result.content.contains("[stderr]"));
        assert!(result.content.contains("warn"));
    }

    #[tokio::test]
    async fn streaming_output() {
        let tool = setup_ok("streamed\n");
        let (tx, mut rx) = mpsc::unbounded_channel();

        let result = tool
            .execute("c5", json!({"command": "echo streamed"}), Some(tx))
            .await
            .unwrap();

        assert!(!result.is_error);
        let partial = rx.recv().await.unwrap();
        assert_eq!(partial, "streamed\n");
    }

    // Synchronous accessors: no async runtime needed.
    #[test]
    fn tool_name_and_definition() {
        let tool = setup_ok("");
        assert_eq!(tool.name(), "bash");
        let def = tool.definition();
        assert_eq!(def.name, "bash");
        assert!(def.input_schema["required"]
            .as_array()
            .unwrap()
            .contains(&json!("command")));
    }
}