spec_ai_core/tools/builtin/
shell.rs1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::{Duration, Instant};
9use tokio::process::Command;
10use tokio::time;
11use tracing::info;
12
/// Timeout applied when the caller does not supply `timeout_ms` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Size cap that `truncate_output` applies to each captured stream (16 Ki).
const MAX_OUTPUT_CHARS: usize = 16 * 1024;
15
/// Arguments accepted by the `shell` tool, deserialized from the tool-call JSON.
///
/// Only `command` is required; when an optional field is absent,
/// `execute_shell_command` substitutes a default.
#[derive(Debug, Deserialize)]
struct ShellArgs {
    /// Command string; passed to the shell as its final argument.
    command: String,
    /// Shell binary to run; defaults to `cmd.exe` on Windows, else `$SHELL` or `/bin/sh`.
    shell: Option<String>,
    /// Arguments placed before the command; `None` or an empty list means `/C` (Windows) or `-c`.
    shell_args: Option<Vec<String>>,
    /// Additional environment variables set on the child process.
    env: Option<HashMap<String, String>>,
    /// Directory the child process is started in.
    working_dir: Option<String>,
    /// Maximum run time in milliseconds; `None` means `DEFAULT_TIMEOUT`.
    timeout_ms: Option<u64>,
}
25
/// Outcome of one shell invocation; serialized to JSON as the tool's output payload.
// Field order here is the key order of the serialized JSON.
#[derive(Debug, Serialize)]
struct ShellOutput {
    /// The command string that was executed.
    command: String,
    /// The shell binary that ran it.
    shell: String,
    /// Captured stdout, lossily UTF-8 decoded and capped by `truncate_output`.
    stdout: String,
    /// Captured stderr, lossily UTF-8 decoded and capped by `truncate_output`.
    stderr: String,
    /// Exit code computed by `execute_shell_command`; `0` marks the tool call successful.
    exit_code: i32,
    /// Elapsed time (from an `Instant` taken just before awaiting the process), in milliseconds.
    duration_ms: u128,
}
35
/// Returns the platform's default shell binary and the flag that makes it run one command string.
///
/// * Windows: `cmd.exe` with `/C`.
/// * Elsewhere: `$SHELL` with `-c`, falling back to `/bin/sh` when `$SHELL` is unset,
///   not valid Unicode, or blank.
fn default_shell() -> (String, Vec<String>) {
    #[cfg(target_os = "windows")]
    {
        ("cmd.exe".to_string(), vec!["/C".to_string()])
    }
    #[cfg(not(target_os = "windows"))]
    {
        (
            shell_from_env(std::env::var("SHELL").ok()),
            vec!["-c".to_string()],
        )
    }
}

/// Chooses the shell from the value of `$SHELL`, ignoring blank values.
///
/// An environment that exports `SHELL=""` would otherwise yield an empty program name,
/// which can never be spawned.
#[cfg(not(target_os = "windows"))]
fn shell_from_env(value: Option<String>) -> String {
    value
        .filter(|shell| !shell.trim().is_empty())
        .unwrap_or_else(|| "/bin/sh".to_string())
}
49
50fn truncate_output(output: &[u8]) -> String {
51 let text = String::from_utf8_lossy(output);
52 if text.len() <= MAX_OUTPUT_CHARS {
53 text.to_string()
54 } else {
55 let mut truncated = text.chars().take(MAX_OUTPUT_CHARS).collect::<String>();
56 truncated.push_str("...<truncated>");
57 truncated
58 }
59}
60
61async fn execute_shell_command(args: &ShellArgs) -> Result<ShellOutput> {
62 if args.command.trim().is_empty() {
63 return Err(anyhow!("shell command cannot be empty"));
64 }
65
66 let (default_shell, default_args) = default_shell();
67 let shell_binary = args.shell.clone().unwrap_or(default_shell);
68 let mut shell_args = args.shell_args.clone().unwrap_or(default_args);
69 if shell_args.is_empty() {
70 shell_args = if cfg!(windows) {
71 vec!["/C".into()]
72 } else {
73 vec!["-c".into()]
74 };
75 }
76
77 let shell_path = PathBuf::from(&shell_binary);
78 if (shell_path.is_absolute() || shell_binary.contains(std::path::MAIN_SEPARATOR))
79 && !shell_path.exists()
80 {
81 return Err(anyhow!(
82 "Shell binary {} does not exist",
83 shell_path.display()
84 ));
85 }
86
87 let timeout = args
88 .timeout_ms
89 .map(Duration::from_millis)
90 .unwrap_or(DEFAULT_TIMEOUT);
91
92 let mut command = Command::new(&shell_binary);
93 for arg in &shell_args {
94 command.arg(arg);
95 }
96 command.arg(&args.command);
97 command.kill_on_drop(true);
98
99 if let Some(dir) = &args.working_dir {
100 command.current_dir(dir);
101 }
102
103 if let Some(ref env) = args.env {
104 for (key, value) in env {
105 command.env(key, value);
106 }
107 }
108
109 info!(
110 target: "spec_ai::tools::shell",
111 command = %args.command,
112 shell = %shell_binary,
113 "Executing shell command"
114 );
115
116 let start = Instant::now();
117 let output = match time::timeout(timeout, command.output()).await {
118 Ok(result) => result.context("Failed to execute shell command")?,
119 Err(_) => {
120 return Err(anyhow!(format!(
121 "Shell command timed out after {} ms",
122 timeout.as_millis()
123 )));
124 }
125 };
126
127 let duration = start.elapsed().as_millis();
128 let stdout = truncate_output(&output.stdout);
129 let stderr = truncate_output(&output.stderr);
130 let exit_code = output.status.code().unwrap_or_default();
131
132 info!(
133 target: "spec_ai::tools::shell",
134 command = %args.command,
135 shell = %shell_binary,
136 exit_code,
137 duration_ms = duration,
138 "Shell command finished"
139 );
140
141 Ok(ShellOutput {
142 command: args.command.clone(),
143 shell: shell_binary,
144 stdout,
145 stderr,
146 exit_code,
147 duration_ms: duration,
148 })
149}
150
/// Stateless tool that executes commands through the system shell.
#[derive(Default)]
pub struct ShellTool;

