Skip to main content

saorsa_agent/tools/
bash.rs

1//! Bash tool for executing shell commands.
2
3use std::path::PathBuf;
4use std::time::Duration;
5
6use tracing::debug;
7
8use crate::error::{Result, SaorsaAgentError};
9use crate::tool::Tool;
10
/// Default command timeout in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 120;

/// Maximum output length in bytes before truncation.
const MAX_OUTPUT_BYTES: usize = 100_000;

/// Tool for executing bash commands.
///
/// Holds the directory commands run in and the maximum time a single command
/// may take ([`DEFAULT_TIMEOUT_SECS`] seconds unless overridden with
/// [`BashTool::timeout`]).
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Working directory for commands.
    working_dir: PathBuf,
    /// Maximum time a single command may run.
    timeout: Duration,
}

impl BashTool {
    /// Create a new bash tool that runs commands in `working_dir`, using the
    /// default timeout of [`DEFAULT_TIMEOUT_SECS`] seconds.
    #[must_use]
    pub fn new(working_dir: impl Into<PathBuf>) -> Self {
        Self {
            working_dir: working_dir.into(),
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        }
    }

    /// Set the command timeout (builder style; consumes and returns `self`).
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Truncate `output` if it exceeds [`MAX_OUTPUT_BYTES`].
    ///
    /// The cut point is moved back to the nearest UTF-8 character boundary so
    /// slicing never panics on multi-byte characters. Truncated text gets a
    /// trailer stating the original length in bytes; shorter text is returned
    /// unchanged.
    fn truncate_output(output: &str) -> String {
        if output.len() <= MAX_OUTPUT_BYTES {
            return output.to_string();
        }

        // Index 0 is always a char boundary, so `find` always succeeds; the
        // fallback only keeps this path free of `unwrap`.
        let boundary = (0..=MAX_OUTPUT_BYTES)
            .rev()
            .find(|&idx| output.is_char_boundary(idx))
            .unwrap_or(0);

        format!(
            "{}\n\n... (output truncated, {} bytes total)",
            &output[..boundary],
            output.len()
        )
    }
}
62
63#[async_trait::async_trait]
64impl Tool for BashTool {
65    fn name(&self) -> &str {
66        "bash"
67    }
68
69    fn description(&self) -> &str {
70        "Execute a bash command and return stdout and stderr"
71    }
72
73    fn input_schema(&self) -> serde_json::Value {
74        serde_json::json!({
75            "type": "object",
76            "properties": {
77                "command": {
78                    "type": "string",
79                    "description": "The bash command to execute"
80                }
81            },
82            "required": ["command"]
83        })
84    }
85
86    async fn execute(&self, input: serde_json::Value) -> Result<String> {
87        let command = input
88            .get("command")
89            .and_then(|v| v.as_str())
90            .ok_or_else(|| SaorsaAgentError::Tool("missing 'command' field".into()))?;
91
92        debug!(command = %command, dir = %self.working_dir.display(), "Executing bash command");
93
94        let result = tokio::time::timeout(
95            self.timeout,
96            tokio::process::Command::new("bash")
97                .arg("-c")
98                .arg(command)
99                .current_dir(&self.working_dir)
100                .output(),
101        )
102        .await;
103
104        let output = match result {
105            Ok(Ok(output)) => output,
106            Ok(Err(e)) => {
107                return Err(SaorsaAgentError::Tool(format!(
108                    "failed to execute command: {e}"
109                )));
110            }
111            Err(_) => {
112                return Err(SaorsaAgentError::Tool(format!(
113                    "command timed out after {} seconds",
114                    self.timeout.as_secs()
115                )));
116            }
117        };
118
119        let stdout = String::from_utf8_lossy(&output.stdout);
120        let stderr = String::from_utf8_lossy(&output.stderr);
121        let exit_code = output.status.code().unwrap_or(-1);
122
123        let mut result_text = String::new();
124
125        if !stdout.is_empty() {
126            result_text.push_str(&stdout);
127        }
128
129        if !stderr.is_empty() {
130            if !result_text.is_empty() {
131                result_text.push('\n');
132            }
133            result_text.push_str("STDERR:\n");
134            result_text.push_str(&stderr);
135        }
136
137        if exit_code != 0 {
138            if !result_text.is_empty() {
139                result_text.push('\n');
140            }
141            result_text.push_str(&format!("Exit code: {exit_code}"));
142        }
143
144        if result_text.is_empty() {
145            result_text = "(no output)".to_string();
146        }
147
148        Ok(Self::truncate_output(&result_text))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool rooted in the OS temp directory with the default timeout.
    fn test_tool() -> BashTool {
        BashTool::new(std::env::temp_dir())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_echo() {
        let tool = test_tool();
        let result = tool
            .execute(serde_json::json!({"command": "echo hello"}))
            .await;
        assert!(result.is_ok());
        if let Ok(output) = result {
            assert!(output.contains("hello"));
        }
    }

    #[tokio::test]
    async fn execute_missing_command_field() {
        let tool = test_tool();
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn execute_non_string_command_field() {
        // `command` must be a JSON string; other types are rejected the same
        // way as a missing field.
        let tool = test_tool();
        let result = tool.execute(serde_json::json!({"command": 42})).await;
        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_no_output() {
        let tool = test_tool();
        let result = tool.execute(serde_json::json!({"command": "true"})).await;
        assert!(result.is_ok());
        if let Ok(output) = result {
            assert_eq!(output, "(no output)");
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_failing_command() {
        let tool = test_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}))
            .await;
        assert!(result.is_ok());
        if let Ok(output) = result {
            assert!(output.contains("Exit code: 42"));
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_stderr() {
        let tool = test_tool();
        let result = tool
            .execute(serde_json::json!({"command": "echo error >&2"}))
            .await;
        assert!(result.is_ok());
        if let Ok(output) = result {
            assert!(output.contains("STDERR:"));
            assert!(output.contains("error"));
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_timeout() {
        let tool = BashTool::new(std::env::temp_dir()).timeout(Duration::from_millis(100));
        let result = tool
            .execute(serde_json::json!({"command": "sleep 10"}))
            .await;
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("timed out"));
        }
    }

    #[test]
    fn tool_metadata() {
        let tool = test_tool();
        assert_eq!(tool.name(), "bash");
        assert!(!tool.description().is_empty());
        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
    }

    #[test]
    fn truncate_long_output() {
        let long = "x".repeat(MAX_OUTPUT_BYTES + 1000);
        let truncated = BashTool::truncate_output(&long);
        assert!(truncated.len() < long.len());
        assert!(truncated.contains("truncated"));
    }

    #[test]
    fn truncate_respects_utf8_boundary() {
        // A one-byte prefix followed by two-byte chars makes every char start
        // at an odd index, so the byte limit can land mid-character; slicing
        // there would panic if the boundary search were wrong.
        let text = format!("a{}", "é".repeat(MAX_OUTPUT_BYTES));
        let truncated = BashTool::truncate_output(&text);
        let marker = truncated.find("\n\n... (output truncated");
        assert!(marker.is_some());
        if let Some(kept) = marker {
            assert!(kept <= MAX_OUTPUT_BYTES);
            // Two-byte chars: the cut moves back by at most one byte.
            assert!(MAX_OUTPUT_BYTES - kept < 2);
            assert!(text.starts_with(&truncated[..kept]));
        }
    }

    #[test]
    fn truncate_short_output() {
        let short = "hello";
        let result = BashTool::truncate_output(short);
        assert_eq!(result, "hello");
    }
}