Skip to main content

tirith_core/
policy_validate.rs

1//! Policy YAML validation — syntax, schema, and conflict checks.
2//!
3//! Separate from `policy.rs` (which handles loading and runtime matching).
4//! Used by `tirith policy validate`.
5
6use crate::verdict::{RuleId, Severity};
7
/// A single validation issue found in a policy file.
///
/// Returned by [`validate`]; derives `serde::Serialize` so issues can be emitted as
/// structured output.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PolicyIssue {
    /// Whether the issue is an error or a warning.
    pub level: IssueLevel,
    /// Human-readable description of the problem.
    pub message: String,
    /// Path of the offending policy field (for example `scan.fail_on`), when one applies.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub field: Option<String>,
}
16
/// Severity of a [`PolicyIssue`]. Serialized in lowercase (`"error"` / `"warning"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum IssueLevel {
    /// Used in this file for invalid values: unknown rule IDs, regexes that do not
    /// compile, out-of-range numbers, unparseable YAML.
    Error,
    /// Used in this file for allowlist/blocklist overlaps and unknown field names.
    Warning,
}
23
24impl std::fmt::Display for IssueLevel {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            IssueLevel::Error => write!(f, "error"),
28            IssueLevel::Warning => write!(f, "warning"),
29        }
30    }
31}
32
33/// Validate a policy YAML string. Returns a list of issues (empty = valid).
34pub fn validate(yaml: &str) -> Vec<PolicyIssue> {
35    let mut issues = Vec::new();
36
37    // Structural parse first — fail early if the YAML shape is wrong.
38    let policy: crate::policy::Policy = match serde_yaml::from_str(yaml) {
39        Ok(p) => p,
40        Err(e) => {
41            issues.push(PolicyIssue {
42                level: IssueLevel::Error,
43                message: format!("YAML parse error: {e}"),
44                field: None,
45            });
46            return issues;
47        }
48    };
49
50    validate_paranoia(&policy, &mut issues);
51    validate_severity_overrides(&policy, &mut issues);
52    validate_allowlist_blocklist_overlap(&policy, &mut issues);
53    validate_custom_rules(&policy, &mut issues);
54    validate_approval_rules(&policy, &mut issues);
55    validate_fail_mode_fields(&policy, &mut issues);
56    validate_scan_config(&policy, &mut issues);
57    validate_network_entries(&policy, &mut issues);
58    validate_action_overrides(&policy, &mut issues);
59    validate_escalation_rules(&policy, &mut issues);
60
61    // Typo guard: flag fields that aren't part of the Policy schema.
62    validate_unknown_fields(yaml, &mut issues);
63
64    issues
65}
66
67fn validate_paranoia(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
68    if policy.paranoia == 0 || policy.paranoia > 4 {
69        issues.push(PolicyIssue {
70            level: IssueLevel::Error,
71            message: format!("paranoia must be 1-4, got {}", policy.paranoia),
72            field: Some("paranoia".into()),
73        });
74    }
75}
76
77fn validate_severity_overrides(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
78    for key in policy.severity_overrides.keys() {
79        // Check if the key is a valid RuleId
80        let parsed: Result<RuleId, _> =
81            serde_json::from_value(serde_json::Value::String(key.clone()));
82        if parsed.is_err() {
83            issues.push(PolicyIssue {
84                level: IssueLevel::Error,
85                message: format!("severity_overrides: unknown rule ID '{key}'"),
86                field: Some(format!("severity_overrides.{key}")),
87            });
88        }
89    }
90}
91
92fn validate_allowlist_blocklist_overlap(
93    policy: &crate::policy::Policy,
94    issues: &mut Vec<PolicyIssue>,
95) {
96    for allow in &policy.allowlist {
97        let allow_lower = allow.to_lowercase();
98        for block in &policy.blocklist {
99            if block.to_lowercase() == allow_lower {
100                issues.push(PolicyIssue {
101                    level: IssueLevel::Warning,
102                    message: format!(
103                        "pattern '{allow}' appears in both allowlist and blocklist \
104                         (blocklist takes precedence)"
105                    ),
106                    field: Some("allowlist/blocklist".into()),
107                });
108            }
109        }
110    }
111}
112
113fn validate_custom_rules(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
114    let mut seen_ids = std::collections::HashSet::new();
115    for rule in &policy.custom_rules {
116        if !seen_ids.insert(&rule.id) {
117            issues.push(PolicyIssue {
118                level: IssueLevel::Error,
119                message: format!("custom_rules: duplicate id '{}'", rule.id),
120                field: Some(format!("custom_rules.{}", rule.id)),
121            });
122        }
123
124        // Validate regex compiles
125        if let Err(e) = regex::Regex::new(&rule.pattern) {
126            issues.push(PolicyIssue {
127                level: IssueLevel::Error,
128                message: format!(
129                    "custom_rules.{}: invalid regex '{}': {e}",
130                    rule.id, rule.pattern
131                ),
132                field: Some(format!("custom_rules.{}.pattern", rule.id)),
133            });
134        }
135
136        // Validate contexts
137        let valid_contexts = ["exec", "paste", "file"];
138        for ctx in &rule.context {
139            if !valid_contexts.contains(&ctx.as_str()) {
140                issues.push(PolicyIssue {
141                    level: IssueLevel::Error,
142                    message: format!(
143                        "custom_rules.{}: invalid context '{}' (valid: exec, paste, file)",
144                        rule.id, ctx
145                    ),
146                    field: Some(format!("custom_rules.{}.context", rule.id)),
147                });
148            }
149        }
150    }
151}
152
153fn validate_approval_rules(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
154    for (i, rule) in policy.approval_rules.iter().enumerate() {
155        for rule_id_str in &rule.rule_ids {
156            let parsed: Result<RuleId, _> =
157                serde_json::from_value(serde_json::Value::String(rule_id_str.clone()));
158            if parsed.is_err() {
159                issues.push(PolicyIssue {
160                    level: IssueLevel::Error,
161                    message: format!("approval_rules[{i}]: unknown rule ID '{rule_id_str}'"),
162                    field: Some(format!("approval_rules[{i}].rule_ids")),
163                });
164            }
165        }
166
167        let valid_fallbacks = ["block", "warn", "allow"];
168        if !valid_fallbacks.contains(&rule.fallback.as_str()) {
169            issues.push(PolicyIssue {
170                level: IssueLevel::Error,
171                message: format!(
172                    "approval_rules[{i}]: invalid fallback '{}' (valid: block, warn, allow)",
173                    rule.fallback
174                ),
175                field: Some(format!("approval_rules[{i}].fallback")),
176            });
177        }
178    }
179}
180
181fn validate_fail_mode_fields(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
182    if let Some(ref mode) = policy.policy_fetch_fail_mode {
183        let valid = ["open", "closed", "cached"];
184        if !valid.contains(&mode.as_str()) {
185            issues.push(PolicyIssue {
186                level: IssueLevel::Error,
187                message: format!(
188                    "policy_fetch_fail_mode: invalid value '{mode}' (valid: open, closed, cached)"
189                ),
190                field: Some("policy_fetch_fail_mode".into()),
191            });
192        }
193    }
194}
195
196fn validate_scan_config(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
197    if let Some(ref fail_on) = policy.scan.fail_on {
198        let parsed: Result<Severity, _> =
199            serde_json::from_value(serde_json::Value::String(fail_on.to_uppercase()));
200        if parsed.is_err() {
201            issues.push(PolicyIssue {
202                level: IssueLevel::Error,
203                message: format!(
204                    "scan.fail_on: invalid severity '{}' (valid: INFO, LOW, MEDIUM, HIGH, CRITICAL)",
205                    fail_on
206                ),
207                field: Some("scan.fail_on".into()),
208            });
209        }
210    }
211
212    // Validate DLP patterns compile
213    for (i, pattern) in policy.dlp_custom_patterns.iter().enumerate() {
214        if let Err(e) = regex::Regex::new(pattern) {
215            issues.push(PolicyIssue {
216                level: IssueLevel::Error,
217                message: format!("dlp_custom_patterns[{i}]: invalid regex '{pattern}': {e}"),
218                field: Some(format!("dlp_custom_patterns[{i}]")),
219            });
220        }
221    }
222}
223
224/// Validate CIDR/host entries in network_deny and network_allow.
225fn validate_network_entries(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
226    for (field_name, entries) in [
227        ("network_deny", &policy.network_deny),
228        ("network_allow", &policy.network_allow),
229    ] {
230        for (i, entry) in entries.iter().enumerate() {
231            if !is_valid_cidr_or_host(entry) {
232                issues.push(PolicyIssue {
233                    level: IssueLevel::Error,
234                    message: format!(
235                        "{field_name}[{i}]: '{entry}' is not a valid hostname or CIDR"
236                    ),
237                    field: Some(format!("{field_name}[{i}]")),
238                });
239            }
240        }
241    }
242}
243
244/// Check if a string is a valid hostname, IP, or CIDR notation.
245fn is_valid_cidr_or_host(s: &str) -> bool {
246    // Allow hostnames (domain-like strings)
247    if s.chars()
248        .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '*')
249        && !s.is_empty()
250    {
251        return true;
252    }
253
254    // Allow IP/CIDR: split on '/' for CIDR prefix
255    if let Some((ip_part, prefix)) = s.split_once('/') {
256        // Validate prefix length
257        let Ok(prefix_len) = prefix.parse::<u32>() else {
258            return false;
259        };
260        // IPv4 CIDR
261        if ip_part.contains('.') {
262            return prefix_len <= 32 && parse_ipv4(ip_part);
263        }
264        // IPv6 CIDR
265        if ip_part.contains(':') {
266            return prefix_len <= 128 && parse_ipv6(ip_part);
267        }
268        return false;
269    }
270
271    // Plain IP
272    if s.contains(':') {
273        return parse_ipv6(s);
274    }
275    if s.contains('.') && s.chars().all(|c| c.is_ascii_digit() || c == '.') {
276        return parse_ipv4(s);
277    }
278
279    false
280}
281
/// Returns `true` when `s` is a dotted-decimal IPv4 address.
///
/// Delegates to the standard library parser, which requires exactly four octets in
/// the range 0-255 and accepts neither a sign (`+1.2.3.4`) nor octets with leading
/// zeros (ambiguous with octal notation).
fn parse_ipv4(s: &str) -> bool {
    s.parse::<std::net::Ipv4Addr>().is_ok()
}
289
/// Returns `true` when `s` is a textual IPv6 address.
///
/// Delegates to the standard library parser, so `::` compression is allowed once,
/// an embedded IPv4 tail (`::ffff:192.168.1.1`) is accepted, and malformed input
/// such as `:::` or a dangling trailing colon is rejected.
fn parse_ipv6(s: &str) -> bool {
    s.parse::<std::net::Ipv6Addr>().is_ok()
}
307
308fn validate_action_overrides(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
309    for (key, value) in &policy.action_overrides {
310        // Validate value: only "block" is allowed
311        if value != "block" {
312            let hint = match value.as_str() {
313                "allow" | "warn" | "warn_ack" => {
314                    " (use severity_overrides to change rule severity instead)"
315                }
316                _ => "",
317            };
318            issues.push(PolicyIssue {
319                level: IssueLevel::Error,
320                message: format!(
321                    "action_overrides.{key}: invalid value '{value}' \
322                     (only 'block' is supported){hint}"
323                ),
324                field: Some(format!("action_overrides.{key}")),
325            });
326        }
327
328        // Validate key is a known RuleId
329        let parsed: Result<RuleId, _> =
330            serde_json::from_value(serde_json::Value::String(key.clone()));
331        if parsed.is_err() {
332            issues.push(PolicyIssue {
333                level: IssueLevel::Error,
334                message: format!("action_overrides: unknown rule ID '{key}'"),
335                field: Some(format!("action_overrides.{key}")),
336            });
337        }
338    }
339}
340
341fn validate_escalation_rules(policy: &crate::policy::Policy, issues: &mut Vec<PolicyIssue>) {
342    for (i, rule) in policy.escalation.iter().enumerate() {
343        match rule {
344            crate::escalation::EscalationRule::RepeatCount {
345                rule_ids,
346                threshold,
347                ..
348            } => {
349                if *threshold == 0 {
350                    issues.push(PolicyIssue {
351                        level: IssueLevel::Error,
352                        message: format!("escalation[{i}]: threshold must be > 0"),
353                        field: Some(format!("escalation[{i}].threshold")),
354                    });
355                }
356                for rule_id_str in rule_ids {
357                    if rule_id_str == "*" {
358                        continue; // wildcard is valid
359                    }
360                    let parsed: Result<RuleId, _> =
361                        serde_json::from_value(serde_json::Value::String(rule_id_str.clone()));
362                    if parsed.is_err() {
363                        issues.push(PolicyIssue {
364                            level: IssueLevel::Error,
365                            message: format!("escalation[{i}]: unknown rule ID '{rule_id_str}'"),
366                            field: Some(format!("escalation[{i}].rule_ids")),
367                        });
368                    }
369                }
370            }
371            crate::escalation::EscalationRule::MultiMedium { min_findings, .. } => {
372                if *min_findings == 0 {
373                    issues.push(PolicyIssue {
374                        level: IssueLevel::Error,
375                        message: format!("escalation[{i}]: min_findings must be > 0"),
376                        field: Some(format!("escalation[{i}].min_findings")),
377                    });
378                }
379            }
380        }
381    }
382}
383
384fn validate_unknown_fields(yaml: &str, issues: &mut Vec<PolicyIssue>) {
385    let known_top_level = [
386        "fail_mode",
387        "allow_bypass_env",
388        "allow_bypass_env_noninteractive",
389        "paranoia",
390        "severity_overrides",
391        "additional_known_domains",
392        "allowlist",
393        "blocklist",
394        "approval_rules",
395        "network_deny",
396        "network_allow",
397        "webhooks",
398        "checkpoints",
399        "scan",
400        "allowlist_rules",
401        "custom_rules",
402        "dlp_custom_patterns",
403        "strict_warn",
404        "action_overrides",
405        "escalation",
406        "policy_server_url",
407        "policy_server_api_key",
408        "policy_fetch_fail_mode",
409        "enforce_fail_mode",
410    ];
411
412    // Known fields for nested objects
413    let known_scan_fields = [
414        "additional_config_files",
415        "trusted_mcp_servers",
416        "ignore_patterns",
417        "fail_on",
418        "profiles",
419    ];
420    let known_checkpoint_fields = ["max_count", "max_age_hours", "max_storage_bytes"];
421
422    // Parse as generic YAML value to check top-level keys
423    if let Ok(serde_yaml::Value::Mapping(map)) = serde_yaml::from_str::<serde_yaml::Value>(yaml) {
424        for (key, value) in &map {
425            if let serde_yaml::Value::String(k) = key {
426                if !known_top_level.contains(&k.as_str()) {
427                    issues.push(PolicyIssue {
428                        level: IssueLevel::Warning,
429                        message: format!("unknown field '{k}'"),
430                        field: Some(k.clone()),
431                    });
432                }
433
434                // Check nested fields for known sub-objects
435                if k == "scan" {
436                    if let serde_yaml::Value::Mapping(sub_map) = value {
437                        let known_profile_fields = ["include", "exclude", "fail_on", "ignore"];
438                        for (sub_key, sub_val) in sub_map {
439                            if let serde_yaml::Value::String(sk) = sub_key {
440                                if !known_scan_fields.contains(&sk.as_str()) {
441                                    issues.push(PolicyIssue {
442                                        level: IssueLevel::Warning,
443                                        message: format!("unknown field 'scan.{sk}'"),
444                                        field: Some(format!("scan.{sk}")),
445                                    });
446                                }
447                                // Validate scan.profiles.<name>.* keys
448                                if sk == "profiles" {
449                                    if let serde_yaml::Value::Mapping(profiles) = sub_val {
450                                        for (pname, pval) in profiles {
451                                            let pname_str = match pname {
452                                                serde_yaml::Value::String(s) => s.clone(),
453                                                _ => continue,
454                                            };
455                                            if let serde_yaml::Value::Mapping(pfields) = pval {
456                                                for pkey in pfields.keys() {
457                                                    if let serde_yaml::Value::String(pk) = pkey {
458                                                        if !known_profile_fields
459                                                            .contains(&pk.as_str())
460                                                        {
461                                                            issues.push(PolicyIssue {
462                                                                level: IssueLevel::Warning,
463                                                                message: format!(
464                                                                    "unknown field 'scan.profiles.{pname_str}.{pk}'"
465                                                                ),
466                                                                field: Some(format!(
467                                                                    "scan.profiles.{pname_str}.{pk}"
468                                                                )),
469                                                            });
470                                                        }
471                                                    }
472                                                }
473                                            }
474                                        }
475                                    }
476                                }
477                            }
478                        }
479                    }
480                }
481
482                if k == "checkpoints" {
483                    if let serde_yaml::Value::Mapping(sub_map) = value {
484                        for sub_key in sub_map.keys() {
485                            if let serde_yaml::Value::String(sk) = sub_key {
486                                if !known_checkpoint_fields.contains(&sk.as_str()) {
487                                    issues.push(PolicyIssue {
488                                        level: IssueLevel::Warning,
489                                        message: format!("unknown field 'checkpoints.{sk}'"),
490                                        field: Some(format!("checkpoints.{sk}")),
491                                    });
492                                }
493                            }
494                        }
495                    }
496                }
497            }
498        }
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// True when at least one issue's message contains `needle`.
    fn has_message(issues: &[PolicyIssue], needle: &str) -> bool {
        issues.iter().any(|issue| issue.message.contains(needle))
    }

    #[test]
    fn test_valid_minimal_policy() {
        let issues = validate("fail_mode: open\nparanoia: 1\n");
        assert!(
            issues.is_empty(),
            "minimal policy should be valid: {issues:?}"
        );
    }

    #[test]
    fn test_invalid_yaml() {
        let issues = validate("{{invalid yaml");
        assert_eq!(issues.len(), 1);
        let only = &issues[0];
        assert_eq!(only.level, IssueLevel::Error);
        assert!(only.message.contains("YAML parse error"));
    }

    #[test]
    fn test_paranoia_out_of_range() {
        let issues = validate("paranoia: 5\n");
        assert!(has_message(&issues, "paranoia must be 1-4"));
    }

    #[test]
    fn test_invalid_severity_override() {
        let issues = validate("severity_overrides:\n  not_a_rule: HIGH\n");
        assert!(has_message(&issues, "unknown rule ID 'not_a_rule'"));
    }

    #[test]
    fn test_allowlist_blocklist_overlap() {
        let issues = validate("allowlist:\n  - example.com\nblocklist:\n  - example.com\n");
        assert!(has_message(&issues, "both allowlist and blocklist"));
    }

    #[test]
    fn test_custom_rule_bad_regex() {
        let yaml = r#"
custom_rules:
  - id: test
    pattern: "[invalid"
    title: "Test rule"
"#;
        let issues = validate(yaml);
        assert!(has_message(&issues, "invalid regex"));
    }

    #[test]
    fn test_unknown_field() {
        let issues = validate("not_a_real_field: true\n");
        assert!(has_message(&issues, "unknown field"));
    }

    #[test]
    fn test_nested_scan_profile_unknown_field() {
        let issues = validate("scan:\n  profiles:\n    ci:\n      nope: true\n");
        assert!(
            has_message(&issues, "scan.profiles.ci.nope"),
            "nested profile typo should be flagged: {issues:?}"
        );
    }
}