Skip to main content

statespace_tool_runtime/
executor.rs

1//! Tool execution with sandboxing and resource limits.
2
3use crate::error::Error;
4use crate::security::{is_private_or_restricted_ip, validate_url_initial};
5use crate::tools::{BuiltinTool, HttpMethod};
6use std::path::PathBuf;
7use std::time::Duration;
8use tokio::process::Command;
9use tokio::time::timeout;
10use tracing::{info, instrument, warn};
11
/// Resource ceilings applied to every tool invocation.
#[derive(Clone, Debug)]
pub struct ExecutionLimits {
    /// Largest textual output, in bytes, that a command or HTTP request may
    /// produce before it is rejected as too large.
    pub max_output_bytes: usize,
    /// Largest number of entries kept in a file listing; extra matches are
    /// dropped.
    pub max_list_items: usize,
    /// Wall-clock budget for a single tool execution.
    pub timeout: Duration,
}
18
19impl Default for ExecutionLimits {
20    fn default() -> Self {
21        Self {
22            max_output_bytes: 1024 * 1024, // 1MB
23            max_list_items: 1000,
24            timeout: Duration::from_secs(30),
25        }
26    }
27}
28
29#[derive(Debug, Clone)]
30#[non_exhaustive]
31pub enum ToolOutput {
32    Text(String),
33    FileList(Vec<FileInfo>),
34}
35
36impl ToolOutput {
37    #[must_use]
38    pub fn to_text(&self) -> String {
39        match self {
40            Self::Text(s) => s.clone(),
41            Self::FileList(files) => files
42                .iter()
43                .map(|f| f.key.as_str())
44                .collect::<Vec<_>>()
45                .join("\n"),
46        }
47    }
48}
49
50#[derive(Debug, Clone)]
51pub struct FileInfo {
52    pub key: String,
53    pub size: u64,
54    pub last_modified: chrono::DateTime<chrono::Utc>,
55}
56
57#[derive(Debug)]
58pub struct ToolExecutor {
59    root: PathBuf,
60    limits: ExecutionLimits,
61}
62
63impl ToolExecutor {
64    #[must_use]
65    pub const fn new(root: PathBuf, limits: ExecutionLimits) -> Self {
66        Self { root, limits }
67    }
68
69    /// # Errors
70    ///
71    /// Returns errors for timeouts, invalid commands, or execution failures.
72    #[instrument(skip(self), fields(tool = ?tool))]
73    pub async fn execute(&self, tool: &BuiltinTool) -> Result<ToolOutput, Error> {
74        let execution = async {
75            match tool {
76                BuiltinTool::Glob { pattern } => self.execute_glob(pattern),
77                BuiltinTool::Curl { url, method } => self.execute_curl(url, *method).await,
78                BuiltinTool::Exec { command, args } => self.execute_exec(command, args).await,
79            }
80        };
81
82        timeout(self.limits.timeout, execution)
83            .await
84            .map_err(|_err| Error::Timeout)?
85    }
86
87    async fn execute_exec(&self, command: &str, args: &[String]) -> Result<ToolOutput, Error> {
88        info!("Executing: {} {:?}", command, args);
89
90        for arg in args {
91            if arg.starts_with('/') {
92                return Err(Error::Security(format!(
93                    "Absolute paths not allowed in command arguments: {arg}"
94                )));
95            }
96            if arg.contains("..") {
97                return Err(Error::Security(format!(
98                    "Path traversal not allowed in command arguments: {arg}"
99                )));
100            }
101        }
102
103        let output = Command::new(command)
104            .args(args)
105            .current_dir(&self.root)
106            .env_clear()
107            .env("PATH", "/usr/local/bin:/usr/bin:/bin")
108            .env("HOME", "/tmp")
109            .env("LANG", "C.UTF-8")
110            .env("LC_ALL", "C.UTF-8")
111            .output()
112            .await
113            .map_err(|e| Error::Internal(format!("Failed to execute {command}: {e}")))?;
114
115        let mut result = String::from_utf8_lossy(&output.stdout).into_owned();
116        if !output.stderr.is_empty() {
117            let stderr = String::from_utf8_lossy(&output.stderr);
118            if !result.is_empty() {
119                result.push('\n');
120            }
121            result.push_str(&stderr);
122        }
123
124        if result.len() > self.limits.max_output_bytes {
125            return Err(Error::OutputTooLarge {
126                size: result.len(),
127                limit: self.limits.max_output_bytes,
128            });
129        }
130
131        Ok(ToolOutput::Text(result))
132    }
133
134    fn execute_glob(&self, pattern: &str) -> Result<ToolOutput, Error> {
135        let full_pattern = self.safe_join(pattern)?;
136        info!("Executing glob: {:?}", full_pattern);
137
138        let pattern_str = full_pattern
139            .to_str()
140            .ok_or_else(|| Error::Internal("Invalid UTF-8 path".to_string()))?;
141
142        let paths = glob::glob(pattern_str)
143            .map_err(|e| Error::InvalidCommand(format!("Invalid glob: {e}")))?;
144
145        let mut files = Vec::new();
146        for entry in paths {
147            match entry {
148                Ok(path) => {
149                    let relative = path
150                        .strip_prefix(&self.root)
151                        .unwrap_or(&path)
152                        .to_string_lossy()
153                        .into_owned();
154
155                    let metadata = std::fs::metadata(&path).ok();
156                    let size = metadata.as_ref().map_or(0, std::fs::Metadata::len);
157                    let last_modified = metadata
158                        .and_then(|m| m.modified().ok())
159                        .map_or_else(chrono::Utc::now, chrono::DateTime::<chrono::Utc>::from);
160
161                    files.push(FileInfo {
162                        key: relative,
163                        size,
164                        last_modified,
165                    });
166                }
167                Err(e) => warn!("Glob error: {}", e),
168            }
169        }
170
171        if files.len() > self.limits.max_list_items {
172            files.truncate(self.limits.max_list_items);
173        }
174
175        Ok(ToolOutput::FileList(files))
176    }
177
178    async fn execute_curl(&self, url: &str, method: HttpMethod) -> Result<ToolOutput, Error> {
179        let parsed = validate_url_initial(url)?;
180        let host = parsed
181            .host_str()
182            .ok_or_else(|| Error::InvalidCommand("URL has no host".to_string()))?;
183        let port = parsed
184            .port_or_known_default()
185            .ok_or_else(|| Error::InvalidCommand("Could not determine port".to_string()))?;
186
187        info!("Executing curl: {} {}", method, host);
188
189        let addr_str = format!("{host}:{port}");
190        let addrs = tokio::net::lookup_host(&addr_str)
191            .await
192            .map_err(|e| Error::Network(format!("DNS resolution failed: {e}")))?;
193
194        for addr in addrs {
195            if is_private_or_restricted_ip(&addr.ip()) {
196                return Err(Error::Security(format!(
197                    "Access to private IP blocked: {}",
198                    addr.ip()
199                )));
200            }
201        }
202
203        let client = reqwest::Client::builder()
204            .timeout(self.limits.timeout)
205            .user_agent("Statespace/1.0")
206            .redirect(reqwest::redirect::Policy::none())
207            .build()
208            .map_err(|e| Error::Network(format!("Client error: {e}")))?;
209
210        let http_method = reqwest::Method::from_bytes(method.as_str().as_bytes())
211            .map_err(|_e| Error::InvalidCommand(format!("Invalid HTTP method: {method}")))?;
212
213        let response = client
214            .request(http_method, parsed.as_str())
215            .send()
216            .await
217            .map_err(|e| Error::Network(format!("Request failed: {e}")))?;
218
219        let text = response
220            .text()
221            .await
222            .map_err(|e| Error::Network(format!("Read failed: {e}")))?;
223
224        if text.len() > self.limits.max_output_bytes {
225            return Err(Error::OutputTooLarge {
226                size: text.len(),
227                limit: self.limits.max_output_bytes,
228            });
229        }
230
231        Ok(ToolOutput::Text(text))
232    }
233
234    fn safe_join(&self, path: &str) -> Result<PathBuf, Error> {
235        let path = path.trim_start_matches('/');
236        if path.contains("..") {
237            return Err(Error::PathTraversal {
238                attempted: path.to_string(),
239                boundary: self.root.to_string_lossy().to_string(),
240            });
241        }
242        Ok(self.root.join(path))
243    }
244
245    #[must_use]
246    pub const fn limits(&self) -> &ExecutionLimits {
247        &self.limits
248    }
249
250    #[must_use]
251    pub fn root(&self) -> &PathBuf {
252        &self.root
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor rooted at a scratch directory with the default limits.
    fn test_executor() -> ToolExecutor {
        ToolExecutor::new(PathBuf::from("/tmp/test-mount"), ExecutionLimits::default())
    }

    /// Builds an `Exec` tool from string slices and runs it on a fresh executor.
    async fn run_exec(command: &str, args: &[&str]) -> Result<ToolOutput, Error> {
        let tool = BuiltinTool::Exec {
            command: command.to_owned(),
            args: args.iter().map(|arg| (*arg).to_owned()).collect(),
        };
        test_executor().execute(&tool).await
    }

    #[tokio::test]
    async fn exec_rejects_absolute_paths() {
        let result = run_exec("grep", &["pattern", "/etc/passwd"]).await;
        assert!(matches!(result, Err(Error::Security(_))));
    }

    #[tokio::test]
    async fn exec_rejects_path_traversal() {
        let result = run_exec("cat", &["../../../etc/passwd"]).await;
        assert!(matches!(result, Err(Error::Security(_))));
    }

    #[tokio::test]
    async fn exec_allows_relative_paths() {
        // Only the security verdict matters here: the command itself may fail
        // for other reasons (e.g. the scratch directory not existing).
        let result = run_exec("ls", &["-la", "subdir/file.txt"]).await;
        assert!(!matches!(result, Err(Error::Security(_))));
    }
}
299}