Skip to main content

tirith_core/
runner.rs

//! Safe runner (Unix only).
//!
//! Downloads a script, analyzes it, and optionally executes it after user confirmation.
3use std::fs;
4use std::io::{self, BufRead, Write};
5use std::process::Command;
6
7use sha2::{Digest, Sha256};
8
9use crate::receipt::Receipt;
10use crate::script_analysis;
11
12pub struct RunResult {
13    pub receipt: Receipt,
14    pub executed: bool,
15    pub exit_code: Option<i32>,
16}
17
/// Caller-supplied settings for `run`.
///
/// Built only from std types, so it derives `Debug` (for logging and
/// diagnostics) and `Clone`.
#[derive(Debug, Clone)]
pub struct RunOptions {
    /// URL of the script to download.
    pub url: String,
    /// Download, analyze and record a receipt, but never execute the script.
    pub no_exec: bool,
    /// Whether an interactive terminal is available for the confirmation
    /// prompt. `run` refuses to start when this and `no_exec` are both `false`.
    pub interactive: bool,
    /// Expected SHA-256 of the downloaded bytes, as hex. The download is
    /// rejected on mismatch. Hex case is ignored.
    pub expected_sha256: Option<String>,
}
24
/// Interpreters accepted only under this exact name (no version suffix).
const ALLOWED_EXACT: &[&str] = &[
    "sh",
    "bash",
    "zsh",
    "dash",
    "ksh",
    "fish",
    "deno",
    "bun",
    "nodejs",
];

/// Interpreter families accepted either bare or followed by a dotted numeric
/// version, e.g. `python3`, `python3.11`, `ruby3.2`, `node18`, `perl5.38`.
const ALLOWED_FAMILIES: &[&str] = &["python", "ruby", "perl", "node"];

/// Decide whether `interpreter` (a bare name or a path) may be executed.
///
/// Only the text after the last `/` is examined. That name is allowed when it
/// appears in `ALLOWED_EXACT`, or when it is an `ALLOWED_FAMILIES` entry that
/// is either bare or followed by a suffix accepted by `is_valid_version_suffix`.
fn is_allowed_interpreter(interpreter: &str) -> bool {
    let name = match interpreter.rfind('/') {
        Some(slash) => &interpreter[slash + 1..],
        None => interpreter,
    };

    ALLOWED_EXACT.contains(&name)
        || ALLOWED_FAMILIES
            .iter()
            .any(|&family| match name.strip_prefix(family) {
                // Bare family name, e.g. "python".
                Some("") => true,
                Some(version) => is_valid_version_suffix(version),
                None => false,
            })
}

