Skip to main content

tirith_core/
runner.rs

//! Safe runner — Unix only.
//! Downloads a script, statically analyzes it, and optionally executes it after the user confirms.
3use std::fs;
4use std::io::{self, BufRead, Write};
5use std::process::Command;
6
7use sha2::{Digest, Sha256};
8
9use crate::receipt::Receipt;
10use crate::script_analysis;
11
12pub struct RunResult {
13    pub receipt: Receipt,
14    pub executed: bool,
15    pub exit_code: Option<i32>,
16}
17
/// Inputs to [`run`].
#[derive(Debug, Clone)]
pub struct RunOptions {
    /// URL of the script to download.
    pub url: String,
    /// Analyze the script and record a receipt, but never execute it.
    pub no_exec: bool,
    /// Whether an interactive terminal is available. [`run`] refuses to
    /// proceed without one unless `no_exec` is set.
    pub interactive: bool,
    /// Pinned SHA-256 of the script as hex. It is compared case-insensitively,
    /// and the download is rejected on mismatch.
    pub expected_sha256: Option<String>,
}
24
/// Interpreter names that are accepted only on an exact (basename) match.
const ALLOWED_EXACT: &[&str] = &[
    "sh",
    "bash",
    "zsh",
    "dash",
    "ksh",
    "fish",
    "deno",
    "bun",
    "nodejs",
];

