Skip to main content

swink_agent/tools/
bash.rs

1//! Built-in tool for executing shell commands.
2
3use std::time::Duration;
4
5use schemars::JsonSchema;
6use serde::Deserialize;
7use serde_json::Value;
8use tokio::process::Command;
9use tokio_util::sync::CancellationToken;
10
11use super::MAX_OUTPUT_BYTES;
12use crate::tool::{AgentTool, AgentToolResult, ToolFuture, validated_schema_for};
13
/// Timeout, in milliseconds, applied when a call does not supply `timeout_ms`
/// (30 seconds).
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
16
17/// Built-in tool that executes a shell command.
18///
19/// On Unix-like targets the command is passed to `sh -c`. On Windows the
20/// command is passed to `cmd /C`, matching the platform's native shell.
21///
22/// # Security
23///
24/// This tool executes arbitrary shell commands via the platform shell. It
25/// should only be used with trusted input. It is NOT suitable for production
26/// agents exposed to untrusted users.
27pub struct BashTool {
28    schema: Value,
29}
30
31impl BashTool {
32    /// Create a new `BashTool`.
33    #[must_use]
34    pub fn new() -> Self {
35        Self {
36            schema: validated_schema_for::<Params>(),
37        }
38    }
39}
40
41impl Default for BashTool {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47#[derive(Deserialize, JsonSchema)]
48#[schemars(deny_unknown_fields)]
49struct Params {
50    /// Shell command to execute.
51    command: String,
52    /// Timeout in milliseconds (default 30000).
53    timeout_ms: Option<u64>,
54}
55
56#[allow(clippy::unnecessary_literal_bound)]
57impl AgentTool for BashTool {
58    fn name(&self) -> &str {
59        "bash"
60    }
61
62    fn label(&self) -> &str {
63        "Bash"
64    }
65
66    fn description(&self) -> &str {
67        "Execute a shell command."
68    }
69
70    fn parameters_schema(&self) -> &Value {
71        &self.schema
72    }
73
74    fn requires_approval(&self) -> bool {
75        true
76    }
77
78    fn execute(
79        &self,
80        _tool_call_id: &str,
81        params: Value,
82        cancellation_token: CancellationToken,
83        _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
84        _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
85        _credential: Option<crate::credential::ResolvedCredential>,
86    ) -> ToolFuture<'_> {
87        Box::pin(async move {
88            let parsed: Params = match serde_json::from_value(params) {
89                Ok(p) => p,
90                Err(e) => return AgentToolResult::error(format!("invalid parameters: {e}")),
91            };
92
93            if cancellation_token.is_cancelled() {
94                return AgentToolResult::error("cancelled");
95            }
96
97            let timeout = Duration::from_millis(parsed.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
98
99            let mut child = match shell_command(&parsed.command)
100                .stdout(std::process::Stdio::piped())
101                .stderr(std::process::Stdio::piped())
102                .spawn()
103            {
104                Ok(c) => c,
105                Err(e) => {
106                    return AgentToolResult::error(format!("failed to spawn command: {e}"));
107                }
108            };
109
110            // Spawn concurrent readers for stdout/stderr to prevent deadlocks
111            // when OS pipe buffers fill up on large output.
112            let stdout_task = tokio::spawn(read_stream(child.stdout.take()));
113            let stderr_task = tokio::spawn(read_stream(child.stderr.take()));
114
115            tokio::select! {
116                result = child.wait() => {
117                    match result {
118                        Ok(status) => {
119                            let stdout = stdout_task.await.unwrap_or_default();
120                            let stderr = stderr_task.await.unwrap_or_default();
121                            format_output(status.code(), &stdout, &stderr)
122                        }
123                        Err(e) => AgentToolResult::error(format!("failed to execute command: {e}")),
124                    }
125                }
126                () = cancellation_token.cancelled() => {
127                    let _ = child.kill().await;
128                    stdout_task.abort();
129                    stderr_task.abort();
130                    AgentToolResult::error("cancelled")
131                }
132                () = tokio::time::sleep(timeout) => {
133                    let _ = child.kill().await;
134                    stdout_task.abort();
135                    stderr_task.abort();
136                    AgentToolResult::error(format!(
137                        "failed to complete command: timed out after {}ms",
138                        timeout.as_millis()
139                    ))
140                }
141            }
142        })
143    }
144}
145
146/// Build a platform-appropriate shell `Command` that executes `command`.
147///
148/// Unix: `sh -c <command>`. Windows: `cmd /C <command>`.
149fn shell_command(command: &str) -> Command {
150    #[cfg(windows)]
151    {
152        let mut cmd = Command::new("cmd");
153        cmd.arg("/C").arg(command);
154        cmd
155    }
156    #[cfg(not(windows))]
157    {
158        let mut cmd = Command::new("sh");
159        cmd.arg("-c").arg(command);
160        cmd
161    }
162}
163
164async fn read_stream<R: tokio::io::AsyncRead + Unpin + Send + 'static>(pipe: Option<R>) -> Vec<u8> {
165    use tokio::io::AsyncReadExt;
166    if let Some(mut p) = pipe {
167        let mut buf = Vec::new();
168        let _ = p.read_to_end(&mut buf).await;
169        buf
170    } else {
171        Vec::new()
172    }
173}
174
175fn format_output(exit_code: Option<i32>, stdout: &[u8], stderr: &[u8]) -> AgentToolResult {
176    let code_str = exit_code.map_or_else(|| "unknown".to_owned(), |c| c.to_string());
177
178    let mut stdout_text = String::from_utf8_lossy(stdout).into_owned();
179    let mut stderr_text = String::from_utf8_lossy(stderr).into_owned();
180
181    let combined_len = stdout_text.len() + stderr_text.len();
182    if combined_len > MAX_OUTPUT_BYTES {
183        // Truncate proportionally, favouring stdout.
184        let stdout_budget = MAX_OUTPUT_BYTES * stdout_text.len() / combined_len.max(1);
185        let stderr_budget = MAX_OUTPUT_BYTES.saturating_sub(stdout_budget);
186
187        if stdout_text.len() > stdout_budget {
188            stdout_text.truncate(stdout_budget);
189            stdout_text.push_str("\n[truncated]");
190        }
191        if stderr_text.len() > stderr_budget {
192            stderr_text.truncate(stderr_budget);
193            stderr_text.push_str("\n[truncated]");
194        }
195    }
196
197    let mut text = format!("Exit code: {code_str}");
198
199    if !stdout_text.is_empty() {
200        use std::fmt::Write;
201        let _ = write!(text, "\n\nStdout:\n{stdout_text}");
202    }
203    if !stderr_text.is_empty() {
204        use std::fmt::Write;
205        let _ = write!(text, "\n\nStderr:\n{stderr_text}");
206    }
207
208    AgentToolResult::text(text)
209}