/// Return `true` when `s` is one or more non-empty groups of ASCII digits
/// separated by single dots.
///
/// Accepted: `"3"`, `"3.11"`, `"3.2.1"`.
/// Rejected: `""`, `".3"`, `"3."`, `"3..11"`, `"evil"`.
fn is_valid_version_suffix(s: &str) -> bool {
    let digit_group = |group: &str| !group.is_empty() && group.bytes().all(|b| b.is_ascii_digit());
    !s.is_empty() && s.split('.').all(digit_group)
}
65
66pub fn run(opts: RunOptions) -> Result<RunResult, String> {
67    if !opts.no_exec && !opts.interactive {
68        return Err("tirith run requires an interactive terminal or --no-exec flag".to_string());
69    }
70
71    let mut redirects: Vec<String> = Vec::new();
72    let redirect_list = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
73    let redirect_list_clone = redirect_list.clone();
74
75    let client = reqwest::blocking::Client::builder()
76        .redirect(reqwest::redirect::Policy::custom(move |attempt| {
77            if let Ok(mut list) = redirect_list_clone.lock() {
78                list.push(attempt.url().to_string());
79            }
80            if attempt.previous().len() >= 10 {
81                attempt.stop()
82            } else {
83                attempt.follow()
84            }
85        }))
86        .timeout(std::time::Duration::from_secs(30))
87        .build()
88        .map_err(|e| format!("http client: {e}"))?;
89
90    let response = client
91        .get(&opts.url)
92        .send()
93        .map_err(|e| format!("download failed: {e}"))?;
94
95    let final_url = response.url().to_string();
96    if let Ok(list) = redirect_list.lock() {
97        redirects = list.clone();
98    }
99
100    const MAX_BODY: u64 = 10 * 1024 * 1024; // 10 MiB
101
102    // Fast-reject via Content-Length before we pay to read the body.
103    if let Some(len) = response.content_length() {
104        if len > MAX_BODY {
105            return Err(format!(
106                "response too large: {len} bytes (max {} MiB)",
107                MAX_BODY / 1024 / 1024
108            ));
109        }
110    }
111
112    use std::io::Read;
113    let mut buf = Vec::new();
114    response
115        .take(MAX_BODY + 1)
116        .read_to_end(&mut buf)
117        .map_err(|e| format!("read body: {e}"))?;
118    if buf.len() as u64 > MAX_BODY {
119        return Err(format!(
120            "response body exceeds {} MiB limit",
121            MAX_BODY / 1024 / 1024
122        ));
123    }
124    let content = buf;
125
126    let mut hasher = Sha256::new();
127    hasher.update(&content);
128    let sha256 = format!("{:x}", hasher.finalize());
129
130    if let Some(ref expected) = opts.expected_sha256 {
131        let expected_lower = expected.to_lowercase();
132        if sha256 != expected_lower {
133            return Err(format!(
134                "SHA-256 mismatch: expected {expected_lower}, got {sha256}"
135            ));
136        }
137    }
138
139    let cache_dir = crate::policy::data_dir()
140        .ok_or("cannot determine data directory")?
141        .join("cache");
142    fs::create_dir_all(&cache_dir).map_err(|e| format!("create cache: {e}"))?;
143    let cached_path = cache_dir.join(&sha256);
144    {
145        use std::io::Write;
146        use tempfile::NamedTempFile;
147
148        let mut tmp = NamedTempFile::new_in(&cache_dir).map_err(|e| format!("tempfile: {e}"))?;
149        #[cfg(unix)]
150        {
151            use std::os::unix::fs::PermissionsExt;
152            tmp.as_file()
153                .set_permissions(std::fs::Permissions::from_mode(0o600))
154                .map_err(|e| format!("permissions: {e}"))?;
155        }
156        tmp.write_all(&content)
157            .map_err(|e| format!("write cache: {e}"))?;
158        tmp.persist(&cached_path)
159            .map_err(|e| format!("persist cache: {e}"))?;
160    }
161
162    let content_str = match String::from_utf8(content.clone()) {
163        Ok(s) => s,
164        Err(_) => {
165            eprintln!("tirith: warning: downloaded content contains invalid UTF-8, using lossy conversion");
166            String::from_utf8_lossy(&content).into_owned()
167        }
168    };
169
170    let interpreter = script_analysis::detect_interpreter(&content_str);
171    let analysis = script_analysis::analyze(&content_str, interpreter);
172
173    // Interpreter allowlist is only enforced when we might execute. With
174    // --no-exec the user has already committed to inspecting the script.
175    if !opts.no_exec && !is_allowed_interpreter(interpreter) {
176        return Err(format!(
177            "interpreter '{interpreter}' is not in the allowed list",
178        ));
179    }
180
181    let (git_repo, git_branch) = detect_git_info();
182
183    let receipt = Receipt {
184        url: opts.url.clone(),
185        final_url: Some(final_url),
186        redirects,
187        sha256: sha256.clone(),
188        size: content.len() as u64,
189        domains_referenced: analysis.domains_referenced,
190        paths_referenced: analysis.paths_referenced,
191        analysis_method: "static".to_string(),
192        privilege: if analysis.has_sudo {
193            "elevated".to_string()
194        } else {
195            "normal".to_string()
196        },
197        timestamp: chrono::Utc::now().to_rfc3339(),
198        cwd: std::env::current_dir()
199            .ok()
200            .map(|p| p.display().to_string()),
201        git_repo,
202        git_branch,
203    };
204
205    if opts.no_exec {
206        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
207        return Ok(RunResult {
208            receipt,
209            executed: false,
210            exit_code: None,
211        });
212    }
213
214    // Show analysis summary
215    eprintln!(
216        "tirith: downloaded {} bytes (SHA256: {})",
217        content.len(),
218        crate::receipt::short_hash(&sha256)
219    );
220    eprintln!("tirith: interpreter: {interpreter}");
221    if analysis.has_sudo {
222        eprintln!("tirith: WARNING: script uses sudo");
223    }
224    if analysis.has_eval {
225        eprintln!("tirith: WARNING: script uses eval");
226    }
227    if analysis.has_base64 {
228        eprintln!("tirith: WARNING: script uses base64");
229    }
230
231    // Confirm from /dev/tty
232    let tty = fs::OpenOptions::new()
233        .read(true)
234        .write(true)
235        .open("/dev/tty")
236        .map_err(|_| "cannot open /dev/tty for confirmation")?;
237
238    let mut tty_writer = io::BufWriter::new(&tty);
239    write!(tty_writer, "Execute this script? [y/N] ").map_err(|e| format!("tty write: {e}"))?;
240    tty_writer.flush().map_err(|e| format!("tty flush: {e}"))?;
241
242    let mut reader = io::BufReader::new(&tty);
243    let mut response_line = String::new();
244    reader
245        .read_line(&mut response_line)
246        .map_err(|e| format!("tty read: {e}"))?;
247
248    if !response_line.trim().eq_ignore_ascii_case("y") {
249        eprintln!("tirith: execution cancelled");
250        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
251        return Ok(RunResult {
252            receipt,
253            executed: false,
254            exit_code: None,
255        });
256    }
257
258    // Execute
259    receipt.save().map_err(|e| format!("save receipt: {e}"))?;
260
261    let status = Command::new(interpreter)
262        .arg(&cached_path)
263        .status()
264        .map_err(|e| format!("execute: {e}"))?;
265
266    Ok(RunResult {
267        receipt,
268        executed: true,
269        exit_code: status.code(),
270    })
271}
272
/// Best-effort lookup of the current Git context.
///
/// Returns `(origin URL, branch)`:
/// - the origin URL comes from `git remote get-url origin`;
/// - the branch comes from `git rev-parse --abbrev-ref HEAD`.
///
/// Each element is `None` when `git` cannot be spawned, exits unsuccessfully,
/// or prints non-UTF-8 output. Otherwise it holds the trimmed stdout.
fn detect_git_info() -> (Option<String>, Option<String>) {
    /// Run `git <args>` and return its trimmed stdout, or `None` on any failure.
    fn git_stdout(args: &[&str]) -> Option<String> {
        let output = Command::new("git").args(args).output().ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8(output.stdout).ok()?;
        Some(text.trim().to_string())
    }

    let origin_url = git_stdout(&["remote", "get-url", "origin"]);
    let branch_name = git_stdout(&["rev-parse", "--abbrev-ref", "HEAD"]);
    (origin_url, branch_name)
}
293
#[cfg(test)]
mod tests {
    use super::*;

