Skip to main content

statespace_tool_runtime/
executor.rs

1//! Tool execution with sandboxing and resource limits.
2
3use crate::error::Error;
4use crate::sandbox::SandboxEnv;
5use crate::security::{is_private_or_restricted_ip, validate_url_initial};
6use crate::tools::{BuiltinTool, HttpMethod};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::time::Duration;
10use tokio::process::Command;
11use tokio::time::timeout;
12use tracing::{debug, info, instrument};
13
/// Resource limits applied to every tool execution.
#[derive(Debug, Clone)]
pub struct ExecutionLimits {
    /// Largest combined output, in bytes, a tool run may produce before it
    /// is rejected.
    pub max_output_bytes: usize,
    /// Wall-clock budget for a single tool run.
    pub timeout: Duration,
}

impl Default for ExecutionLimits {
    /// One MiB of output and a thirty-second timeout.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;
        const DEFAULT_TIMEOUT_SECS: u64 = 30;

        Self {
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            max_output_bytes: ONE_MIB,
        }
    }
}
28
/// Result produced by running a tool.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ToolOutput {
    /// A plain textual result.
    Text(String),
    /// Captured output of a spawned process.
    Process {
        /// Captured standard output.
        stdout: String,
        /// Captured standard error.
        stderr: String,
        /// Exit code recorded for the process.
        exit_code: i32,
    },
}

