Skip to main content

tirith_core/
policy.rs

1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
/// Return the policy file inside `dir`, checking `policy.yaml` first and then
/// `policy.yml`. Returns `None` when neither file exists.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    const CANDIDATES: [&str; 2] = ["policy.yaml", "policy.yml"];
    CANDIDATES
        .iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists())
}
20
21/// Policy configuration loaded from YAML.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(default)]
24pub struct Policy {
25    /// Path this policy was loaded from.
26    #[serde(skip)]
27    pub path: Option<String>,
28
29    /// Fail mode: "open" (default) or "closed".
30    pub fail_mode: FailMode,
31
32    /// Allow TIRITH=0 bypass in interactive mode.
33    pub allow_bypass_env: bool,
34
35    /// Allow TIRITH=0 bypass in non-interactive mode.
36    pub allow_bypass_env_noninteractive: bool,
37
38    /// Paranoia tier (1-4).
39    pub paranoia: u8,
40
41    /// Severity overrides per rule.
42    #[serde(default)]
43    pub severity_overrides: HashMap<String, Severity>,
44
45    /// Additional known domains (extends built-in list).
46    #[serde(default)]
47    pub additional_known_domains: Vec<String>,
48
49    /// Allowlist: URL patterns that are always allowed.
50    #[serde(default)]
51    pub allowlist: Vec<String>,
52
53    /// Blocklist: URL patterns that are always blocked.
54    #[serde(default)]
55    pub blocklist: Vec<String>,
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "lowercase")]
60#[derive(Default)]
61pub enum FailMode {
62    #[default]
63    Open,
64    Closed,
65}
66
67impl Default for Policy {
68    fn default() -> Self {
69        Self {
70            path: None,
71            fail_mode: FailMode::Open,
72            allow_bypass_env: true,
73            allow_bypass_env_noninteractive: false,
74            paranoia: 1,
75            severity_overrides: HashMap::new(),
76            additional_known_domains: Vec::new(),
77            allowlist: Vec::new(),
78            blocklist: Vec::new(),
79        }
80    }
81}
82
83impl Policy {
84    /// Discover and load partial policy (just bypass + fail_mode fields).
85    /// Used in Tier 2 for fast bypass resolution.
86    pub fn discover_partial(cwd: Option<&str>) -> Self {
87        match discover_policy_path(cwd) {
88            Some(path) => match std::fs::read_to_string(&path) {
89                Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
90                    Ok(mut p) => {
91                        p.path = Some(path.display().to_string());
92                        p
93                    }
94                    Err(e) => {
95                        eprintln!(
96                            "tirith: warning: failed to parse policy at {}: {e}",
97                            path.display()
98                        );
99                        // Parse error: use fail_mode default behavior
100                        Policy::default()
101                    }
102                },
103                Err(_) => Policy::default(),
104            },
105            None => Policy::default(),
106        }
107    }
108
109    /// Discover and load full policy.
110    pub fn discover(cwd: Option<&str>) -> Self {
111        // Check env override first
112        if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
113            if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
114                return Self::load_from_path(&path);
115            }
116        }
117
118        match discover_policy_path(cwd) {
119            Some(path) => Self::load_from_path(&path),
120            None => {
121                // Try user-level policy
122                if let Some(user_path) = user_policy_path() {
123                    if user_path.exists() {
124                        return Self::load_from_path(&user_path);
125                    }
126                }
127                Policy::default()
128            }
129        }
130    }
131
132    fn load_from_path(path: &Path) -> Self {
133        match std::fs::read_to_string(path) {
134            Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
135                Ok(mut p) => {
136                    p.path = Some(path.display().to_string());
137                    p
138                }
139                Err(e) => {
140                    eprintln!(
141                        "tirith: warning: failed to parse policy at {}: {e}",
142                        path.display(),
143                    );
144                    Policy::default()
145                }
146            },
147            Err(_) => Policy::default(),
148        }
149    }
150
151    /// Get severity override for a rule.
152    pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
153        let key = serde_json::to_value(rule_id)
154            .ok()
155            .and_then(|v| v.as_str().map(String::from))?;
156        self.severity_overrides.get(&key).copied()
157    }
158
159    /// Check if a URL is in the blocklist.
160    pub fn is_blocklisted(&self, url: &str) -> bool {
161        let url_lower = url.to_lowercase();
162        self.blocklist.iter().any(|pattern| {
163            let p = pattern.to_lowercase();
164            url_lower.contains(&p)
165        })
166    }
167
168    /// Check if a URL is in the allowlist.
169    pub fn is_allowlisted(&self, url: &str) -> bool {
170        let url_lower = url.to_lowercase();
171        self.allowlist.iter().any(|pattern| {
172            let p = pattern.to_lowercase();
173            if p.is_empty() {
174                return false;
175            }
176            if is_domain_pattern(&p) {
177                if let Some(host) = extract_host_for_match(url) {
178                    return domain_matches(&host, &p);
179                }
180                return false;
181            }
182            url_lower.contains(&p)
183        })
184    }
185
186    /// Load and merge user-level lists (allowlist/blocklist flat text files).
187    pub fn load_user_lists(&mut self) {
188        if let Some(config) = crate::policy::config_dir() {
189            let allowlist_path = config.join("allowlist");
190            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
191                for line in content.lines() {
192                    let line = line.trim();
193                    if !line.is_empty() && !line.starts_with('#') {
194                        self.allowlist.push(line.to_string());
195                    }
196                }
197            }
198            let blocklist_path = config.join("blocklist");
199            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
200                for line in content.lines() {
201                    let line = line.trim();
202                    if !line.is_empty() && !line.starts_with('#') {
203                        self.blocklist.push(line.to_string());
204                    }
205                }
206            }
207        }
208    }
209
210    /// Load and merge org-level lists from a repo root's .tirith/ dir.
211    pub fn load_org_lists(&mut self, cwd: Option<&str>) {
212        if let Some(repo_root) = find_repo_root(cwd) {
213            let org_dir = repo_root.join(".tirith");
214            let allowlist_path = org_dir.join("allowlist");
215            if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
216                for line in content.lines() {
217                    let line = line.trim();
218                    if !line.is_empty() && !line.starts_with('#') {
219                        self.allowlist.push(line.to_string());
220                    }
221                }
222            }
223            let blocklist_path = org_dir.join("blocklist");
224            if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
225                for line in content.lines() {
226                    let line = line.trim();
227                    if !line.is_empty() && !line.starts_with('#') {
228                        self.blocklist.push(line.to_string());
229                    }
230                }
231            }
232        }
233    }
234}
235
/// Decide whether an (already lowercased) allowlist entry is a bare domain
/// pattern such as `github.com` or `*.github.com`, rather than a URL fragment.
///
/// A domain pattern contains none of `/`, `?`, `#`, or `:`. Any `://` scheme
/// separator is already excluded by the `:` and `/` checks.
fn is_domain_pattern(p: &str) -> bool {
    let has_url_syntax = p.chars().any(|c| matches!(c, '/' | '?' | '#' | ':'));
    !has_url_syntax
}
243
244fn extract_host_for_match(url: &str) -> Option<String> {
245    if let Some(host) = crate::parse::parse_url(url).host() {
246        return Some(host.trim_end_matches('.').to_lowercase());
247    }
248    // Fallback for schemeless host/path (e.g., example.com/path)
249    let candidate = url.split('/').next().unwrap_or(url).trim();
250    if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
251        return None;
252    }
253    let host = if let Some((h, port)) = candidate.rsplit_once(':') {
254        if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
255            h
256        } else {
257            candidate
258        }
259    } else {
260        candidate
261    };
262    Some(host.trim_end_matches('.').to_lowercase())
263}
264
/// Return `true` if `host` equals `pattern` or is a subdomain of it.
///
/// Leading `*.` prefixes are stripped from the pattern, so `*.github.com`
/// behaves like `github.com` (both also match the apex). Trailing dots are
/// ignored on both sides. The comparison is case-sensitive, so callers pass
/// lowercased values.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let pattern = pattern.trim_start_matches("*.").trim_end_matches('.');
    match host.strip_suffix(pattern) {
        // Exact match, or the remaining prefix ends on a label boundary.
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
270
271/// Discover policy path by walking up from cwd to .git boundary.
272fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
273    let start = cwd
274        .map(PathBuf::from)
275        .or_else(|| std::env::current_dir().ok())?;
276
277    let mut current = start.as_path();
278    loop {
279        // Check for .tirith/policy.yaml or .tirith/policy.yml
280        if let Some(candidate) = find_policy_in_dir(&current.join(".tirith")) {
281            return Some(candidate);
282        }
283
284        // Check for .git boundary (directory or file for worktrees)
285        let git_dir = current.join(".git");
286        if git_dir.exists() {
287            return None; // Hit repo root without finding policy
288        }
289
290        // Go up
291        match current.parent() {
292            Some(parent) if parent != current => current = parent,
293            _ => break,
294        }
295    }
296
297    None
298}
299
/// Return the nearest directory at or above `cwd` (or the process working
/// directory when `cwd` is `None`) that contains a `.git` entry.
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
318
319/// Get user-level policy path.
320fn user_policy_path() -> Option<PathBuf> {
321    let base = etcetera::choose_base_strategy().ok()?;
322    find_policy_in_dir(&base.config_dir().join("tirith"))
323}
324
325/// Get tirith data directory.
326pub fn data_dir() -> Option<PathBuf> {
327    let base = etcetera::choose_base_strategy().ok()?;
328    Some(base.data_dir().join("tirith"))
329}
330
331/// Get tirith config directory.
332pub fn config_dir() -> Option<PathBuf> {
333    let base = etcetera::choose_base_strategy().ok()?;
334    Some(base.config_dir().join("tirith"))
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a default policy whose allowlist holds exactly `entries`.
    fn allowlisting(entries: &[&str]) -> Policy {
        Policy {
            allowlist: entries.iter().copied().map(String::from).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let policy = allowlisting(&["github.com"]);
        assert!(policy.is_allowlisted("https://api.github.com/repos"));
        assert!(policy.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!policy.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let policy = allowlisting(&["raw.githubusercontent.com"]);
        assert!(policy.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let policy = allowlisting(&["example.com"]);
        assert!(policy.is_allowlisted("example.com:8080/path"));
    }
}