    // --- interpreter allowlist ---------------------------------------------

    #[test]
    fn test_allowed_interpreter_sh() {
        assert!(is_allowed_interpreter("sh"));
    }

    #[test]
    fn test_allowed_interpreter_python3() {
        assert!(is_allowed_interpreter("python3"));
    }

    #[test]
    fn test_allowed_interpreter_python3_11() {
        assert!(is_allowed_interpreter("python3.11"));
    }

    #[test]
    fn test_allowed_interpreter_nodejs() {
        assert!(is_allowed_interpreter("nodejs"));
    }

    #[test]
    fn test_allowed_interpreter_versioned_families() {
        assert!(is_allowed_interpreter("ruby3.2"));
        assert!(is_allowed_interpreter("node18"));
        assert!(is_allowed_interpreter("perl5.38"));
        assert!(is_allowed_interpreter("/usr/local/bin/python3.11"));
    }

    #[test]
    fn test_disallowed_interpreter_vim() {
        assert!(!is_allowed_interpreter("vim"));
    }

    #[test]
    fn test_disallowed_interpreter_expect() {
        assert!(!is_allowed_interpreter("expect"));
    }

    #[test]
    fn test_disallowed_interpreter_python_evil() {
        assert!(!is_allowed_interpreter("python.evil"));
    }

