1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
/// Returns the policy file inside `dir`, preferring `policy.yaml` over
/// `policy.yml`, or `None` when neither exists.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    // Candidates are probed lazily in priority order.
    ["policy.yaml", "policy.yml"]
        .into_iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists())
}
20
/// Effective tirith policy.
///
/// Loaded from a `policy.yaml` / `policy.yml` file, a remote policy server, or
/// the remote-policy cache. Keys missing from the input fall back to defaults:
/// `Policy::default()` for fields without their own `#[serde(default)]`, and
/// the field type's `Default` for the rest.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Policy {
    /// Where this policy came from: a file path, `remote:<url>`,
    /// `cached:<path>` or `fail-closed`. Set by the loaders; never serialized
    /// or deserialized.
    #[serde(skip)]
    pub path: Option<String>,

    /// Fail mode; defaults to `FailMode::Open`.
    pub fail_mode: FailMode,

    /// Defaults to `true`; forced to `false` by the fail-closed policy.
    /// NOTE(review): presumably gates an environment-variable bypass; confirm
    /// at the call site.
    pub allow_bypass_env: bool,

    /// Defaults to `false`; also forced to `false` by the fail-closed policy.
    pub allow_bypass_env_noninteractive: bool,

    /// Paranoia level; defaults to `1`. The meaning of each level is defined
    /// outside this file.
    pub paranoia: u8,

    /// Per-rule severity overrides. Keys are the string a `RuleId` serializes
    /// to with `serde_json` (see `Policy::severity_override`).
    #[serde(default)]
    pub severity_overrides: HashMap<String, Severity>,

    /// Additional domains to treat as known; consumed outside this file.
    #[serde(default)]
    pub additional_known_domains: Vec<String>,

    /// URL allowlist checked by `Policy::is_allowlisted`.
    /// - Entries without `/`, `?`, `#` or `:` are host patterns. They match
    ///   the host and its subdomains, and a leading `*.` is allowed.
    /// - Other entries are case-insensitive substring matches.
    #[serde(default)]
    pub allowlist: Vec<String>,

    /// URL blocklist: case-insensitive substring matches, checked by
    /// `Policy::is_blocklisted`.
    #[serde(default)]
    pub blocklist: Vec<String>,

    /// Approval rules; enforcement lives outside this file.
    #[serde(default)]
    pub approval_rules: Vec<ApprovalRule>,

    /// Network deny patterns; format defined by the consumer.
    #[serde(default)]
    pub network_deny: Vec<String>,

    /// Network allow patterns; format defined by the consumer.
    #[serde(default)]
    pub network_allow: Vec<String>,

    /// Webhook destinations.
    #[serde(default)]
    pub webhooks: Vec<WebhookConfig>,

    /// Checkpoint limits (see `CheckpointPolicyConfig::default()` for defaults).
    #[serde(default)]
    pub checkpoints: CheckpointPolicyConfig,

    /// Scan settings.
    #[serde(default)]
    pub scan: ScanPolicyConfig,

    /// Per-rule allowlist entries.
    #[serde(default)]
    pub allowlist_rules: Vec<AllowlistRule>,

    /// User-defined rules.
    #[serde(default)]
    pub custom_rules: Vec<CustomRule>,

    /// Extra DLP patterns; their syntax is defined by the consumer.
    #[serde(default)]
    pub dlp_custom_patterns: Vec<String>,

    /// Policy server URL. `Policy::discover` uses it when `TIRITH_SERVER_URL`
    /// is unset or empty.
    #[serde(default)]
    pub policy_server_url: Option<String>,
    /// Policy server API key. `Policy::discover` uses it when `TIRITH_API_KEY`
    /// is unset or empty.
    /// NOTE(review): the derived `Debug` and `Serialize` impls include this
    /// secret, so avoid logging a `Policy` with `{:?}`.
    #[serde(default)]
    pub policy_server_api_key: Option<String>,
    /// How `Policy::discover` reacts to a remote fetch or parse failure:
    /// - `"closed"`: use the fail-closed policy.
    /// - `"cached"`: use the cached remote policy, falling back to local.
    /// - anything else, or unset: use the local policy.
    ///
    /// Authentication failures always fail closed.
    #[serde(default)]
    pub policy_fetch_fail_mode: Option<String>,
    /// Not read in this file; its meaning is defined by consumers.
    #[serde(default)]
    pub enforce_fail_mode: Option<bool>,
}
109
/// An entry in the policy's `approval_rules` list.
///
/// NOTE(review): how these fields are enforced is not visible in this file;
/// confirm against the approval code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalRule {
    /// Rule identifiers this entry applies to (required key).
    pub rule_ids: Vec<String>,
    /// Timeout in seconds; `0` when omitted. What `0` means is decided by the consumer.
    #[serde(default)]
    pub timeout_secs: u64,
    /// Fallback action; `"block"` when omitted.
    #[serde(default = "default_approval_fallback")]
    pub fallback: String,
}
122
/// Serde default for `ApprovalRule::fallback`: the string `"block"`.
fn default_approval_fallback() -> String {
    String::from("block")
}
126
/// An entry in the policy's `webhooks` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Destination URL (required key).
    pub url: String,
    /// Minimum severity for this webhook; `Severity::High` when omitted.
    /// NOTE(review): presumably lower-severity findings are not sent; confirm
    /// with the dispatcher.
    #[serde(default = "default_webhook_severity")]
    pub min_severity: Severity,
    /// Name-to-value header map (presumably sent as HTTP headers); empty by default.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional payload template; its syntax is defined by the consumer.
    #[serde(default)]
    pub payload_template: Option<String>,
}
142
/// Serde default for `WebhookConfig::min_severity`: `Severity::High`.
fn default_webhook_severity() -> Severity {
    Severity::High
}
146
/// Settings under the policy's `checkpoints` key.
///
/// Missing keys take their values from `CheckpointPolicyConfig::default()`:
/// 100, 168 hours and 500 MiB. How the limits are enforced is outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CheckpointPolicyConfig {
    /// Checkpoint count limit (default 100).
    pub max_count: usize,
    /// Age limit in hours (default 168, i.e. one week).
    pub max_age_hours: u64,
    /// Storage limit in bytes (default 500 MiB).
    pub max_storage_bytes: u64,
}
158
159impl Default for CheckpointPolicyConfig {
160 fn default() -> Self {
161 Self {
162 max_count: 100,
163 max_age_hours: 168, max_storage_bytes: 500 * 1024 * 1024, }
166 }
167}
168
/// Settings under the policy's `scan` key.
///
/// Every field defaults to empty or `None`; the fields are consumed outside this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ScanPolicyConfig {
    /// Extra configuration files to scan (presumably paths; confirm with the scanner).
    #[serde(default)]
    pub additional_config_files: Vec<String>,
    /// MCP servers to treat as trusted; entry format defined by the consumer.
    #[serde(default)]
    pub trusted_mcp_servers: Vec<String>,
    /// Patterns to ignore; syntax defined by the consumer.
    #[serde(default)]
    pub ignore_patterns: Vec<String>,
    /// Optional failure threshold, kept as a free-form string and interpreted elsewhere.
    #[serde(default)]
    pub fail_on: Option<String>,
}
186
/// An entry in the policy's `allowlist_rules` list. Both keys are required.
///
/// NOTE(review): presumably `patterns` exempt matches of `rule_id`; confirm
/// against the rule engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllowlistRule {
    /// Identifier of the rule this entry applies to.
    pub rule_id: String,
    /// Patterns for this rule; syntax defined by the consumer.
    pub patterns: Vec<String>,
}
195
/// A user-defined rule from the policy's `custom_rules` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomRule {
    /// Rule identifier (required key).
    pub id: String,
    /// Match pattern (required key); presumably a regex; confirm with the rule engine.
    pub pattern: String,
    /// Contexts the rule applies to; `["exec", "paste"]` when omitted.
    #[serde(default = "default_custom_rule_contexts")]
    pub context: Vec<String>,
    /// Severity; `Severity::High` when omitted.
    #[serde(default = "default_custom_rule_severity")]
    pub severity: Severity,
    /// Title (required key).
    pub title: String,
    /// Description; empty when omitted.
    #[serde(default)]
    pub description: String,
}
215
/// Serde default for `CustomRule::context`: `["exec", "paste"]`.
fn default_custom_rule_contexts() -> Vec<String> {
    Vec::from(["exec", "paste"].map(String::from))
}
219
/// Serde default for `CustomRule::severity`: `Severity::High`.
fn default_custom_rule_severity() -> Severity {
    Severity::High
}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "lowercase")]
226#[derive(Default)]
227pub enum FailMode {
228 #[default]
229 Open,
230 Closed,
231}
232
233impl Default for Policy {
234 fn default() -> Self {
235 Self {
236 path: None,
237 fail_mode: FailMode::Open,
238 allow_bypass_env: true,
239 allow_bypass_env_noninteractive: false,
240 paranoia: 1,
241 severity_overrides: HashMap::new(),
242 additional_known_domains: Vec::new(),
243 allowlist: Vec::new(),
244 blocklist: Vec::new(),
245 approval_rules: Vec::new(),
246 network_deny: Vec::new(),
247 network_allow: Vec::new(),
248 webhooks: Vec::new(),
249 checkpoints: CheckpointPolicyConfig::default(),
250 scan: ScanPolicyConfig::default(),
251 allowlist_rules: Vec::new(),
252 custom_rules: Vec::new(),
253 dlp_custom_patterns: Vec::new(),
254 policy_server_url: None,
255 policy_server_api_key: None,
256 policy_fetch_fail_mode: None,
257 enforce_fail_mode: None,
258 }
259 }
260}
261
262impl Policy {
263 pub fn discover_partial(cwd: Option<&str>) -> Self {
266 match discover_policy_path(cwd) {
267 Some(path) => match std::fs::read_to_string(&path) {
268 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
269 Ok(mut p) => {
270 p.path = Some(path.display().to_string());
271 p
272 }
273 Err(e) => {
274 eprintln!(
275 "tirith: warning: failed to parse policy at {}: {e}",
276 path.display()
277 );
278 Policy::default()
280 }
281 },
282 Err(e) => {
283 eprintln!(
284 "tirith: warning: cannot read policy at {}: {e}",
285 path.display()
286 );
287 Policy::default()
288 }
289 },
290 None => Policy::default(),
291 }
292 }
293
294 pub fn discover(cwd: Option<&str>) -> Self {
307 let local = Self::discover_local(cwd);
309
310 if crate::license::current_tier() < crate::license::Tier::Team {
312 return local;
313 }
314
315 let server_url = std::env::var("TIRITH_SERVER_URL")
317 .ok()
318 .filter(|s| !s.is_empty())
319 .or_else(|| local.policy_server_url.clone());
320 let api_key = std::env::var("TIRITH_API_KEY")
321 .ok()
322 .filter(|s| !s.is_empty())
323 .or_else(|| local.policy_server_api_key.clone());
324
325 let (server_url, api_key) = match (server_url, api_key) {
326 (Some(u), Some(k)) => (u, k),
327 _ => return local, };
329
330 let fail_mode = local.policy_fetch_fail_mode.as_deref().unwrap_or("open");
331
332 match crate::policy_client::fetch_remote_policy(&server_url, &api_key) {
334 Ok(yaml) => {
335 let _ = cache_remote_policy(&yaml);
337 match serde_yaml::from_str::<Policy>(&yaml) {
338 Ok(mut p) => {
339 p.path = Some(format!("remote:{server_url}"));
340 if p.policy_server_url.is_none() {
342 p.policy_server_url = Some(server_url);
343 }
344 if p.policy_server_api_key.is_none() {
345 p.policy_server_api_key = Some(api_key);
346 }
347 p
348 }
349 Err(e) => match fail_mode {
350 "closed" => {
351 eprintln!(
352 "tirith: error: remote policy parse error ({e}), failing closed"
353 );
354 Self::fail_closed_policy()
355 }
356 "cached" => {
357 eprintln!(
358 "tirith: warning: remote policy parse error ({e}), trying cache"
359 );
360 match load_cached_remote_policy() {
361 Some(p) => p,
362 None => {
363 eprintln!(
364 "tirith: warning: no cached remote policy, using local"
365 );
366 local
367 }
368 }
369 }
370 _ => {
371 eprintln!("tirith: warning: remote policy parse error: {e}");
372 local
373 }
374 },
375 }
376 }
377 Err(crate::policy_client::PolicyFetchError::AuthError(code)) => {
378 eprintln!("tirith: error: policy server auth failed (HTTP {code}), failing closed");
380 Self::fail_closed_policy()
381 }
382 Err(e) => {
383 match fail_mode {
385 "closed" => {
386 eprintln!(
387 "tirith: error: remote policy fetch failed ({e}), failing closed"
388 );
389 Self::fail_closed_policy()
390 }
391 "cached" => {
392 eprintln!(
393 "tirith: warning: remote policy fetch failed ({e}), trying cache"
394 );
395 match load_cached_remote_policy() {
396 Some(p) => p,
397 None => {
398 eprintln!("tirith: warning: no cached remote policy, using local");
399 local
400 }
401 }
402 }
403 _ => {
404 eprintln!(
406 "tirith: warning: remote policy fetch failed ({e}), using local policy"
407 );
408 local
409 }
410 }
411 }
412 }
413 }
414
415 fn discover_local(cwd: Option<&str>) -> Self {
417 if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
419 if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
420 return Self::load_from_path(&path);
421 }
422 }
423
424 match discover_policy_path(cwd) {
425 Some(path) => Self::load_from_path(&path),
426 None => {
427 if let Some(user_path) = user_policy_path() {
429 if user_path.exists() {
430 return Self::load_from_path(&user_path);
431 }
432 }
433 Policy::default()
434 }
435 }
436 }
437
438 fn fail_closed_policy() -> Self {
440 Policy {
441 fail_mode: FailMode::Closed,
442 allow_bypass_env: false,
443 allow_bypass_env_noninteractive: false,
444 path: Some("fail-closed".into()),
445 ..Default::default()
446 }
447 }
448
449 fn load_from_path(path: &Path) -> Self {
450 match std::fs::read_to_string(path) {
451 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
452 Ok(mut p) => {
453 p.path = Some(path.display().to_string());
454 p
455 }
456 Err(e) => {
457 eprintln!(
458 "tirith: warning: failed to parse policy at {}: {e}",
459 path.display(),
460 );
461 Policy::default()
462 }
463 },
464 Err(e) => {
465 eprintln!(
466 "tirith: warning: cannot read policy at {}: {e}",
467 path.display()
468 );
469 Policy::default()
470 }
471 }
472 }
473
474 pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
476 let key = serde_json::to_value(rule_id)
477 .ok()
478 .and_then(|v| v.as_str().map(String::from))?;
479 self.severity_overrides.get(&key).copied()
480 }
481
482 pub fn is_blocklisted(&self, url: &str) -> bool {
484 let url_lower = url.to_lowercase();
485 self.blocklist.iter().any(|pattern| {
486 let p = pattern.to_lowercase();
487 url_lower.contains(&p)
488 })
489 }
490
491 pub fn is_allowlisted(&self, url: &str) -> bool {
493 let url_lower = url.to_lowercase();
494 self.allowlist.iter().any(|pattern| {
495 let p = pattern.to_lowercase();
496 if p.is_empty() {
497 return false;
498 }
499 if is_domain_pattern(&p) {
500 if let Some(host) = extract_host_for_match(url) {
501 return domain_matches(&host, &p);
502 }
503 return false;
504 }
505 url_lower.contains(&p)
506 })
507 }
508
509 pub fn load_user_lists(&mut self) {
511 if let Some(config) = crate::policy::config_dir() {
512 let allowlist_path = config.join("allowlist");
513 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
514 for line in content.lines() {
515 let line = line.trim();
516 if !line.is_empty() && !line.starts_with('#') {
517 self.allowlist.push(line.to_string());
518 }
519 }
520 }
521 let blocklist_path = config.join("blocklist");
522 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
523 for line in content.lines() {
524 let line = line.trim();
525 if !line.is_empty() && !line.starts_with('#') {
526 self.blocklist.push(line.to_string());
527 }
528 }
529 }
530 }
531 }
532
533 pub fn load_org_lists(&mut self, cwd: Option<&str>) {
539 if let Some(repo_root) = find_repo_root(cwd) {
540 let org_dir = repo_root.join(".tirith");
541 let allowlist_path = org_dir.join("allowlist");
542 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
543 eprintln!(
544 "tirith: loading org-level allowlist from {}",
545 allowlist_path.display()
546 );
547 for line in content.lines() {
548 let line = line.trim();
549 if !line.is_empty() && !line.starts_with('#') {
550 self.allowlist.push(line.to_string());
551 }
552 }
553 }
554 let blocklist_path = org_dir.join("blocklist");
555 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
556 eprintln!(
557 "tirith: loading org-level blocklist from {}",
558 blocklist_path.display()
559 );
560 for line in content.lines() {
561 let line = line.trim();
562 if !line.is_empty() && !line.starts_with('#') {
563 self.blocklist.push(line.to_string());
564 }
565 }
566 }
567 }
568 }
569}
570
/// Returns `true` when `p` looks like a bare host pattern (e.g. `github.com`
/// or `*.example.com`) rather than a URL fragment.
///
/// A host pattern contains none of `/`, `?`, `#` or `:`. Excluding `:` also
/// excludes `://` schemes and `host:port` forms.
fn is_domain_pattern(p: &str) -> bool {
    !p.chars().any(|c| matches!(c, '/' | '?' | '#' | ':'))
}
578
579fn extract_host_for_match(url: &str) -> Option<String> {
580 if let Some(host) = crate::parse::parse_url(url).host() {
581 return Some(host.trim_end_matches('.').to_lowercase());
582 }
583 let candidate = url.split('/').next().unwrap_or(url).trim();
585 if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
586 return None;
587 }
588 let host = if let Some((h, port)) = candidate.rsplit_once(':') {
589 if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
590 h
591 } else {
592 candidate
593 }
594 } else {
595 candidate
596 };
597 Some(host.trim_end_matches('.').to_lowercase())
598}
599
/// Returns `true` if `host` equals `pattern` or is a subdomain of it.
///
/// A leading `*.` on the pattern and trailing dots on either side are
/// ignored. Matching is exact (case-sensitive); callers lowercase both sides.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let base = pattern.trim_start_matches("*.").trim_end_matches('.');
    // Whatever precedes `base` must be empty (exact match) or end at a label
    // boundary, so `evil-github.com` does not match `github.com`.
    match host.strip_suffix(base) {
        Some("") => true,
        Some(prefix) => prefix.ends_with('.'),
        None => false,
    }
}
605
606fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
608 let start = cwd
609 .map(PathBuf::from)
610 .or_else(|| std::env::current_dir().ok())?;
611
612 let mut current = start.as_path();
613 loop {
614 if let Some(candidate) = find_policy_in_dir(¤t.join(".tirith")) {
616 return Some(candidate);
617 }
618
619 let git_dir = current.join(".git");
621 if git_dir.exists() {
622 return None; }
624
625 match current.parent() {
627 Some(parent) if parent != current => current = parent,
628 _ => break,
629 }
630 }
631
632 None
633}
634
/// Returns the nearest directory at or above `cwd` (or the current directory
/// when `None`) that contains a `.git` entry, or `None` if there is none.
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
653
654fn user_policy_path() -> Option<PathBuf> {
656 let base = etcetera::choose_base_strategy().ok()?;
657 find_policy_in_dir(&base.config_dir().join("tirith"))
658}
659
660pub fn data_dir() -> Option<PathBuf> {
662 let base = etcetera::choose_base_strategy().ok()?;
663 Some(base.data_dir().join("tirith"))
664}
665
666pub fn config_dir() -> Option<PathBuf> {
668 let base = etcetera::choose_base_strategy().ok()?;
669 Some(base.config_dir().join("tirith"))
670}
671
672pub fn state_dir() -> Option<PathBuf> {
675 match std::env::var("XDG_STATE_HOME") {
676 Ok(val) if !val.trim().is_empty() => Some(PathBuf::from(val.trim()).join("tirith")),
677 _ => home::home_dir().map(|h| h.join(".local/state/tirith")),
678 }
679}
680
681fn remote_policy_cache_path() -> Option<PathBuf> {
683 let cache_dir = std::env::var("XDG_CACHE_HOME")
684 .ok()
685 .filter(|s| !s.is_empty())
686 .map(PathBuf::from)
687 .or_else(|| home::home_dir().map(|h| h.join(".cache")))?;
688 Some(cache_dir.join("tirith").join("remote-policy.yaml"))
689}
690
691fn cache_remote_policy(yaml: &str) -> std::io::Result<()> {
693 if let Some(path) = remote_policy_cache_path() {
694 if let Some(parent) = path.parent() {
695 std::fs::create_dir_all(parent)?;
696 }
697 let mut opts = std::fs::OpenOptions::new();
699 opts.write(true).create(true).truncate(true);
700 #[cfg(unix)]
701 {
702 use std::os::unix::fs::OpenOptionsExt;
703 opts.mode(0o600);
704 }
705 let mut f = opts.open(&path)?;
706 use std::io::Write;
707 f.write_all(yaml.as_bytes())?;
708 }
709 Ok(())
710}
711
712fn load_cached_remote_policy() -> Option<Policy> {
714 let path = remote_policy_cache_path()?;
715 let content = std::fs::read_to_string(&path).ok()?;
716 match serde_yaml::from_str::<Policy>(&content) {
717 Ok(mut p) => {
718 p.path = Some(format!("cached:{}", path.display()));
719 Some(p)
720 }
721 Err(e) => {
722 eprintln!("tirith: warning: cached remote policy parse error: {e}");
723 None
724 }
725 }
726}
727
#[cfg(test)]
mod tests {
    use super::*;

