1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
/// Looks for a policy file directly inside `dir`.
///
/// `policy.yaml` takes precedence over `policy.yml`; `None` means neither exists.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    ["policy.yaml", "policy.yml"]
        .iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists())
}
20
/// Effective tirith policy, deserialized from `policy.yaml` / `policy.yml` or from a
/// remote policy server.
///
/// Every field is optional in YAML: because of the container-level `#[serde(default)]`,
/// missing fields are filled from `Policy::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Policy {
    /// Where this policy came from: a file path, `remote:<url>`, `cached:<path>`, or
    /// `fail-closed`; `None` for the built-in default. Skipped by serde in both directions.
    #[serde(skip)]
    pub path: Option<String>,

    /// Fail-open / fail-closed mode (default: `open`; the fail-closed fallback policy
    /// uses `closed`). NOTE(review): how the mode is enforced lives outside this file.
    pub fail_mode: FailMode,

    /// Environment-variable bypass switch (default: `true`; forced to `false` by the
    /// fail-closed policy). NOTE(review): the bypass mechanism lives outside this file.
    pub allow_bypass_env: bool,

    /// Companion to `allow_bypass_env`, presumably for non-interactive sessions
    /// (default: `false`) — TODO confirm against the consumer.
    pub allow_bypass_env_noninteractive: bool,

    /// Paranoia level (default: `1`). NOTE(review): valid range and meaning are defined
    /// by callers.
    pub paranoia: u8,

    /// Per-rule severity overrides, keyed by the rule id's serialized string form
    /// (looked up by `Policy::severity_override`).
    #[serde(default)]
    pub severity_overrides: HashMap<String, Severity>,

    /// Extra domains to treat as known. NOTE(review): consumed outside this file.
    #[serde(default)]
    pub additional_known_domains: Vec<String>,

    /// Allowlist patterns. A bare domain (no `/`, `?`, `#` or `:`) matches that host and
    /// its subdomains; any other pattern is a case-insensitive substring match on the URL.
    /// Extended at runtime by `load_user_lists` / `load_org_lists`.
    #[serde(default)]
    pub allowlist: Vec<String>,

    /// Blocklist patterns, matched as case-insensitive substrings of the URL by
    /// `Policy::is_blocklisted`. Extended by `load_user_lists` / `load_org_lists`.
    #[serde(default)]
    pub blocklist: Vec<String>,

    /// Rules that require approval; see [`ApprovalRule`].
    #[serde(default)]
    pub approval_rules: Vec<ApprovalRule>,

    /// Network deny entries. NOTE(review): format and matching are defined outside this file.
    #[serde(default)]
    pub network_deny: Vec<String>,

    /// Network allow entries. NOTE(review): format and matching are defined outside this file.
    #[serde(default)]
    pub network_allow: Vec<String>,

    /// Webhook destinations; see [`WebhookConfig`].
    #[serde(default)]
    pub webhooks: Vec<WebhookConfig>,

    /// Checkpoint retention limits; see [`CheckpointPolicyConfig`].
    #[serde(default)]
    pub checkpoints: CheckpointPolicyConfig,

    /// Scan settings; see [`ScanPolicyConfig`].
    #[serde(default)]
    pub scan: ScanPolicyConfig,

    /// Allowlists scoped to a single rule id (checked by `Policy::is_allowlisted_for_rule`).
    #[serde(default)]
    pub allowlist_rules: Vec<AllowlistRule>,

    /// User-defined detection rules; see [`CustomRule`].
    #[serde(default)]
    pub custom_rules: Vec<CustomRule>,

    /// Additional DLP patterns. NOTE(review): presumably regular expressions — confirm
    /// against the consumer.
    #[serde(default)]
    pub dlp_custom_patterns: Vec<String>,

    /// Remote policy server URL; used by `Policy::discover` when `TIRITH_SERVER_URL`
    /// is unset or empty.
    #[serde(default)]
    pub policy_server_url: Option<String>,
    /// Remote policy server API key; used by `Policy::discover` when `TIRITH_API_KEY`
    /// is unset or empty. NOTE(review): this secret is included in the derived `Debug`
    /// output.
    #[serde(default)]
    pub policy_server_api_key: Option<String>,
    /// How `Policy::discover` reacts to a failed remote fetch/parse: `"closed"`,
    /// `"cached"`, or anything else / unset, which behaves as `"open"`.
    #[serde(default)]
    pub policy_fetch_fail_mode: Option<String>,
    /// NOTE(review): not read in this file; semantics defined by its consumer.
    #[serde(default)]
    pub enforce_fail_mode: Option<bool>,
}
109
/// An approval requirement attached to a set of rule ids.
/// NOTE(review): the approval flow itself is implemented outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRule {
    /// Rule ids this entry applies to (required in YAML).
    pub rule_ids: Vec<String>,
    /// Approval timeout in seconds; `0` when omitted.
    /// NOTE(review): whether `0` means "no timeout" is decided by the consumer — confirm.
    #[serde(default)]
    pub timeout_secs: u64,
    /// Presumably the action taken when approval is not granted; defaults to `"block"`.
    #[serde(default = "default_approval_fallback")]
    pub fallback: String,
}
122
/// Serde default for `ApprovalRule::fallback`: the string `"block"`.
fn default_approval_fallback() -> String {
    String::from("block")
}
126
/// A webhook destination.
/// NOTE(review): delivery is implemented outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Target URL (required in YAML).
    pub url: String,
    /// Minimum severity for this webhook, presumably the threshold below which events are
    /// not sent; defaults to `Severity::High`.
    #[serde(default = "default_webhook_severity")]
    pub min_severity: Severity,
    /// Extra headers (name → value); empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional custom payload template. NOTE(review): template syntax is defined by the
    /// consumer.
    #[serde(default)]
    pub payload_template: Option<String>,
}
142
143fn default_webhook_severity() -> Severity {
144 Severity::High
145}
146
/// Checkpoint retention limits. Missing fields fall back to
/// `CheckpointPolicyConfig::default()` (100 checkpoints, 168 hours, 500 MiB).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CheckpointPolicyConfig {
    /// Maximum number of checkpoints (default: 100).
    pub max_count: usize,
    /// Maximum checkpoint age in hours (default: 168, i.e. 7 days).
    pub max_age_hours: u64,
    /// Maximum total checkpoint storage in bytes (default: 500 MiB).
    pub max_storage_bytes: u64,
}
158
159impl Default for CheckpointPolicyConfig {
160 fn default() -> Self {
161 Self {
162 max_count: 100,
163 max_age_hours: 168, max_storage_bytes: 500 * 1024 * 1024, }
166 }
167}
168
/// Settings for the scan feature; every field defaults to empty / `None`.
/// NOTE(review): the scanner that reads these lives outside this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ScanPolicyConfig {
    /// Additional config files to include in a scan.
    #[serde(default)]
    pub additional_config_files: Vec<String>,
    /// MCP servers to treat as trusted.
    #[serde(default)]
    pub trusted_mcp_servers: Vec<String>,
    /// Patterns to ignore during a scan. NOTE(review): pattern syntax defined by the scanner.
    #[serde(default)]
    pub ignore_patterns: Vec<String>,
    /// Presumably the severity threshold that makes a scan fail — TODO confirm accepted values.
    #[serde(default)]
    pub fail_on: Option<String>,
}
186
/// An allowlist scoped to a single rule id (see `Policy::is_allowlisted_for_rule`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllowlistRule {
    /// Rule id this entry applies to; compared ASCII-case-insensitively.
    pub rule_id: String,
    /// Patterns with the same semantics as `Policy::allowlist` entries.
    pub patterns: Vec<String>,
}
195
/// A user-defined detection rule.
/// NOTE(review): matching is implemented outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomRule {
    /// Identifier for the rule (required).
    pub id: String,
    /// Match pattern (required). NOTE(review): presumably a regular expression — confirm.
    pub pattern: String,
    /// Contexts the rule applies to; defaults to `["exec", "paste"]`.
    #[serde(default = "default_custom_rule_contexts")]
    pub context: Vec<String>,
    /// Severity reported on a match; defaults to `Severity::High`.
    #[serde(default = "default_custom_rule_severity")]
    pub severity: Severity,
    /// Short human-readable title (required).
    pub title: String,
    /// Longer description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