    #[test]
    fn test_disallowed_interpreter_node_sass() {
        assert!(!is_allowed_interpreter("node-sass"));
    }

    #[test]
    fn test_disallowed_interpreter_python3_trailing_dot() {
        assert!(!is_allowed_interpreter("python3."));
    }

    #[test]
    fn test_disallowed_interpreter_python3_double_dot() {
        assert!(!is_allowed_interpreter("python3..11"));
    }

    #[test]
    fn test_disallowed_interpreter_degenerate_names() {
        assert!(!is_allowed_interpreter(""));
        assert!(!is_allowed_interpreter("/usr/bin/"));
        assert!(!is_allowed_interpreter("python3evil"));
        assert!(!is_allowed_interpreter("pythonpython3"));
    }

    #[test]
    fn test_disallowed_interpreter_only_basename_counts() {
        // A directory named like an allowed interpreter must not help.
        assert!(!is_allowed_interpreter("/opt/bash/vim"));
    }

    #[test]
    fn test_allowed_interpreter_strips_path() {
        assert!(is_allowed_interpreter("/usr/bin/bash"));
    }

    // --- version suffix grammar --------------------------------------------

    #[test]
    fn test_version_suffix_accepts_dotted_digits() {
        for ok in ["3", "3.11", "3.2.1", "10.0"] {
            assert!(is_valid_version_suffix(ok), "{ok:?} should be accepted");
        }
    }

    #[test]
    fn test_version_suffix_rejects_malformed() {
        // Includes a non-ASCII digit ("٣"): only ASCII 0-9 are accepted.
        for bad in ["", ".3", "3.", "3..11", "evil", "3a", "٣"] {
            assert!(!is_valid_version_suffix(bad), "{bad:?} should be rejected");
        }
    }

    // --- cache write -------------------------------------------------------
    // These replicate the tempfile-then-persist sequence `run` uses for its
    // cache rather than calling `run`, which needs the network.

    #[cfg(unix)]
    #[test]
    fn test_cache_write_permissions_0600() {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::NamedTempFile;

        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("test_cache");

        {
            use std::io::Write;

            let mut tmp = NamedTempFile::new_in(dir.path()).unwrap();
            tmp.as_file()
                .set_permissions(std::fs::Permissions::from_mode(0o600))
                .unwrap();
            tmp.write_all(b"test content").unwrap();
            tmp.persist(&cache_path).unwrap();
        }

        let meta = std::fs::metadata(&cache_path).unwrap();
        assert_eq!(
            meta.permissions().mode() & 0o777,
            0o600,
            "cache file should be 0600"
        );
    }

    #[test]
    fn test_cache_write_no_predictable_tmp() {
        use tempfile::NamedTempFile;

        let dir = tempfile::tempdir().unwrap();
        let sha = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
        let cached_path = dir.path().join(sha);

        {
            use std::io::Write;
            let mut tmp = NamedTempFile::new_in(dir.path()).unwrap();
            tmp.write_all(b"cached script").unwrap();
            tmp.persist(&cached_path).unwrap();
        }

        // No predictable temp file should remain
        let entries: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .map(|e| e.file_name().to_string_lossy().to_string())
            .collect();

        // Should only contain the final cached file
        assert_eq!(
            entries.len(),
            1,
            "only the cached file should exist, found: {entries:?}"
        );
        assert!(
            cached_path.exists(),
            "cached file should exist after persist"
        );
    }
}