1use anyhow::Result;
2use clap::Args;
3use console::style;
4use serde::{Deserialize, Serialize};
5
6use crate::config::Manifest;
7use crate::github::Client;
8use crate::github::branch_protection::BranchProtectionState;
9use crate::github::security::SecurityState;
10
11#[derive(Args)]
12pub struct PolicyCommand {
13 #[command(subcommand)]
14 action: PolicyAction,
15}
16
17#[derive(clap::Subcommand)]
18enum PolicyAction {
19 Check,
21
22 List,
24}
25
26#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
27pub struct PolicyRule {
28 pub name: String,
29 pub rule: String,
30 #[serde(default = "default_error")]
31 pub severity: PolicySeverity,
32}
33
34fn default_error() -> PolicySeverity {
35 PolicySeverity::Error
36}
37
38#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
39#[serde(rename_all = "lowercase")]
40pub enum PolicySeverity {
41 Error,
42 Warning,
43}
44
45impl std::fmt::Display for PolicySeverity {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 PolicySeverity::Error => write!(f, "error"),
49 PolicySeverity::Warning => write!(f, "warning"),
50 }
51 }
52}
53
54#[derive(Debug)]
55struct RepoContext {
56 visibility: String,
57 archived: bool,
58 security: SecurityState,
59 branch_protection: BranchProtectionState,
60}
61
62#[derive(Debug, Serialize)]
63struct Violation {
64 repo: String,
65 policy: String,
66 severity: String,
67 rule: String,
68}
69
70#[derive(Debug)]
71enum ParsedRule {
72 BoolField {
73 path: Vec<String>,
74 negated: bool,
75 },
76 Comparison {
77 path: Vec<String>,
78 op: CmpOp,
79 value: CmpValue,
80 },
81}
82
/// Comparison operator used in a rule expression.
///
/// The enum is fieldless, so it is cheap to copy and compare; deriving
/// `Copy`/`PartialEq`/`Eq` lets callers pass it by value and use `assert_eq!`
/// instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CmpOp {
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `>=`
    Ge,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `<`
    Lt,
}
92
/// Literal on the right-hand side of a comparison rule.
///
/// `PartialEq` (not `Eq`) is derived because the numeric variant holds an
/// `f64`.
#[derive(Debug, Clone, PartialEq)]
enum CmpValue {
    /// A numeric literal, e.g. the `2` in `required_approvals >= 2`.
    Number(f64),
    /// A quoted string literal with the quotes removed, e.g. `public`.
    Str(String),
}
98
99impl PolicyCommand {
100 pub async fn run(
101 &self,
102 client: &Client,
103 manifest: &Manifest,
104 system: Option<&str>,
105 repo: Option<&str>,
106 json: bool,
107 ) -> Result<()> {
108 match &self.action {
109 PolicyAction::Check => check(client, manifest, system, repo, json).await,
110 PolicyAction::List => list(manifest, json),
111 }
112 }
113}
114
115fn list(manifest: &Manifest, json: bool) -> Result<()> {
116 if manifest.policies.is_empty() {
117 println!("\n No policies configured in ward.toml");
118 return Ok(());
119 }
120
121 if json {
122 println!(
123 "{}",
124 serde_json::to_string_pretty(&manifest.policies).unwrap_or_default()
125 );
126 return Ok(());
127 }
128
129 println!();
130 println!(" {}", style("Configured Policies").bold().cyan());
131 println!(" {}", style("-".repeat(60)).dim());
132
133 for p in &manifest.policies {
134 let sev = match p.severity {
135 PolicySeverity::Error => style("error").red().bold(),
136 PolicySeverity::Warning => style("warning").yellow(),
137 };
138 println!(
139 " {} [{}] {}",
140 style(&p.name).bold(),
141 sev,
142 style(&p.rule).dim()
143 );
144 }
145
146 Ok(())
147}
148
149async fn check(
150 client: &Client,
151 manifest: &Manifest,
152 system: Option<&str>,
153 repo: Option<&str>,
154 json: bool,
155) -> Result<()> {
156 if manifest.policies.is_empty() {
157 anyhow::bail!("No policies configured in ward.toml. Add [[policies]] entries first.");
158 }
159
160 let repos = resolve_repos(client, manifest, system, repo).await?;
161
162 if !json {
163 println!(
164 "\n {} Checking {} repos against {} policies...",
165 style("[..]").dim(),
166 repos.len(),
167 manifest.policies.len()
168 );
169 }
170
171 let mut violations = Vec::new();
172
173 for repo_info in &repos {
174 let (sec_result, prot_result) = tokio::join!(
175 client.get_security_state(&repo_info.name),
176 client.get_branch_protection(&repo_info.name, &repo_info.default_branch)
177 );
178
179 let ctx = RepoContext {
180 visibility: repo_info.visibility.clone(),
181 archived: repo_info.archived,
182 security: sec_result.unwrap_or_default(),
183 branch_protection: prot_result.unwrap_or(None).unwrap_or_default(),
184 };
185
186 for policy in &manifest.policies {
187 match parse_rule(&policy.rule) {
188 Ok(parsed) => {
189 if !evaluate_rule(&parsed, &ctx) {
190 violations.push(Violation {
191 repo: repo_info.name.clone(),
192 policy: policy.name.clone(),
193 severity: policy.severity.to_string(),
194 rule: policy.rule.clone(),
195 });
196 }
197 }
198 Err(e) => {
199 if !json {
200 println!(
201 " {} Skipping policy '{}': {}",
202 style("[!!]").yellow(),
203 policy.name,
204 e
205 );
206 }
207 }
208 }
209 }
210 }
211
212 if json {
213 println!(
214 "{}",
215 serde_json::to_string_pretty(&violations).unwrap_or_default()
216 );
217 } else {
218 print_violations(&violations);
219 }
220
221 let error_count = violations.iter().filter(|v| v.severity == "error").count();
222 if error_count > 0 {
223 std::process::exit(1);
224 }
225
226 Ok(())
227}
228
229async fn resolve_repos(
230 client: &Client,
231 manifest: &Manifest,
232 system: Option<&str>,
233 repo: Option<&str>,
234) -> Result<Vec<crate::github::repos::Repository>> {
235 if let Some(repo_name) = repo {
236 let r = client.get_repo(repo_name).await?;
237 return Ok(vec![r]);
238 }
239
240 if let Some(sys) = system {
241 let excludes = manifest.exclude_patterns_for_system(sys);
242 let explicit = manifest.explicit_repos_for_system(sys);
243 return client
244 .list_repos_for_system(sys, &excludes, &explicit)
245 .await;
246 }
247
248 client.list_repos().await
249}
250
251fn print_violations(violations: &[Violation]) {
252 if violations.is_empty() {
253 println!(
254 "\n {} All repos comply with all policies.",
255 style("[ok]").green()
256 );
257 return;
258 }
259
260 println!();
261
262 let mut current_repo = "";
263 for v in violations {
264 if v.repo != current_repo {
265 current_repo = &v.repo;
266 println!(" {}", style(&v.repo).bold());
267 }
268
269 let sev = if v.severity == "error" {
270 style(&v.severity).red()
271 } else {
272 style(&v.severity).yellow()
273 };
274
275 println!(
276 " {} [{}] {} ({})",
277 style("-").dim(),
278 sev,
279 v.policy,
280 style(&v.rule).dim()
281 );
282 }
283
284 let errors = violations.iter().filter(|v| v.severity == "error").count();
285 let warnings = violations
286 .iter()
287 .filter(|v| v.severity == "warning")
288 .count();
289
290 println!();
291 println!(
292 " Summary: {} errors, {} warnings",
293 if errors > 0 {
294 style(errors).red().bold()
295 } else {
296 style(errors).green().bold()
297 },
298 if warnings > 0 {
299 style(warnings).yellow().bold()
300 } else {
301 style(warnings).green().bold()
302 }
303 );
304}
305
306fn parse_rule(rule: &str) -> Result<ParsedRule> {
307 let rule = rule.trim();
308
309 if let Some(rest) = rule.strip_prefix('!') {
311 let path = parse_path(rest.trim())?;
312 return Ok(ParsedRule::BoolField {
313 path,
314 negated: true,
315 });
316 }
317
318 let ops = [">=", "<=", "!=", "==", ">", "<"];
320 for op_str in ops {
321 if let Some(pos) = rule.find(op_str) {
322 let lhs = rule[..pos].trim();
323 let rhs = rule[pos + op_str.len()..].trim();
324 let path = parse_path(lhs)?;
325 let op = match op_str {
326 ">=" => CmpOp::Ge,
327 "<=" => CmpOp::Le,
328 "!=" => CmpOp::Ne,
329 "==" => CmpOp::Eq,
330 ">" => CmpOp::Gt,
331 "<" => CmpOp::Lt,
332 _ => unreachable!(),
333 };
334 let value = parse_value(rhs)?;
335 return Ok(ParsedRule::Comparison { path, op, value });
336 }
337 }
338
339 let path = parse_path(rule)?;
341 Ok(ParsedRule::BoolField {
342 path,
343 negated: false,
344 })
345}
346
347fn parse_path(s: &str) -> Result<Vec<String>> {
348 let parts: Vec<String> = s.split('.').map(|p| p.trim().to_string()).collect();
349 if parts.is_empty() || parts.iter().any(|p| p.is_empty()) {
350 anyhow::bail!("Invalid field path: {s}");
351 }
352 Ok(parts)
353}
354
355fn parse_value(s: &str) -> Result<CmpValue> {
356 let s = s.trim();
357 if (s.starts_with('\'') && s.ends_with('\'')) || (s.starts_with('"') && s.ends_with('"')) {
358 return Ok(CmpValue::Str(s[1..s.len() - 1].to_string()));
359 }
360 if let Ok(n) = s.parse::<f64>() {
361 return Ok(CmpValue::Number(n));
362 }
363 anyhow::bail!("Cannot parse value: {s}")
364}
365
366fn evaluate_rule(rule: &ParsedRule, ctx: &RepoContext) -> bool {
367 match rule {
368 ParsedRule::BoolField { path, negated } => {
369 let val = resolve_bool(path, ctx);
370 if *negated { !val } else { val }
371 }
372 ParsedRule::Comparison { path, op, value } => match value {
373 CmpValue::Number(expected) => {
374 let actual = resolve_number(path, ctx);
375 match op {
376 CmpOp::Ge => actual >= *expected,
377 CmpOp::Le => actual <= *expected,
378 CmpOp::Gt => actual > *expected,
379 CmpOp::Lt => actual < *expected,
380 CmpOp::Eq => (actual - expected).abs() < f64::EPSILON,
381 CmpOp::Ne => (actual - expected).abs() >= f64::EPSILON,
382 }
383 }
384 CmpValue::Str(expected) => {
385 let actual = resolve_string(path, ctx);
386 match op {
387 CmpOp::Eq => actual == *expected,
388 CmpOp::Ne => actual != *expected,
389 _ => false,
390 }
391 }
392 },
393 }
394}
395
396fn resolve_bool(path: &[String], ctx: &RepoContext) -> bool {
397 match path.first().map(String::as_str) {
398 Some("security") => match path.get(1).map(String::as_str) {
399 Some("secret_scanning") => ctx.security.secret_scanning,
400 Some("push_protection") => ctx.security.push_protection,
401 Some("dependabot_alerts") => ctx.security.dependabot_alerts,
402 Some("dependabot_security_updates") => ctx.security.dependabot_security_updates,
403 Some("secret_scanning_ai_detection") => ctx.security.secret_scanning_ai_detection,
404 _ => false,
405 },
406 Some("branch_protection") => match path.get(1).map(String::as_str) {
407 Some("enabled") => ctx.branch_protection.required_pull_request_reviews,
408 Some("dismiss_stale_reviews") => ctx.branch_protection.dismiss_stale_reviews,
409 Some("require_code_owner_reviews") => ctx.branch_protection.require_code_owner_reviews,
410 Some("require_status_checks") => ctx.branch_protection.required_status_checks,
411 Some("strict_status_checks") => ctx.branch_protection.strict_status_checks,
412 Some("enforce_admins") => ctx.branch_protection.enforce_admins,
413 Some("required_linear_history") => ctx.branch_protection.required_linear_history,
414 Some("allow_force_pushes") => ctx.branch_protection.allow_force_pushes,
415 Some("allow_deletions") => ctx.branch_protection.allow_deletions,
416 _ => false,
417 },
418 Some("archived") => ctx.archived,
419 _ => false,
420 }
421}
422
423fn resolve_number(path: &[String], ctx: &RepoContext) -> f64 {
424 match path.first().map(String::as_str) {
425 Some("branch_protection") => match path.get(1).map(String::as_str) {
426 Some("required_approvals") => {
427 ctx.branch_protection.required_approving_review_count as f64
428 }
429 _ => 0.0,
430 },
431 _ => 0.0,
432 }
433}
434
435fn resolve_string(path: &[String], ctx: &RepoContext) -> String {
436 match path.first().map(String::as_str) {
437 Some("visibility") => ctx.visibility.clone(),
438 _ => String::new(),
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a private, non-archived repository with a mix of enabled and
    /// disabled security and branch-protection settings.
    fn make_ctx() -> RepoContext {
        let security = SecurityState {
            secret_scanning: true,
            push_protection: false,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: false,
        };
        let branch_protection = BranchProtectionState {
            required_pull_request_reviews: true,
            required_approving_review_count: 2,
            dismiss_stale_reviews: true,
            require_code_owner_reviews: false,
            required_status_checks: true,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: true,
            allow_deletions: false,
        };
        RepoContext {
            visibility: "private".to_string(),
            archived: false,
            security,
            branch_protection,
        }
    }

    // Parsing

    #[test]
    fn test_parse_boolean_rule() {
        let parsed = parse_rule("security.secret_scanning").unwrap();
        if let ParsedRule::BoolField { path, negated } = parsed {
            assert_eq!(path, vec!["security", "secret_scanning"]);
            assert!(!negated);
        } else {
            panic!("expected BoolField");
        }
    }

    #[test]
    fn test_parse_negated_rule() {
        let parsed = parse_rule("!branch_protection.allow_force_pushes").unwrap();
        if let ParsedRule::BoolField { path, negated } = parsed {
            assert_eq!(path, vec!["branch_protection", "allow_force_pushes"]);
            assert!(negated);
        } else {
            panic!("expected negated BoolField");
        }
    }

    #[test]
    fn test_parse_comparison_rule() {
        let parsed = parse_rule("branch_protection.required_approvals >= 2").unwrap();
        if let ParsedRule::Comparison { path, op, value } = parsed {
            assert_eq!(path, vec!["branch_protection", "required_approvals"]);
            assert!(matches!(op, CmpOp::Ge));
            assert!(matches!(value, CmpValue::Number(n) if (n - 2.0).abs() < f64::EPSILON));
        } else {
            panic!("expected Comparison");
        }
    }

    #[test]
    fn test_parse_string_rule() {
        let parsed = parse_rule("visibility != 'public'").unwrap();
        if let ParsedRule::Comparison { path, op, value } = parsed {
            assert_eq!(path, vec!["visibility"]);
            assert!(matches!(op, CmpOp::Ne));
            assert!(matches!(value, CmpValue::Str(ref s) if s == "public"));
        } else {
            panic!("expected Comparison");
        }
    }

    // Evaluation

    #[test]
    fn test_evaluate_policy_pass() {
        let ctx = make_ctx();

        // Each of these holds for the fixture.
        let passing = [
            "security.secret_scanning",
            "visibility != 'public'",
            "branch_protection.required_approvals >= 2",
        ];
        for expr in passing {
            let rule = parse_rule(expr).unwrap();
            assert!(evaluate_rule(&rule, &ctx));
        }
    }

    #[test]
    fn test_evaluate_policy_fail() {
        let ctx = make_ctx();

        // Each of these is violated by the fixture.
        let failing = [
            "security.push_protection",
            "!branch_protection.allow_force_pushes",
            "branch_protection.required_approvals >= 3",
        ];
        for expr in failing {
            let rule = parse_rule(expr).unwrap();
            assert!(!evaluate_rule(&rule, &ctx));
        }
    }

    #[test]
    fn test_parse_equality_string() {
        let parsed = parse_rule("visibility == 'private'").unwrap();
        if let ParsedRule::Comparison { path, op, value } = parsed {
            assert_eq!(path, vec!["visibility"]);
            assert!(matches!(op, CmpOp::Eq));
            assert!(matches!(value, CmpValue::Str(ref s) if s == "private"));
        } else {
            panic!("expected Comparison");
        }
    }

    #[test]
    fn test_evaluate_archived_bool() {
        let ctx = make_ctx();
        let rule = parse_rule("!archived").unwrap();
        // The fixture is not archived, so the negated rule holds.
        assert!(evaluate_rule(&rule, &ctx));
    }

    // Serde

    #[test]
    fn test_policy_rule_serde() {
        let toml_str = r#"
            name = "no-public"
            rule = "visibility != 'public'"
            severity = "error"
        "#;
        let rule: PolicyRule = toml::from_str(toml_str).unwrap();
        assert_eq!(rule.name, "no-public");
        assert_eq!(rule.severity, PolicySeverity::Error);
    }

    #[test]
    fn test_policy_severity_default() {
        // No `severity` key: the serde default must kick in.
        let toml_str = r#"
            name = "test"
            rule = "security.secret_scanning"
        "#;
        let rule: PolicyRule = toml::from_str(toml_str).unwrap();
        assert_eq!(rule.severity, PolicySeverity::Error);
    }
}