spec_ai_core/tools/builtin/
bash.rs

1use crate::tools::{Tool, ToolResult};
2use anyhow::{anyhow, Context, Result};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::path::Path;
8use std::time::{Duration, Instant};
9use tokio::process::Command;
10use tokio::time;
11use tracing::info;
12
/// Time limit applied to a command when the caller supplies no `timeout_ms`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Longest command string accepted, measured in bytes (`str::len`).
const MAX_COMMAND_LENGTH: usize = 4 * 1024;

/// Cap applied to captured stdout and stderr before they are returned.
const MAX_OUTPUT_CHARS: usize = 16 * 1024;

/// Substrings that cause a command to be rejected during validation.
const DENYLIST: &[&str] = &[
    "sudo",
    "rm -rf",
    "reboot",
    "shutdown",
    ":(){",
    "mkfs",
    "dd if=",
    ">|",
];
19
20#[derive(Debug, Deserialize)]
21struct BashArgs {
22    command: String,
23    timeout_ms: Option<u64>,
24    env: Option<HashMap<String, String>>,
25    working_dir: Option<String>,
26}
27
28#[derive(Debug, Serialize)]
29struct CommandOutput {
30    command: String,
31    stdout: String,
32    stderr: String,
33    exit_code: i32,
34    duration_ms: u128,
35}
36
37fn truncate_output(input: &[u8]) -> String {
38    let text = String::from_utf8_lossy(input);
39    if text.len() <= MAX_OUTPUT_CHARS {
40        text.to_string()
41    } else {
42        let mut truncated = text.chars().take(MAX_OUTPUT_CHARS).collect::<String>();
43        truncated.push_str("...<truncated>");
44        truncated
45    }
46}
47
48fn validate_command(command: &str) -> Result<()> {
49    if command.trim().is_empty() {
50        return Err(anyhow!("Command cannot be empty"));
51    }
52
53    if command.len() > MAX_COMMAND_LENGTH {
54        return Err(anyhow!(
55            "Command exceeds maximum allowed length ({})",
56            MAX_COMMAND_LENGTH
57        ));
58    }
59
60    for forbidden in DENYLIST {
61        if command.contains(forbidden) {
62            return Err(anyhow!(format!(
63                "Command contains forbidden pattern '{}'",
64                forbidden
65            )));
66        }
67    }
68
69    Ok(())
70}
71
72async fn run_bash_command(args: &BashArgs, shell_path: &Path) -> Result<CommandOutput> {
73    if !shell_path.exists() {
74        return Err(anyhow!(format!(
75            "Shell path {} does not exist",
76            shell_path.display()
77        )));
78    }
79
80    validate_command(&args.command)?;
81
82    info!(
83        target: "spec_ai::tools::bash",
84        command = %args.command,
85        shell = %shell_path.display(),
86        "Executing bash command"
87    );
88
89    let timeout = args
90        .timeout_ms
91        .map(Duration::from_millis)
92        .unwrap_or(DEFAULT_TIMEOUT);
93
94    let mut command = Command::new(shell_path);
95    command.arg("-c").arg(&args.command);
96    command.kill_on_drop(true);
97
98    if let Some(dir) = &args.working_dir {
99        command.current_dir(dir);
100    }
101
102    if let Some(env) = &args.env {
103        for (key, value) in env {
104            command.env(key, value);
105        }
106    }
107
108    let start = Instant::now();
109    let output = match time::timeout(timeout, command.output()).await {
110        Ok(result) => result.context("Failed to execute bash command")?,
111        Err(_) => {
112            return Err(anyhow!(format!(
113                "Command timed out after {} ms",
114                timeout.as_millis()
115            )));
116        }
117    };
118
119    let duration = start.elapsed().as_millis();
120    let stdout = truncate_output(&output.stdout);
121    let stderr = truncate_output(&output.stderr);
122    let exit_code = output.status.code().unwrap_or_default();
123
124    info!(
125        target: "spec_ai::tools::bash",
126        command = %args.command,
127        exit_code,
128        duration_ms = duration,
129        "Bash command finished"
130    );
131
132    Ok(CommandOutput {
133        command: args.command.clone(),
134        stdout,
135        stderr,
136        exit_code,
137        duration_ms: duration,
138    })
139}
140
/// Tool that executes bash commands with safety checks
pub struct BashTool {
    /// Path of the shell executable used to run commands.
    shell_path: String,
}

