Skip to main content

steer_core/tools/static_tools/
bash.rs

1use async_trait::async_trait;
2use regex::Regex;
3use std::time::Duration;
4use tokio::io::AsyncReadExt;
5use tokio::process::Command;
6use tokio::time::timeout;
7
8use crate::tools::capability::Capabilities;
9use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
10use steer_tools::result::BashResult;
11use steer_tools::tools::bash::{BashError, BashParams, BashToolSpec};
12
13pub struct BashTool;
14
15#[async_trait]
16impl StaticTool for BashTool {
17    type Params = BashParams;
18    type Output = BashResult;
19    type Spec = BashToolSpec;
20
21    const DESCRIPTION: &'static str = "Run a bash command in the terminal";
22    const REQUIRES_APPROVAL: bool = true;
23    const REQUIRED_CAPABILITIES: Capabilities = Capabilities::WORKSPACE;
24
25    async fn execute(
26        &self,
27        params: Self::Params,
28        ctx: &StaticToolContext,
29    ) -> Result<Self::Output, StaticToolError<BashError>> {
30        if ctx.is_cancelled() {
31            return Err(StaticToolError::Cancelled);
32        }
33
34        if is_banned_command(&params.command) {
35            return Err(StaticToolError::execution(BashError::DisallowedCommand {
36                command: params.command,
37            }));
38        }
39
40        let timeout_ms = params.timeout.unwrap_or(3_600_000).min(3_600_000);
41        let timeout_duration = Duration::from_millis(timeout_ms);
42        let working_directory = ctx.services.workspace.working_directory().to_path_buf();
43
44        tokio::select! {
45            () = ctx.cancellation_token.cancelled() => Err(StaticToolError::Cancelled),
46            res = timeout(timeout_duration, run_command(&params.command, &working_directory, ctx.cancellation_token.clone())) => {
47                match res {
48                    Ok(output) => output,
49                    Err(_) => Err(StaticToolError::Timeout),
50                }
51            }
52        }
53    }
54}
55
56async fn run_command(
57    command: &str,
58    working_directory: &std::path::Path,
59    cancellation_token: tokio_util::sync::CancellationToken,
60) -> Result<BashResult, StaticToolError<BashError>> {
61    let mut cmd = Command::new("/bin/bash");
62    cmd.arg("-c")
63        .arg(command)
64        .current_dir(working_directory)
65        .stdout(std::process::Stdio::piped())
66        .stderr(std::process::Stdio::piped())
67        .kill_on_drop(true);
68
69    let mut child = cmd.spawn().map_err(|e| {
70        StaticToolError::execution(BashError::Io {
71            message: e.to_string(),
72        })
73    })?;
74
75    let mut stdout = child.stdout.take().ok_or_else(|| {
76        StaticToolError::execution(BashError::Io {
77            message: "Failed to capture stdout".to_string(),
78        })
79    })?;
80    let mut stderr = child.stderr.take().ok_or_else(|| {
81        StaticToolError::execution(BashError::Io {
82            message: "Failed to capture stderr".to_string(),
83        })
84    })?;
85
86    let stdout_handle = tokio::spawn(async move {
87        let mut buf = Vec::new();
88        stdout.read_to_end(&mut buf).await.map(|_| buf)
89    });
90
91    let stderr_handle = tokio::spawn(async move {
92        let mut buf = Vec::new();
93        stderr.read_to_end(&mut buf).await.map(|_| buf)
94    });
95
96    let result = tokio::select! {
97        () = cancellation_token.cancelled() => {
98            let _ = child.kill().await;
99            stdout_handle.abort();
100            stderr_handle.abort();
101            Err(StaticToolError::Cancelled)
102        }
103        status = child.wait() => {
104            match status {
105                Ok(status) => {
106                    let (stdout_result, stderr_result) = tokio::try_join!(stdout_handle, stderr_handle)
107                        .map_err(|e| {
108                            StaticToolError::execution(BashError::Io {
109                                message: format!("Failed to join read tasks: {e}"),
110                            })
111                        })?;
112
113                    let stdout_bytes = stdout_result
114                        .map_err(|e| {
115                            StaticToolError::execution(BashError::Io {
116                                message: format!("Failed to read stdout: {e}"),
117                            })
118                        })?;
119                    let stderr_bytes = stderr_result
120                        .map_err(|e| {
121                            StaticToolError::execution(BashError::Io {
122                                message: format!("Failed to read stderr: {e}"),
123                            })
124                        })?;
125
126                    let stdout = String::from_utf8_lossy(&stdout_bytes).to_string();
127                    let stderr = String::from_utf8_lossy(&stderr_bytes).to_string();
128                    let exit_code = status.code().unwrap_or(-1);
129
130                    Ok(BashResult {
131                        stdout,
132                        stderr,
133                        exit_code,
134                        command: command.to_string(),
135                    })
136                }
137                Err(e) => Err(StaticToolError::execution(BashError::Io {
138                    message: e.to_string(),
139                })),
140            }
141        }
142    };
143
144    result
145}
146
147static BANNED_COMMAND_REGEXES: std::sync::LazyLock<Vec<Regex>> = std::sync::LazyLock::new(|| {
148    let banned_commands = [
149        "curl", "wget", "nc", "telnet", "ssh", "scp", "ftp", "sftp", "lynx", "w3m", "links",
150        "elinks", "httpie", "xh", "chrome", "firefox", "safari", "edge", "opera", "chromium",
151        "axel", "aria2c", "alias", "unalias", "exec", "source", ".", "history", "sudo", "su",
152        "chown", "chmod", "useradd", "userdel", "groupadd", "groupdel", "vi", "vim", "nano",
153        "pico", "emacs", "ed",
154    ];
155
156    banned_commands
157        .iter()
158        .filter_map(|cmd| {
159            let pattern = format!(r"^\\s*(\\S*/)?{}\\b", regex::escape(cmd));
160            match Regex::new(&pattern) {
161                Ok(regex) => Some(regex),
162                Err(err) => {
163                    tracing::error!(
164                        target: "tools::bash",
165                        command = %cmd,
166                        error = %err,
167                        "Failed to compile banned command regex"
168                    );
169                    None
170                }
171            }
172        })
173        .collect()
174});
175
176fn is_banned_command(command: &str) -> bool {
177    BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
178}