1use etcetera::BaseStrategy;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::verdict::{RuleId, Severity};
7
/// Return the policy file located directly inside `dir`.
///
/// `policy.yaml` takes precedence over `policy.yml`; `None` is returned when
/// neither path exists.
fn find_policy_in_dir(dir: &Path) -> Option<PathBuf> {
    ["policy.yaml", "policy.yml"]
        .iter()
        .map(|file_name| dir.join(file_name))
        .find(|candidate| candidate.exists())
}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23#[serde(default)]
24pub struct Policy {
25 #[serde(skip)]
27 pub path: Option<String>,
28
29 pub fail_mode: FailMode,
31
32 pub allow_bypass_env: bool,
34
35 pub allow_bypass_env_noninteractive: bool,
37
38 pub paranoia: u8,
40
41 #[serde(default)]
43 pub severity_overrides: HashMap<String, Severity>,
44
45 #[serde(default)]
47 pub additional_known_domains: Vec<String>,
48
49 #[serde(default)]
51 pub allowlist: Vec<String>,
52
53 #[serde(default)]
55 pub blocklist: Vec<String>,
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59#[serde(rename_all = "lowercase")]
60#[derive(Default)]
61pub enum FailMode {
62 #[default]
63 Open,
64 Closed,
65}
66
67impl Default for Policy {
68 fn default() -> Self {
69 Self {
70 path: None,
71 fail_mode: FailMode::Open,
72 allow_bypass_env: true,
73 allow_bypass_env_noninteractive: false,
74 paranoia: 1,
75 severity_overrides: HashMap::new(),
76 additional_known_domains: Vec::new(),
77 allowlist: Vec::new(),
78 blocklist: Vec::new(),
79 }
80 }
81}
82
83impl Policy {
84 pub fn discover_partial(cwd: Option<&str>) -> Self {
87 match discover_policy_path(cwd) {
88 Some(path) => match std::fs::read_to_string(&path) {
89 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
90 Ok(mut p) => {
91 p.path = Some(path.display().to_string());
92 p
93 }
94 Err(e) => {
95 eprintln!(
96 "tirith: warning: failed to parse policy at {}: {e}",
97 path.display()
98 );
99 Policy::default()
101 }
102 },
103 Err(_) => Policy::default(),
104 },
105 None => Policy::default(),
106 }
107 }
108
109 pub fn discover(cwd: Option<&str>) -> Self {
111 if let Ok(root) = std::env::var("TIRITH_POLICY_ROOT") {
113 if let Some(path) = find_policy_in_dir(&PathBuf::from(&root).join(".tirith")) {
114 return Self::load_from_path(&path);
115 }
116 }
117
118 match discover_policy_path(cwd) {
119 Some(path) => Self::load_from_path(&path),
120 None => {
121 if let Some(user_path) = user_policy_path() {
123 if user_path.exists() {
124 return Self::load_from_path(&user_path);
125 }
126 }
127 Policy::default()
128 }
129 }
130 }
131
132 fn load_from_path(path: &Path) -> Self {
133 match std::fs::read_to_string(path) {
134 Ok(content) => match serde_yaml::from_str::<Policy>(&content) {
135 Ok(mut p) => {
136 p.path = Some(path.display().to_string());
137 p
138 }
139 Err(e) => {
140 eprintln!(
141 "tirith: warning: failed to parse policy at {}: {e}",
142 path.display(),
143 );
144 Policy::default()
145 }
146 },
147 Err(_) => Policy::default(),
148 }
149 }
150
151 pub fn severity_override(&self, rule_id: &RuleId) -> Option<Severity> {
153 let key = serde_json::to_value(rule_id)
154 .ok()
155 .and_then(|v| v.as_str().map(String::from))?;
156 self.severity_overrides.get(&key).copied()
157 }
158
159 pub fn is_blocklisted(&self, url: &str) -> bool {
161 let url_lower = url.to_lowercase();
162 self.blocklist.iter().any(|pattern| {
163 let p = pattern.to_lowercase();
164 url_lower.contains(&p)
165 })
166 }
167
168 pub fn is_allowlisted(&self, url: &str) -> bool {
170 let url_lower = url.to_lowercase();
171 self.allowlist.iter().any(|pattern| {
172 let p = pattern.to_lowercase();
173 if p.is_empty() {
174 return false;
175 }
176 if is_domain_pattern(&p) {
177 if let Some(host) = extract_host_for_match(url) {
178 return domain_matches(&host, &p);
179 }
180 return false;
181 }
182 url_lower.contains(&p)
183 })
184 }
185
186 pub fn load_user_lists(&mut self) {
188 if let Some(config) = crate::policy::config_dir() {
189 let allowlist_path = config.join("allowlist");
190 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
191 for line in content.lines() {
192 let line = line.trim();
193 if !line.is_empty() && !line.starts_with('#') {
194 self.allowlist.push(line.to_string());
195 }
196 }
197 }
198 let blocklist_path = config.join("blocklist");
199 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
200 for line in content.lines() {
201 let line = line.trim();
202 if !line.is_empty() && !line.starts_with('#') {
203 self.blocklist.push(line.to_string());
204 }
205 }
206 }
207 }
208 }
209
210 pub fn load_org_lists(&mut self, cwd: Option<&str>) {
212 if let Some(repo_root) = find_repo_root(cwd) {
213 let org_dir = repo_root.join(".tirith");
214 let allowlist_path = org_dir.join("allowlist");
215 if let Ok(content) = std::fs::read_to_string(&allowlist_path) {
216 for line in content.lines() {
217 let line = line.trim();
218 if !line.is_empty() && !line.starts_with('#') {
219 self.allowlist.push(line.to_string());
220 }
221 }
222 }
223 let blocklist_path = org_dir.join("blocklist");
224 if let Ok(content) = std::fs::read_to_string(&blocklist_path) {
225 for line in content.lines() {
226 let line = line.trim();
227 if !line.is_empty() && !line.starts_with('#') {
228 self.blocklist.push(line.to_string());
229 }
230 }
231 }
232 }
233 }
234}
235
/// Return `true` when allow-list entry `p` should be matched against the URL's
/// host rather than as a substring.
///
/// This is the case when `p` contains none of `/`, `?`, `#`, `:`. Excluding
/// `:` and `/` also rules out any `scheme://` prefix.
fn is_domain_pattern(p: &str) -> bool {
    !p.contains(|c: char| matches!(c, '/' | '?' | '#' | ':'))
}
243
244fn extract_host_for_match(url: &str) -> Option<String> {
245 if let Some(host) = crate::parse::parse_url(url).host() {
246 return Some(host.trim_end_matches('.').to_lowercase());
247 }
248 let candidate = url.split('/').next().unwrap_or(url).trim();
250 if candidate.starts_with('-') || !candidate.contains('.') || candidate.contains(' ') {
251 return None;
252 }
253 let host = if let Some((h, port)) = candidate.rsplit_once(':') {
254 if port.chars().all(|c| c.is_ascii_digit()) && !port.is_empty() {
255 h
256 } else {
257 candidate
258 }
259 } else {
260 candidate
261 };
262 Some(host.trim_end_matches('.').to_lowercase())
263}
264
/// Return `true` if `host` equals `pattern` or is a subdomain of it.
///
/// - A leading `*.` on `pattern` is stripped.
/// - Trailing `.`s on either side are ignored.
/// - So `github.com`, `*.github.com` and `github.com.` all match both
///   `github.com` and `api.github.com`, but not `evil-github.com`.
/// - Comparison is exact; callers pass lowercased values.
/// - A pattern that is empty after stripping matches nothing. Otherwise
///   entries such as `.` or `*.` would match an empty host.
fn domain_matches(host: &str, pattern: &str) -> bool {
    let host = host.trim_end_matches('.');
    let pattern = pattern.trim_start_matches("*.").trim_end_matches('.');
    if pattern.is_empty() {
        return false;
    }
    // `host` must be `pattern` itself or end in `.pattern` (a label boundary),
    // checked without allocating a `.pattern` string.
    host == pattern || matches!(host.strip_suffix(pattern), Some(rest) if rest.ends_with('.'))
}
270
271fn discover_policy_path(cwd: Option<&str>) -> Option<PathBuf> {
273 let start = cwd
274 .map(PathBuf::from)
275 .or_else(|| std::env::current_dir().ok())?;
276
277 let mut current = start.as_path();
278 loop {
279 if let Some(candidate) = find_policy_in_dir(¤t.join(".tirith")) {
281 return Some(candidate);
282 }
283
284 let git_dir = current.join(".git");
286 if git_dir.exists() {
287 return None; }
289
290 match current.parent() {
292 Some(parent) if parent != current => current = parent,
293 _ => break,
294 }
295 }
296
297 None
298}
299
/// Return the nearest ancestor of `cwd` (inclusive) that contains a `.git`
/// entry.
///
/// When `cwd` is `None` the process working directory is used. Returns `None`
/// if that cannot be determined or no ancestor contains `.git`.
fn find_repo_root(cwd: Option<&str>) -> Option<PathBuf> {
    let start = match cwd {
        Some(dir) => PathBuf::from(dir),
        None => std::env::current_dir().ok()?,
    };
    start
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
318
319fn user_policy_path() -> Option<PathBuf> {
321 let base = etcetera::choose_base_strategy().ok()?;
322 find_policy_in_dir(&base.config_dir().join("tirith"))
323}
324
325pub fn data_dir() -> Option<PathBuf> {
327 let base = etcetera::choose_base_strategy().ok()?;
328 Some(base.data_dir().join("tirith"))
329}
330
331pub fn config_dir() -> Option<PathBuf> {
333 let base = etcetera::choose_base_strategy().ok()?;
334 Some(base.config_dir().join("tirith"))
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allowlist_domain_matches_subdomain() {
        let p = Policy {
            allowlist: vec!["github.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://api.github.com/repos"));
        assert!(p.is_allowlisted("git@github.com:owner/repo.git"));
        assert!(!p.is_allowlisted("https://evil-github.com"));
    }

    #[test]
    fn test_allowlist_schemeless_host() {
        let p = Policy {
            allowlist: vec!["raw.githubusercontent.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("raw.githubusercontent.com/path/to/file"));
    }

    #[test]
    fn test_allowlist_schemeless_host_with_port() {
        let p = Policy {
            allowlist: vec!["example.com".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("example.com:8080/path"));
    }

    /// Entries containing `/` are matched as case-insensitive substrings.
    #[test]
    fn test_allowlist_path_pattern_is_case_insensitive_substring() {
        let p = Policy {
            allowlist: vec!["GitHub.com/MyOrg/".to_string()],
            ..Default::default()
        };
        assert!(p.is_allowlisted("https://github.com/myorg/repo"));
        assert!(!p.is_allowlisted("https://github.com/other/repo"));
    }

    #[test]
    fn test_allowlist_ignores_empty_pattern() {
        let p = Policy {
            allowlist: vec![String::new()],
            ..Default::default()
        };
        assert!(!p.is_allowlisted("https://example.com"));
    }

    #[test]
    fn test_blocklist_is_case_insensitive_substring() {
        let p = Policy {
            blocklist: vec!["Evil.COM".to_string()],
            ..Default::default()
        };
        assert!(p.is_blocklisted("https://sub.evil.com/payload"));
        assert!(!p.is_blocklisted("https://example.com"));
    }

    #[test]
    fn test_domain_pattern_detection() {
        assert!(is_domain_pattern("github.com"));
        assert!(is_domain_pattern("*.github.com"));
        assert!(!is_domain_pattern("https://github.com"));
        assert!(!is_domain_pattern("github.com/org"));
        assert!(!is_domain_pattern("example.com:8080"));
    }

    #[test]
    fn test_domain_matches_wildcard_and_trailing_dot() {
        assert!(domain_matches("api.github.com", "*.github.com"));
        assert!(domain_matches("github.com.", "github.com"));
        assert!(domain_matches("github.com", "github.com."));
        assert!(!domain_matches("evilgithub.com", "github.com"));
        assert!(!domain_matches("github.com.evil.com", "github.com"));
    }
}
369}