impl BashTool {
    /// Creates a tool that uses the shell named by the `SHELL` environment
    /// variable.
    ///
    /// Falls back to `/bin/bash` when `SHELL` is unset, is not valid Unicode,
    /// or is empty/blank. A blank value would otherwise become the shell
    /// path and make every command fail.
    pub fn new() -> Self {
        let shell_path = std::env::var("SHELL")
            .ok()
            .filter(|candidate| !candidate.trim().is_empty())
            .unwrap_or_else(|| "/bin/bash".to_string());
        Self { shell_path }
    }

    /// Replaces the shell path, builder-style, and returns the updated tool.
    pub fn with_shell(mut self, path: impl Into<String>) -> Self {
        self.shell_path = path.into();
        self
    }
}

impl Default for BashTool {
    /// Equivalent to [`BashTool::new`].
    fn default() -> Self {
        Self::new()
    }
}
163
164#[async_trait]
165impl Tool for BashTool {
166    fn name(&self) -> &str {
167        "bash"
168    }
169
170    fn description(&self) -> &str {
171        "Executes bash commands with timeout, output capture, and denylisted operations"
172    }
173
174    fn parameters(&self) -> Value {
175        serde_json::json!({
176            "type": "object",
177            "properties": {
178                "command": {
179                    "type": "string",
180                    "description": "Bash command to run"
181                },
182                "timeout_ms": {
183                    "type": "integer",
184                    "description": "Maximum execution time in milliseconds",
185                    "minimum": 1000
186                },
187                "env": {
188                    "type": "object",
189                    "additionalProperties": {"type": "string"},
190                    "description": "Environment variables for the process"
191                },
192                "working_dir": {
193                    "type": "string",
194                    "description": "Working directory for the command"
195                }
196            },
197            "required": ["command"]
198        })
199    }
200
201    async fn execute(&self, args: Value) -> Result<ToolResult> {
202        let args: BashArgs =
203            serde_json::from_value(args).context("Failed to parse bash arguments")?;
204        let shell_path = Path::new(&self.shell_path);
205
206        let output = run_bash_command(&args, shell_path).await?;
207
208        if output.exit_code == 0 {
209            Ok(ToolResult::success(
210                serde_json::to_string(&output).context("Failed to serialize bash output")?,
211            ))
212        } else {
213            Ok(ToolResult::failure(
214                serde_json::to_string(&output).context("Failed to serialize bash output")?,
215            ))
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-exit command is reported as success and its stdout is captured.
    #[tokio::test]
    async fn test_bash_success() {
        let request = serde_json::json!({ "command": "echo test" });
        let outcome = BashTool::new().execute(request).await.unwrap();
        assert!(outcome.success);

        let parsed: serde_json::Value = serde_json::from_str(&outcome.output).unwrap();
        let stdout = parsed["stdout"].as_str().unwrap();
        assert!(stdout.contains("test"));
    }

    /// A non-zero exit code is reported as a failed (but not errored) result.
    #[tokio::test]
    async fn test_bash_failure() {
        let request = serde_json::json!({ "command": "exit 5" });
        let outcome = BashTool::new().execute(request).await.unwrap();
        assert!(!outcome.success);
    }

    /// A command that outlives its timeout surfaces as an `Err`.
    #[tokio::test]
    async fn test_bash_timeout() {
        let request = serde_json::json!({
            "command": "sleep 5",
            "timeout_ms": 1000
        });
        let outcome = BashTool::new().execute(request).await;
        assert!(outcome.is_err());
    }
}