// spec_ai_core/tools/builtin/shell.rs

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::process::ExitStatus;
use std::time::{Duration, Instant};

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::process::Command;
use tokio::time;
use tracing::info;

use crate::tools::{Tool, ToolResult};
12
/// Timeout applied when the caller does not supply `timeout_ms` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Maximum number of characters kept from each captured output stream (16 KiB worth).
const MAX_OUTPUT_CHARS: usize = 16 * 1024;
15
/// Arguments accepted by the `shell` tool, deserialized from the tool-call JSON value.
#[derive(Debug, Deserialize)]
struct ShellArgs {
    /// Command string passed to the shell as the final argument, after `shell_args`.
    command: String,
    /// Shell binary to run; when absent the platform default from `default_shell` is used.
    shell: Option<String>,
    /// Flags placed before `command`; an absent or empty list falls back to `-c` / `/C`.
    shell_args: Option<Vec<String>>,
    /// Extra environment variables set on the child on top of the inherited environment.
    env: Option<HashMap<String, String>>,
    /// Directory to run the command in; the parent's current directory is used when absent.
    working_dir: Option<String>,
    /// Maximum run time in milliseconds; `DEFAULT_TIMEOUT` applies when absent.
    timeout_ms: Option<u64>,
}
25
/// Result of a finished shell command, serialized to JSON as the tool's output payload.
///
/// Field order determines the key order of the serialized JSON.
#[derive(Debug, Serialize)]
struct ShellOutput {
    /// The command string that was executed.
    command: String,
    /// The shell binary that ran the command.
    shell: String,
    /// Captured standard output, lossily decoded and capped by `truncate_output`.
    stdout: String,
    /// Captured standard error, lossily decoded and capped by `truncate_output`.
    stderr: String,
    /// Exit code reported for the process; the tool treats `0` as success.
    exit_code: i32,
    /// Elapsed time in milliseconds, measured around spawning and waiting for the process.
    duration_ms: u128,
}
35
/// Returns the platform's default shell binary together with the flag that makes it run a
/// single command string.
///
/// On Windows this is `cmd.exe /C`. Elsewhere it is `$SHELL -c`, falling back to `/bin/sh`
/// when `SHELL` is unset, not valid Unicode, or blank.
fn default_shell() -> (String, Vec<String>) {
    if cfg!(target_os = "windows") {
        ("cmd.exe".to_string(), vec!["/C".to_string()])
    } else {
        (
            shell_from_env(std::env::var("SHELL").ok()),
            vec!["-c".to_string()],
        )
    }
}

/// Chooses the shell binary from the value of `$SHELL`, if any.
///
/// An unset or blank value yields `/bin/sh`. An empty program name would otherwise reach the
/// process spawner and make every command fail to start.
fn shell_from_env(shell_env: Option<String>) -> String {
    shell_env
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "/bin/sh".to_string())
}
49
50fn truncate_output(output: &[u8]) -> String {
51    let text = String::from_utf8_lossy(output);
52    if text.len() <= MAX_OUTPUT_CHARS {
53        text.to_string()
54    } else {
55        let mut truncated = text.chars().take(MAX_OUTPUT_CHARS).collect::<String>();
56        truncated.push_str("...<truncated>");
57        truncated
58    }
59}
60
61async fn execute_shell_command(args: &ShellArgs) -> Result<ShellOutput> {
62    if args.command.trim().is_empty() {
63        return Err(anyhow!("shell command cannot be empty"));
64    }
65
66    let (default_shell, default_args) = default_shell();
67    let shell_binary = args.shell.clone().unwrap_or(default_shell);
68    let mut shell_args = args.shell_args.clone().unwrap_or(default_args);
69    if shell_args.is_empty() {
70        shell_args = if cfg!(windows) {
71            vec!["/C".into()]
72        } else {
73            vec!["-c".into()]
74        };
75    }
76
77    let shell_path = PathBuf::from(&shell_binary);
78    if (shell_path.is_absolute() || shell_binary.contains(std::path::MAIN_SEPARATOR))
79        && !shell_path.exists()
80    {
81        return Err(anyhow!(
82            "Shell binary {} does not exist",
83            shell_path.display()
84        ));
85    }
86
87    let timeout = args
88        .timeout_ms
89        .map(Duration::from_millis)
90        .unwrap_or(DEFAULT_TIMEOUT);
91
92    let mut command = Command::new(&shell_binary);
93    for arg in &shell_args {
94        command.arg(arg);
95    }
96    command.arg(&args.command);
97    command.kill_on_drop(true);
98
99    if let Some(dir) = &args.working_dir {
100        command.current_dir(dir);
101    }
102
103    if let Some(ref env) = args.env {
104        for (key, value) in env {
105            command.env(key, value);
106        }
107    }
108
109    info!(
110        target: "spec_ai::tools::shell",
111        command = %args.command,
112        shell = %shell_binary,
113        "Executing shell command"
114    );
115
116    let start = Instant::now();
117    let output = match time::timeout(timeout, command.output()).await {
118        Ok(result) => result.context("Failed to execute shell command")?,
119        Err(_) => {
120            return Err(anyhow!(format!(
121                "Shell command timed out after {} ms",
122                timeout.as_millis()
123            )));
124        }
125    };
126
127    let duration = start.elapsed().as_millis();
128    let stdout = truncate_output(&output.stdout);
129    let stderr = truncate_output(&output.stderr);
130    let exit_code = output.status.code().unwrap_or_default();
131
132    info!(
133        target: "spec_ai::tools::shell",
134        command = %args.command,
135        shell = %shell_binary,
136        exit_code,
137        duration_ms = duration,
138        "Shell command finished"
139    );
140
141    Ok(ShellOutput {
142        command: args.command.clone(),
143        shell: shell_binary,
144        stdout,
145        stderr,
146        exit_code,
147        duration_ms: duration,
148    })
149}
150
/// Cross-platform shell execution tool.
///
/// The tool carries no state: each `execute` call builds and runs a fresh child process.
#[derive(Default)]
pub struct ShellTool;

