1use crate::registry::Tool;
4use async_trait::async_trait;
5use rustant_core::error::ToolError;
6use rustant_core::types::{ProgressUpdate, RiskLevel, ToolOutput};
7use std::path::PathBuf;
8use std::time::Duration;
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio::sync::mpsc;
11use tracing::{debug, warn};
12
/// Tool that executes shell commands inside a workspace directory.
///
/// When a progress channel is configured, command output is forwarded over it
/// while the command runs; otherwise output is collected and returned only
/// once the command has finished.
pub struct ShellExecTool {
    /// Base directory for command execution. An optional `working_dir`
    /// argument is joined onto this path.
    workspace: PathBuf,
    /// Optional channel for progress updates. `Some` selects the streaming
    /// execution path, `None` the buffered one.
    progress_tx: Option<mpsc::UnboundedSender<ProgressUpdate>>,
}
21
22impl ShellExecTool {
23 pub fn new(workspace: PathBuf) -> Self {
24 Self {
25 workspace,
26 progress_tx: None,
27 }
28 }
29
30 pub fn with_progress(workspace: PathBuf, tx: mpsc::UnboundedSender<ProgressUpdate>) -> Self {
32 Self {
33 workspace,
34 progress_tx: Some(tx),
35 }
36 }
37}
38
39#[async_trait]
40impl Tool for ShellExecTool {
41 fn name(&self) -> &str {
42 "shell_exec"
43 }
44
45 fn description(&self) -> &str {
46 "Execute a shell command in the workspace directory. Returns stdout, stderr, and exit code."
47 }
48
49 fn parameters_schema(&self) -> serde_json::Value {
50 serde_json::json!({
51 "type": "object",
52 "properties": {
53 "command": {
54 "type": "string",
55 "description": "The shell command to execute"
56 },
57 "working_dir": {
58 "type": "string",
59 "description": "Working directory (relative to workspace). Defaults to workspace root."
60 }
61 },
62 "required": ["command"]
63 })
64 }
65
66 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
67 let command = args["command"]
68 .as_str()
69 .ok_or_else(|| ToolError::InvalidArguments {
70 name: "shell_exec".into(),
71 reason: "'command' parameter is required".into(),
72 })?;
73
74 let working_dir = if let Some(dir) = args["working_dir"].as_str() {
75 self.workspace.join(dir)
76 } else {
77 self.workspace.clone()
78 };
79
80 debug!(command = command, cwd = %working_dir.display(), "Executing shell command");
81
82 if let Some(ref tx) = self.progress_tx {
84 self.execute_streaming(command, &working_dir, tx).await
85 } else {
86 self.execute_buffered(command, &working_dir).await
87 }
88 }
89
90 fn risk_level(&self) -> RiskLevel {
91 RiskLevel::Execute
92 }
93
94 fn timeout(&self) -> Duration {
95 Duration::from_secs(120)
96 }
97}
98
99impl ShellExecTool {
100 async fn execute_streaming(
102 &self,
103 command: &str,
104 working_dir: &PathBuf,
105 tx: &mpsc::UnboundedSender<ProgressUpdate>,
106 ) -> Result<ToolOutput, ToolError> {
107 use tokio::process::Command;
108
109 let mut child = Command::new("sh")
110 .arg("-c")
111 .arg(command)
112 .current_dir(working_dir)
113 .stdout(std::process::Stdio::piped())
114 .stderr(std::process::Stdio::piped())
115 .spawn()
116 .map_err(|e| ToolError::ExecutionFailed {
117 name: "shell_exec".into(),
118 message: format!("Failed to execute command: {}", e),
119 })?;
120
121 let _ = tx.send(ProgressUpdate::ToolProgress {
123 tool: "shell_exec".into(),
124 stage: format!("running: {}", truncate_cmd(command, 50)),
125 percent: None,
126 });
127
128 let stdout_pipe = child.stdout.take();
129 let stderr_pipe = child.stderr.take();
130
131 let mut stdout_lines = Vec::new();
132 let mut stderr_lines = Vec::new();
133
134 let tx_stdout = tx.clone();
135 let tx_stderr = tx.clone();
136
137 let stdout_task = tokio::spawn(async move {
139 let mut lines = Vec::new();
140 if let Some(pipe) = stdout_pipe {
141 let reader = BufReader::new(pipe);
142 let mut line_stream = reader.lines();
143 while let Ok(Some(line)) = line_stream.next_line().await {
144 let _ = tx_stdout.send(ProgressUpdate::ShellOutput {
145 line: line.clone(),
146 is_stderr: false,
147 });
148 lines.push(line);
149 }
150 }
151 lines
152 });
153
154 let stderr_task = tokio::spawn(async move {
155 let mut lines = Vec::new();
156 if let Some(pipe) = stderr_pipe {
157 let reader = BufReader::new(pipe);
158 let mut line_stream = reader.lines();
159 while let Ok(Some(line)) = line_stream.next_line().await {
160 let _ = tx_stderr.send(ProgressUpdate::ShellOutput {
161 line: line.clone(),
162 is_stderr: true,
163 });
164 lines.push(line);
165 }
166 }
167 lines
168 });
169
170 let status = child.wait().await.map_err(|e| ToolError::ExecutionFailed {
172 name: "shell_exec".into(),
173 message: format!("Failed to wait for command: {}", e),
174 })?;
175
176 if let Ok(lines) = stdout_task.await {
178 stdout_lines = lines;
179 }
180 if let Ok(lines) = stderr_task.await {
181 stderr_lines = lines;
182 }
183
184 let exit_code = status.code().unwrap_or(-1);
185 let stdout = stdout_lines.join("\n");
186 let stderr = stderr_lines.join("\n");
187
188 let result = format!(
189 "Exit code: {}\n\n--- stdout ---\n{}\n--- stderr ---\n{}",
190 exit_code,
191 if stdout.is_empty() {
192 "(empty)"
193 } else {
194 &stdout
195 },
196 if stderr.is_empty() {
197 "(empty)"
198 } else {
199 &stderr
200 }
201 );
202
203 if exit_code != 0 {
204 warn!(
205 command = command,
206 exit_code, "Command exited with non-zero status"
207 );
208 }
209
210 Ok(ToolOutput::text(result))
211 }
212
213 async fn execute_buffered(
215 &self,
216 command: &str,
217 working_dir: &PathBuf,
218 ) -> Result<ToolOutput, ToolError> {
219 let output = tokio::process::Command::new("sh")
220 .arg("-c")
221 .arg(command)
222 .current_dir(working_dir)
223 .output()
224 .await
225 .map_err(|e| ToolError::ExecutionFailed {
226 name: "shell_exec".into(),
227 message: format!("Failed to execute command: {}", e),
228 })?;
229
230 let stdout = String::from_utf8_lossy(&output.stdout);
231 let stderr = String::from_utf8_lossy(&output.stderr);
232 let exit_code = output.status.code().unwrap_or(-1);
233
234 let result = format!(
235 "Exit code: {}\n\n--- stdout ---\n{}\n--- stderr ---\n{}",
236 exit_code,
237 if stdout.is_empty() {
238 "(empty)"
239 } else {
240 &stdout
241 },
242 if stderr.is_empty() {
243 "(empty)"
244 } else {
245 &stderr
246 }
247 );
248
249 if exit_code != 0 {
250 warn!(
251 command = command,
252 exit_code, "Command exited with non-zero status"
253 );
254 }
255
256 Ok(ToolOutput::text(result))
257 }
258}
259
/// Shortens `cmd` for display so that it occupies at most `max` bytes.
///
/// Commands that already fit are returned unchanged. Longer commands are cut
/// to leave room for a trailing `".."` marker. The cut is moved back to the
/// nearest UTF-8 character boundary, so commands containing multi-byte
/// characters never cause a slicing panic (the result may then be slightly
/// shorter than `max`).
///
/// For `max < 2` the marker alone (`".."`) is returned, which is longer than
/// `max`.
fn truncate_cmd(cmd: &str, max: usize) -> String {
    if cmd.len() <= max {
        return cmd.to_string();
    }

    // `end` starts below `cmd.len()` and index 0 is always a boundary, so the
    // loop terminates without underflow.
    let mut end = max.saturating_sub(2);
    while !cmd.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}..", &cmd[..end])
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary workspace containing `test.txt` with the text
    /// "hello world". The directory lives as long as the returned guard.
    fn setup_workspace() -> TempDir {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world").unwrap();
        dir
    }

    /// A simple command reports its stdout and a zero exit code.
    #[tokio::test]
    async fn test_shell_exec_basic() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "echo hello"}))
            .await
            .unwrap();

        assert!(result.content.contains("hello"));
        assert!(result.content.contains("Exit code: 0"));
    }

    /// `working_dir` is resolved against the workspace, so a file in a
    /// subdirectory can be read by its bare name.
    #[tokio::test]
    async fn test_shell_exec_with_cwd() {
        let dir = setup_workspace();
        std::fs::create_dir_all(dir.path().join("subdir")).unwrap();
        std::fs::write(dir.path().join("subdir/file.txt"), "sub content").unwrap();

        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({
                "command": "cat file.txt",
                "working_dir": "subdir"
            }))
            .await
            .unwrap();

        assert!(result.content.contains("sub content"));
    }

    /// A non-zero exit status is reported in the output rather than as an error.
    #[tokio::test]
    async fn test_shell_exec_nonzero_exit() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}))
            .await
            .unwrap();

        assert!(result.content.contains("Exit code: 42"));
    }

    /// Text written to stderr is captured and included in the result.
    #[tokio::test]
    async fn test_shell_exec_stderr() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "echo error >&2"}))
            .await
            .unwrap();

        assert!(result.content.contains("error"));
        assert!(result.content.contains("stderr"));
    }

    /// Omitting `command` yields `InvalidArguments` naming this tool.
    #[tokio::test]
    async fn test_shell_exec_missing_command() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidArguments { name, .. } => assert_eq!(name, "shell_exec"),
            e => panic!("Expected InvalidArguments, got: {:?}", e),
        }
    }

    /// Static tool metadata: name, risk level and timeout.
    #[test]
    fn test_shell_exec_properties() {
        let tool = ShellExecTool::new(PathBuf::from("/tmp"));
        assert_eq!(tool.name(), "shell_exec");
        assert_eq!(tool.risk_level(), RiskLevel::Execute);
        assert_eq!(tool.timeout(), Duration::from_secs(120));
    }

    /// With a progress channel, the result still contains all output and the
    /// channel receives the start notification plus one update per line.
    #[tokio::test]
    async fn test_shell_exec_streaming() {
        let dir = setup_workspace();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let tool = ShellExecTool::with_progress(dir.path().to_path_buf(), tx);

        let result = tool
            .execute(serde_json::json!({"command": "echo line1 && echo line2"}))
            .await
            .unwrap();

        assert!(result.content.contains("line1"));
        assert!(result.content.contains("line2"));
        assert!(result.content.contains("Exit code: 0"));

        // Drain everything that was sent while the command ran.
        let mut progress_count = 0;
        while let Ok(update) = rx.try_recv() {
            progress_count += 1;
            match update {
                ProgressUpdate::ToolProgress { tool, .. } => {
                    assert_eq!(tool, "shell_exec");
                }
                ProgressUpdate::ShellOutput { is_stderr, .. } => {
                    // The command only writes to stdout.
                    assert!(!is_stderr);
                }
                _ => {}
            }
        }
        // One ToolProgress plus one ShellOutput for each of the two lines.
        assert!(
            progress_count >= 3,
            "Expected at least 3 progress updates, got {}",
            progress_count
        );
    }

    /// Lines written to stderr are streamed with `is_stderr` set.
    #[tokio::test]
    async fn test_shell_exec_streaming_stderr() {
        let dir = setup_workspace();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let tool = ShellExecTool::with_progress(dir.path().to_path_buf(), tx);

        let result = tool
            .execute(serde_json::json!({"command": "echo err >&2"}))
            .await
            .unwrap();

        assert!(result.content.contains("err"));

        let mut has_stderr = false;
        while let Ok(update) = rx.try_recv() {
            if let ProgressUpdate::ShellOutput { is_stderr, .. } = update
                && is_stderr
            {
                has_stderr = true;
            }
        }
        assert!(has_stderr, "Expected at least one stderr progress update");
    }

    /// Short commands pass through; long ones are cut and end in "..".
    #[test]
    fn test_truncate_cmd() {
        assert_eq!(truncate_cmd("echo hello", 20), "echo hello");
        assert_eq!(
            truncate_cmd("a very long command that should be truncated", 20),
            "a very long comman.."
        );
    }

    /// The schema declares both parameters but requires only `command`.
    #[test]
    fn test_shell_exec_schema() {
        let tool = ShellExecTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["command"].is_object());
        assert!(schema["properties"]["working_dir"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("command")));
        assert!(!required.contains(&serde_json::json!("working_dir")));
    }

    /// An empty command string is accepted and is expected to exit with 0.
    #[tokio::test]
    async fn test_shell_exec_empty_command() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": ""}))
            .await
            .unwrap();
        assert!(result.content.contains("Exit code: 0"));
    }

    /// Every line of multi-line output appears in the result.
    #[tokio::test]
    async fn test_shell_exec_multiline_output() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "echo line1; echo line2; echo line3"}))
            .await
            .unwrap();

        assert!(result.content.contains("line1"));
        assert!(result.content.contains("line2"));
        assert!(result.content.contains("line3"));
    }

    /// Single and double quotes in the command are handled by the shell.
    #[tokio::test]
    async fn test_shell_exec_special_chars() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "echo 'hello world' \"with quotes\""}))
            .await
            .unwrap();

        assert!(result.content.contains("hello world"));
        assert!(result.content.contains("with quotes"));
    }

    /// Without `working_dir`, commands run in the workspace root and can see
    /// the file created by `setup_workspace`.
    #[tokio::test]
    async fn test_shell_exec_reads_workspace_file() {
        let dir = setup_workspace();
        let tool = ShellExecTool::new(dir.path().to_path_buf());

        let result = tool
            .execute(serde_json::json!({"command": "cat test.txt"}))
            .await
            .unwrap();

        assert!(result.content.contains("hello world"));
    }

    /// A command whose length equals the limit is not truncated.
    #[test]
    fn test_truncate_cmd_exact_length() {
        assert_eq!(truncate_cmd("12345", 5), "12345");
    }

    /// An empty command stays empty.
    #[test]
    fn test_truncate_cmd_empty() {
        assert_eq!(truncate_cmd("", 10), "");
    }
}