Skip to main content

saorsa_agent/tools/
bash.rs

1//! Bash tool for executing shell commands.
2
3use std::path::PathBuf;
4use std::time::Duration;
5
6use tracing::debug;
7
8use crate::error::{Result, SaorsaAgentError};
9use crate::tool::Tool;
10
/// Default command timeout in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 120;

/// Maximum output length in bytes before truncation.
const MAX_OUTPUT_BYTES: usize = 100_000;

/// Tool for executing bash commands.
///
/// Commands run with `working_dir` as their current directory and are
/// abandoned with an error once `timeout` has elapsed.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Working directory for commands.
    working_dir: PathBuf,
    /// Command timeout.
    timeout: Duration,
}

impl BashTool {
    /// Create a new bash tool with the given working directory.
    ///
    /// The timeout starts at [`DEFAULT_TIMEOUT_SECS`] seconds; use
    /// [`BashTool::timeout`] to override it.
    pub fn new(working_dir: impl Into<PathBuf>) -> Self {
        Self {
            working_dir: working_dir.into(),
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        }
    }

    /// Set the command timeout.
    ///
    /// Builder-style: consumes `self` and returns the updated tool.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Truncate output if it exceeds the maximum length.
    ///
    /// Output of at most [`MAX_OUTPUT_BYTES`] bytes is returned unchanged.
    /// Longer output is cut to at most [`MAX_OUTPUT_BYTES`] bytes and a
    /// trailer reporting the original byte length is appended.
    fn truncate_output(output: &str) -> String {
        if output.len() <= MAX_OUTPUT_BYTES {
            return output.to_string();
        }

        // Slicing a `str` in the middle of a multi-byte character panics, so
        // back off to the nearest char boundary at or below the limit. Index 0
        // is always a boundary, which guarantees termination.
        let mut end = MAX_OUTPUT_BYTES;
        while !output.is_char_boundary(end) {
            end -= 1;
        }

        format!(
            "{}\n\n... (output truncated, {} bytes total)",
            &output[..end],
            output.len()
        )
    }
}
54
55#[async_trait::async_trait]
56impl Tool for BashTool {
57    fn name(&self) -> &str {
58        "bash"
59    }
60
61    fn description(&self) -> &str {
62        "Execute a bash command and return stdout and stderr"
63    }
64
65    fn input_schema(&self) -> serde_json::Value {
66        serde_json::json!({
67            "type": "object",
68            "properties": {
69                "command": {
70                    "type": "string",
71                    "description": "The bash command to execute"
72                }
73            },
74            "required": ["command"]
75        })
76    }
77
78    async fn execute(&self, input: serde_json::Value) -> Result<String> {
79        let command = input
80            .get("command")
81            .and_then(|v| v.as_str())
82            .ok_or_else(|| SaorsaAgentError::Tool("missing 'command' field".into()))?;
83
84        debug!(command = %command, dir = %self.working_dir.display(), "Executing bash command");
85
86        let result = tokio::time::timeout(
87            self.timeout,
88            tokio::process::Command::new("bash")
89                .arg("-c")
90                .arg(command)
91                .current_dir(&self.working_dir)
92                .output(),
93        )
94        .await;
95
96        let output = match result {
97            Ok(Ok(output)) => output,
98            Ok(Err(e)) => {
99                return Err(SaorsaAgentError::Tool(format!(
100                    "failed to execute command: {e}"
101                )));
102            }
103            Err(_) => {
104                return Err(SaorsaAgentError::Tool(format!(
105                    "command timed out after {} seconds",
106                    self.timeout.as_secs()
107                )));
108            }
109        };
110
111        let stdout = String::from_utf8_lossy(&output.stdout);
112        let stderr = String::from_utf8_lossy(&output.stderr);
113        let exit_code = output.status.code().unwrap_or(-1);
114
115        let mut result_text = String::new();
116
117        if !stdout.is_empty() {
118            result_text.push_str(&stdout);
119        }
120
121        if !stderr.is_empty() {
122            if !result_text.is_empty() {
123                result_text.push('\n');
124            }
125            result_text.push_str("STDERR:\n");
126            result_text.push_str(&stderr);
127        }
128
129        if exit_code != 0 {
130            if !result_text.is_empty() {
131                result_text.push('\n');
132            }
133            result_text.push_str(&format!("Exit code: {exit_code}"));
134        }
135
136        if result_text.is_empty() {
137            result_text = "(no output)".to_string();
138        }
139
140        Ok(Self::truncate_output(&result_text))
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool rooted in the system temp directory with the default timeout.
    fn temp_dir_tool() -> BashTool {
        BashTool::new(std::env::temp_dir())
    }

    /// Sends `command` to `tool` wrapped in the JSON shape `execute` expects.
    async fn run(tool: &BashTool, command: &str) -> Result<String> {
        tool.execute(serde_json::json!({ "command": command })).await
    }

    /// Stdout of a successful command is returned.
    #[tokio::test]
    async fn execute_echo() {
        let outcome = run(&temp_dir_tool(), "echo hello").await;
        assert!(matches!(&outcome, Ok(text) if text.contains("hello")));
    }

    /// Input without a `command` field is rejected.
    #[tokio::test]
    async fn execute_missing_command_field() {
        let outcome = temp_dir_tool().execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    /// A non-zero exit status is reported in the text, not as an error.
    #[tokio::test]
    async fn execute_failing_command() {
        let outcome = run(&temp_dir_tool(), "exit 42").await;
        assert!(matches!(&outcome, Ok(text) if text.contains("Exit code: 42")));
    }

    /// Stderr is captured under its own heading.
    #[tokio::test]
    async fn execute_stderr() {
        let outcome = run(&temp_dir_tool(), "echo error >&2").await;
        assert!(matches!(
            &outcome,
            Ok(text) if text.contains("STDERR:") && text.contains("error")
        ));
    }

    /// A command outliving the timeout yields a "timed out" error.
    #[tokio::test]
    async fn execute_timeout() {
        let impatient = BashTool::new(std::env::temp_dir()).timeout(Duration::from_millis(100));
        let outcome = run(&impatient, "sleep 10").await;
        assert!(matches!(
            &outcome,
            Err(err) if err.to_string().contains("timed out")
        ));
    }

    /// Name, description and schema expose the expected basics.
    #[test]
    fn tool_metadata() {
        let subject = temp_dir_tool();
        assert_eq!(subject.name(), "bash");
        assert!(!subject.description().is_empty());
        assert_eq!(subject.input_schema()["type"], "object");
    }

    /// Oversized output is shortened and carries the truncation notice.
    #[test]
    fn truncate_long_output() {
        let oversized = "x".repeat(MAX_OUTPUT_BYTES + 1000);
        let shortened = BashTool::truncate_output(&oversized);
        assert!(shortened.len() < oversized.len());
        assert!(shortened.contains("truncated"));
    }

    /// Output under the limit passes through untouched.
    #[test]
    fn truncate_short_output() {
        assert_eq!(BashTool::truncate_output("hello"), "hello");
    }
}