steer_tools/tools/
bash.rs1use once_cell::sync::Lazy;
2use regex::Regex;
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6use steer_macros::tool;
7use tokio::io::AsyncReadExt;
8use tokio::process::Command;
9use tokio::time::timeout;
10
11use crate::result::BashResult;
12use crate::{ExecutionContext, ToolError};
13
14#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
15pub struct BashParams {
16 pub command: String,
18 #[schemars(range(min = 1, max = 3600000))]
20 pub timeout: Option<u64>,
21}
22
23tool! {
24 BashTool {
25 params: BashParams,
26 output: BashResult,
27 variant: Bash,
28 description: "Run a bash command in the terminal",
29 name: "bash",
30 require_approval: true
31 }
32
33 async fn run(
34 _tool: &BashTool,
35 params: BashParams,
36 context: &ExecutionContext,
37 ) -> Result<BashResult, ToolError> {
38 if context.is_cancelled() {
39 return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
40 }
41
42 if is_banned_command(¶ms.command) {
44 return Err(ToolError::execution(
45 BASH_TOOL_NAME,
46 format!(
47 "Command '{}' is disallowed for security reasons",
48 params.command
49 ),
50 ));
51 }
52
53 let timeout_ms = params.timeout.unwrap_or(3_600_000).min(3_600_000);
54 let timeout_duration = Duration::from_millis(timeout_ms);
55
56 let result = tokio::select! {
58 _ = context.cancellation_token.cancelled() => {
59 return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
60 }
61 res = timeout(timeout_duration, run_command(¶ms.command, context)) => {
62 match res {
63 Ok(output) => output,
64 Err(_) => return Err(ToolError::Timeout(BASH_TOOL_NAME.to_string())),
65 }
66 }
67 };
68
69 result
70 }
71}
72
73async fn run_command(command: &str, context: &ExecutionContext) -> Result<BashResult, ToolError> {
74 let mut cmd = Command::new("/bin/bash");
75 cmd.arg("-c")
76 .arg(command)
77 .current_dir(&context.working_directory)
78 .stdout(std::process::Stdio::piped())
79 .stderr(std::process::Stdio::piped())
80 .kill_on_drop(true); for (key, value) in &context.environment {
85 cmd.env(key, value);
86 }
87
88 let mut child = cmd
89 .spawn()
90 .map_err(|e| ToolError::io(BASH_TOOL_NAME, e.to_string()))?;
91
92 let mut stdout = child
94 .stdout
95 .take()
96 .ok_or_else(|| ToolError::io(BASH_TOOL_NAME, "Failed to capture stdout".to_string()))?;
97 let mut stderr = child
98 .stderr
99 .take()
100 .ok_or_else(|| ToolError::io(BASH_TOOL_NAME, "Failed to capture stderr".to_string()))?;
101
102 let stdout_handle = tokio::spawn(async move {
104 let mut buf = Vec::new();
105 stdout.read_to_end(&mut buf).await.map(|_| buf)
106 });
107
108 let stderr_handle = tokio::spawn(async move {
109 let mut buf = Vec::new();
110 stderr.read_to_end(&mut buf).await.map(|_| buf)
111 });
112
113 let result = tokio::select! {
115 _ = context.cancellation_token.cancelled() => {
116 let _ = child.kill().await;
119 stdout_handle.abort();
121 stderr_handle.abort();
122 return Err(ToolError::Cancelled(BASH_TOOL_NAME.to_string()));
123 }
124 status = child.wait() => {
125 match status {
126 Ok(status) => {
127 let (stdout_result, stderr_result) = tokio::try_join!(stdout_handle, stderr_handle)
129 .map_err(|e| ToolError::io(BASH_TOOL_NAME, format!("Failed to join read tasks: {e}")))?;
130
131 let stdout_bytes = stdout_result.map_err(|e|
132 ToolError::io(BASH_TOOL_NAME, format!("Failed to read stdout: {e}"))
133 )?;
134 let stderr_bytes = stderr_result.map_err(|e|
135 ToolError::io(BASH_TOOL_NAME, format!("Failed to read stderr: {e}"))
136 )?;
137
138 let stdout = String::from_utf8_lossy(&stdout_bytes).to_string();
139 let stderr = String::from_utf8_lossy(&stderr_bytes).to_string();
140 let exit_code = status.code().unwrap_or(-1);
141
142 Ok(BashResult {
143 stdout,
144 stderr,
145 exit_code,
146 command: command.to_string(),
147 })
148 }
149 Err(e) => Err(ToolError::io(BASH_TOOL_NAME, e.to_string()))
150 }
151 }
152 };
153
154 result
155}
156
157static BANNED_COMMAND_REGEXES: Lazy<Vec<Regex>> = Lazy::new(|| {
158 let banned_commands = [
159 "curl", "wget", "nc", "telnet", "ssh", "scp", "ftp", "sftp",
161 "lynx", "w3m", "links", "elinks", "httpie", "xh", "chrome", "firefox", "safari", "edge",
163 "opera", "chromium", "axel", "aria2c", "alias", "unalias", "exec", "source", ".", "history",
166 "sudo", "su", "chown", "chmod", "useradd", "userdel", "groupadd", "groupdel",
168 "vi", "vim", "nano", "pico", "emacs", "ed",
170 ];
171
172 banned_commands
173 .iter()
174 .map(|cmd| {
175 Regex::new(&format!(r"^\s*(\S*/)?{}\b", regex::escape(cmd)))
176 .expect("Failed to compile banned command regex")
177 })
178 .collect()
179});
180
181fn is_banned_command(command: &str) -> bool {
182 BANNED_COMMAND_REGEXES.iter().any(|re| re.is_match(command))
183}