steer_tools/tools/
bash.rs

1use once_cell::sync::Lazy;
2use regex::Regex;
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6use steer_macros::tool;
7use tokio::io::AsyncReadExt;
8use tokio::process::Command;
9use tokio::time::timeout;
10
11use crate::result::BashResult;
12use crate::{ExecutionContext, ToolError};
13
14#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
15pub struct BashParams {
16    /// The command to execute
17    pub command: String,
18    /// Optional timeout in milliseconds (default 3600000, max 3600000)
19    #[schemars(range(min = 1, max = 3600000))]
20    pub timeout: Option<u64>,
21}
22
23tool! {
24    BashTool {
25        params: BashParams,
26        output: BashResult,
27        variant: Bash,
28        description: "Run a bash command in the terminal",
29        name: "bash",
30        require_approval: true
31    }
32
33    async fn run(
34        _tool: &BashTool,
35        params: BashParams,
36        context: &ExecutionContext,
37    ) -> Result<BashResult, ToolError> {
38        if context.is_cancelled() {
39            return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
40        }
41
42        // Basic security check
43        if is_banned_command(&params.command) {
44            return Err(ToolError::execution(
45                BASH_TOOL_NAME,
46                format!(
47                    "Command '{}' is disallowed for security reasons",
48                    params.command
49                ),
50            ));
51        }
52
53        let timeout_ms = params.timeout.unwrap_or(3_600_000).min(3_600_000);
54        let timeout_duration = Duration::from_millis(timeout_ms);
55
56        // Execute the command with cancellation support
57        let result = tokio::select! {
58            _ = context.cancellation_token.cancelled() => {
59                return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
60            }
61            res = timeout(timeout_duration, run_command(&params.command, context)) => {
62                match res {
63                    Ok(output) => output,
64                    Err(_) => return Err(ToolError::Timeout(BASH_TOOL_NAME.to_string())),
65                }
66            }
67        };
68
69        result
70    }
71}
72
73async fn run_command(command: &str, context: &ExecutionContext) -> Result<BashResult, ToolError> {
74    let mut cmd = Command::new("/bin/bash");
75    cmd.arg("-c")
76        .arg(command)
77        .current_dir(&context.working_directory)
78        .stdout(std::process::Stdio::piped())
79        .stderr(std::process::Stdio::piped())
80        .kill_on_drop(true); // This ensures the child is killed when dropped
81
82    // Set environment variables
83    // TODO: Do we want to do this?
84    for (key, value) in &context.environment {
85        cmd.env(key, value);
86    }
87
88    let mut child = cmd
89        .spawn()
90        .map_err(|e| ToolError::io(BASH_TOOL_NAME, e.to_string()))?;
91
92    // Take the stdout and stderr handles
93    let mut stdout = child
94        .stdout
95        .take()
96        .ok_or_else(|| ToolError::io(BASH_TOOL_NAME, "Failed to capture stdout".to_string()))?;
97    let mut stderr = child
98        .stderr
99        .take()
100        .ok_or_else(|| ToolError::io(BASH_TOOL_NAME, "Failed to capture stderr".to_string()))?;
101
102    // Read stdout and stderr concurrently with the process execution
103    let stdout_handle = tokio::spawn(async move {
104        let mut buf = Vec::new();
105        stdout.read_to_end(&mut buf).await.map(|_| buf)
106    });
107
108    let stderr_handle = tokio::spawn(async move {
109        let mut buf = Vec::new();
110        stderr.read_to_end(&mut buf).await.map(|_| buf)
111    });
112
113    // Wait for the process to complete, with cancellation support
114    let result = tokio::select! {
115        _ = context.cancellation_token.cancelled() => {
116            // The child will be killed automatically when dropped due to kill_on_drop(true)
117            // But we can also explicitly kill it to be sure
118            let _ = child.kill().await;
119            // Also abort the read tasks
120            stdout_handle.abort();
121            stderr_handle.abort();
122            return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
123        }
124        status = child.wait() => {
125            match status {
126                Ok(status) => {
127                    // Now collect the output that was already being read concurrently
128                    let (stdout_result, stderr_result) = tokio::try_join!(stdout_handle, stderr_handle)
129                        .map_err(|e| ToolError::io(BASH_TOOL_NAME, format!("Failed to join read tasks: {e}")))?;
130
131                    let stdout_bytes = stdout_result.map_err(|e|
132                        ToolError::io(BASH_TOOL_NAME, format!("Failed to read stdout: {e}"))
133                    )?;
134                    let stderr_bytes = stderr_result.map_err(|e|
135                        ToolError::io(BASH_TOOL_NAME, format!("Failed to read stderr: {e}"))
136                    )?;
137
138                    let stdout = String::from_utf8_lossy(&stdout_bytes).to_string();
139                    let stderr = String::from_utf8_lossy(&stderr_bytes).to_string();
140                    let exit_code = status.code().unwrap_or(-1);
141
142                    Ok(BashResult {
143                        stdout,
144                        stderr,
145                        exit_code,
146                        command: command.to_string(),
147                    })
148                }
149                Err(e) => Err(ToolError::io(BASH_TOOL_NAME, e.to_string()))
150            }
151        }
152    };
153
154    result
155}
156
157static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
158    let banned_commands = [
159        // Network tools
160        "curl", "wget", "nc", "telnet", "ssh", "scp", "ftp", "sftp",
161        // Web browsers/clients
162        "lynx", "w3m", "links", "elinks", "httpie", "xh", "chrome", "firefox", "safari", "edge",
163        "opera", "chromium", // Download managers
164        "axel", "aria2c", // Shell utilities that might be risky
165        "alias", "unalias", "exec", "source", ".", "history",
166        // Potentially dangerous system modification tools
167        "sudo", "su", "chown", "chmod", "useradd", "userdel", "groupadd", "groupdel",
168        // File editors
169        "vi", "vim", "nano", "pico", "emacs", "ed",
170    ];
171
172    banned_commands
173        .iter()
174        .map(|cmd| {
175            Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
176                .expect("Failed to compile banned command regex")
177        })
178        .collect()
179});
180
181fn is_banned_command(command: &str) -> bool {
182    BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
183}