Skip to main content

tirith_core/
policy.rs

1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
8/// Policy configuration loaded from YAML.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(default)]
11pub struct Policy {
12    /// Path this policy was loaded from.
13    #[serde(skip)]
14    pub path: Option<String>,
15
16    /// Fail mode: "open" (default) or "closed".
17    pub fail_mode: FailMode,
18
19    /// Allow TIRITH=0 bypass in interactive mode.
20    pub allow_bypass_env: bool,
21
22    /// Allow TIRITH=0 bypass in non-interactive mode.
23    pub allow_bypass_env_noninteractive: bool,
24
25    /// Paranoia tier (1-4).
26    pub paranoia: u8,
27
28    /// Severity overrides per rule.
29    #[serde(default)]
30    pub severity_overrides: HashMap<String, Severity>,
31
32    /// Additional known domains (extends built-in list).
33    #[serde(default)]
34    pub additional_known_domains: Vec<String>,
35
36    /// Allowlist: URL patterns that are always allowed.
37    #[serde(default)]
38    pub allowlist: Vec<String>,
39
40    /// Blocklist: URL patterns that are always blocked.
41    #[serde(default)]
42    pub blocklist: Vec<String>,
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[serde(rename_all = "lowercase")]
47#[derive(Default)]
48pub enum FailMode {
49    #[default]
50    Open,
51    Closed,
52}
53
54impl Default for Policy {
55    fn default() -> Self {
56        Self {
57            path: None,
58            fail_mode: FailMode::Open,
59            allow_bypass_env: true,
60            allow_bypass_env_noninteractive: false,
61            paranoia: 1,
62            severity_overrides: HashMap::new(),
63            additional_known_domains: Vec::new(),
64            allowlist: Vec::new(),
65            blocklist: Vec::new(),
66        }
67    }
68}
69
70impl Policy {
71    /// Discover and load partial policy (just bypass + fail_mode fields).
72    /// Used in Tier 2 for fast bypass resolution.
73    pub fn discover_partial(cwd: Option<&str>) -> Self {
74        match discover_policy_path(cwd) {
75            Some(path) => match std::fs::read_to_string(&path) {
76                Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
77                    Ok(mut p) => {
78                        p.path = Some(path.display().to_string());
79                        p
80                    }
81                    Err(_) => {
82                        // Parse error: use fail_mode default behavior
83                        Policy::default()
84                    }
85                },
86                Err(_) => Policy::default(),
87            },
88            None => Policy::default(),
89        }
90    }
91
92    /// Discover and load full policy.
93    pub fn discover(cwd: Option<&str>) -> Self {
94        // Check env override first
95        if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
96            let path = PathBuf::from(&root).join("policy.yml");
97            if path.exists() {
98                return Self::load_from_path(&path);
99            }
100        }
101
102        match discover_policy_path(cwd) {
103            Some(path) => Self::load_from_path(&path),
104            None => {
105                // Try user-level policy
106                if let Some(user_path) = user_policy_path() {
107                    if user_path.exists() {
108                        return Self::load_from_path(&user_path);
109                    }
110                }
111                Policy::default()
112            }
113        }
114    }
115
116    fn load_from_path(path: &Path) -> Self {
117        match std::fs::read_to_string(path) {
118            Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
119                Ok(mut p) => {
120                    p.path = Some(path.display().to_string());
121                    p
122                }
123                Err(_) => {
124                    eprintln!(
125                        "tirith: warning: failed to parse policy at {}",
126                        path.display()
127                    );
128                    Policy::default()
129                }
130            },
131            Err(_) => Policy::default(),
132        }
133    }
134
135    /// Get severity override for a rule.
136    pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
137        let key = serde_json::to_value(rule_id)
138            .ok()
139            .and_then(|v| v.as_str().map(String::from))?;
140        self.severity_overrides.get(&key).copied()
141    }
142
143    /// Check if a URL is in the blocklist.
144    pub fn is_blocklisted(&self, url: &str) -> bool {
145        let url_lower = url.to_lowercase();
146        self.blocklist.iter().any(|pattern| {
147            let p = pattern.to_lowercase();
148            url_lower.contains(&p)
149        })
150    }
151
152    /// Check if a URL is in the allowlist.
153    pub fn is_allowlisted(&self, url: &str) -> bool {
154        let url_lower = url.to_lowercase();
155        self.allowlist.iter().any(|pattern| {
156            let p = pattern.to_lowercase();
157            url_lower.contains(&p)
158        })
159    }
160
161    /// Load and merge user-level lists (allowlist/blocklist flat text files).
162    pub fn load_user_lists(&mut self) {
163        if let Some(config) = crate::policy::config_dir() {
164            let allowlist_path = config.join("allowlist");
165            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
166                for line in content.lines() {
167                    let line = line.trim();
168                    if !line.is_empty() && !line.starts_with('#') {
169                        self.allowlist.push(line.to_string());
170                    }
171                }
172            }
173            let blocklist_path = config.join("blocklist");
174            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
175                for line in content.lines() {
176                    let line = line.trim();
177                    if !line.is_empty() && !line.starts_with('#') {
178                        self.blocklist.push(line.to_string());
179                    }
180                }
181            }
182        }
183    }
184
185    /// Load and merge org-level lists from a repo root's .tirith/ dir.
186    pub fn load_org_lists(&mut self, cwd: Option<&str>) {
187        if let Some(repo_root) = find_repo_root(cwd) {
188            let org_dir = repo_root.join(".tirith");
189            let allowlist_path = org_dir.join("allowlist");
190            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
191                for line in content.lines() {
192                    let line = line.trim();
193                    if !line.is_empty() && !line.starts_with('#') {
194                        self.allowlist.push(line.to_string());
195                    }
196                }
197            }
198            let blocklist_path = org_dir.join("blocklist");
199            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
200                for line in content.lines() {
201                    let line = line.trim();
202                    if !line.is_empty() && !line.starts_with('#') {
203                        self.blocklist.push(line.to_string());
204                    }
205                }
206            }
207        }
208    }
209}
210
/// Locate the nearest `.tirith/policy.yml`, searching from `cwd` upward.
///
/// When `cwd` is `None`, the search starts from the process working
/// directory. The search is bounded by the repository. A directory
/// containing a `.git` entry (a directory, or a file for worktrees) is the
/// last one examined, so a policy at the repository root is found but nothing
/// above it is.
///
/// Returns `None` when no policy is found or the working directory cannot be
/// determined.
fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
    let origin = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };

    for dir in origin.ancestors() {
        let policy_file = dir.join(".tirith").join("policy.yml");
        if policy_file.exists() {
            return Some(policy_file);
        }
        if dir.join(".git").exists() {
            // Reached the repository root without finding a policy.
            break;
        }
    }

    None
}
240
/// Return the nearest directory at or above `cwd` that contains a `.git` entry.
///
/// When `cwd` is `None`, the search starts from the process working directory.
/// A `.git` directory and a worktree-style `.git` file both count. Returns
/// `None` if no such ancestor exists or the working directory cannot be
/// determined.
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let origin = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };

    origin
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
259
260/// Get user-level policy path.
261fn user_policy_path() -> Option<PathBuf> {
262    let base = etcetera::choose_base_strategy().ok()?;
263    Some(base.config_dir().join("tirith").join("policy.yml"))
264}
265
266/// Get tirith data directory.
267pub fn data_dir() -> Option<PathBuf> {
268    let base = etcetera::choose_base_strategy().ok()?;
269    Some(base.data_dir().join("tirith"))
270}
271
272/// Get tirith config directory.
273pub fn config_dir() -> Option<PathBuf> {
274    let base = etcetera::choose_base_strategy().ok()?;
275    Some(base.config_dir().join("tirith"))
276}