impl ShellTool {
    /// Creates the tool; it carries no configuration, so this is equivalent to `default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
165
166#[async_trait]
167impl Tool for ShellTool {
168 fn name(&self) -> &str {
169 "shell"
170 }
171
172 fn description(&self) -> &str {
173 "Executes commands using the system shell with cross-platform support"
174 }
175
176 fn parameters(&self) -> Value {
177 serde_json::json!({
178 "type": "object",
179 "properties": {
180 "command": {
181 "type": "string",
182 "description": "Command to execute"
183 },
184 "shell": {
185 "type": "string",
186 "description": "Shell binary to use (defaults to system shell)"
187 },
188 "shell_args": {
189 "type": "array",
190 "description": "Custom shell arguments (default -c or /C)",
191 "items": {"type": "string"}
192 },
193 "env": {
194 "type": "object",
195 "additionalProperties": {"type": "string"},
196 "description": "Environment variables for the shell"
197 },
198 "working_dir": {
199 "type": "string",
200 "description": "Working directory for the command"
201 },
202 "timeout_ms": {
203 "type": "integer",
204 "description": "Maximum execution time in milliseconds"
205 }
206 },
207 "required": ["command"]
208 })
209 }
210
211 async fn execute(&self, args: Value) -> Result<ToolResult> {
212 let args: ShellArgs =
213 serde_json::from_value(args).context("Failed to parse shell arguments")?;
214
215 let output = execute_shell_command(&args).await?;
216
217 if output.exit_code == 0 {
218 Ok(ToolResult::success(
219 serde_json::to_string(&output).context("Failed to serialize shell output")?,
220 ))
221 } else {
222 Ok(ToolResult::failure(
223 serde_json::to_string(&output).context("Failed to serialize shell output")?,
224 ))
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_shell_default() {
        let tool = ShellTool::new();
        let args = serde_json::json!({
            "command": "echo shell-tool"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(payload["stdout"].as_str().unwrap().contains("shell-tool"));
    }

    #[tokio::test]
    async fn test_shell_nonzero_exit() {
        let tool = ShellTool::new();
        let args = serde_json::json!({ "command": "exit 42" });
        let result = tool.execute(args).await.unwrap();
        assert!(!result.success);
    }

    /// A whitespace-only command is rejected before any process is spawned.
    #[tokio::test]
    async fn test_shell_rejects_blank_command() {
        let tool = ShellTool::new();
        let args = serde_json::json!({ "command": "   " });
        let err = tool
            .execute(args)
            .await
            .err()
            .expect("blank command should be an error");
        assert!(err.to_string().contains("cannot be empty"));
    }

    /// Arguments without the required `command` field fail to parse.
    #[tokio::test]
    async fn test_shell_rejects_missing_command() {
        let tool = ShellTool::new();
        let err = tool
            .execute(serde_json::json!({}))
            .await
            .err()
            .expect("missing command should be an error");
        assert!(err.to_string().contains("Failed to parse shell arguments"));
    }

    /// A shell given as a non-existent path is an error, not a failed ToolResult.
    #[tokio::test]
    async fn test_shell_missing_binary() {
        let tool = ShellTool::new();
        let args = serde_json::json!({
            "command": "echo hi",
            "shell": "/definitely/not/a/real/shell"
        });
        assert!(tool.execute(args).await.is_err());
    }

    /// Commands that outlive `timeout_ms` are aborted with a timeout error.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_shell_timeout() {
        let tool = ShellTool::new();
        let args = serde_json::json!({
            "command": "sleep 5",
            "shell": "/bin/sh",
            "timeout_ms": 100
        });
        let err = tool
            .execute(args)
            .await
            .err()
            .expect("long-running command should time out");
        assert!(err.to_string().contains("timed out"));
    }

    /// Variables from `env` are visible to the command.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_shell_env_is_passed() {
        let tool = ShellTool::new();
        let args = serde_json::json!({
            "command": "echo $SPEC_AI_SHELL_TEST_VAR",
            "shell": "/bin/sh",
            "env": { "SPEC_AI_SHELL_TEST_VAR": "hello-env" }
        });
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(payload["stdout"].as_str().unwrap().contains("hello-env"));
    }
}