// ward/cli/policy.rs

1use anyhow::Result;
2use clap::Args;
3use console::style;
4use serde::{Deserialize, Serialize};
5
6use crate::config::Manifest;
7use crate::github::Client;
8use crate::github::branch_protection::BranchProtectionState;
9use crate::github::security::SecurityState;
10
// Arguments of the policy command: it only carries the selected subcommand (see `PolicyAction`).
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into CLI help.
#[derive(Args)]
pub struct PolicyCommand {
    #[command(subcommand)]
    action: PolicyAction,
}
16
// Subcommands of the policy command, dispatched by `PolicyCommand::run`.
// The `///` lines on each variant are kept verbatim: clap uses them as the subcommands' help text.
#[derive(clap::Subcommand)]
enum PolicyAction {
    /// Check all repos against policies
    Check,

    /// List configured policies
    List,
}
25
26#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
27pub struct PolicyRule {
28    pub name: String,
29    pub rule: String,
30    #[serde(default = "default_error")]
31    pub severity: PolicySeverity,
32}
33
34fn default_error() -> PolicySeverity {
35    PolicySeverity::Error
36}
37
38#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
39#[serde(rename_all = "lowercase")]
40pub enum PolicySeverity {
41    Error,
42    Warning,
43}
44
45impl std::fmt::Display for PolicySeverity {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            PolicySeverity::Error => write!(f, "error"),
49            PolicySeverity::Warning => write!(f, "warning"),
50        }
51    }
52}
53
54#[derive(Debug)]
55struct RepoContext {
56    visibility: String,
57    archived: bool,
58    security: SecurityState,
59    branch_protection: BranchProtectionState,
60}
61
62#[derive(Debug, Serialize)]
63struct Violation {
64    repo: String,
65    policy: String,
66    severity: String,
67    rule: String,
68}
69
70#[derive(Debug)]
71enum ParsedRule {
72    BoolField {
73        path: Vec<String>,
74        negated: bool,
75    },
76    Comparison {
77        path: Vec<String>,
78        op: CmpOp,
79        value: CmpValue,
80    },
81}
82
/// Comparison operators understood in policy rules.
#[derive(Debug)]
enum CmpOp {
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `>=`
    Ge,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `<`
    Lt,
}
92
/// Literal on the right-hand side of a comparison rule.
#[derive(Debug)]
enum CmpValue {
    /// Unquoted numeric literal.
    Number(f64),
    /// Quoted string literal, stored without its quotes.
    Str(String),
}
98
99impl PolicyCommand {
100    pub async fn run(
101        &self,
102        client: &Client,
103        manifest: &Manifest,
104        system: Option<&str>,
105        repo: Option<&str>,
106        json: bool,
107    ) -> Result<()> {
108        match &self.action {
109            PolicyAction::Check => check(client, manifest, system, repo, json).await,
110            PolicyAction::List => list(manifest, json),
111        }
112    }
113}
114
115fn list(manifest: &Manifest, json: bool) -> Result<()> {
116    if manifest.policies.is_empty() {
117        println!("\n  No policies configured in ward.toml");
118        return Ok(());
119    }
120
121    if json {
122        println!(
123            "{}",
124            serde_json::to_string_pretty(&manifest.policies).unwrap_or_default()
125        );
126        return Ok(());
127    }
128
129    println!();
130    println!("  {}", style("Configured Policies").bold().cyan());
131    println!("  {}", style("-".repeat(60)).dim());
132
133    for p in &manifest.policies {
134        let sev = match p.severity {
135            PolicySeverity::Error => style("error").red().bold(),
136            PolicySeverity::Warning => style("warning").yellow(),
137        };
138        println!(
139            "  {} [{}] {}",
140            style(&p.name).bold(),
141            sev,
142            style(&p.rule).dim()
143        );
144    }
145
146    Ok(())
147}
148
149async fn check(
150    client: &Client,
151    manifest: &Manifest,
152    system: Option<&str>,
153    repo: Option<&str>,
154    json: bool,
155) -> Result<()> {
156    if manifest.policies.is_empty() {
157        anyhow::bail!("No policies configured in ward.toml. Add [[policies]] entries first.");
158    }
159
160    let repos = resolve_repos(client, manifest, system, repo).await?;
161
162    if !json {
163        println!(
164            "\n  {} Checking {} repos against {} policies...",
165            style("[..]").dim(),
166            repos.len(),
167            manifest.policies.len()
168        );
169    }
170
171    let mut violations = Vec::new();
172
173    for repo_info in &repos {
174        let (sec_result, prot_result) = tokio::join!(
175            client.get_security_state(&repo_info.name),
176            client.get_branch_protection(&repo_info.name, &repo_info.default_branch)
177        );
178
179        let ctx = RepoContext {
180            visibility: repo_info.visibility.clone(),
181            archived: repo_info.archived,
182            security: sec_result.unwrap_or_default(),
183            branch_protection: prot_result.unwrap_or(None).unwrap_or_default(),
184        };
185
186        for policy in &manifest.policies {
187            match parse_rule(&policy.rule) {
188                Ok(parsed) => {
189                    if !evaluate_rule(&parsed, &ctx) {
190                        violations.push(Violation {
191                            repo: repo_info.name.clone(),
192                            policy: policy.name.clone(),
193                            severity: policy.severity.to_string(),
194                            rule: policy.rule.clone(),
195                        });
196                    }
197                }
198                Err(e) => {
199                    if !json {
200                        println!(
201                            "  {} Skipping policy '{}': {}",
202                            style("[!!]").yellow(),
203                            policy.name,
204                            e
205                        );
206                    }
207                }
208            }
209        }
210    }
211
212    if json {
213        println!(
214            "{}",
215            serde_json::to_string_pretty(&violations).unwrap_or_default()
216        );
217    } else {
218        print_violations(&violations);
219    }
220
221    let error_count = violations.iter().filter(|v| v.severity == "error").count();
222    if error_count > 0 {
223        std::process::exit(1);
224    }
225
226    Ok(())
227}
228
229async fn resolve_repos(
230    client: &Client,
231    manifest: &Manifest,
232    system: Option<&str>,
233    repo: Option<&str>,
234) -> Result<Vec<crate::github::repos::Repository>> {
235    if let Some(repo_name) = repo {
236        let r = client.get_repo(repo_name).await?;
237        return Ok(vec![r]);
238    }
239
240    if let Some(sys) = system {
241        let excludes = manifest.exclude_patterns_for_system(sys);
242        let explicit = manifest.explicit_repos_for_system(sys);
243        return client
244            .list_repos_for_system(sys, &excludes, &explicit)
245            .await;
246    }
247
248    client.list_repos().await
249}
250
251fn print_violations(violations: &[Violation]) {
252    if violations.is_empty() {
253        println!(
254            "\n  {} All repos comply with all policies.",
255            style("[ok]").green()
256        );
257        return;
258    }
259
260    println!();
261
262    let mut current_repo = "";
263    for v in violations {
264        if v.repo != current_repo {
265            current_repo = &v.repo;
266            println!("  {}", style(&v.repo).bold());
267        }
268
269        let sev = if v.severity == "error" {
270            style(&v.severity).red()
271        } else {
272            style(&v.severity).yellow()
273        };
274
275        println!(
276            "    {} [{}] {} ({})",
277            style("-").dim(),
278            sev,
279            v.policy,
280            style(&v.rule).dim()
281        );
282    }
283
284    let errors = violations.iter().filter(|v| v.severity == "error").count();
285    let warnings = violations
286        .iter()
287        .filter(|v| v.severity == "warning")
288        .count();
289
290    println!();
291    println!(
292        "  Summary: {} errors, {} warnings",
293        if errors > 0 {
294            style(errors).red().bold()
295        } else {
296            style(errors).green().bold()
297        },
298        if warnings > 0 {
299            style(warnings).yellow().bold()
300        } else {
301            style(warnings).green().bold()
302        }
303    );
304}
305
306fn parse_rule(rule: &str) -> Result<ParsedRule> {
307    let rule = rule.trim();
308
309    // Negated boolean: !field.subfield
310    if let Some(rest) = rule.strip_prefix('!') {
311        let path = parse_path(rest.trim())?;
312        return Ok(ParsedRule::BoolField {
313            path,
314            negated: true,
315        });
316    }
317
318    // Comparison operators: >=, <=, !=, ==, >, <
319    let ops = [">=", "<=", "!=", "==", ">", "<"];
320    for op_str in ops {
321        if let Some(pos) = rule.find(op_str) {
322            let lhs = rule[..pos].trim();
323            let rhs = rule[pos + op_str.len()..].trim();
324            let path = parse_path(lhs)?;
325            let op = match op_str {
326                ">=" => CmpOp::Ge,
327                "<=" => CmpOp::Le,
328                "!=" => CmpOp::Ne,
329                "==" => CmpOp::Eq,
330                ">" => CmpOp::Gt,
331                "<" => CmpOp::Lt,
332                _ => unreachable!(),
333            };
334            let value = parse_value(rhs)?;
335            return Ok(ParsedRule::Comparison { path, op, value });
336        }
337    }
338
339    // Simple boolean: field.subfield
340    let path = parse_path(rule)?;
341    Ok(ParsedRule::BoolField {
342        path,
343        negated: false,
344    })
345}
346
347fn parse_path(s: &str) -> Result<Vec<String>> {
348    let parts: Vec<String> = s.split('.').map(|p| p.trim().to_string()).collect();
349    if parts.is_empty() || parts.iter().any(|p| p.is_empty()) {
350        anyhow::bail!("Invalid field path: {s}");
351    }
352    Ok(parts)
353}
354
355fn parse_value(s: &str) -> Result<CmpValue> {
356    let s = s.trim();
357    if (s.starts_with('\'') && s.ends_with('\'')) || (s.starts_with('"') && s.ends_with('"')) {
358        return Ok(CmpValue::Str(s[1..s.len() - 1].to_string()));
359    }
360    if let Ok(n) = s.parse::<f64>() {
361        return Ok(CmpValue::Number(n));
362    }
363    anyhow::bail!("Cannot parse value: {s}")
364}
365
366fn evaluate_rule(rule: &ParsedRule, ctx: &RepoContext) -> bool {
367    match rule {
368        ParsedRule::BoolField { path, negated } => {
369            let val = resolve_bool(path, ctx);
370            if *negated { !val } else { val }
371        }
372        ParsedRule::Comparison { path, op, value } => match value {
373            CmpValue::Number(expected) => {
374                let actual = resolve_number(path, ctx);
375                match op {
376                    CmpOp::Ge => actual >= *expected,
377                    CmpOp::Le => actual <= *expected,
378                    CmpOp::Gt => actual > *expected,
379                    CmpOp::Lt => actual < *expected,
380                    CmpOp::Eq => (actual - expected).abs() < f64::EPSILON,
381                    CmpOp::Ne => (actual - expected).abs() >= f64::EPSILON,
382                }
383            }
384            CmpValue::Str(expected) => {
385                let actual = resolve_string(path, ctx);
386                match op {
387                    CmpOp::Eq => actual == *expected,
388                    CmpOp::Ne => actual != *expected,
389                    _ => false,
390                }
391            }
392        },
393    }
394}
395
396fn resolve_bool(path: &[String], ctx: &RepoContext) -> bool {
397    match path.first().map(String::as_str) {
398        Some("security") => match path.get(1).map(String::as_str) {
399            Some("secret_scanning") => ctx.security.secret_scanning,
400            Some("push_protection") => ctx.security.push_protection,
401            Some("dependabot_alerts") => ctx.security.dependabot_alerts,
402            Some("dependabot_security_updates") => ctx.security.dependabot_security_updates,
403            Some("secret_scanning_ai_detection") => ctx.security.secret_scanning_ai_detection,
404            _ => false,
405        },
406        Some("branch_protection") => match path.get(1).map(String::as_str) {
407            Some("enabled") => ctx.branch_protection.required_pull_request_reviews,
408            Some("dismiss_stale_reviews") => ctx.branch_protection.dismiss_stale_reviews,
409            Some("require_code_owner_reviews") => ctx.branch_protection.require_code_owner_reviews,
410            Some("require_status_checks") => ctx.branch_protection.required_status_checks,
411            Some("strict_status_checks") => ctx.branch_protection.strict_status_checks,
412            Some("enforce_admins") => ctx.branch_protection.enforce_admins,
413            Some("required_linear_history") => ctx.branch_protection.required_linear_history,
414            Some("allow_force_pushes") => ctx.branch_protection.allow_force_pushes,
415            Some("allow_deletions") => ctx.branch_protection.allow_deletions,
416            _ => false,
417        },
418        Some("archived") => ctx.archived,
419        _ => false,
420    }
421}
422
423fn resolve_number(path: &[String], ctx: &RepoContext) -> f64 {
424    match path.first().map(String::as_str) {
425        Some("branch_protection") => match path.get(1).map(String::as_str) {
426            Some("required_approvals") => {
427                ctx.branch_protection.required_approving_review_count as f64
428            }
429            _ => 0.0,
430        },
431        _ => 0.0,
432    }
433}
434
435fn resolve_string(path: &[String], ctx: &RepoContext) -> String {
436    match path.first().map(String::as_str) {
437        Some("visibility") => ctx.visibility.clone(),
438        _ => String::new(),
439    }
440}
441
// Unit tests for rule parsing, rule evaluation, and `PolicyRule` deserialization from TOML.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a private, non-archived repository with a mix of enabled and disabled settings.
    /// Values relied on below: `secret_scanning` is on, `push_protection` is off,
    /// `allow_force_pushes` is on, and two approving reviews are required.
    fn make_ctx() -> RepoContext {
        RepoContext {
            visibility: "private".to_string(),
            archived: false,
            security: SecurityState {
                secret_scanning: true,
                push_protection: false,
                dependabot_alerts: true,
                dependabot_security_updates: true,
                secret_scanning_ai_detection: false,
            },
            branch_protection: BranchProtectionState {
                required_pull_request_reviews: true,
                required_approving_review_count: 2,
                dismiss_stale_reviews: true,
                require_code_owner_reviews: false,
                required_status_checks: true,
                strict_status_checks: false,
                enforce_admins: false,
                required_linear_history: false,
                allow_force_pushes: true,
                allow_deletions: false,
            },
        }
    }

    // A bare dotted path parses as a non-negated boolean field.
    #[test]
    fn test_parse_boolean_rule() {
        let parsed = parse_rule("security.secret_scanning").unwrap();
        match parsed {
            ParsedRule::BoolField { path, negated } => {
                assert_eq!(path, vec!["security", "secret_scanning"]);
                assert!(!negated);
            }
            _ => panic!("expected BoolField"),
        }
    }

    // A leading `!` parses as a negated boolean field.
    #[test]
    fn test_parse_negated_rule() {
        let parsed = parse_rule("!branch_protection.allow_force_pushes").unwrap();
        match parsed {
            ParsedRule::BoolField { path, negated } => {
                assert_eq!(path, vec!["branch_protection", "allow_force_pushes"]);
                assert!(negated);
            }
            _ => panic!("expected negated BoolField"),
        }
    }

    // `>=` with an unquoted number parses as a numeric comparison.
    #[test]
    fn test_parse_comparison_rule() {
        let parsed = parse_rule("branch_protection.required_approvals >= 2").unwrap();
        match parsed {
            ParsedRule::Comparison { path, op, value } => {
                assert_eq!(path, vec!["branch_protection", "required_approvals"]);
                assert!(matches!(op, CmpOp::Ge));
                assert!(matches!(value, CmpValue::Number(n) if (n - 2.0).abs() < f64::EPSILON));
            }
            _ => panic!("expected Comparison"),
        }
    }

    // `!=` with a single-quoted value parses as a string comparison (quotes stripped).
    #[test]
    fn test_parse_string_rule() {
        let parsed = parse_rule("visibility != 'public'").unwrap();
        match parsed {
            ParsedRule::Comparison { path, op, value } => {
                assert_eq!(path, vec!["visibility"]);
                assert!(matches!(op, CmpOp::Ne));
                assert!(matches!(value, CmpValue::Str(ref s) if s == "public"));
            }
            _ => panic!("expected Comparison"),
        }
    }

    // One passing case per rule kind (boolean, string comparison, numeric comparison).
    #[test]
    fn test_evaluate_policy_pass() {
        let ctx = make_ctx();

        // security.secret_scanning is true -- should pass
        let rule = parse_rule("security.secret_scanning").unwrap();
        assert!(evaluate_rule(&rule, &ctx));

        // visibility != 'public' -- we are 'private' -- should pass
        let rule = parse_rule("visibility != 'public'").unwrap();
        assert!(evaluate_rule(&rule, &ctx));

        // required_approvals >= 2 -- we have 2 -- should pass
        let rule = parse_rule("branch_protection.required_approvals >= 2").unwrap();
        assert!(evaluate_rule(&rule, &ctx));
    }

    // One failing case per rule kind (boolean, negated boolean, numeric comparison).
    #[test]
    fn test_evaluate_policy_fail() {
        let ctx = make_ctx();

        // push_protection is false -- should fail
        let rule = parse_rule("security.push_protection").unwrap();
        assert!(!evaluate_rule(&rule, &ctx));

        // !allow_force_pushes -- allow_force_pushes is true, so negation is false -- should fail
        let rule = parse_rule("!branch_protection.allow_force_pushes").unwrap();
        assert!(!evaluate_rule(&rule, &ctx));

        // required_approvals >= 3 -- we have 2 -- should fail
        let rule = parse_rule("branch_protection.required_approvals >= 3").unwrap();
        assert!(!evaluate_rule(&rule, &ctx));
    }

    // `==` with a quoted value parses as a string equality comparison.
    #[test]
    fn test_parse_equality_string() {
        let parsed = parse_rule("visibility == 'private'").unwrap();
        match parsed {
            ParsedRule::Comparison { path, op, value } => {
                assert_eq!(path, vec!["visibility"]);
                assert!(matches!(op, CmpOp::Eq));
                assert!(matches!(value, CmpValue::Str(ref s) if s == "private"));
            }
            _ => panic!("expected Comparison"),
        }
    }

    // Single-segment path `archived` resolves to the top-level archived flag.
    #[test]
    fn test_evaluate_archived_bool() {
        let ctx = make_ctx();
        let rule = parse_rule("!archived").unwrap();
        assert!(evaluate_rule(&rule, &ctx)); // archived is false, !false = true
    }

    // A fully specified `[[policies]]` entry deserializes from TOML.
    #[test]
    fn test_policy_rule_serde() {
        let toml_str = r#"
            name = "no-public"
            rule = "visibility != 'public'"
            severity = "error"
        "#;
        let rule: PolicyRule = toml::from_str(toml_str).unwrap();
        assert_eq!(rule.name, "no-public");
        assert_eq!(rule.severity, PolicySeverity::Error);
    }

    // Omitting `severity` falls back to the serde default (`error`).
    #[test]
    fn test_policy_severity_default() {
        let toml_str = r#"
            name = "test"
            rule = "security.secret_scanning"
        "#;
        let rule: PolicyRule = toml::from_str(toml_str).unwrap();
        assert_eq!(rule.severity, PolicySeverity::Error);
    }
}