215
/// Serde default for `CustomRule::context`: both the `exec` and `paste` contexts.
fn default_custom_rule_contexts() -> Vec<String> {
    ["exec", "paste"].iter().map(|ctx| ctx.to_string()).collect()
}
219
220fn default_custom_rule_severity() -> Severity {
221 Severity::High
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "lowercase")]
226#[derive(Default)]
227pub enum FailMode {
228 #[default]
229 Open,
230 Closed,
231}
232
233impl Default for Policy {
234 fn default() -> Self {
235 Self {
236 path: None,
237 fail_mode: FailMode::Open,
238 allow_bypass_env: true,
239 allow_bypass_env_noninteractive: false,
240 paranoia: 1,
241 severity_overrides: HashMap::new(),
242 additional_known_domains: Vec::new(),
243 allowlist: Vec::new(),
244 blocklist: Vec::new(),
245 approval_rules: Vec::new(),
246 network_deny: Vec::new(),
247 network_allow: Vec::new(),
248 webhooks: Vec::new(),
249 checkpoints: CheckpointPolicyConfig::default(),
250 scan: ScanPolicyConfig::default(),
251 allowlist_rules: Vec::new(),
252 custom_rules: Vec::new(),
253 dlp_custom_patterns: Vec::new(),
254 policy_server_url: None,
255 policy_server_api_key: None,
256 policy_fetch_fail_mode: None,
257 enforce_fail_mode: None,
258 }
259 }
260}
261
262impl Policy {
263 pub fn discover_partial(cwd: Option<&str>) -> Self {
266 match discover_policy_path(cwd) {
267 Some(path) => match std::fs::read_to_string(&path) {
268 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
269 Ok(mut p) => {
270 p.path = Some(path.display().to_string());
271 p
272 }
273 Err(e) => {
274 eprintln!(
275 "tirith: warning: failed to parse policy at {}: {e}",
276 path.display()
277 );
278 Policy::default()
280 }
281 },
282 Err(e) => {
283 eprintln!(
284 "tirith: warning: cannot read policy at {}: {e}",
285 path.display()
286 );
287 Policy::default()
288 }
289 },
290 None => Policy::default(),
291 }
292 }
293
294 pub fn discover(cwd: Option<&str>) -> Self {
307 let local = Self::discover_local(cwd);
309
310 let server_url = std::env::var("TIRITH_SERVER_URL")
312 .ok()
313 .filter(|s| !s.is_empty())
314 .or_else(|| local.policy_server_url.clone());
315 let api_key = std::env::var("TIRITH_API_KEY")
316 .ok()
317 .filter(|s| !s.is_empty())
318 .or_else(|| local.policy_server_api_key.clone());
319
320 let (server_url, api_key) = match (server_url, api_key) {
321 (Some(u), Some(k)) => (u, k),
322 _ => return local, };
324
325 let fail_mode = local.policy_fetch_fail_mode.as_deref().unwrap_or("open");
326
327 match crate::policy_client::fetch_remote_policy(&server_url, &api_key) {
329 Ok(yaml) => {
330 let _ = cache_remote_policy(&yaml);
332 match serde_yaml::from_str::<Policy>(&yaml) {
333 Ok(mut p) => {
334 p.path = Some(format!("remote:{server_url}"));
335 if p.policy_server_url.is_none() {
337 p.policy_server_url = Some(server_url);
338 }
339 if p.policy_server_api_key.is_none() {
340 p.policy_server_api_key = Some(api_key);
341 }
342 p
343 }
344 Err(e) => match fail_mode {
345 "closed" => {
346 eprintln!(
347 "tirith: error: remote policy parse error ({e}), failing closed"
348 );
349 Self::fail_closed_policy()
350 }
351 "cached" => {
352 eprintln!(
353 "tirith: warning: remote policy parse error ({e}), trying cache"
354 );
355 match load_cached_remote_policy() {
356 Some(p) => p,
357 None => {
358 eprintln!(
359 "tirith: warning: no cached remote policy, using local"
360 );
361 local
362 }
363 }
364 }
365 _ => {
366 eprintln!("tirith: warning: remote policy parse error: {e}");
367 local
368 }
369 },
370 }
371 }
372 Err(crate::policy_client::PolicyFetchError::AuthError(code)) => {
373 eprintln!("tirith: error: policy server auth failed (HTTP {code}), failing closed");
375 Self::fail_closed_policy()
376 }
377 Err(e) => {
378 match fail_mode {
380 "closed" => {
381 eprintln!(
382 "tirith: error: remote policy fetch failed ({e}), failing closed"
383 );
384 Self::fail_closed_policy()
385 }
386 "cached" => {
387 eprintln!(
388 "tirith: warning: remote policy fetch failed ({e}), trying cache"
389 );
390 match load_cached_remote_policy() {
391 Some(p) => p,
392 None => {
393 eprintln!("tirith: warning: no cached remote policy, using local");
394 local
395 }
396 }
397 }
398 _ => {
399 eprintln!(
401 "tirith: warning: remote policy fetch failed ({e}), using local policy"
402 );
403 local
404 }
405 }
406 }
407 }
408 }
409
410 fn discover_local(cwd: Option<&str>) -> Self {
412 if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
414 if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
415 return Self::load_from_path(&path);
416 }
417 }
418
419 match discover_policy_path(cwd) {
420 Some(path) => Self::load_from_path(&path),
421 None => {
422 if let Some(user_path) = user_policy_path() {
424 if user_path.exists() {
425 return Self::load_from_path(&user_path);
426 }
427 }
428 Policy::default()
429 }
430 }
431 }
432
433 fn fail_closed_policy() -> Self {
435 Policy {
436 fail_mode: FailMode::Closed,
437 allow_bypass_env: false,
438 allow_bypass_env_noninteractive: false,
439 path: Some("fail-closed".into()),
440 ..Default::default()
441 }
442 }
443
444 fn load_from_path(path: &Path) -> Self {
445 match std::fs::read_to_string(path) {
446 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
447 Ok(mut p) => {
448 p.path = Some(path.display().to_string());
449 p
450 }
451 Err(e) => {
452 eprintln!(
453 "tirith: warning: failed to parse policy at {}: {e}",
454 path.display(),
455 );
456 Policy::default()
457 }
458 },
459 Err(e) => {
460 eprintln!(
461 "tirith: warning: cannot read policy at {}: {e}",
462 path.display()
463 );
464 Policy::default()
465 }
466 }
467 }
468
469 pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
471 let key = serde_json::to_value(rule_id)
472 .ok()
473 .and_then(|v| v.as_str().map(String::from))?;
474 self.severity_overrides.get(&key).copied()
475 }
476
477 pub fn is_blocklisted(&self, url: &str) -> bool {
479 let url_lower = url.to_lowercase();
480 self.blocklist.iter().any(|pattern| {
481 let p = pattern.to_lowercase();
482 url_lower.contains(&p)
483 })
484 }
485
486 pub fn is_allowlisted(&self, url: &str) -> bool {
488 self.allowlist
489 .iter()
490 .any(|pattern| allowlist_pattern_matches(pattern, url))
491 }
492
493 pub fn is_allowlisted_for_rule(&self, rule_id: &str, url: &str) -> bool {
495 self.allowlist_rules.iter().any(|rule| {
496 rule.rule_id.eq_ignore_ascii_case(rule_id)
497 && rule
498 .patterns
499 .iter()
500 .any(|pattern| allowlist_pattern_matches(pattern, url))
501 })
502 }
503
504 pub fn load_user_lists(&mut self) {
506 if let Some(config) = crate::policy::config_dir() {
507 let allowlist_path = config.join("allowlist");
508 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
509 for line in content.lines() {
510 let line = line.trim();
511 if !line.is_empty() && !line.starts_with('#') {
512 self.allowlist.push(line.to_string());
513 }
514 }
515 }
516 let blocklist_path = config.join("blocklist");
517 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
518 for line in content.lines() {
519 let line = line.trim();
520 if !line.is_empty() && !line.starts_with('#') {
521 self.blocklist.push(line.to_string());
522 }
523 }
524 }
525 }
526 }
527
528 pub fn load_org_lists(&mut self, cwd: Option<&str>) {
534 if let Some(repo_root) = find_repo_root(cwd) {
535 let org_dir = repo_root.join(".tirith");
536 let allowlist_path = org_dir.join("allowlist");
537 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
538 eprintln!(
539 "tirith: loading org-level allowlist from {}",
540 allowlist_path.display()
541 );
542 for line in content.lines() {
543 let line = line.trim();
544 if !line.is_empty() && !line.starts_with('#') {
545 self.allowlist.push(line.to_string());
546 }
547 }
548 }
549 let blocklist_path = org_dir.join("blocklist");
550 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
551 eprintln!(
552 "tirith: loading org-level blocklist from {}",
553 blocklist_path.display()
554 );
555 for line in content.lines() {
556 let line = line.trim();
557 if !line.is_empty() && !line.starts_with('#') {
558 self.blocklist.push(line.to_string());
559 }
560 }
561 }
562 }
563 }
564}
565
/// Reports whether an allowlist pattern is a bare domain (e.g. `github.com` or
/// `*.example.org`) rather than a URL fragment.
///
/// Any `/`, `?`, `#` or `:` marks the pattern as URL-like. Checking `:` and `/` also
/// covers a `scheme://` prefix.
fn is_domain_pattern(p: &str) -> bool {
    !p.chars().any(|c| matches!(c, '/' | '?' | '#' | ':'))
}
573
574fn extract_host_for_match(url: &str) -> Option<String> {
575 if let Some(host) = crate::parse::parse_url(url).host() {
576 return Some(host.trim_end_matches('.').to_lowercase());
577 }
578 let candidate = url.split('/').next().unwrap_or(url).trim();
580 if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
581 return None;
582 }
583 let host = if let Some((h, port)) = candidate.rsplit_once(':') {
584 if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
585 h
586 } else {
587 candidate
588 }
589 } else {
590 candidate
591 };
592 Some(host.trim_end_matches('.').to_lowercase())
593}
594
/// Reports whether `host` equals `pattern` or is a subdomain of it.
///
/// Leading `*.` prefixes are stripped from `pattern`, so `*.example.com` also matches
/// `example.com` itself. Trailing dots on either side are ignored. The comparison is
/// case-sensitive; the caller in this file lowercases both sides first.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let base = pattern.trim_start_matches("*.").trim_end_matches('.');
    match host.strip_suffix(base) {
        // Exact match, or the suffix starts right after a label boundary.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
600
601fn allowlist_pattern_matches(pattern: &str, url: &str) -> bool {
602 let p = pattern.to_lowercase();
603 if p.is_empty() {
604 return false;
605 }
606 if is_domain_pattern(&p) {
607 if let Some(host) = extract_host_for_match(url) {
608 return domain_matches(&host, &p);
609 }
610 return false;
611 }
612 url.to_lowercase().contains(&p)
613}
614
615fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
617 let start = cwd
618 .map(PathBuf::from)
619 .or_else(|| std::env::current_dir().ok())?;
620
621 let mut current = start.as_path();
622 loop {
623 if let Some(candidate) = find_policy_in_dir(¤t.join(".tirith")) {
625 return Some(candidate);
626 }
627
628 let git_dir = current.join(".git");
630 if git_dir.exists() {
631 return None; }
633
634 match current.parent() {
636 Some(parent) if parent != current => current = parent,
637 _ => break,
638 }
639 }
640
641 None
642}
643
/// Returns the nearest directory containing a `.git` entry.
///
/// The search starts at `cwd` itself (or at the process working directory when `cwd`
/// is `None`) and walks up through its ancestors. Returns `None` when no such
/// directory exists or the working directory is unavailable.
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
662
663fn user_policy_path() -> Option<PathBuf> {
665 let base = etcetera::choose_base_strategy().ok()?;
666 find_policy_in_dir(&base.config_dir().join("tirith"))
667}
668
669pub fn data_dir() -> Option<PathBuf> {
671 let base = etcetera::choose_base_strategy().ok()?;
672 Some(base.data_dir().join("tirith"))
673}
674
675pub fn config_dir() -> Option<PathBuf> {
677 let base = etcetera::choose_base_strategy().ok()?;
678 Some(base.config_dir().join("tirith"))
679}
680
681pub fn state_dir() -> Option<PathBuf> {
684 match std::env::var("XDG_STATE_HOME") {
685 Ok(val) if !val.trim().is_empty() => Some(PathBuf::from(val.trim()).join("tirith")),
686 _ => home::home_dir().map(|h| h.join(".local/state/tirith")),
687 }
688}
689
690fn remote_policy_cache_path() -> Option<PathBuf> {
692 let cache_dir = std::env::var("XDG_CACHE_HOME")
693 .ok()
694 .filter(|s| !s.is_empty())
695 .map(PathBuf::from)
696 .or_else(|| home::home_dir().map(|h| h.join(".cache")))?;
697 Some(cache_dir.join("tirith").join("remote-policy.yaml"))
698}
699
700fn cache_remote_policy(yaml: &str) -> std::io::Result<()> {
702 if let Some(path) = remote_policy_cache_path() {
703 if let Some(parent) = path.parent() {
704 std::fs::create_dir_all(parent)?;
705 }
706 let mut opts = std::fs::OpenOptions::new();
708 opts.write(true).create(true).truncate(true);
709 #[cfg(unix)]
710 {
711 use std::os::unix::fs::OpenOptionsExt;
712 opts.mode(0o600);
713 }
714 let mut f = opts.open(&path)?;
715 use std::io::Write;
716 f.write_all(yaml.as_bytes())?;
717 }
718 Ok(())
719}
720
721fn load_cached_remote_policy() -> Option<Policy> {
723 let path = remote_policy_cache_path()?;
724 let content = std::fs::read_to_string(&path).ok()?;
725 match serde_yaml::from_str::<Policy>(&content) {
726 Ok(mut p) => {
727 p.path = Some(format!("cached:{}", path.display()));
728 Some(p)
729 }
730 Err(e) => {
731 eprintln!("tirith: warning: cached remote policy parse error: {e}");
732 None
733 }
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop it restores the previous value, or removes the variable if it was unset.
    /// This also runs when an assertion panics, so a failing test cannot leak settings
    /// into later tests. Callers must hold `crate::TEST_ENV_LOCK` while the guard lives.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: tests that mutate the process environment are serialized by
            // `TEST_ENV_LOCK`, which the caller holds while this guard is alive.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: as in `EnvVarGuard::set`; callers declare guards after taking the
            // lock, so guards drop while the lock is still held.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let p = Policy {
            allowlist: vec!["github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(p.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!p.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let p = Policy {
            allowlist: vec!["raw.githubusercontent.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let p = Policy {
            allowlist: vec!["example.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("example.com:8080/path"));
    }

    #[test]
    fn test_domain_matches_requires_label_boundary() {
        assert!(domain_matches("github.com", "*.github.com"));
        assert!(domain_matches("api.github.com.", "github.com"));
        assert!(!domain_matches("evil-github.com", "github.com"));
    }

    #[test]
    fn test_discover_applies_remote_fetch_fail_mode_when_configured() {
        let _lock = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let dir = tempfile::tempdir().unwrap();
        let policy_dir = dir.path().join(".tirith");
        std::fs::create_dir_all(&policy_dir).unwrap();
        std::fs::write(
            policy_dir.join("policy.yaml"),
            "fail_mode: open\npolicy_fetch_fail_mode: closed\nallow_bypass_env_noninteractive: true\n",
        )
        .unwrap();

        // Declared after `_lock`, so these drop (restoring the environment) first.
        let _server = EnvVarGuard::set("TIRITH_SERVER_URL", "http://127.0.0.1");
        let _api_key = EnvVarGuard::set("TIRITH_API_KEY", "dummy");

        let policy = Policy::discover(Some(dir.path().to_str().unwrap()));
        assert_eq!(policy.path.as_deref(), Some("fail-closed"));
        assert_eq!(policy.fail_mode, FailMode::Closed);
        assert!(!policy.allow_bypass_env_noninteractive);
    }
}