Skip to main content

statespace_tool_runtime/
executor.rs

1//! Tool execution with sandboxing and resource limits.
2
3use crate::error::Error;
4use crate::sandbox::SandboxEnv;
5use crate::security::{is_private_or_restricted_ip, validate_url_initial};
6use crate::tools::{BuiltinTool, HttpMethod};
7use std::path::PathBuf;
8use std::time::Duration;
9use tokio::process::Command;
10use tokio::time::timeout;
11use tracing::{info, instrument, warn};
12
/// Resource caps applied to every tool invocation.
#[derive(Clone, Debug)]
pub struct ExecutionLimits {
    /// Largest accepted output in bytes; bigger results are rejected rather than truncated.
    pub max_output_bytes: usize,
    /// Maximum number of entries a listing (such as a glob) returns; extra entries are dropped.
    pub max_list_items: usize,
    /// Wall-clock budget for a single tool execution.
    pub timeout: Duration,
}
19
20impl Default for ExecutionLimits {
21    fn default() -> Self {
22        Self {
23            max_output_bytes: 1024 * 1024, // 1MB
24            max_list_items: 1000,
25            timeout: Duration::from_secs(30),
26        }
27    }
28}
29
30#[derive(Debug, Clone)]
31#[non_exhaustive]
32pub enum ToolOutput {
33    Text(String),
34    FileList(Vec<FileInfo>),
35}
36
37impl ToolOutput {
38    #[must_use]
39    pub fn to_text(&self) -> String {
40        match self {
41            Self::Text(s) => s.clone(),
42            Self::FileList(files) => files
43                .iter()
44                .map(|f| f.key.as_str())
45                .collect::<Vec<_>>()
46                .join("\n"),
47        }
48    }
49}
50
51#[derive(Debug, Clone)]
52pub struct FileInfo {
53    pub key: String,
54    pub size: u64,
55    pub last_modified: chrono::DateTime<chrono::Utc>,
56}
57
58#[derive(Debug)]
59pub struct ToolExecutor {
60    root: PathBuf,
61    limits: ExecutionLimits,
62    sandbox_env: SandboxEnv,
63}
64
65impl ToolExecutor {
66    #[must_use]
67    pub fn new(root: PathBuf, limits: ExecutionLimits) -> Self {
68        Self {
69            root,
70            limits,
71            sandbox_env: SandboxEnv::default(),
72        }
73    }
74
75    #[must_use]
76    pub fn with_sandbox_env(mut self, sandbox_env: SandboxEnv) -> Self {
77        self.sandbox_env = sandbox_env;
78        self
79    }
80
81    /// # Errors
82    ///
83    /// Returns errors for timeouts, invalid commands, or execution failures.
84    #[instrument(skip(self), fields(tool = ?tool))]
85    pub async fn execute(&self, tool: &BuiltinTool) -> Result<ToolOutput, Error> {
86        let execution = async {
87            match tool {
88                BuiltinTool::Glob { pattern } => self.execute_glob(pattern),
89                BuiltinTool::Curl { url, method } => self.execute_curl(url, *method).await,
90                BuiltinTool::Exec { command, args } => self.execute_exec(command, args).await,
91            }
92        };
93
94        timeout(self.limits.timeout, execution)
95            .await
96            .map_err(|_err| Error::Timeout)?
97    }
98
99    async fn execute_exec(&self, command: &str, args: &[String]) -> Result<ToolOutput, Error> {
100        info!("Executing: {} {:?}", command, args);
101
102        for arg in args {
103            if arg.starts_with('/') {
104                return Err(Error::Security(format!(
105                    "Absolute paths not allowed in command arguments: {arg}"
106                )));
107            }
108            if arg.contains("..") {
109                return Err(Error::Security(format!(
110                    "Path traversal not allowed in command arguments: {arg}"
111                )));
112            }
113        }
114
115        let output = Command::new(command)
116            .args(args)
117            .current_dir(&self.root)
118            .env_clear()
119            .env("PATH", self.sandbox_env.path())
120            .env("HOME", self.sandbox_env.home())
121            .env("LANG", self.sandbox_env.lang())
122            .env("LC_ALL", self.sandbox_env.lc_all())
123            .output()
124            .await
125            .map_err(|e| {
126                if e.kind() == std::io::ErrorKind::NotFound {
127                    return Error::InvalidCommand(format!(
128                        "Command '{command}' not found in PATH: {}",
129                        self.sandbox_env.path()
130                    ));
131                }
132
133                if e.kind() == std::io::ErrorKind::PermissionDenied {
134                    return Error::InvalidCommand(format!(
135                        "Command '{command}' not executable in PATH: {}",
136                        self.sandbox_env.path()
137                    ));
138                }
139
140                Error::Internal(format!("Failed to execute {command}: {e}"))
141            })?;
142
143        let mut result = String::from_utf8_lossy(&output.stdout).into_owned();
144        if !output.stderr.is_empty() {
145            let stderr = String::from_utf8_lossy(&output.stderr);
146            if !result.is_empty() {
147                result.push('\n');
148            }
149            result.push_str(&stderr);
150        }
151
152        if result.len() > self.limits.max_output_bytes {
153            return Err(Error::OutputTooLarge {
154                size: result.len(),
155                limit: self.limits.max_output_bytes,
156            });
157        }
158
159        Ok(ToolOutput::Text(result))
160    }
161
162    fn execute_glob(&self, pattern: &str) -> Result<ToolOutput, Error> {
163        let full_pattern = self.safe_join(pattern)?;
164        info!("Executing glob: {:?}", full_pattern);
165
166        let pattern_str = full_pattern
167            .to_str()
168            .ok_or_else(|| Error::Internal("Invalid UTF-8 path".to_string()))?;
169
170        let paths = glob::glob(pattern_str)
171            .map_err(|e| Error::InvalidCommand(format!("Invalid glob: {e}")))?;
172
173        let mut files = Vec::new();
174        for entry in paths {
175            match entry {
176                Ok(path) => {
177                    let relative = path
178                        .strip_prefix(&self.root)
179                        .unwrap_or(&path)
180                        .to_string_lossy()
181                        .into_owned();
182
183                    let metadata = std::fs::metadata(&path).ok();
184                    let size = metadata.as_ref().map_or(0, std::fs::Metadata::len);
185                    let last_modified = metadata
186                        .and_then(|m| m.modified().ok())
187                        .map_or_else(chrono::Utc::now, chrono::DateTime::<chrono::Utc>::from);
188
189                    files.push(FileInfo {
190                        key: relative,
191                        size,
192                        last_modified,
193                    });
194                }
195                Err(e) => warn!("Glob error: {}", e),
196            }
197        }
198
199        if files.len() > self.limits.max_list_items {
200            files.truncate(self.limits.max_list_items);
201        }
202
203        Ok(ToolOutput::FileList(files))
204    }
205
206    async fn execute_curl(&self, url: &str, method: HttpMethod) -> Result<ToolOutput, Error> {
207        let parsed = validate_url_initial(url)?;
208        let host = parsed
209            .host_str()
210            .ok_or_else(|| Error::InvalidCommand("URL has no host".to_string()))?;
211        let port = parsed
212            .port_or_known_default()
213            .ok_or_else(|| Error::InvalidCommand("Could not determine port".to_string()))?;
214
215        info!("Executing curl: {} {}", method, host);
216
217        let addr_str = format!("{host}:{port}");
218        let addrs = tokio::net::lookup_host(&addr_str)
219            .await
220            .map_err(|e| Error::Network(format!("DNS resolution failed: {e}")))?;
221
222        for addr in addrs {
223            if is_private_or_restricted_ip(&addr.ip()) {
224                return Err(Error::Security(format!(
225                    "Access to private IP blocked: {}",
226                    addr.ip()
227                )));
228            }
229        }
230
231        let client = reqwest::Client::builder()
232            .timeout(self.limits.timeout)
233            .user_agent("Statespace/1.0")
234            .redirect(reqwest::redirect::Policy::none())
235            .build()
236            .map_err(|e| Error::Network(format!("Client error: {e}")))?;
237
238        let http_method = reqwest::Method::from_bytes(method.as_str().as_bytes())
239            .map_err(|_e| Error::InvalidCommand(format!("Invalid HTTP method: {method}")))?;
240
241        let response = client
242            .request(http_method, parsed.as_str())
243            .send()
244            .await
245            .map_err(|e| Error::Network(format!("Request failed: {e}")))?;
246
247        let text = response
248            .text()
249            .await
250            .map_err(|e| Error::Network(format!("Read failed: {e}")))?;
251
252        if text.len() > self.limits.max_output_bytes {
253            return Err(Error::OutputTooLarge {
254                size: text.len(),
255                limit: self.limits.max_output_bytes,
256            });
257        }
258
259        Ok(ToolOutput::Text(text))
260    }
261
262    fn safe_join(&self, path: &str) -> Result<PathBuf, Error> {
263        let path = path.trim_start_matches('/');
264        if path.contains("..") {
265            return Err(Error::PathTraversal {
266                attempted: path.to_string(),
267                boundary: self.root.to_string_lossy().to_string(),
268            });
269        }
270        Ok(self.root.join(path))
271    }
272
273    #[must_use]
274    pub const fn limits(&self) -> &ExecutionLimits {
275        &self.limits
276    }
277
278    #[must_use]
279    pub fn root(&self) -> &PathBuf {
280        &self.root
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sandbox::SandboxEnv;

    /// Builds an executor rooted at a directory that exists.
    ///
    /// A non-existent working directory also makes process spawning report `NotFound`.
    /// With the old `/tmp/test-mount` root, spawn-failure tests could therefore pass
    /// because of the missing directory rather than the condition under test.
    fn test_executor() -> ToolExecutor {
        ToolExecutor::new(std::env::temp_dir(), ExecutionLimits::default())
    }

    #[tokio::test]
    async fn exec_rejects_absolute_paths() {
        let executor = test_executor();
        let tool = BuiltinTool::Exec {
            command: "grep".to_string(),
            args: vec!["pattern".to_string(), "/etc/passwd".to_string()],
        };

        let result = executor.execute(&tool).await;
        assert!(matches!(result, Err(Error::Security(_))));
    }

    #[tokio::test]
    async fn exec_rejects_path_traversal() {
        let executor = test_executor();
        let tool = BuiltinTool::Exec {
            command: "cat".to_string(),
            args: vec!["../../../etc/passwd".to_string()],
        };

        let result = executor.execute(&tool).await;
        assert!(matches!(result, Err(Error::Security(_))));
    }

    #[tokio::test]
    async fn exec_allows_relative_paths() {
        let executor = test_executor();
        let tool = BuiltinTool::Exec {
            command: "ls".to_string(),
            args: vec!["-la".to_string(), "subdir/file.txt".to_string()],
        };

        // Only the security screening is under test; the command itself may fail.
        let result = executor.execute(&tool).await;
        assert!(!matches!(result, Err(Error::Security(_))));
    }

    #[tokio::test]
    async fn missing_binary_returns_clear_invalid_command_error() {
        let executor = test_executor().with_sandbox_env(SandboxEnv::default());
        let tool = BuiltinTool::Exec {
            command: "definitely-not-a-real-binary".to_string(),
            args: vec![],
        };

        let result = executor.execute(&tool).await;
        assert!(matches!(
            result,
            Err(Error::InvalidCommand(message))
                if message.contains("not found in PATH")
        ));
    }

    #[tokio::test]
    async fn glob_rejects_path_traversal() {
        let executor = test_executor();
        let tool = BuiltinTool::Glob {
            pattern: "../*".to_string(),
        };

        let result = executor.execute(&tool).await;
        assert!(matches!(result, Err(Error::PathTraversal { .. })));
    }
}