saorsa_agent/tools/
bash.rs1use std::path::PathBuf;
4use std::time::Duration;
5
6use tracing::debug;
7
8use crate::error::{Result, SaorsaAgentError};
9use crate::tool::Tool;
10
/// Timeout, in seconds, that `BashTool::new` applies until it is overridden
/// with `BashTool::timeout` (two minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 2 * 60;

/// Byte budget for the text returned by `execute`; longer results are cut
/// down by `BashTool::truncate_output`.
const MAX_OUTPUT_BYTES: usize = 100 * 1000;
16
/// Tool that runs shell commands through `bash -c` and reports their output.
///
/// Derives `Debug` (and `Clone`) so the tool can be logged, inspected in
/// assertions, and duplicated cheaply; both fields support these traits.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Directory used as the current directory of every spawned command.
    working_dir: PathBuf,
    /// Maximum wall-clock time a command may run before `execute` gives up.
    timeout: Duration,
}
24
25impl BashTool {
26 pub fn new(working_dir: impl Into<PathBuf>) -> Self {
28 Self {
29 working_dir: working_dir.into(),
30 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
31 }
32 }
33
34 #[must_use]
36 pub fn timeout(mut self, timeout: Duration) -> Self {
37 self.timeout = timeout;
38 self
39 }
40
41 fn truncate_output(output: &str) -> String {
43 if output.len() > MAX_OUTPUT_BYTES {
44 let truncated = &output[..MAX_OUTPUT_BYTES];
45 format!(
46 "{truncated}\n\n... (output truncated, {} bytes total)",
47 output.len()
48 )
49 } else {
50 output.to_string()
51 }
52 }
53}
54
55#[async_trait::async_trait]
56impl Tool for BashTool {
57 fn name(&self) -> &str {
58 "bash"
59 }
60
61 fn description(&self) -> &str {
62 "Execute a bash command and return stdout and stderr"
63 }
64
65 fn input_schema(&self) -> serde_json::Value {
66 serde_json::json!({
67 "type": "object",
68 "properties": {
69 "command": {
70 "type": "string",
71 "description": "The bash command to execute"
72 }
73 },
74 "required": ["command"]
75 })
76 }
77
78 async fn execute(&self, input: serde_json::Value) -> Result<String> {
79 let command = input
80 .get("command")
81 .and_then(|v| v.as_str())
82 .ok_or_else(|| SaorsaAgentError::Tool("missing 'command' field".into()))?;
83
84 debug!(command = %command, dir = %self.working_dir.display(), "Executing bash command");
85
86 let result = tokio::time::timeout(
87 self.timeout,
88 tokio::process::Command::new("bash")
89 .arg("-c")
90 .arg(command)
91 .current_dir(&self.working_dir)
92 .output(),
93 )
94 .await;
95
96 let output = match result {
97 Ok(Ok(output)) => output,
98 Ok(Err(e)) => {
99 return Err(SaorsaAgentError::Tool(format!(
100 "failed to execute command: {e}"
101 )));
102 }
103 Err(_) => {
104 return Err(SaorsaAgentError::Tool(format!(
105 "command timed out after {} seconds",
106 self.timeout.as_secs()
107 )));
108 }
109 };
110
111 let stdout = String::from_utf8_lossy(&output.stdout);
112 let stderr = String::from_utf8_lossy(&output.stderr);
113 let exit_code = output.status.code().unwrap_or(-1);
114
115 let mut result_text = String::new();
116
117 if !stdout.is_empty() {
118 result_text.push_str(&stdout);
119 }
120
121 if !stderr.is_empty() {
122 if !result_text.is_empty() {
123 result_text.push('\n');
124 }
125 result_text.push_str("STDERR:\n");
126 result_text.push_str(&stderr);
127 }
128
129 if exit_code != 0 {
130 if !result_text.is_empty() {
131 result_text.push('\n');
132 }
133 result_text.push_str(&format!("Exit code: {exit_code}"));
134 }
135
136 if result_text.is_empty() {
137 result_text = "(no output)".to_string();
138 }
139
140 Ok(Self::truncate_output(&result_text))
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a default-configured tool rooted in the OS temp directory.
    fn tool_in_temp() -> BashTool {
        BashTool::new(std::env::temp_dir())
    }

    #[tokio::test]
    async fn execute_echo() {
        let tool = tool_in_temp();
        let outcome = tool
            .execute(serde_json::json!({"command": "echo hello"}))
            .await;
        // Must succeed and echo the word back.
        assert!(matches!(&outcome, Ok(text) if text.contains("hello")));
    }

    #[tokio::test]
    async fn execute_missing_command_field() {
        let tool = tool_in_temp();
        let outcome = tool.execute(serde_json::json!({})).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn execute_failing_command() {
        let tool = tool_in_temp();
        let outcome = tool
            .execute(serde_json::json!({"command": "exit 42"}))
            .await;
        // A non-zero exit is reported inside the text, not as an `Err`.
        assert!(matches!(&outcome, Ok(text) if text.contains("Exit code: 42")));
    }

    #[tokio::test]
    async fn execute_stderr() {
        let tool = tool_in_temp();
        let outcome = tool
            .execute(serde_json::json!({"command": "echo error >&2"}))
            .await;
        assert!(matches!(
            &outcome,
            Ok(text) if text.contains("STDERR:") && text.contains("error")
        ));
    }

    #[tokio::test]
    async fn execute_timeout() {
        let tool = BashTool::new(std::env::temp_dir()).timeout(Duration::from_millis(100));
        let outcome = tool
            .execute(serde_json::json!({"command": "sleep 10"}))
            .await;
        assert!(matches!(&outcome, Err(err) if err.to_string().contains("timed out")));
    }

    #[test]
    fn tool_metadata() {
        let tool = tool_in_temp();
        assert_eq!(tool.name(), "bash");
        assert!(!tool.description().is_empty());
        let schema = tool.input_schema();
        assert_eq!(schema["type"], "object");
    }

    #[test]
    fn truncate_long_output() {
        let oversized = "x".repeat(MAX_OUTPUT_BYTES + 1000);
        let shortened = BashTool::truncate_output(&oversized);
        assert!(shortened.len() < oversized.len());
        assert!(shortened.contains("truncated"));
    }

    #[test]
    fn truncate_short_output() {
        assert_eq!(BashTool::truncate_output("hello"), "hello");
    }
}