Skip to main content

tirith_core/
runner.rs

//! Safe runner — Unix only.
//! Downloads a script, analyzes it, and optionally executes it after explicit user confirmation.
3use std::fs;
4use std::io::{self, BufRead, Write};
5use std::process::Command;
6
7use sha2::{Digest, Sha256};
8
9use crate::receipt::Receipt;
10use crate::script_analysis;
11
12pub struct RunResult {
13    pub receipt: Receipt,
14    pub executed: bool,
15    pub exit_code: Option<i32>,
16}
17
/// Inputs for [`run`].
#[derive(Debug, Clone)]
pub struct RunOptions {
    /// URL of the script to download.
    pub url: String,
    /// Download, analyze, and save a receipt, but never execute the script.
    pub no_exec: bool,
    /// Whether the caller is attached to an interactive terminal.
    ///
    /// `run` refuses to proceed when this is `false` and `no_exec` is not set,
    /// because execution requires answering a confirmation prompt.
    pub interactive: bool,
}
23
24pub fn run(opts: RunOptions) -> Result<RunResult, String> {
25    // Check TTY requirement
26    if !opts.no_exec && !opts.interactive {
27        return Err("tirith run requires an interactive terminal or --no-exec flag".to_string());
28    }
29
30    // Download with redirect chain collection
31    let mut redirects: Vec<String> = Vec::new();
32    let redirect_list = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
33    let redirect_list_clone = redirect_list.clone();
34
35    let client = reqwest::blocking::Client::builder()
36        .redirect(reqwest::redirect::Policy::custom(move |attempt| {
37            if let Ok(mut list) = redirect_list_clone.lock() {
38                list.push(attempt.url().to_string());
39            }
40            if attempt.previous().len() >= 10 {
41                attempt.stop()
42            } else {
43                attempt.follow()
44            }
45        }))
46        .timeout(std::time::Duration::from_secs(30))
47        .build()
48        .map_err(|e| format!("http client: {e}"))?;
49
50    let response = client
51        .get(&opts.url)
52        .send()
53        .map_err(|e| format!("download failed: {e}"))?;
54
55    let final_url = response.url().to_string();
56    if let Ok(list) = redirect_list.lock() {
57        redirects = list.clone();
58    }
59
60    const MAX_BODY: u64 = 10 * 1024 * 1024; // 10 MiB
61
62    // Check Content-Length hint first (fast rejection)
63    if let Some(len) = response.content_length() {
64        if len > MAX_BODY {
65            return Err(format!(
66                "response too large: {len} bytes (max {} MiB)",
67                MAX_BODY / 1024 / 1024
68            ));
69        }
70    }
71
72    // Read with cap using std::io::Read::take
73    use std::io::Read;
74    let mut buf = Vec::new();
75    response
76        .take(MAX_BODY + 1)
77        .read_to_end(&mut buf)
78        .map_err(|e| format!("read body: {e}"))?;
79    if buf.len() as u64 > MAX_BODY {
80        return Err(format!(
81            "response body exceeds {} MiB limit",
82            MAX_BODY / 1024 / 1024
83        ));
84    }
85    let content = buf;
86
87    // Compute SHA256
88    let mut hasher = Sha256::new();
89    hasher.update(&content);
90    let sha256 = format!("{:x}", hasher.finalize());
91
92    // Cache
93    let cache_dir = crate::policy::data_dir()
94        .ok_or("cannot determine data directory")?
95        .join("cache");
96    fs::create_dir_all(&cache_dir).map_err(|e| format!("create cache: {e}"))?;
97    let cached_path = cache_dir.join(&sha256);
98    fs::write(&cached_path, &content).map_err(|e| format!("write cache: {e}"))?;
99
100    let content_str = String::from_utf8_lossy(&content);
101
102    // Analyze
103    let interpreter = script_analysis::detect_interpreter(&content_str);
104    let analysis = script_analysis::analyze(&content_str, interpreter);
105
106    // Detect git repo and branch
107    let (git_repo, git_branch) = detect_git_info();
108
109    // Create receipt
110    let receipt = Receipt {
111        url: opts.url.clone(),
112        final_url: Some(final_url),
113        redirects,
114        sha256: sha256.clone(),
115        size: content.len() as u64,
116        domains_referenced: analysis.domains_referenced,
117        paths_referenced: analysis.paths_referenced,
118        analysis_method: "static".to_string(),
119        privilege: if analysis.has_sudo {
120            "elevated".to_string()
121        } else {
122            "normal".to_string()
123        },
124        timestamp: chrono::Utc::now().to_rfc3339(),
125        cwd: std::env::current_dir()
126            .ok()
127            .map(|p| p.display().to_string()),
128        git_repo,
129        git_branch,
130    };
131
132    if opts.no_exec {
133        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
134        return Ok(RunResult {
135            receipt,
136            executed: false,
137            exit_code: None,
138        });
139    }
140
141    // Show analysis summary
142    eprintln!(
143        "tirith: downloaded {} bytes (SHA256: {})",
144        content.len(),
145        &sha256[..12]
146    );
147    eprintln!("tirith: interpreter: {interpreter}");
148    if analysis.has_sudo {
149        eprintln!("tirith: WARNING: script uses sudo");
150    }
151    if analysis.has_eval {
152        eprintln!("tirith: WARNING: script uses eval");
153    }
154    if analysis.has_base64 {
155        eprintln!("tirith: WARNING: script uses base64");
156    }
157
158    // Confirm from /dev/tty
159    let tty = fs::OpenOptions::new()
160        .read(true)
161        .write(true)
162        .open("/dev/tty")
163        .map_err(|_| "cannot open /dev/tty for confirmation")?;
164
165    let mut tty_writer = io::BufWriter::new(&tty);
166    write!(tty_writer, "Execute this script? [y/N] ").map_err(|e| format!("tty write: {e}"))?;
167    tty_writer.flush().map_err(|e| format!("tty flush: {e}"))?;
168
169    let mut reader = io::BufReader::new(&tty);
170    let mut response_line = String::new();
171    reader
172        .read_line(&mut response_line)
173        .map_err(|e| format!("tty read: {e}"))?;
174
175    if !response_line.trim().eq_ignore_ascii_case("y") {
176        eprintln!("tirith: execution cancelled");
177        receipt.save().map_err(|e| format!("save receipt: {e}"))?;
178        return Ok(RunResult {
179            receipt,
180            executed: false,
181            exit_code: None,
182        });
183    }
184
185    // Execute
186    receipt.save().map_err(|e| format!("save receipt: {e}"))?;
187
188    let status = Command::new(interpreter)
189        .arg(&cached_path)
190        .status()
191        .map_err(|e| format!("execute: {e}"))?;
192
193    Ok(RunResult {
194        receipt,
195        executed: true,
196        exit_code: status.code(),
197    })
198}
199
/// Best-effort lookup of the `origin` remote URL and the current branch name.
///
/// Each value comes from running `git` in the process's current working
/// directory. A value is `None` when `git` cannot be spawned, exits
/// unsuccessfully, or prints non-UTF-8 output; otherwise it is the command's
/// stdout with surrounding whitespace trimmed.
fn detect_git_info() -> (Option<String>, Option<String>) {
    /// Run `git <args>` and return its trimmed stdout if it succeeded.
    fn git_stdout(args: &[&str]) -> Option<String> {
        let output = Command::new("git").args(args).output().ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8(output.stdout).ok()?;
        Some(text.trim().to_string())
    }

    let origin_url = git_stdout(&["remote", "get-url", "origin"]);
    let current_branch = git_stdout(&["rev-parse", "--abbrev-ref", "HEAD"]);
    (origin_url, current_branch)
}
215        .and_then(|o| String::from_utf8(o.stdout).ok())
216        .map(|s| s.trim().to_string());
217
218    (repo, branch)
219}