/// Interpreter families accepted either bare or followed by a numeric version
/// suffix such as `python3`, `python3.11`, `ruby3.2`, `node18` or `perl5.38`.
/// The suffix grammar is `digits(.digits)*`; see `is_valid_version_suffix`.
const ALLOWED_FAMILIES: &[&str] = &["python", "ruby", "perl", "node"];
33
34fn is_allowed_interpreter(interpreter: &str) -> bool {
35    let base = interpreter.rsplit('/').next().unwrap_or(interpreter);
36
37    if ALLOWED_EXACT.contains(&base) {
38        return true;
39    }
40
41    for &family in ALLOWED_FAMILIES {
42        if base == family {
43            return true;
44        }
45        if let Some(suffix) = base.strip_prefix(family) {
46            if is_valid_version_suffix(suffix) {
47                return true;
48            }
49        }
50    }
51
52    false
53}
54
/// Return `true` when `s` has the shape `digits(.digits)*`, using ASCII digits only.
///
/// Accepted: `"3"`, `"3.11"`, `"3.2.1"`.
/// Rejected: `""`, `".3"`, `"3."`, `"3..11"`, `"evil"`.
fn is_valid_version_suffix(s: &str) -> bool {
    // Every dot-separated component must be a non-empty run of ASCII digits;
    // an empty `s` is rejected explicitly.
    !s.is_empty()
        && s.split('.').all(|component| {
            !component.is_empty() && component.bytes().all(|b| b.is_ascii_digit())
        })
}
65
66pub fn run(opts: RunOptions) -> Result<RunResult, String> {
67    // Check TTY requirement
68    if !opts.no_exec && !opts.interactive {
69        return Err("tirith run requires an interactive terminal or --no-exec flag".to_string());
70    }
71
72    // Download with redirect chain collection
73    let mut redirects: Vec<String> = Vec::new();
74    let redirect_list = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
75    let redirect_list_clone = redirect_list.clone();
76
77    let client = reqwest::blocking::Client::builder()
78        .redirect(reqwest::redirect::Policy::custom(move |attempt| {
79            if let Ok(mut list) = redirect_list_clone.lock() {
80                list.push(attempt.url().to_string());
81            }
82            if attempt.previous().len() >= 10 {
83                attempt.stop()
84            } else {
85                attempt.follow()
86            }
87        }))
88        .timeout(std::time::Duration::from_secs(30))
89        .build()
90        .map_err(|e| format!("http client: {e}"))?;
91
92    let response = client
93        .get(&opts.url)
94        .send()
95        .map_err(|e| format!("download failed: {e}"))?;
96
97    let final_url = response.url().to_string();
98    if let Ok(list) = redirect_list.lock() {
99        redirects = list.clone();
100    }
101
102    const MAX_BODY: u64 = 10 * 1024 * 1024; // 10 MiB
103
104    // Check Content-Length hint first (fast rejection)
105    if let Some(len) = response.content_length() {
106        if len > MAX_BODY {
107            return Err(format!(
108                "response too large: {len} bytes (max {} MiB)",
109                MAX_BODY / 1024 / 1024
110            ));
111        }
112    }
113
114    // Read with cap using std::io::Read::take
115    use std::io::Read;
116    let mut buf = Vec::new();
117    response
118        .take(MAX_BODY + 1)
119        .read_to_end(&mut buf)
120        .map_err(|e| format!("read body: {e}"))?;
121    if buf.len() as u64 > MAX_BODY {
122        return Err(format!(
123            "response body exceeds {} MiB limit",
124            MAX_BODY / 1024 / 1024
125        ));
126    }
127    let content = buf;
128
129    // Compute SHA256
130    let mut hasher = Sha256::new();
131    hasher.update(&content);
132    let sha256 = format!("{:x}", hasher.finalize());
133
134    // Verify hash if pinned
135    if let Some(ref expected) = opts.expected_sha256 {
136        let expected_lower = expected.to_lowercase();
137        if sha256 != expected_lower {
138            return Err(format!(
139                "SHA-256 mismatch: expected {expected_lower}, got {sha256}"
140            ));
141        }
142    }
143
144    // Cache
145    let cache_dir = crate::policy::data_dir()
146        .ok_or("cannot determine data directory")?
147        .join("cache");
148    fs::create_dir_all(&cache_dir).map_err(|e| format!("create cache: {e}"))?;
149    let cached_path = cache_dir.join(&sha256);
150    {
151        use std::io::Write;
152        use tempfile::NamedTempFile;
153
154        let mut tmp = NamedTempFile::new_in(&cache_dir).map_err(|e| format!("tempfile: {e}"))?;
155        #[cfg(unix)]
156        {
157            use std::os::unix::fs::PermissionsExt;
158            tmp.as_file()
159                .set_permissions(std::fs::Permissions::from_mode(0o600))
160                .map_err(|e| format!("permissions: {e}"))?;
161        }
162        tmp.write_all(&content)
163            .map_err(|e| format!("write cache: {e}"))?;
164        tmp.persist(&cached_path)
165            .map_err(|e| format!("persist cache: {e}"))?;
166    }
167
168    let content_str = match String::from_utf8(content.clone()) {
169        Ok(s) => s,
170        Err(_) => {
171            eprintln!("tirith: warning: downloaded content contains invalid UTF-8, using lossy conversion");
172            String::from_utf8_lossy(&content).into_owned()
173        }
174    };
175
176    // Analyze
177    let interpreter = script_analysis::detect_interpreter(&content_str);
178    let analysis = script_analysis::analyze(&content_str, interpreter);
179
180    // Enforce interpreter policy only when we might execute.
181    if !opts.no_exec && !is_allowed_interpreter(interpreter) {
182        return Err(format!(
183            "interpreter '{interpreter}' is not in the allowed list",
184        ));
185    }
186
187    // Detect git repo and branch
188    let (git_repo, git_branch) = detect_git_info();
189
190    // Create receipt
191    let receipt = Receipt {
192        url: opts.url.clone(),
193        final_url: Some(final_url),
194        redirects,
195        sha256: sha256.clone(),
196        size: content.len() as u64,
197        domains_referenced: analysis.domains_referenced,
198        paths_referenced: analysis.paths_referenced,
199        analysis_method: "static".to_string(),
200        privilege: if analysis.has_sudo {
201            "elevated".to_string()
202        } else {
203            "normal".to_string()
204        },
205        timestamp: chrono::Utc::now().to_rfc3339(),
206        cwd: std::env::current_dir()
207            .ok()
208            .map(|p| p.display().to_string()),
209        git_repo,
210        git_branch,
211    };
212
213    if opts.no_exec {
214        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
215        return Ok(RunResult {
216            receipt,
217            executed: false,
218            exit_code: None,
219        });
220    }
221
222    // Show analysis summary
223    eprintln!(
224        "tirith: downloaded {} bytes (SHA256: {})",
225        content.len(),
226        crate::receipt::short_hash(&sha256)
227    );
228    eprintln!("tirith: interpreter: {interpreter}");
229    if analysis.has_sudo {
230        eprintln!("tirith: WARNING: script uses sudo");
231    }
232    if analysis.has_eval {
233        eprintln!("tirith: WARNING: script uses eval");
234    }
235    if analysis.has_base64 {
236        eprintln!("tirith: WARNING: script uses base64");
237    }
238
239    // Confirm from /dev/tty
240    let tty = fs::OpenOptions::new()
241        .read(true)
242        .write(true)
243        .open("/dev/tty")
244        .map_err(|_| "cannot open /dev/tty for confirmation")?;
245
246    let mut tty_writer = io::BufWriter::new(&tty);
247    write!(tty_writer, "Execute this script? [y/N] ").map_err(|e| format!("tty write: {e}"))?;
248    tty_writer.flush().map_err(|e| format!("tty flush: {e}"))?;
249
250    let mut reader = io::BufReader::new(&tty);
251    let mut response_line = String::new();
252    reader
253        .read_line(&mut response_line)
254        .map_err(|e| format!("tty read: {e}"))?;
255
256    if !response_line.trim().eq_ignore_ascii_case("y") {
257        eprintln!("tirith: execution cancelled");
258        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
259        return Ok(RunResult {
260            receipt,
261            executed: false,
262            exit_code: None,
263        });
264    }
265
266    // Execute
267    receipt.save().map_err(|e| format!("save receipt: {e}"))?;
268
269    let status = Command::new(interpreter)
270        .arg(&cached_path)
271        .status()
272        .map_err(|e| format!("execute: {e}"))?;
273
274    Ok(RunResult {
275        receipt,
276        executed: true,
277        exit_code: status.code(),
278    })
279}
280
/// Look up `(remote_url, branch)` for the git repository around the current directory.
///
/// `remote_url` is the trimmed stdout of `git remote get-url origin`.
/// `branch` is the trimmed stdout of `git rev-parse --abbrev-ref HEAD`.
/// Either value is `None` if `git` cannot be spawned, exits unsuccessfully,
/// or prints non-UTF-8 output. The two commands run in that order.
fn detect_git_info() -> (Option<String>, Option<String>) {
    let git_stdout = |args: &[&str]| -> Option<String> {
        let output = Command::new("git").args(args).output().ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8(output.stdout).ok()?;
        Some(text.trim().to_string())
    };

    let remote = git_stdout(&["remote", "get-url", "origin"]);
    let branch = git_stdout(&["rev-parse", "--abbrev-ref", "HEAD"]);
    (remote, branch)
}
301
302#[cfg(test)]
303mod tests {
304    use super::*;
305
306    #[test]
307    fn test_allowed_interpreter_sh() {
308        assert!(is_allowed_interpreter("sh"));
309    }
310
311    #[test]
312    fn test_allowed_interpreter_python3() {
313        assert!(is_allowed_interpreter("python3"));
314    }
315
316    #[test]
317    fn test_allowed_interpreter_python3_11() {
318        assert!(is_allowed_interpreter("python3.11"));
319    }
320
321    #[test]
322    fn test_allowed_interpreter_nodejs() {
323        assert!(is_allowed_interpreter("nodejs"));
324    }
325
326    #[test]
327    fn test_disallowed_interpreter_vim() {
328        assert!(!is_allowed_interpreter("vim"));
329    }
330
331    #[test]
332    fn test_disallowed_interpreter_expect() {
333        assert!(!is_allowed_interpreter("expect"));
334    }
335
336    #[test]
337    fn test_disallowed_interpreter_python_evil() {
338        assert!(!is_allowed_interpreter("python.evil"));
339    }
340
341    #[test]
342    fn test_disallowed_interpreter_node_sass() {
343        assert!(!is_allowed_interpreter("node-sass"));
344    }
345
346    #[test]
347    fn test_disallowed_interpreter_python3_trailing_dot() {
348        assert!(!is_allowed_interpreter("python3."));
349    }
350
351    #[test]
352    fn test_disallowed_interpreter_python3_double_dot() {
353        assert!(!is_allowed_interpreter("python3..11"));
354    }
355
356    #[test]
357    fn test_allowed_interpreter_strips_path() {
358        assert!(is_allowed_interpreter("/usr/bin/bash"));
359    }
360
361    #[cfg(unix)]
362    #[test]
363    fn test_cache_write_permissions_0600() {
364        use std::os::unix::fs::PermissionsExt;
365        use tempfile::NamedTempFile;
366
367        let dir = tempfile::tempdir().unwrap();
368        let cache_path = dir.path().join("test_cache");
369
370        {
371            use std::io::Write;
372
373            let mut tmp = NamedTempFile::new_in(dir.path()).unwrap();
374            tmp.as_file()
375                .set_permissions(std::fs::Permissions::from_mode(0o600))
376                .unwrap();
377            tmp.write_all(b"test content").unwrap();
378            tmp.persist(&cache_path).unwrap();
379        }
380
381        let meta = std::fs::metadata(&cache_path).unwrap();
382        assert_eq!(
383            meta.permissions().mode() & 0o777,
384            0o600,
385            "cache file should be 0600"
386        );
387    }
388
389    #[test]
390    fn test_cache_write_no_predictable_tmp() {
391        use tempfile::NamedTempFile;
392
393        let dir = tempfile::tempdir().unwrap();
394        let sha = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
395        let cached_path = dir.path().join(sha);
396
397        {
398            use std::io::Write;
399            let mut tmp = NamedTempFile::new_in(dir.path()).unwrap();
400            tmp.write_all(b"cached script").unwrap();
401            tmp.persist(&cached_path).unwrap();
402        }
403
404        // No predictable temp file should remain
405        let entries: Vec<_> = std::fs::read_dir(dir.path())
406            .unwrap()
407            .filter_map(|e| e.ok())
408            .map(|e| e.file_name().to_string_lossy().to_string())
409            .collect();
410
411        // Should only contain the final cached file
412        assert_eq!(
413            entries.len(),
414            1,
415            "only the cached file should exist, found: {entries:?}"
416        );
417        assert!(
418            cached_path.exists(),
419            "cached file should exist after persist"
420        );
421    }
422}