Skip to main content

tirith_core/
policy.rs

1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6/// A named scan profile for reusable filter configurations.
7#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8pub struct ScanProfile {
9    #[serde(default)]
10    pub include: Vec<String>,
11    #[serde(default)]
12    pub exclude: Vec<String>,
13    #[serde(default)]
14    pub fail_on: Option<String>,
15    #[serde(default)]
16    pub ignore: Vec<String>,
17}
18
19use crate::verdict::{RuleId, Severity};
20
/// Return the policy file inside `dir`, preferring `policy.yaml` over `policy.yml`.
///
/// Returns `None` when neither file exists.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    ["policy.yaml", "policy.yml"]
        .iter()
        .map(|name| dir.join(name))
        .find(|candidate| candidate.exists())
}
33
34/// Policy configuration loaded from YAML.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36#[serde(default)]
37pub struct Policy {
38    /// Path this policy was loaded from.
39    #[serde(skip)]
40    pub path: Option<String>,
41
42    /// Fail mode: "open" (default) or "closed".
43    pub fail_mode: FailMode,
44
45    /// Allow TIRITH=0 bypass in interactive mode.
46    pub allow_bypass_env: bool,
47
48    /// Allow TIRITH=0 bypass in non-interactive mode.
49    pub allow_bypass_env_noninteractive: bool,
50
51    /// Paranoia tier (1-4).
52    pub paranoia: u8,
53
54    /// Severity overrides per rule.
55    #[serde(default)]
56    pub severity_overrides: HashMap<String, Severity>,
57
58    /// Additional known domains (extends built-in list).
59    #[serde(default)]
60    pub additional_known_domains: Vec<String>,
61
62    /// Allowlist: URL patterns that are always allowed.
63    #[serde(default)]
64    pub allowlist: Vec<String>,
65
66    /// Blocklist: URL patterns that are always blocked.
67    #[serde(default)]
68    pub blocklist: Vec<String>,
69
70    /// Approval rules: commands matching these rules require human approval.
71    #[serde(default)]
72    pub approval_rules: Vec<ApprovalRule>,
73
74    /// Network deny list: block commands targeting these hosts/CIDRs.
75    #[serde(default)]
76    pub network_deny: Vec<String>,
77
78    /// Network allow list: exempt these hosts/CIDRs from network deny.
79    #[serde(default)]
80    pub network_allow: Vec<String>,
81
82    /// Webhook endpoints to notify on findings.
83    #[serde(default)]
84    pub webhooks: Vec<WebhookConfig>,
85
86    /// Checkpoint configuration (Pro+).
87    #[serde(default)]
88    pub checkpoints: CheckpointPolicyConfig,
89
90    /// Scan configuration overrides.
91    #[serde(default)]
92    pub scan: ScanPolicyConfig,
93
94    /// Per-rule allowlist scoping (Team).
95    #[serde(default)]
96    pub allowlist_rules: Vec<AllowlistRule>,
97
98    /// Custom detection rules defined in YAML (Team).
99    #[serde(default)]
100    pub custom_rules: Vec<CustomRule>,
101
102    /// Custom DLP redaction patterns (Team). Regex patterns applied alongside
103    /// built-in patterns when redacting commands in audit logs and webhooks.
104    #[serde(default)]
105    pub dlp_custom_patterns: Vec<String>,
106
107    /// Require explicit acknowledgement for warn findings in interactive mode.
108    #[serde(default)]
109    pub strict_warn: bool,
110
111    /// Per-rule action overrides: force action for specific rules (upgrade only: "block").
112    #[serde(default)]
113    pub action_overrides: HashMap<String, String>,
114
115    /// Escalation rules: upgrade action based on session history or finding count.
116    #[serde(default)]
117    pub escalation: Vec<crate::escalation::EscalationRule>,
118
119    /// URL of the centralized policy server (e.g., "https://policy.example.com").
120    #[serde(default)]
121    pub policy_server_url: Option<String>,
122    /// API key for authenticating with the policy server.
123    #[serde(default)]
124    pub policy_server_api_key: Option<String>,
125    /// Fail mode for remote policy fetch: "open" (default), "closed", or "cached".
126    #[serde(default)]
127    pub policy_fetch_fail_mode: Option<String>,
128    /// Whether to enforce the fetch fail mode strictly (ignore local fallback on auth errors).
129    #[serde(default)]
130    pub enforce_fail_mode: Option<bool>,
131
132    /// Threat intelligence configuration.
133    #[serde(default)]
134    pub threat_intel: ThreatIntelConfig,
135}
136
137/// Threat intelligence configuration.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139#[serde(default)]
140pub struct ThreatIntelConfig {
141    /// Auto-update interval in hours. 0 = disabled. Default: 24.
142    pub auto_update_hours: u64,
143    /// Enable real-time OSV.dev queries. Default: true.
144    pub osv_enabled: bool,
145    /// Enable real-time deps.dev queries. Default: true.
146    pub deps_dev_enabled: bool,
147    /// Optional: Google Safe Browsing API key (user gets own free key).
148    #[serde(skip_serializing)]
149    pub google_safe_browsing_key: Option<String>,
150    /// Optional: abuse.ch Auth-Key for URLhaus/ThreatFox feeds.
151    #[serde(skip_serializing)]
152    pub abusech_auth_key: Option<String>,
153    /// Optional: enable Phishing Army feed (CC BY-NC 4.0, non-commercial only).
154    pub phishing_army_enabled: bool,
155}
156
157impl Default for ThreatIntelConfig {
158    fn default() -> Self {
159        Self {
160            auto_update_hours: 24,
161            osv_enabled: true,
162            deps_dev_enabled: true,
163            google_safe_browsing_key: None,
164            abusech_auth_key: None,
165            phishing_army_enabled: false,
166        }
167    }
168}
169
170/// Approval rule: when a command matches, require human approval before execution.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct ApprovalRule {
173    /// Rule IDs that trigger approval (e.g., "pipe_to_interpreter").
174    pub rule_ids: Vec<String>,
175    /// Timeout in seconds (0 = indefinite).
176    #[serde(default)]
177    pub timeout_secs: u64,
178    /// Fallback when approval times out: "block", "warn", or "allow".
179    #[serde(default = "default_approval_fallback")]
180    pub fallback: String,
181}
182
183fn default_approval_fallback() -> String {
184    "block".to_string()
185}
186
187/// Webhook configuration for event notification.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct WebhookConfig {
190    /// Webhook URL.
191    pub url: String,
192    /// Minimum severity to trigger webhook.
193    #[serde(default = "default_webhook_severity")]
194    pub min_severity: Severity,
195    /// Optional headers (supports env var expansion: `$ENV_VAR`).
196    #[serde(default)]
197    pub headers: HashMap<String, String>,
198    /// Payload template (supports `{{rule_id}}`, `{{command_preview}}`).
199    #[serde(default)]
200    pub payload_template: Option<String>,
201}
202
203fn default_webhook_severity() -> Severity {
204    Severity::High
205}
206
207/// Checkpoint policy configuration.
208#[derive(Debug, Clone, Serialize, Deserialize)]
209#[serde(default)]
210pub struct CheckpointPolicyConfig {
211    /// Max checkpoints to retain.
212    pub max_count: usize,
213    /// Max age in hours.
214    pub max_age_hours: u64,
215    /// Max total storage in bytes.
216    pub max_storage_bytes: u64,
217}
218
219impl Default for CheckpointPolicyConfig {
220    fn default() -> Self {
221        Self {
222            max_count: 100,
223            max_age_hours: 168,                   // 1 week
224            max_storage_bytes: 500 * 1024 * 1024, // 500 MiB
225        }
226    }
227}
228
229/// Scan policy configuration.
230#[derive(Debug, Clone, Default, Serialize, Deserialize)]
231#[serde(default)]
232pub struct ScanPolicyConfig {
233    /// Additional config file paths to scan as priority files.
234    #[serde(default)]
235    pub additional_config_files: Vec<String>,
236    /// Trusted MCP server URLs (suppress McpUntrustedServer for these).
237    #[serde(default)]
238    pub trusted_mcp_servers: Vec<String>,
239    /// Glob patterns to ignore during scan.
240    #[serde(default)]
241    pub ignore_patterns: Vec<String>,
242    /// Severity threshold for CI failure (default: "critical").
243    #[serde(default)]
244    pub fail_on: Option<String>,
245    /// Named scan profiles with preset include/exclude/fail_on.
246    #[serde(default)]
247    pub profiles: HashMap<String, ScanProfile>,
248}
249
250/// Per-rule allowlist scoping.
251#[derive(Debug, Clone, Serialize, Deserialize)]
252pub struct AllowlistRule {
253    /// Rule ID to scope the allowlist entry to.
254    pub rule_id: String,
255    /// Patterns that suppress this specific rule.
256    pub patterns: Vec<String>,
257}
258
259/// Custom detection rule defined in policy YAML.
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct CustomRule {
262    /// Unique identifier for this custom rule.
263    pub id: String,
264    /// Regex pattern to match.
265    pub pattern: String,
266    /// Contexts this rule applies to: "exec", "paste", "file".
267    #[serde(default = "default_custom_rule_contexts")]
268    pub context: Vec<String>,
269    /// Severity level.
270    #[serde(default = "default_custom_rule_severity")]
271    pub severity: Severity,
272    /// Short title for findings.
273    pub title: String,
274    /// Description for findings.
275    #[serde(default)]
276    pub description: String,
277}
278
279fn default_custom_rule_contexts() -> Vec<String> {
280    vec!["exec".to_string(), "paste".to_string()]
281}
282
283fn default_custom_rule_severity() -> Severity {
284    Severity::High
285}
286
287#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289#[derive(Default)]
290pub enum FailMode {
291    #[default]
292    Open,
293    Closed,
294}
295
296impl Default for Policy {
297    fn default() -> Self {
298        Self {
299            path: None,
300            fail_mode: FailMode::Open,
301            allow_bypass_env: true,
302            allow_bypass_env_noninteractive: false,
303            paranoia: 1,
304            severity_overrides: HashMap::new(),
305            additional_known_domains: Vec::new(),
306            allowlist: Vec::new(),
307            blocklist: Vec::new(),
308            approval_rules: Vec::new(),
309            network_deny: Vec::new(),
310            network_allow: Vec::new(),
311            webhooks: Vec::new(),
312            checkpoints: CheckpointPolicyConfig::default(),
313            scan: ScanPolicyConfig::default(),
314            allowlist_rules: Vec::new(),
315            custom_rules: Vec::new(),
316            dlp_custom_patterns: Vec::new(),
317            strict_warn: false,
318            action_overrides: HashMap::new(),
319            escalation: Vec::new(),
320            policy_server_url: None,
321            policy_server_api_key: None,
322            policy_fetch_fail_mode: None,
323            enforce_fail_mode: None,
324            threat_intel: ThreatIntelConfig::default(),
325        }
326    }
327}
328
329impl Policy {
330    /// Discover and load partial policy (just bypass + fail_mode fields).
331    /// Used in Tier 2 for fast bypass resolution.
332    /// Uses the same resolution order as full discovery (TIRITH_POLICY_ROOT,
333    /// walk-up, user-level) so bypass settings are consistent.
334    pub fn discover_partial(cwd: Option<&str>) -> Self {
335        Self::discover_local(cwd)
336    }
337
338    /// Discover and load full policy.
339    ///
340    /// Resolution order:
341    /// 1. Local policy (TIRITH_POLICY_ROOT, walk-up discovery, user-level)
342    /// 2. If `TIRITH_SERVER_URL` + `TIRITH_API_KEY` are set (or policy has
343    ///    `policy_server_url`), try remote fetch. On success the
344    ///    remote policy **replaces** the local one entirely and is cached.
345    /// 3. On remote failure, apply `policy_fetch_fail_mode`:
346    ///    - `"open"` (default): warn and use local policy
347    ///    - `"closed"`: return a fail-closed default (all actions = Block)
348    ///    - `"cached"`: try cached remote policy, else fall back to local
349    /// 4. Auth errors (401/403) always fail closed regardless of mode.
350    pub fn discover(cwd: Option<&str>) -> Self {
351        let local = Self::discover_local(cwd);
352
353        let server_url = std::env::var("TIRITH_SERVER_URL")
354            .ok()
355            .filter(|s| !s.is_empty())
356            .or_else(|| local.policy_server_url.clone());
357        let api_key = std::env::var("TIRITH_API_KEY")
358            .ok()
359            .filter(|s| !s.is_empty())
360            .or_else(|| local.policy_server_api_key.clone());
361
362        let (server_url, api_key) = match (server_url, api_key) {
363            (Some(u), Some(k)) => (u, k),
364            _ => return local,
365        };
366
367        let fail_mode = local.policy_fetch_fail_mode.as_deref().unwrap_or("open");
368
369        match crate::policy_client::fetch_remote_policy(&server_url, &api_key) {
370            Ok(yaml) => {
371                let _ = cache_remote_policy(&yaml);
372                match serde_yaml::from_str::<Policy>(&yaml) {
373                    Ok(mut p) => {
374                        p.path = Some(format!("remote:{server_url}"));
375                        // Retain connection details so audit upload can reuse them.
376                        if p.policy_server_url.is_none() {
377                            p.policy_server_url = Some(server_url);
378                        }
379                        if p.policy_server_api_key.is_none() {
380                            p.policy_server_api_key = Some(api_key);
381                        }
382                        p
383                    }
384                    Err(e) => match fail_mode {
385                        "closed" => {
386                            eprintln!(
387                                "tirith: error: remote policy parse error ({e}), failing closed"
388                            );
389                            Self::fail_closed_policy()
390                        }
391                        "cached" => {
392                            eprintln!(
393                                "tirith: warning: remote policy parse error ({e}), trying cache"
394                            );
395                            match load_cached_remote_policy() {
396                                Some(p) => p,
397                                None => {
398                                    eprintln!(
399                                        "tirith: warning: no cached remote policy, using local"
400                                    );
401                                    local
402                                }
403                            }
404                        }
405                        _ => {
406                            eprintln!("tirith: warning: remote policy parse error: {e}");
407                            local
408                        }
409                    },
410                }
411            }
412            Err(crate::policy_client::PolicyFetchError::AuthError(code)) => {
413                // Auth errors always fail closed, regardless of fail_mode —
414                // the server is explicitly saying "no".
415                eprintln!("tirith: error: policy server auth failed (HTTP {code}), failing closed");
416                Self::fail_closed_policy()
417            }
418            Err(e) => match fail_mode {
419                "closed" => {
420                    eprintln!("tirith: error: remote policy fetch failed ({e}), failing closed");
421                    Self::fail_closed_policy()
422                }
423                "cached" => {
424                    eprintln!("tirith: warning: remote policy fetch failed ({e}), trying cache");
425                    match load_cached_remote_policy() {
426                        Some(p) => p,
427                        None => {
428                            eprintln!("tirith: warning: no cached remote policy, using local");
429                            local
430                        }
431                    }
432                }
433                _ => {
434                    eprintln!(
435                        "tirith: warning: remote policy fetch failed ({e}), using local policy"
436                    );
437                    local
438                }
439            },
440        }
441    }
442
443    /// Discover local policy only (no remote fetch).
444    fn discover_local(cwd: Option<&str>) -> Self {
445        if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
446            if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
447                return Self::load_from_path(&path);
448            }
449        }
450
451        match discover_policy_path(cwd) {
452            Some(path) => Self::load_from_path(&path),
453            None => {
454                if let Some(user_path) = user_policy_path() {
455                    if user_path.exists() {
456                        return Self::load_from_path(&user_path);
457                    }
458                }
459                Policy::default()
460            }
461        }
462    }
463
464    /// Return a fail-closed policy that blocks everything.
465    fn fail_closed_policy() -> Self {
466        Policy {
467            fail_mode: FailMode::Closed,
468            allow_bypass_env: false,
469            allow_bypass_env_noninteractive: false,
470            path: Some("fail-closed".into()),
471            ..Default::default()
472        }
473    }
474
475    fn load_from_path(path: &Path) -> Self {
476        match std::fs::read_to_string(path) {
477            Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
478                Ok(mut p) => {
479                    p.path = Some(path.display().to_string());
480                    p
481                }
482                Err(e) => {
483                    eprintln!(
484                        "tirith: warning: failed to parse policy at {}: {e}",
485                        path.display(),
486                    );
487                    Policy::default()
488                }
489            },
490            Err(e) => {
491                eprintln!(
492                    "tirith: warning: cannot read policy at {}: {e}",
493                    path.display()
494                );
495                Policy::default()
496            }
497        }
498    }
499
500    /// Get severity override for a rule.
501    pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
502        let key = serde_json::to_value(rule_id)
503            .ok()
504            .and_then(|v| v.as_str().map(String::from))?;
505        self.severity_overrides.get(&key).copied()
506    }
507
508    /// Check if a URL is in the blocklist.
509    pub fn is_blocklisted(&self, url: &str) -> bool {
510        let url_lower = url.to_lowercase();
511        self.blocklist.iter().any(|pattern| {
512            let p = pattern.to_lowercase();
513            url_lower.contains(&p)
514        })
515    }
516
517    /// Check if a URL is in the allowlist.
518    pub fn is_allowlisted(&self, url: &str) -> bool {
519        self.allowlist
520            .iter()
521            .any(|pattern| allowlist_pattern_matches(pattern, url))
522    }
523
524    /// Check if a URL is allowlisted for a specific rule or custom rule ID.
525    pub fn is_allowlisted_for_rule(&self, rule_id: &str, url: &str) -> bool {
526        self.allowlist_rules.iter().any(|rule| {
527            rule.rule_id.eq_ignore_ascii_case(rule_id)
528                && rule
529                    .patterns
530                    .iter()
531                    .any(|pattern| allowlist_pattern_matches(pattern, url))
532        })
533    }
534
535    /// Load and merge user-level lists (allowlist/blocklist flat text files).
536    pub fn load_user_lists(&mut self) {
537        if let Some(config) = crate::policy::config_dir() {
538            let allowlist_path = config.join("allowlist");
539            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
540                for line in content.lines() {
541                    let line = line.trim();
542                    if !line.is_empty() && !line.starts_with('#') {
543                        self.allowlist.push(line.to_string());
544                    }
545                }
546            }
547            let blocklist_path = config.join("blocklist");
548            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
549                for line in content.lines() {
550                    let line = line.trim();
551                    if !line.is_empty() && !line.starts_with('#') {
552                        self.blocklist.push(line.to_string());
553                    }
554                }
555            }
556        }
557    }
558
559    /// Load trust entries from trust.json files and merge non-expired entries
560    /// into the policy's allowlist and allowlist_rules.
561    ///
562    /// Called on the analysis hot path — MUST stay read-only (no file mutation).
563    pub fn load_trust_entries(&mut self, cwd: Option<&str>) {
564        if let Some(config) = config_dir() {
565            let user_trust = config.join("trust.json");
566            self.merge_trust_store(&user_trust);
567        }
568        if let Some(repo_root) = find_repo_root(cwd) {
569            let repo_trust = repo_root.join(".tirith").join("trust.json");
570            self.merge_trust_store(&repo_trust);
571        }
572    }
573
574    /// Read a trust.json file and merge non-expired entries into the policy.
575    fn merge_trust_store(&mut self, path: &Path) {
576        let content = match std::fs::read_to_string(path) {
577            Ok(c) => c,
578            Err(_) => return,
579        };
580
581        let store: serde_json::Value = match serde_json::from_str(&content) {
582            Ok(v) => v,
583            Err(e) => {
584                crate::audit::audit_diagnostic(format!(
585                    "tirith: trust: corrupt trust store at {} — trust entries skipped: {e}",
586                    path.display()
587                ));
588                return;
589            }
590        };
591
592        let entries = match store.get("entries").and_then(|v| v.as_array()) {
593            Some(arr) => arr,
594            None => return,
595        };
596
597        let now = chrono::Utc::now();
598
599        for entry in entries {
600            // Unparseable or past-expiry timestamps are treated as expired.
601            if let Some(exp_str) = entry.get("ttl_expires").and_then(|v| v.as_str()) {
602                match chrono::DateTime::parse_from_rfc3339(exp_str) {
603                    Ok(expiry) if expiry < now => continue,
604                    Ok(_) => {}
605                    Err(_) => continue,
606                }
607            }
608
609            let pattern = match entry.get("pattern").and_then(|v| v.as_str()) {
610                Some(p) if !p.is_empty() => p.to_string(),
611                _ => continue,
612            };
613
614            let rule_id = entry
615                .get("rule_id")
616                .and_then(|v| v.as_str())
617                .map(String::from);
618
619            match rule_id {
620                Some(rid) => {
621                    if let Some(existing) = self
622                        .allowlist_rules
623                        .iter_mut()
624                        .find(|r| r.rule_id.eq_ignore_ascii_case(&rid))
625                    {
626                        if !existing.patterns.contains(&pattern) {
627                            existing.patterns.push(pattern);
628                        }
629                    } else {
630                        self.allowlist_rules.push(AllowlistRule {
631                            rule_id: rid,
632                            patterns: vec![pattern],
633                        });
634                    }
635                }
636                None => {
637                    if !self.allowlist.contains(&pattern) {
638                        self.allowlist.push(pattern);
639                    }
640                }
641            }
642        }
643    }
644
645    /// Load and merge org-level lists from a repo root's .tirith/ dir.
646    ///
647    /// **Note:** Org-level policies are committed to the repository and may be
648    /// controlled by other contributors. A diagnostic is emitted so the user
649    /// knows that repo-level policy is active.
650    pub fn load_org_lists(&mut self, cwd: Option<&str>) {
651        if let Some(repo_root) = find_repo_root(cwd) {
652            let org_dir = repo_root.join(".tirith");
653            let allowlist_path = org_dir.join("allowlist");
654            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
655                eprintln!(
656                    "tirith: loading org-level allowlist from {}",
657                    allowlist_path.display()
658                );
659                for line in content.lines() {
660                    let line = line.trim();
661                    if !line.is_empty() && !line.starts_with('#') {
662                        self.allowlist.push(line.to_string());
663                    }
664                }
665            }
666            let blocklist_path = org_dir.join("blocklist");
667            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
668                eprintln!(
669                    "tirith: loading org-level blocklist from {}",
670                    blocklist_path.display()
671                );
672                for line in content.lines() {
673                    let line = line.trim();
674                    if !line.is_empty() && !line.starts_with('#') {
675                        self.blocklist.push(line.to_string());
676                    }
677                }
678            }
679        }
680    }
681}
682
/// True when `p` looks like a bare host name: no scheme, path, query, fragment or port.
fn is_domain_pattern(p: &str) -> bool {
    // ':' already covers "://", so one character-set check is enough.
    const URL_MARKERS: &[char] = &['/', '?', '#', ':'];
    !p.contains(URL_MARKERS)
}
690
691fn extract_host_for_match(url: &str) -> Option<String> {
692    if let Some(host) = crate::parse::parse_url(url).host() {
693        return Some(host.trim_end_matches('.').to_lowercase());
694    }
695    // Fallback for schemeless host/path (e.g., example.com/path)
696    let candidate = url.split('/').next().unwrap_or(url).trim();
697    if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
698        return None;
699    }
700    let host = if let Some((h, port)) = candidate.rsplit_once(':') {
701        if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
702            h
703        } else {
704            candidate
705        }
706    } else {
707        candidate
708    };
709    Some(host.trim_end_matches('.').to_lowercase())
710}
711
/// True when `host` equals `pattern` or is a subdomain of it.
///
/// Trailing dots on either side are ignored, and a leading `*.` on the pattern is stripped
/// (so `*.example.com` also matches `example.com` itself). Both inputs are expected to be
/// lowercase already.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let base = pattern.trim_start_matches("*.").trim_end_matches('.');
    // After removing `base` from the end, what remains must be nothing (exact match) or
    // end at a label boundary (subdomain match); this rejects `evil-example.com`.
    host.strip_suffix(base)
        .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
}
717
718pub fn allowlist_pattern_matches(pattern: &str, url: &str) -> bool {
719    let p = pattern.to_lowercase();
720    if p.is_empty() {
721        return false;
722    }
723    if is_domain_pattern(&p) {
724        if let Some(host) = extract_host_for_match(url) {
725            return domain_matches(&host, &p);
726        }
727        return false;
728    }
729    url.to_lowercase().contains(&p)
730}
731
732/// Discover policy path by walking up from cwd to .git boundary.
733fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
734    let start = cwd
735        .map(PathBuf::from)
736        .or_else(|| std::env::current_dir().ok())?;
737
738    let mut current = start.as_path();
739    loop {
740        if let Some(candidate) = find_policy_in_dir(&current.join(".tirith")) {
741            return Some(candidate);
742        }
743
744        // `.git` may be a directory or a file (worktrees), so `.exists()` handles both.
745        let git_dir = current.join(".git");
746        if git_dir.exists() {
747            return None;
748        }
749
750        match current.parent() {
751            Some(parent) if parent != current => current = parent,
752            _ => break,
753        }
754    }
755
756    None
757}
758
/// Find the repository root: the nearest directory containing a `.git` entry.
///
/// The search starts at `cwd` (or the process's current directory when `cwd` is `None`)
/// and moves towards the filesystem root. `.git` may be a directory or a file; only its
/// existence is checked. Returns `None` when no ancestor qualifies or the current directory
/// cannot be determined.
pub fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
777
778/// Find the nearest ancestor directory containing a `.kiro/` subdirectory.
779///
780/// Mirrors Kiro CLI's own workspace-local agent discovery. Returns the
781/// directory that CONTAINS `.kiro/` (not `.kiro/` itself), so callers can
782/// `dir.join(".kiro/agents/foo.json")`.
783///
784/// Excludes `$HOME`: `~/.kiro` is the user-scope agent root, not a project
785/// workspace. Without this guard, any project inside `$HOME` would collapse
786/// onto the user-scope dir.
787pub fn find_workspace_kiro_dir(start: &Path) -> Option<PathBuf> {
788    let home = home::home_dir();
789    let mut current = start;
790    loop {
791        let is_home = home.as_deref().map(|h| current == h).unwrap_or(false);
792        if !is_home && current.join(".kiro").is_dir() {
793            return Some(current.to_path_buf());
794        }
795        match current.parent() {
796            Some(parent) if parent != current => current = parent,
797            _ => break,
798        }
799    }
800    None
801}
802
803/// Get user-level policy path.
804fn user_policy_path() -> Option<PathBuf> {
805    let base = etcetera::choose_base_strategy().ok()?;
806    find_policy_in_dir(&base.config_dir().join("tirith"))
807}
808
809/// Get tirith data directory.
810pub fn data_dir() -> Option<PathBuf> {
811    let base = etcetera::choose_base_strategy().ok()?;
812    Some(base.data_dir().join("tirith"))
813}
814
815/// Get tirith config directory.
816pub fn config_dir() -> Option<PathBuf> {
817    let base = etcetera::choose_base_strategy().ok()?;
818    Some(base.config_dir().join("tirith"))
819}
820
821/// Get tirith state directory.
822///
823/// MUST match the path computed by bash-hook.bash:
824/// `${XDG_STATE_HOME:-$HOME/.local/state}/tirith`. Any divergence here will
825/// make the hook and the binary disagree about where session state lives.
826/// Treat an empty `XDG_STATE_HOME` as unset to mirror `${VAR:-fallback}`.
827pub fn state_dir() -> Option<PathBuf> {
828    match std::env::var("XDG_STATE_HOME") {
829        Ok(val) if !val.trim().is_empty() => Some(PathBuf::from(val.trim()).join("tirith")),
830        _ => home::home_dir().map(|h| h.join(".local/state/tirith")),
831    }
832}
833
834/// Get the path for caching remote policy: ~/.cache/tirith/remote-policy.yaml
835fn remote_policy_cache_path() -> Option<PathBuf> {
836    let cache_dir = std::env::var("XDG_CACHE_HOME")
837        .ok()
838        .filter(|s| !s.is_empty())
839        .map(PathBuf::from)
840        .or_else(|| home::home_dir().map(|h| h.join(".cache")))?;
841    Some(cache_dir.join("tirith").join("remote-policy.yaml"))
842}
843
844/// Cache the raw YAML from a remote policy fetch.
845fn cache_remote_policy(yaml: &str) -> std::io::Result<()> {
846    if let Some(path) = remote_policy_cache_path() {
847        if let Some(parent) = path.parent() {
848            std::fs::create_dir_all(parent)?;
849        }
850        let mut opts = std::fs::OpenOptions::new();
851        opts.write(true).create(true).truncate(true);
852        #[cfg(unix)]
853        {
854            use std::os::unix::fs::OpenOptionsExt;
855            opts.mode(0o600);
856        }
857        let mut f = opts.open(&path)?;
858        use std::io::Write;
859        f.write_all(yaml.as_bytes())?;
860    }
861    Ok(())
862}
863
864/// Load a previously cached remote policy.
865fn load_cached_remote_policy() -> Option<Policy> {
866    let path = remote_policy_cache_path()?;
867    let content = std::fs::read_to_string(&path).ok()?;
868    match serde_yaml::from_str::<Policy>(&content) {
869        Ok(mut p) => {
870            p.path = Some(format!("cached:{}", path.display()));
871            Some(p)
872        }
873        Err(e) => {
874            eprintln!("tirith: warning: cached remote policy parse error: {e}");
875            None
876        }
877    }
878}
879
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets or removes environment variables for the lifetime of the guard, and restores
    /// the previous values on drop. A failing assertion therefore cannot leak settings into
    /// other tests. Create it only while holding `crate::TEST_ENV_LOCK`, and declare it
    /// *after* the lock guard so it is dropped first.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        fn new(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|&(key, _)| (key, std::env::var_os(key)))
                .collect();
            for &(key, value) in vars {
                match value {
                    // SAFETY: the caller holds `crate::TEST_ENV_LOCK`, which serializes the
                    // tests in this crate that read or modify the process environment.
                    Some(v) => unsafe { std::env::set_var(key, v) },
                    // SAFETY: as above.
                    None => unsafe { std::env::remove_var(key) },
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in self.saved.drain(..) {
                match value {
                    // SAFETY: the guard is declared after the `TEST_ENV_LOCK` guard, so it is
                    // dropped first and the lock is still held here.
                    Some(v) => unsafe { std::env::set_var(key, v) },
                    // SAFETY: as above.
                    None => unsafe { std::env::remove_var(key) },
                }
            }
        }
    }

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let p = Policy {
            allowlist: vec!["github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(p.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!p.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let p = Policy {
            allowlist: vec!["raw.githubusercontent.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let p = Policy {
            allowlist: vec!["example.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("example.com:8080/path"));
    }

    #[test]
    fn test_allowlist_wildcard_domain_pattern() {
        let p = Policy {
            allowlist: vec!["*.github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(!p.is_allowlisted("https://github.com.evil.net/x"));
    }

    #[test]
    fn test_allowlist_empty_pattern_never_matches() {
        assert!(!allowlist_pattern_matches("", "https://example.com"));
    }

    #[test]
    fn test_allowlist_for_rule_is_case_insensitive_on_rule_id() {
        let p = Policy {
            allowlist_rules: vec![AllowlistRule {
                rule_id: "Pipe_To_Interpreter".to_string(),
                patterns: vec!["example.com".to_string()],
            }],
            ..Default::default()
        };
        assert!(p.is_allowlisted_for_rule("pipe_to_interpreter", "https://example.com/i.sh"));
        assert!(!p.is_allowlisted_for_rule("other_rule", "https://example.com/i.sh"));
    }

    #[test]
    fn test_discover_applies_remote_fetch_fail_mode_when_configured() {
        let _lock = crate::TEST_ENV_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let dir = tempfile::tempdir().unwrap();
        let policy_dir = dir.path().join(".tirith");
        std::fs::create_dir_all(&policy_dir).unwrap();
        std::fs::write(
            policy_dir.join("policy.yaml"),
            "fail_mode: open\npolicy_fetch_fail_mode: closed\nallow_bypass_env_noninteractive: true\n",
        )
        .unwrap();

        // Declared after `_lock`, so the environment is restored before the lock is released,
        // even if an assertion below fails.
        let _env = EnvGuard::new(&[
            ("TIRITH_SERVER_URL", Some("http://127.0.0.1")),
            ("TIRITH_API_KEY", Some("dummy")),
            // Keep an ambient policy root from redirecting local discovery.
            ("TIRITH_POLICY_ROOT", None),
        ]);

        let policy = Policy::discover(Some(dir.path().to_str().unwrap()));
        assert_eq!(policy.path.as_deref(), Some("fail-closed"));
        assert_eq!(policy.fail_mode, FailMode::Closed);
        assert!(!policy.allow_bypass_env_noninteractive);
    }
}
939}