Skip to main content

tirith_core/
policy.rs

1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
/// Look for a policy file inside `dir`, preferring `policy.yaml` over `policy.yml`.
///
/// Returns the first candidate that exists on disk, or `None` when neither does.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    ["policy.yaml", "policy.yml"]
        .iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists())
}
20
21/// Policy configuration loaded from YAML.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(default)]
24pub struct Policy {
25    /// Path this policy was loaded from.
26    #[serde(skip)]
27    pub path: Option<String>,
28
29    /// Fail mode: "open" (default) or "closed".
30    pub fail_mode: FailMode,
31
32    /// Allow TIRITH=0 bypass in interactive mode.
33    pub allow_bypass_env: bool,
34
35    /// Allow TIRITH=0 bypass in non-interactive mode.
36    pub allow_bypass_env_noninteractive: bool,
37
38    /// Paranoia tier (1-4).
39    pub paranoia: u8,
40
41    /// Severity overrides per rule.
42    #[serde(default)]
43    pub severity_overrides: HashMap<String, Severity>,
44
45    /// Additional known domains (extends built-in list).
46    #[serde(default)]
47    pub additional_known_domains: Vec<String>,
48
49    /// Allowlist: URL patterns that are always allowed.
50    #[serde(default)]
51    pub allowlist: Vec<String>,
52
53    /// Blocklist: URL patterns that are always blocked.
54    #[serde(default)]
55    pub blocklist: Vec<String>,
56
57    // --- Team features (Phase 18) ---
58    /// Approval rules: commands matching these rules require human approval.
59    #[serde(default)]
60    pub approval_rules: Vec<ApprovalRule>,
61
62    /// Network deny list: block commands targeting these hosts/CIDRs.
63    #[serde(default)]
64    pub network_deny: Vec<String>,
65
66    /// Network allow list: exempt these hosts/CIDRs from network deny.
67    #[serde(default)]
68    pub network_allow: Vec<String>,
69
70    /// Webhook endpoints to notify on findings.
71    #[serde(default)]
72    pub webhooks: Vec<WebhookConfig>,
73
74    /// Checkpoint configuration (Pro+).
75    #[serde(default)]
76    pub checkpoints: CheckpointPolicyConfig,
77
78    /// Scan configuration overrides.
79    #[serde(default)]
80    pub scan: ScanPolicyConfig,
81
82    /// Per-rule allowlist scoping (Team).
83    #[serde(default)]
84    pub allowlist_rules: Vec<AllowlistRule>,
85
86    /// Custom detection rules defined in YAML (Team).
87    #[serde(default)]
88    pub custom_rules: Vec<CustomRule>,
89
90    /// Custom DLP redaction patterns (Team). Regex patterns applied alongside
91    /// built-in patterns when redacting commands in audit logs and webhooks.
92    #[serde(default)]
93    pub dlp_custom_patterns: Vec<String>,
94
95    // --- Policy server (Phase 27, Team) ---
96    /// URL of the centralized policy server (e.g., "https://policy.example.com").
97    #[serde(default)]
98    pub policy_server_url: Option<String>,
99    /// API key for authenticating with the policy server.
100    #[serde(default)]
101    pub policy_server_api_key: Option<String>,
102    /// Fail mode for remote policy fetch: "open" (default), "closed", or "cached".
103    #[serde(default)]
104    pub policy_fetch_fail_mode: Option<String>,
105    /// Whether to enforce the fetch fail mode strictly (ignore local fallback on auth errors).
106    #[serde(default)]
107    pub enforce_fail_mode: Option<bool>,
108}
109
110/// Approval rule: when a command matches, require human approval before execution.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct ApprovalRule {
113    /// Rule IDs that trigger approval (e.g., "pipe_to_interpreter").
114    pub rule_ids: Vec<String>,
115    /// Timeout in seconds (0 = indefinite).
116    #[serde(default)]
117    pub timeout_secs: u64,
118    /// Fallback when approval times out: "block", "warn", or "allow".
119    #[serde(default = "default_approval_fallback")]
120    pub fallback: String,
121}
122
123fn default_approval_fallback() -> String {
124    "block".to_string()
125}
126
127/// Webhook configuration for event notification.
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct WebhookConfig {
130    /// Webhook URL.
131    pub url: String,
132    /// Minimum severity to trigger webhook.
133    #[serde(default = "default_webhook_severity")]
134    pub min_severity: Severity,
135    /// Optional headers (supports env var expansion: `$ENV_VAR`).
136    #[serde(default)]
137    pub headers: HashMap<String, String>,
138    /// Payload template (supports `{{rule_id}}`, `{{command_preview}}`).
139    #[serde(default)]
140    pub payload_template: Option<String>,
141}
142
143fn default_webhook_severity() -> Severity {
144    Severity::High
145}
146
147/// Checkpoint policy configuration.
148#[derive(Debug, Clone, Serialize, Deserialize)]
149#[serde(default)]
150pub struct CheckpointPolicyConfig {
151    /// Max checkpoints to retain.
152    pub max_count: usize,
153    /// Max age in hours.
154    pub max_age_hours: u64,
155    /// Max total storage in bytes.
156    pub max_storage_bytes: u64,
157}
158
159impl Default for CheckpointPolicyConfig {
160    fn default() -> Self {
161        Self {
162            max_count: 100,
163            max_age_hours: 168,                   // 1 week
164            max_storage_bytes: 500 * 1024 * 1024, // 500 MiB
165        }
166    }
167}
168
169/// Scan policy configuration.
170#[derive(Debug, Clone, Default, Serialize, Deserialize)]
171#[serde(default)]
172pub struct ScanPolicyConfig {
173    /// Additional config file paths to scan as priority files.
174    #[serde(default)]
175    pub additional_config_files: Vec<String>,
176    /// Trusted MCP server URLs (suppress McpUntrustedServer for these).
177    #[serde(default)]
178    pub trusted_mcp_servers: Vec<String>,
179    /// Glob patterns to ignore during scan.
180    #[serde(default)]
181    pub ignore_patterns: Vec<String>,
182    /// Severity threshold for CI failure (default: "critical").
183    #[serde(default)]
184    pub fail_on: Option<String>,
185}
186
187/// Per-rule allowlist scoping.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct AllowlistRule {
190    /// Rule ID to scope the allowlist entry to.
191    pub rule_id: String,
192    /// Patterns that suppress this specific rule.
193    pub patterns: Vec<String>,
194}
195
196/// Custom detection rule defined in policy YAML.
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct CustomRule {
199    /// Unique identifier for this custom rule.
200    pub id: String,
201    /// Regex pattern to match.
202    pub pattern: String,
203    /// Contexts this rule applies to: "exec", "paste", "file".
204    #[serde(default = "default_custom_rule_contexts")]
205    pub context: Vec<String>,
206    /// Severity level.
207    #[serde(default = "default_custom_rule_severity")]
208    pub severity: Severity,
209    /// Short title for findings.
210    pub title: String,
211    /// Description for findings.
212    #[serde(default)]
213    pub description: String,
214}
215
216fn default_custom_rule_contexts() -> Vec<String> {
217    vec!["exec".to_string(), "paste".to_string()]
218}
219
220fn default_custom_rule_severity() -> Severity {
221    Severity::High
222}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(rename_all = "lowercase")]
226#[derive(Default)]
227pub enum FailMode {
228    #[default]
229    Open,
230    Closed,
231}
232
233impl Default for Policy {
234    fn default() -> Self {
235        Self {
236            path: None,
237            fail_mode: FailMode::Open,
238            allow_bypass_env: true,
239            allow_bypass_env_noninteractive: false,
240            paranoia: 1,
241            severity_overrides: HashMap::new(),
242            additional_known_domains: Vec::new(),
243            allowlist: Vec::new(),
244            blocklist: Vec::new(),
245            approval_rules: Vec::new(),
246            network_deny: Vec::new(),
247            network_allow: Vec::new(),
248            webhooks: Vec::new(),
249            checkpoints: CheckpointPolicyConfig::default(),
250            scan: ScanPolicyConfig::default(),
251            allowlist_rules: Vec::new(),
252            custom_rules: Vec::new(),
253            dlp_custom_patterns: Vec::new(),
254            policy_server_url: None,
255            policy_server_api_key: None,
256            policy_fetch_fail_mode: None,
257            enforce_fail_mode: None,
258        }
259    }
260}
261
262impl Policy {
263    /// Discover and load partial policy (just bypass + fail_mode fields).
264    /// Used in Tier 2 for fast bypass resolution.
265    pub fn discover_partial(cwd: Option<&str>) -> Self {
266        match discover_policy_path(cwd) {
267            Some(path) => match std::fs::read_to_string(&path) {
268                Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
269                    Ok(mut p) => {
270                        p.path = Some(path.display().to_string());
271                        p
272                    }
273                    Err(e) => {
274                        eprintln!(
275                            "tirith: warning: failed to parse policy at {}: {e}",
276                            path.display()
277                        );
278                        // Parse error: use fail_mode default behavior
279                        Policy::default()
280                    }
281                },
282                Err(e) => {
283                    eprintln!(
284                        "tirith: warning: cannot read policy at {}: {e}",
285                        path.display()
286                    );
287                    Policy::default()
288                }
289            },
290            None => Policy::default(),
291        }
292    }
293
294    /// Discover and load full policy.
295    ///
296    /// Resolution order:
297    /// 1. Local policy (TIRITH_POLICY_ROOT, walk-up discovery, user-level)
298    /// 2. Team+ only: if `TIRITH_SERVER_URL` + `TIRITH_API_KEY` are set (or
299    ///    policy has `policy_server_url`), try remote fetch. On success the
300    ///    remote policy **replaces** the local one entirely and is cached.
301    /// 3. On remote failure, apply `policy_fetch_fail_mode`:
302    ///    - `"open"` (default): warn and use local policy
303    ///    - `"closed"`: return a fail-closed default (all actions = Block)
304    ///    - `"cached"`: try cached remote policy, else fall back to local
305    /// 4. Auth errors (401/403) always fail closed regardless of mode.
306    pub fn discover(cwd: Option<&str>) -> Self {
307        // --- Step 1: resolve local policy ---
308        let local = Self::discover_local(cwd);
309
310        // Centralized policy fetch is a Team+ feature.
311        if crate::license::current_tier() < crate::license::Tier::Team {
312            return local;
313        }
314
315        // --- Step 2: determine remote fetch parameters ---
316        let server_url = std::env::var("TIRITH_SERVER_URL")
317            .ok()
318            .filter(|s| !s.is_empty())
319            .or_else(|| local.policy_server_url.clone());
320        let api_key = std::env::var("TIRITH_API_KEY")
321            .ok()
322            .filter(|s| !s.is_empty())
323            .or_else(|| local.policy_server_api_key.clone());
324
325        let (server_url, api_key) = match (server_url, api_key) {
326            (Some(u), Some(k)) => (u, k),
327            _ => return local, // no remote configured
328        };
329
330        let fail_mode = local.policy_fetch_fail_mode.as_deref().unwrap_or("open");
331
332        // --- Step 3: attempt remote fetch ---
333        match crate::policy_client::fetch_remote_policy(&server_url, &api_key) {
334            Ok(yaml) => {
335                // Cache the fetched policy for offline use
336                let _ = cache_remote_policy(&yaml);
337                match serde_yaml::from_str::<Policy>(&yaml) {
338                    Ok(mut p) => {
339                        p.path = Some(format!("remote:{server_url}"));
340                        // Carry over server connection info so audit upload can use it
341                        if p.policy_server_url.is_none() {
342                            p.policy_server_url = Some(server_url);
343                        }
344                        if p.policy_server_api_key.is_none() {
345                            p.policy_server_api_key = Some(api_key);
346                        }
347                        p
348                    }
349                    Err(e) => match fail_mode {
350                        "closed" => {
351                            eprintln!(
352                                "tirith: error: remote policy parse error ({e}), failing closed"
353                            );
354                            Self::fail_closed_policy()
355                        }
356                        "cached" => {
357                            eprintln!(
358                                "tirith: warning: remote policy parse error ({e}), trying cache"
359                            );
360                            match load_cached_remote_policy() {
361                                Some(p) => p,
362                                None => {
363                                    eprintln!(
364                                        "tirith: warning: no cached remote policy, using local"
365                                    );
366                                    local
367                                }
368                            }
369                        }
370                        _ => {
371                            eprintln!("tirith: warning: remote policy parse error: {e}");
372                            local
373                        }
374                    },
375                }
376            }
377            Err(crate::policy_client::PolicyFetchError::AuthError(code)) => {
378                // Auth errors always fail closed
379                eprintln!("tirith: error: policy server auth failed (HTTP {code}), failing closed");
380                Self::fail_closed_policy()
381            }
382            Err(e) => {
383                // Apply fail mode
384                match fail_mode {
385                    "closed" => {
386                        eprintln!(
387                            "tirith: error: remote policy fetch failed ({e}), failing closed"
388                        );
389                        Self::fail_closed_policy()
390                    }
391                    "cached" => {
392                        eprintln!(
393                            "tirith: warning: remote policy fetch failed ({e}), trying cache"
394                        );
395                        match load_cached_remote_policy() {
396                            Some(p) => p,
397                            None => {
398                                eprintln!("tirith: warning: no cached remote policy, using local");
399                                local
400                            }
401                        }
402                    }
403                    _ => {
404                        // "open" (default): warn and use local
405                        eprintln!(
406                            "tirith: warning: remote policy fetch failed ({e}), using local policy"
407                        );
408                        local
409                    }
410                }
411            }
412        }
413    }
414
415    /// Discover local policy only (no remote fetch).
416    fn discover_local(cwd: Option<&str>) -> Self {
417        // Check env override first
418        if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
419            if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
420                return Self::load_from_path(&path);
421            }
422        }
423
424        match discover_policy_path(cwd) {
425            Some(path) => Self::load_from_path(&path),
426            None => {
427                // Try user-level policy
428                if let Some(user_path) = user_policy_path() {
429                    if user_path.exists() {
430                        return Self::load_from_path(&user_path);
431                    }
432                }
433                Policy::default()
434            }
435        }
436    }
437
438    /// Return a fail-closed policy that blocks everything.
439    fn fail_closed_policy() -> Self {
440        Policy {
441            fail_mode: FailMode::Closed,
442            allow_bypass_env: false,
443            allow_bypass_env_noninteractive: false,
444            path: Some("fail-closed".into()),
445            ..Default::default()
446        }
447    }
448
449    fn load_from_path(path: &Path) -> Self {
450        match std::fs::read_to_string(path) {
451            Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
452                Ok(mut p) => {
453                    p.path = Some(path.display().to_string());
454                    p
455                }
456                Err(e) => {
457                    eprintln!(
458                        "tirith: warning: failed to parse policy at {}: {e}",
459                        path.display(),
460                    );
461                    Policy::default()
462                }
463            },
464            Err(e) => {
465                eprintln!(
466                    "tirith: warning: cannot read policy at {}: {e}",
467                    path.display()
468                );
469                Policy::default()
470            }
471        }
472    }
473
474    /// Get severity override for a rule.
475    pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
476        let key = serde_json::to_value(rule_id)
477            .ok()
478            .and_then(|v| v.as_str().map(String::from))?;
479        self.severity_overrides.get(&key).copied()
480    }
481
482    /// Check if a URL is in the blocklist.
483    pub fn is_blocklisted(&self, url: &str) -> bool {
484        let url_lower = url.to_lowercase();
485        self.blocklist.iter().any(|pattern| {
486            let p = pattern.to_lowercase();
487            url_lower.contains(&p)
488        })
489    }
490
491    /// Check if a URL is in the allowlist.
492    pub fn is_allowlisted(&self, url: &str) -> bool {
493        let url_lower = url.to_lowercase();
494        self.allowlist.iter().any(|pattern| {
495            let p = pattern.to_lowercase();
496            if p.is_empty() {
497                return false;
498            }
499            if is_domain_pattern(&p) {
500                if let Some(host) = extract_host_for_match(url) {
501                    return domain_matches(&host, &p);
502                }
503                return false;
504            }
505            url_lower.contains(&p)
506        })
507    }
508
509    /// Load and merge user-level lists (allowlist/blocklist flat text files).
510    pub fn load_user_lists(&mut self) {
511        if let Some(config) = crate::policy::config_dir() {
512            let allowlist_path = config.join("allowlist");
513            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
514                for line in content.lines() {
515                    let line = line.trim();
516                    if !line.is_empty() && !line.starts_with('#') {
517                        self.allowlist.push(line.to_string());
518                    }
519                }
520            }
521            let blocklist_path = config.join("blocklist");
522            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
523                for line in content.lines() {
524                    let line = line.trim();
525                    if !line.is_empty() && !line.starts_with('#') {
526                        self.blocklist.push(line.to_string());
527                    }
528                }
529            }
530        }
531    }
532
533    /// Load and merge org-level lists from a repo root's .tirith/ dir.
534    ///
535    /// **Note:** Org-level policies are committed to the repository and may be
536    /// controlled by other contributors. A diagnostic is emitted so the user
537    /// knows that repo-level policy is active.
538    pub fn load_org_lists(&mut self, cwd: Option<&str>) {
539        if let Some(repo_root) = find_repo_root(cwd) {
540            let org_dir = repo_root.join(".tirith");
541            let allowlist_path = org_dir.join("allowlist");
542            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
543                eprintln!(
544                    "tirith: loading org-level allowlist from {}",
545                    allowlist_path.display()
546                );
547                for line in content.lines() {
548                    let line = line.trim();
549                    if !line.is_empty() && !line.starts_with('#') {
550                        self.allowlist.push(line.to_string());
551                    }
552                }
553            }
554            let blocklist_path = org_dir.join("blocklist");
555            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
556                eprintln!(
557                    "tirith: loading org-level blocklist from {}",
558                    blocklist_path.display()
559                );
560                for line in content.lines() {
561                    let line = line.trim();
562                    if !line.is_empty() && !line.starts_with('#') {
563                        self.blocklist.push(line.to_string());
564                    }
565                }
566            }
567        }
568    }
569}
570
/// Whether an allowlist pattern is a bare domain (e.g. `github.com`, `*.example.org`)
/// rather than a URL fragment. It must contain none of `/`, `?`, `#` or `:`; the last
/// also rules out any `scheme://` prefix.
fn is_domain_pattern(p: &str) -> bool {
    !p.chars().any(|c| matches!(c, '/' | '?' | '#' | ':'))
}
578
579fn extract_host_for_match(url: &str) -> Option<String> {
580    if let Some(host) = crate::parse::parse_url(url).host() {
581        return Some(host.trim_end_matches('.').to_lowercase());
582    }
583    // Fallback for schemeless host/path (e.g., example.com/path)
584    let candidate = url.split('/').next().unwrap_or(url).trim();
585    if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
586        return None;
587    }
588    let host = if let Some((h, port)) = candidate.rsplit_once(':') {
589        if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
590            h
591        } else {
592            candidate
593        }
594    } else {
595        candidate
596    };
597    Some(host.trim_end_matches('.').to_lowercase())
598}
599
/// Whether `host` equals the domain `pattern` or is a subdomain of it.
///
/// A leading `*.` on the pattern and trailing dots on either argument are ignored,
/// so `*.github.com` behaves exactly like `github.com`. No case folding happens here;
/// callers pass lowercase values.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let domain = pattern.trim_start_matches("*.").trim_end_matches('.');
    match host.strip_suffix(domain) {
        // Either an exact match, or the suffix starts on a label boundary.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
605
606/// Discover policy path by walking up from cwd to .git boundary.
607fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
608    let start = cwd
609        .map(PathBuf::from)
610        .or_else(|| std::env::current_dir().ok())?;
611
612    let mut current = start.as_path();
613    loop {
614        // Check for .tirith/policy.yaml or .tirith/policy.yml
615        if let Some(candidate) = find_policy_in_dir(&current.join(".tirith")) {
616            return Some(candidate);
617        }
618
619        // Check for .git boundary (directory or file for worktrees)
620        let git_dir = current.join(".git");
621        if git_dir.exists() {
622            return None; // Hit repo root without finding policy
623        }
624
625        // Go up
626        match current.parent() {
627            Some(parent) if parent != current => current = parent,
628            _ => break,
629        }
630    }
631
632    None
633}
634
/// Find the nearest enclosing repository root: the first directory, starting at
/// `cwd` (or the process working directory when `cwd` is `None`) and moving
/// upwards, that contains a `.git` entry (directory or worktree file).
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
653
654/// Get user-level policy path.
655fn user_policy_path() -> Option<PathBuf> {
656    let base = etcetera::choose_base_strategy().ok()?;
657    find_policy_in_dir(&base.config_dir().join("tirith"))
658}
659
660/// Get tirith data directory.
661pub fn data_dir() -> Option<PathBuf> {
662    let base = etcetera::choose_base_strategy().ok()?;
663    Some(base.data_dir().join("tirith"))
664}
665
666/// Get tirith config directory.
667pub fn config_dir() -> Option<PathBuf> {
668    let base = etcetera::choose_base_strategy().ok()?;
669    Some(base.config_dir().join("tirith"))
670}
671
672/// Get tirith state directory.
673/// Must match bash-hook.bash path: ${XDG_STATE_HOME:-$HOME/.local/state}/tirith
674pub fn state_dir() -> Option<PathBuf> {
675    match std::env::var("XDG_STATE_HOME") {
676        Ok(val) if !val.trim().is_empty() => Some(PathBuf::from(val.trim()).join("tirith")),
677        _ => home::home_dir().map(|h| h.join(".local/state/tirith")),
678    }
679}
680
681/// Get the path for caching remote policy: ~/.cache/tirith/remote-policy.yaml
682fn remote_policy_cache_path() -> Option<PathBuf> {
683    let cache_dir = std::env::var("XDG_CACHE_HOME")
684        .ok()
685        .filter(|s| !s.is_empty())
686        .map(PathBuf::from)
687        .or_else(|| home::home_dir().map(|h| h.join(".cache")))?;
688    Some(cache_dir.join("tirith").join("remote-policy.yaml"))
689}
690
691/// Cache the raw YAML from a remote policy fetch.
692fn cache_remote_policy(yaml: &str) -> std::io::Result<()> {
693    if let Some(path) = remote_policy_cache_path() {
694        if let Some(parent) = path.parent() {
695            std::fs::create_dir_all(parent)?;
696        }
697        // Write with restricted permissions (owner-only)
698        let mut opts = std::fs::OpenOptions::new();
699        opts.write(true).create(true).truncate(true);
700        #[cfg(unix)]
701        {
702            use std::os::unix::fs::OpenOptionsExt;
703            opts.mode(0o600);
704        }
705        let mut f = opts.open(&path)?;
706        use std::io::Write;
707        f.write_all(yaml.as_bytes())?;
708    }
709    Ok(())
710}
711
712/// Load a previously cached remote policy.
713fn load_cached_remote_policy() -> Option<Policy> {
714    let path = remote_policy_cache_path()?;
715    let content = std::fs::read_to_string(&path).ok()?;
716    match serde_yaml::from_str::<Policy>(&content) {
717        Ok(mut p) => {
718            p.path = Some(format!("cached:{}", path.display()));
719            Some(p)
720        }
721        Err(e) => {
722            eprintln!("tirith: warning: cached remote policy parse error: {e}");
723            None
724        }
725    }
726}
727
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets environment variables for one test and, when dropped, restores their
    /// previous values (or removes them). This also happens when the test panics, so
    /// a failing assertion cannot leak e.g. `TIRITH_LICENSE=!` into later tests, and
    /// a developer's own `TIRITH_*` settings are not wiped.
    ///
    /// Create it only while holding `crate::TEST_ENV_LOCK`, and declare it *after* the
    /// lock guard so it is dropped (restoring the environment) before the lock is released.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    impl EnvGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let mut saved = Vec::with_capacity(vars.len());
            for &(key, value) in vars {
                saved.push((key, std::env::var_os(key)));
                // SAFETY: environment-mutating tests serialize on TEST_ENV_LOCK, which
                // the caller holds for this guard's whole lifetime.
                unsafe { std::env::set_var(key, value) };
            }
            EnvGuard { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Reverse order so a key listed twice ends at its original value.
            for (key, previous) in self.saved.iter().rev() {
                match previous {
                    // SAFETY: see `EnvGuard::set`.
                    Some(value) => unsafe { std::env::set_var(key, value) },
                    // SAFETY: see `EnvGuard::set`.
                    None => unsafe { std::env::remove_var(key) },
                }
            }
        }
    }

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let p = Policy {
            allowlist: vec!["github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(p.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!p.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let p = Policy {
            allowlist: vec!["raw.githubusercontent.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let p = Policy {
            allowlist: vec!["example.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("example.com:8080/path"));
    }

    #[test]
    fn test_discover_skips_remote_fetch_below_team_tier() {
        let _lock = crate::TEST_ENV_LOCK.lock().unwrap();

        let dir = tempfile::tempdir().unwrap();
        let policy_dir = dir.path().join(".tirith");
        std::fs::create_dir_all(&policy_dir).unwrap();
        std::fs::write(
            policy_dir.join("policy.yaml"),
            "fail_mode: open\npolicy_fetch_fail_mode: closed\nallow_bypass_env_noninteractive: true\n",
        )
        .unwrap();

        // Force Community tier regardless of host machine config. Declared after
        // `_lock`, so the environment is restored before the lock is released.
        let _env = EnvGuard::set(&[
            ("TIRITH_LICENSE", "!"),
            ("TIRITH_SERVER_URL", "http://127.0.0.1"),
            ("TIRITH_API_KEY", "dummy"),
        ]);

        let policy = Policy::discover(Some(dir.path().to_str().unwrap()));
        assert_ne!(policy.path.as_deref(), Some("fail-closed"));
        assert_eq!(policy.fail_mode, FailMode::Open);
        assert!(policy.allow_bypass_env_noninteractive);
        assert!(policy
            .path
            .as_deref()
            .unwrap_or_default()
            .contains(".tirith"));
    }
}