impl ToolOutput {
    /// Flattens the output into a single string.
    ///
    /// `Text` yields its contents. `Process` yields stdout followed by
    /// stderr, separated by a newline only when both are non-empty.
    #[must_use]
    pub fn to_text(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Process { stdout, stderr, .. } => {
                match (stdout.is_empty(), stderr.is_empty()) {
                    (_, true) => stdout.clone(),
                    (true, false) => stderr.clone(),
                    (false, false) => format!("{stdout}\n{stderr}"),
                }
            }
        }
    }

    /// Returns an owned copy of stdout, or of the text for a `Text` output.
    #[must_use]
    pub fn stdout(&self) -> String {
        match self {
            Self::Text(primary) | Self::Process { stdout: primary, .. } => primary.clone(),
        }
    }

    /// Returns stderr for a process, or the empty string for a `Text` output.
    #[must_use]
    pub fn stderr(&self) -> &str {
        match self {
            Self::Text(_) => "",
            Self::Process { stderr, .. } => stderr.as_str(),
        }
    }

    /// Returns the process exit code; a `Text` output always reports `0`.
    #[must_use]
    pub fn exit_code(&self) -> i32 {
        match self {
            Self::Text(_) => 0,
            Self::Process { exit_code, .. } => *exit_code,
        }
    }
}
82
83#[derive(Debug)]
84pub struct ToolExecutor {
85    root: PathBuf,
86    limits: ExecutionLimits,
87    sandbox_env: SandboxEnv,
88    user_env: HashMap<String, String>,
89}
90
91impl ToolExecutor {
92    #[must_use]
93    pub fn new(root: PathBuf, limits: ExecutionLimits) -> Self {
94        Self {
95            root,
96            limits,
97            sandbox_env: SandboxEnv::default(),
98            user_env: HashMap::new(),
99        }
100    }
101
102    #[must_use]
103    pub fn with_sandbox_env(mut self, sandbox_env: SandboxEnv) -> Self {
104        self.sandbox_env = sandbox_env;
105        self
106    }
107
108    #[must_use]
109    pub fn with_user_env(mut self, env: HashMap<String, String>) -> Self {
110        self.user_env = env;
111        self
112    }
113
114    /// # Errors
115    ///
116    /// Returns errors for timeouts, invalid commands, or execution failures.
117    #[instrument(skip(self), fields(tool = ?tool))]
118    pub async fn execute(&self, tool: &BuiltinTool) -> Result<ToolOutput, Error> {
119        let execution = async {
120            match tool {
121                BuiltinTool::Curl { url, method } => self.execute_curl(url, *method).await,
122                BuiltinTool::Exec { command, args } => self.execute_exec(command, args).await,
123            }
124        };
125
126        timeout(self.limits.timeout, execution)
127            .await
128            .map_err(|_err| Error::Timeout)?
129    }
130
131    async fn execute_exec(&self, command: &str, args: &[String]) -> Result<ToolOutput, Error> {
132        debug!("Executing: {}", command);
133
134        let output = Command::new(command)
135            .args(args)
136            .current_dir(&self.root)
137            .env_clear()
138            .envs(&self.user_env)
139            // Restore the minimum env a CLI tool needs after env_clear():
140            // PATH so subprocesses can find binaries, HOME so tools can
141            // locate config files, LANG/LC_ALL for deterministic UTF-8 output.
142            .env("PATH", self.sandbox_env.path())
143            .env("HOME", self.sandbox_env.home())
144            .env("LANG", self.sandbox_env.lang())
145            .env("LC_ALL", self.sandbox_env.lc_all())
146            .output()
147            .await
148            .map_err(|e| {
149                if e.kind() == std::io::ErrorKind::NotFound {
150                    return Error::InvalidCommand(format!(
151                        "Command '{command}' not found in PATH: {}",
152                        self.sandbox_env.path()
153                    ));
154                }
155
156                if e.kind() == std::io::ErrorKind::PermissionDenied {
157                    return Error::InvalidCommand(format!(
158                        "Command '{command}' not executable in PATH: {}",
159                        self.sandbox_env.path()
160                    ));
161                }
162
163                Error::Internal(format!("Failed to execute {command}: {e}"))
164            })?;
165
166        let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
167        let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
168        #[cfg(unix)]
169        let exit_code = output.status.code().unwrap_or_else(|| {
170            use std::os::unix::process::ExitStatusExt;
171            output.status.signal().map_or(1, |s| 128 + s)
172        });
173        #[cfg(not(unix))]
174        let exit_code = output.status.code().unwrap_or(1);
175
176        if stdout.len() + stderr.len() > self.limits.max_output_bytes {
177            return Err(Error::OutputTooLarge {
178                size: stdout.len() + stderr.len(),
179                limit: self.limits.max_output_bytes,
180            });
181        }
182
183        Ok(ToolOutput::Process {
184            stdout,
185            stderr,
186            exit_code,
187        })
188    }
189
190    async fn execute_curl(&self, url: &str, method: HttpMethod) -> Result<ToolOutput, Error> {
191        let parsed = validate_url_initial(url)?;
192        let host = parsed
193            .host_str()
194            .ok_or_else(|| Error::InvalidCommand("URL has no host".to_string()))?;
195        let port = parsed
196            .port_or_known_default()
197            .ok_or_else(|| Error::InvalidCommand("Could not determine port".to_string()))?;
198
199        info!("Executing curl: {} {}", method, host);
200
201        let addr_str = format!("{host}:{port}");
202        let addrs = tokio::net::lookup_host(&addr_str)
203            .await
204            .map_err(|e| Error::Network(format!("DNS resolution failed: {e}")))?;
205
206        for addr in addrs {
207            if is_private_or_restricted_ip(&addr.ip()) {
208                return Err(Error::Security(format!(
209                    "Access to private IP blocked: {}",
210                    addr.ip()
211                )));
212            }
213        }
214
215        let client = reqwest::Client::builder()
216            .timeout(self.limits.timeout)
217            .user_agent("Statespace/1.0")
218            .redirect(reqwest::redirect::Policy::none())
219            .build()
220            .map_err(|e| Error::Network(format!("Client error: {e}")))?;
221
222        let http_method = reqwest::Method::from_bytes(method.as_str().as_bytes())
223            .map_err(|_e| Error::InvalidCommand(format!("Invalid HTTP method: {method}")))?;
224
225        let response = client
226            .request(http_method, parsed.as_str())
227            .send()
228            .await
229            .map_err(|e| Error::Network(format!("Request failed: {e}")))?;
230
231        let text = response
232            .text()
233            .await
234            .map_err(|e| Error::Network(format!("Read failed: {e}")))?;
235
236        if text.len() > self.limits.max_output_bytes {
237            return Err(Error::OutputTooLarge {
238                size: text.len(),
239                limit: self.limits.max_output_bytes,
240            });
241        }
242
243        Ok(ToolOutput::Text(text))
244    }
245
246    #[must_use]
247    pub const fn limits(&self) -> &ExecutionLimits {
248        &self.limits
249    }
250
251    #[must_use]
252    pub fn root(&self) -> &PathBuf {
253        &self.root
254    }
255}
256
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::sandbox::SandboxEnv;

    /// Spawning a binary that does not exist must surface as
    /// `Error::InvalidCommand` with a message that mentions the PATH lookup.
    #[tokio::test]
    async fn missing_binary_returns_clear_invalid_command_error() {
        let workdir = tempfile::tempdir().expect("tempdir");
        let tool = BuiltinTool::Exec {
            command: "definitely-not-a-real-binary".to_string(),
            args: Vec::new(),
        };
        let executor =
            ToolExecutor::new(workdir.path().to_path_buf(), ExecutionLimits::default())
                .with_sandbox_env(SandboxEnv::default());

        // Any outcome other than InvalidCommand yields an empty message,
        // which cannot satisfy the assertion below.
        let message = match executor.execute(&tool).await {
            Err(Error::InvalidCommand(message)) => message,
            _ => String::new(),
        };
        assert!(message.contains("not found in PATH"));
    }
}
280}