spec_ai_core/tools/builtin/
bash.rs1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::path::Path;
8use std::time::{Duration, Instant};
9use tokio::process::Command;
10use tokio::time;
11use tracing::info;
12
/// Timeout used when the caller does not pass `timeout_ms`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Maximum accepted length of the command string, in bytes (compared with `str::len`).
const MAX_COMMAND_LENGTH: usize = 4_096;

/// Maximum number of characters kept from each captured stream (stdout / stderr).
const MAX_OUTPUT_CHARS: usize = 16384;

/// Substrings that cause `validate_command` to reject a command.
///
/// Matching is substring search, so this is a coarse guard rather than a sandbox:
/// it also rejects innocent text such as `pseudo` (contains `sudo`) and does not
/// catch equivalent spellings such as `rm -r -f`.
const DENYLIST: &[&str] = &[
    "sudo",
    "rm -rf",
    "reboot",
    "shutdown",
    ":(){",
    "mkfs",
    "dd if=",
    ">|",
];
19
20#[derive(Debug, Deserialize)]
21struct BashArgs {
22 command: String,
23 timeout_ms: Option<u64>,
24 env: Option<HashMap<String, String>>,
25 working_dir: Option<String>,
26}
27
28#[derive(Debug, Serialize)]
29struct CommandOutput {
30 command: String,
31 stdout: String,
32 stderr: String,
33 exit_code: i32,
34 duration_ms: u128,
35}
36
37fn truncate_output(input: &[u8]) -> String {
38 let text = String::from_utf8_lossy(input);
39 if text.len() <= MAX_OUTPUT_CHARS {
40 text.to_string()
41 } else {
42 let mut truncated = text.chars().take(MAX_OUTPUT_CHARS).collect::<String>();
43 truncated.push_str("...<truncated>");
44 truncated
45 }
46}
47
48fn validate_command(command: &str) -> Result<()> {
49 if command.trim().is_empty() {
50 return Err(anyhow!("Command cannot be empty"));
51 }
52
53 if command.len() > MAX_COMMAND_LENGTH {
54 return Err(anyhow!(
55 "Command exceeds maximum allowed length ({})",
56 MAX_COMMAND_LENGTH
57 ));
58 }
59
60 for forbidden in DENYLIST {
61 if command.contains(forbidden) {
62 return Err(anyhow!(format!(
63 "Command contains forbidden pattern '{}'",
64 forbidden
65 )));
66 }
67 }
68
69 Ok(())
70}
71
72async fn run_bash_command(args: &BashArgs, shell_path: &Path) -> Result<CommandOutput> {
73 if !shell_path.exists() {
74 return Err(anyhow!(format!(
75 "Shell path {} does not exist",
76 shell_path.display()
77 )));
78 }
79
80 validate_command(&args.command)?;
81
82 info!(
83 target: "spec_ai::tools::bash",
84 command = %args.command,
85 shell = %shell_path.display(),
86 "Executing bash command"
87 );
88
89 let timeout = args
90 .timeout_ms
91 .map(Duration::from_millis)
92 .unwrap_or(DEFAULT_TIMEOUT);
93
94 let mut command = Command::new(shell_path);
95 command.arg("-c").arg(&args.command);
96 command.kill_on_drop(true);
97
98 if let Some(dir) = &args.working_dir {
99 command.current_dir(dir);
100 }
101
102 if let Some(env) = &args.env {
103 for (key, value) in env {
104 command.env(key, value);
105 }
106 }
107
108 let start = Instant::now();
109 let output = match time::timeout(timeout, command.output()).await {
110 Ok(result) => result.context("Failed to execute bash command")?,
111 Err(_) => {
112 return Err(anyhow!(format!(
113 "Command timed out after {} ms",
114 timeout.as_millis()
115 )));
116 }
117 };
118
119 let duration = start.elapsed().as_millis();
120 let stdout = truncate_output(&output.stdout);
121 let stderr = truncate_output(&output.stderr);
122 let exit_code = output.status.code().unwrap_or_default();
123
124 info!(
125 target: "spec_ai::tools::bash",
126 command = %args.command,
127 exit_code,
128 duration_ms = duration,
129 "Bash command finished"
130 );
131
132 Ok(CommandOutput {
133 command: args.command.clone(),
134 stdout,
135 stderr,
136 exit_code,
137 duration_ms: duration,
138 })
139}
140
/// Tool that runs commands as `<shell_path> -c <command>`.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Path to the shell binary used to execute commands.
    shell_path: String,
}
145
146impl BashTool {
147 pub fn new() -> Self {
148 let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string());
149 Self { shell_path: shell }
150 }
151
152 pub fn with_shell(mut self, path: impl Into<String>) -> Self {
153 self.shell_path = path.into();
154 self
155 }
156}
157
158impl Default for BashTool {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
164#[async_trait]
165impl Tool for BashTool {
166 fn name(&self) -> &str {
167 "bash"
168 }
169
170 fn description(&self) -> &str {
171 "Executes bash commands with timeout, output capture, and denylisted operations"
172 }
173
174 fn parameters(&self) -> Value {
175 serde_json::json!({
176 "type": "object",
177 "properties": {
178 "command": {
179 "type": "string",
180 "description": "Bash command to run"
181 },
182 "timeout_ms": {
183 "type": "integer",
184 "description": "Maximum execution time in milliseconds",
185 "minimum": 1000
186 },
187 "env": {
188 "type": "object",
189 "additionalProperties": {"type": "string"},
190 "description": "Environment variables for the process"
191 },
192 "working_dir": {
193 "type": "string",
194 "description": "Working directory for the command"
195 }
196 },
197 "required": ["command"]
198 })
199 }
200
201 async fn execute(&self, args: Value) -> Result<ToolResult> {
202 let args: BashArgs =
203 serde_json::from_value(args).context("Failed to parse bash arguments")?;
204 let shell_path = Path::new(&self.shell_path);
205
206 let output = run_bash_command(&args, shell_path).await?;
207
208 if output.exit_code == 0 {
209 Ok(ToolResult::success(
210 serde_json::to_string(&output).context("Failed to serialize bash output")?,
211 ))
212 } else {
213 Ok(ToolResult::failure(
214 serde_json::to_string(&output).context("Failed to serialize bash output")?,
215 ))
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool pinned to `/bin/sh`, so results do not depend on the developer's
    /// `$SHELL` (which may be a non-POSIX shell, or unset on a host without bash).
    fn sh_tool() -> BashTool {
        BashTool::new().with_shell("/bin/sh")
    }

    #[tokio::test]
    async fn test_bash_success() {
        let args = serde_json::json!({ "command": "echo test" });
        let result = sh_tool().execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(payload["stdout"].as_str().unwrap().contains("test"));
    }

    #[tokio::test]
    async fn test_bash_failure() {
        let args = serde_json::json!({ "command": "exit 5" });
        let result = sh_tool().execute(args).await.unwrap();
        assert!(!result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert_eq!(payload["exit_code"], 5);
    }

    #[tokio::test]
    async fn test_bash_timeout() {
        let args = serde_json::json!({
            "command": "sleep 5",
            "timeout_ms": 1000
        });
        let result = sh_tool().execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bash_env_and_working_dir() {
        let args = serde_json::json!({
            "command": "echo \"$GREETING\"; pwd",
            "env": { "GREETING": "hello" },
            "working_dir": "/"
        });
        let result = sh_tool().execute(args).await.unwrap();
        assert!(result.success);
        let payload: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        let stdout = payload["stdout"].as_str().unwrap();
        assert!(stdout.contains("hello"));
        assert!(stdout.lines().any(|line| line == "/"));
    }

    #[tokio::test]
    async fn test_bash_rejects_denylisted_command() {
        let args = serde_json::json!({ "command": "sudo true" });
        assert!(sh_tool().execute(args).await.is_err());
    }

    #[test]
    fn test_validate_command_limits() {
        assert!(validate_command("").is_err());
        assert!(validate_command("   \t").is_err());
        assert!(validate_command(&"a".repeat(MAX_COMMAND_LENGTH + 1)).is_err());
        assert!(validate_command("rm -rf /tmp/x").is_err());
        assert!(validate_command("ls -la").is_ok());
    }

    #[test]
    fn test_truncate_output_caps_length() {
        assert_eq!(truncate_output(b"short"), "short");
        let long = "a".repeat(MAX_OUTPUT_CHARS + 10);
        let out = truncate_output(long.as_bytes());
        assert!(out.starts_with(&"a".repeat(MAX_OUTPUT_CHARS)));
        assert!(out.ends_with("...<truncated>"));
        assert_eq!(out.len(), MAX_OUTPUT_CHARS + "...<truncated>".len());
    }
}