    /// Records environment changes made by a test and undoes them on drop,
    /// including when the test panics. Each touched variable is restored to its
    /// previous value, or removed if it was previously unset.
    #[derive(Default)]
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        fn set(&mut self, key: &'static str, value: &str) {
            self.saved.push((key, std::env::var_os(key)));
            // SAFETY: callers hold `crate::TEST_ENV_LOCK`. Env-mutating tests are
            // expected to take that lock, so no other test thread touches the
            // environment concurrently.
            unsafe { std::env::set_var(key, value) };
        }

        fn remove(&mut self, key: &'static str) {
            self.saved.push((key, std::env::var_os(key)));
            // SAFETY: same invariant as `set`.
            unsafe { std::env::remove_var(key) };
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Undo in reverse order so a key touched twice ends at its original value.
            for (key, previous) in self.saved.drain(..).rev() {
                // SAFETY: the guard is declared after the lock guard in each test,
                // so it is dropped, and restores the env, while the lock is still held.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let p = Policy {
            allowlist: vec!["github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(p.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!p.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let p = Policy {
            allowlist: vec!["raw.githubusercontent.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let p = Policy {
            allowlist: vec!["example.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("example.com:8080/path"));
    }

    #[test]
    fn test_allowlist_wildcard_and_empty_entries() {
        let p = Policy {
            allowlist: vec!["*.example.com".to_string(), String::new()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://sub.example.com/x"));
        assert!(p.is_allowlisted("https://example.com/"));
        // The empty entry must not turn the allowlist into "allow everything".
        assert!(!p.is_allowlisted("https://other.org/"));
    }

    #[test]
    fn test_domain_matches_requires_label_boundary() {
        assert!(domain_matches("api.github.com", "github.com"));
        assert!(domain_matches("github.com.", "github.com"));
        assert!(!domain_matches("evil-github.com", "github.com"));
        assert!(!domain_matches("github.com", "api.github.com"));
    }

    #[test]
    fn test_discover_skips_remote_fetch_below_team_tier() {
        // Tolerate poisoning so one failed env test doesn't cascade into others.
        let _env_lock = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let dir = tempfile::tempdir().unwrap();
        let policy_dir = dir.path().join(".tirith");
        std::fs::create_dir_all(&policy_dir).unwrap();
        std::fs::write(
            policy_dir.join("policy.yaml"),
            "fail_mode: open\npolicy_fetch_fail_mode: closed\nallow_bypass_env_noninteractive: true\n",
        )
        .unwrap();

        // Declared after the lock so it restores the environment before unlocking.
        let mut env = EnvGuard::default();
        // An inherited TIRITH_POLICY_ROOT would make discovery ignore `dir`.
        env.remove("TIRITH_POLICY_ROOT");
        // NOTE(review): "!" is presumably rejected as a license, leaving the tier below Team.
        env.set("TIRITH_LICENSE", "!");
        // If a fetch were attempted, `policy_fetch_fail_mode: closed` would
        // presumably produce the fail-closed policy, which the asserts rule out.
        env.set("TIRITH_SERVER_URL", "http://127.0.0.1");
        env.set("TIRITH_API_KEY", "dummy");

        let policy = Policy::discover(Some(dir.path().to_str().unwrap()));
        assert_ne!(policy.path.as_deref(), Some("fail-closed"));
        assert_eq!(policy.fail_mode, FailMode::Open);
        assert!(policy.allow_bypass_env_noninteractive);
        assert!(policy
            .path
            .as_deref()
            .unwrap_or_default()
            .contains(".tirith"));
    }
}