impl ShellTool {
    /// Creates a shell tool; equivalent to [`ShellTool::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
165
166#[async_trait]
167impl Tool for ShellTool {
168    fn name(&self) -> &str {
169        "shell"
170    }
171
172    fn description(&self) -> &str {
173        "Executes commands using the system shell with cross-platform support"
174    }
175
176    fn parameters(&self) -> Value {
177        serde_json::json!({
178            "type": "object",
179            "properties": {
180                "command": {
181                    "type": "string",
182                    "description": "Command to execute"
183                },
184                "shell": {
185                    "type": "string",
186                    "description": "Shell binary to use (defaults to system shell)"
187                },
188                "shell_args": {
189                    "type": "array",
190                    "description": "Custom shell arguments (default -c or /C)",
191                    "items": {"type": "string"}
192                },
193                "env": {
194                    "type": "object",
195                    "additionalProperties": {"type": "string"},
196                    "description": "Environment variables for the shell"
197                },
198                "working_dir": {
199                    "type": "string",
200                    "description": "Working directory for the command"
201                },
202                "timeout_ms": {
203                    "type": "integer",
204                    "description": "Maximum execution time in milliseconds"
205                }
206            },
207            "required": ["command"]
208        })
209    }
210
211    async fn execute(&self, args: Value) -> Result<ToolResult> {
212        let args: ShellArgs =
213            serde_json::from_value(args).context("Failed to parse shell arguments")?;
214
215        let output = execute_shell_command(&args).await?;
216
217        if output.exit_code == 0 {
218            Ok(ToolResult::success(
219                serde_json::to_string(&output).context("Failed to serialize shell output")?,
220            ))
221        } else {
222            Ok(ToolResult::failure(
223                serde_json::to_string(&output).context("Failed to serialize shell output")?,
224            ))
225        }
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_shell_default() {
        let tool = ShellTool::new();
        let args = serde_json::json!({
            "command": "echo shell-tool"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(payload["stdout"].as_str().unwrap().contains("shell-tool"));
    }

    #[tokio::test]
    async fn test_shell_nonzero_exit() {
        let tool = ShellTool::new();
        let args = serde_json::json!({ "command": "exit 42" });
        let result = tool.execute(args).await.unwrap();
        assert!(!result.success);
        // The exit code must be reported in the payload, not just reflected in `success`.
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(payload["exit_code"], 42);
    }

    #[tokio::test]
    async fn test_shell_env_is_passed() {
        let tool = ShellTool::new();
        // Variable expansion syntax differs between cmd.exe and POSIX-style shells.
        let command = if cfg!(windows) {
            "echo %SHELL_TOOL_TEST_VAR%"
        } else {
            "echo $SHELL_TOOL_TEST_VAR"
        };
        let args = serde_json::json!({
            "command": command,
            "env": { "SHELL_TOOL_TEST_VAR": "from-env" }
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(payload["stdout"].as_str().unwrap().contains("from-env"));
    }

    #[tokio::test]
    async fn test_shell_empty_command_rejected() {
        let tool = ShellTool::new();
        let args = serde_json::json!({ "command": "   " });
        assert!(tool.execute(args).await.is_err());
    }

    #[tokio::test]
    async fn test_shell_missing_binary_rejected() {
        let tool = ShellTool::new();
        let shell = if cfg!(windows) {
            r"C:\nonexistent\shell-tool-missing.exe"
        } else {
            "/nonexistent/shell-tool-missing"
        };
        let args = serde_json::json!({ "command": "echo hi", "shell": shell });
        assert!(tool.execute(args).await.is_err());
    }
}