1use anyhow::Result;
2use clap::Args;
3use console::style;
4use dialoguer::Confirm;
5
6use crate::config::Manifest;
7use crate::engine::audit_log::AuditLog;
8use crate::github::Client;
9use crate::github::branch_protection::BranchProtectionState;
10
11#[derive(Args)]
12pub struct ProtectionCommand {
13 #[command(subcommand)]
14 action: ProtectionAction,
15}
16
17#[derive(clap::Subcommand)]
18enum ProtectionAction {
19 Plan,
21
22 Apply {
24 #[arg(long)]
26 yes: bool,
27 },
28
29 Audit,
31}
32
33impl ProtectionCommand {
34 pub async fn run(
35 &self,
36 client: &Client,
37 manifest: &Manifest,
38 system: Option<&str>,
39 repo: Option<&str>,
40 ) -> Result<()> {
41 match &self.action {
42 ProtectionAction::Plan => plan(client, manifest, system, repo).await,
43 ProtectionAction::Apply { yes } => apply(client, manifest, system, repo, *yes).await,
44 ProtectionAction::Audit => audit(client, manifest, system, repo).await,
45 }
46 }
47}
48
49async fn resolve_repos_with_branches(
50 client: &Client,
51 manifest: &Manifest,
52 system: Option<&str>,
53 repo: Option<&str>,
54) -> Result<Vec<(String, String)>> {
55 if let Some(repo_name) = repo {
56 let r = client.get_repo(repo_name).await?;
57 return Ok(vec![(r.name, r.default_branch)]);
58 }
59
60 let sys = system.ok_or_else(|| {
61 anyhow::anyhow!("Either --system or --repo is required for protection commands")
62 })?;
63
64 let excludes = manifest.exclude_patterns_for_system(sys);
65 let explicit = manifest.explicit_repos_for_system(sys);
66 let repos = client
67 .list_repos_for_system(sys, &excludes, &explicit)
68 .await?;
69 Ok(repos
70 .into_iter()
71 .map(|r| (r.name, r.default_branch))
72 .collect())
73}
74
/// The set of differences between the current and the desired branch
/// protection of one repository branch.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ProtectionDiff {
    /// Repository name the diff was computed for.
    repo: String,
    /// Branch the diff was computed for.
    branch: String,
    /// One entry per setting whose current value differs from the desired one.
    changes: Vec<ProtectionChange>,
}

/// A single setting whose current value differs from the desired value.
///
/// Both values are stored in their string form, exactly as they are compared
/// and printed.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ProtectionChange {
    /// Name of the setting.
    field: String,
    /// Value currently in effect.
    current: String,
    /// Value requested by the configuration.
    desired: String,
}

impl ProtectionDiff {
    /// Returns `true` when at least one setting differs.
    fn has_changes(&self) -> bool {
        !self.changes.is_empty()
    }
}
92
93fn diff_protection(
94 repo: &str,
95 branch: &str,
96 current: &BranchProtectionState,
97 config: &crate::config::manifest::BranchProtectionConfig,
98) -> ProtectionDiff {
99 let mut changes = Vec::new();
100
101 let checks: Vec<(&str, String, String)> = vec![
102 (
103 "required_pull_request_reviews",
104 current.required_pull_request_reviews.to_string(),
105 config.enabled.to_string(),
106 ),
107 (
108 "required_approvals",
109 current.required_approving_review_count.to_string(),
110 config.required_approvals.to_string(),
111 ),
112 (
113 "dismiss_stale_reviews",
114 current.dismiss_stale_reviews.to_string(),
115 config.dismiss_stale_reviews.to_string(),
116 ),
117 (
118 "require_code_owner_reviews",
119 current.require_code_owner_reviews.to_string(),
120 config.require_code_owner_reviews.to_string(),
121 ),
122 (
123 "require_status_checks",
124 current.required_status_checks.to_string(),
125 config.require_status_checks.to_string(),
126 ),
127 (
128 "strict_status_checks",
129 current.strict_status_checks.to_string(),
130 config.strict_status_checks.to_string(),
131 ),
132 (
133 "enforce_admins",
134 current.enforce_admins.to_string(),
135 config.enforce_admins.to_string(),
136 ),
137 (
138 "required_linear_history",
139 current.required_linear_history.to_string(),
140 config.required_linear_history.to_string(),
141 ),
142 (
143 "allow_force_pushes",
144 current.allow_force_pushes.to_string(),
145 config.allow_force_pushes.to_string(),
146 ),
147 (
148 "allow_deletions",
149 current.allow_deletions.to_string(),
150 config.allow_deletions.to_string(),
151 ),
152 ];
153
154 for (field, current_val, desired_val) in checks {
155 if current_val != desired_val {
156 changes.push(ProtectionChange {
157 field: field.to_string(),
158 current: current_val,
159 desired: desired_val,
160 });
161 }
162 }
163
164 ProtectionDiff {
165 repo: repo.to_string(),
166 branch: branch.to_string(),
167 changes,
168 }
169}
170
171async fn build_diffs(
172 client: &Client,
173 manifest: &Manifest,
174 system: Option<&str>,
175 repo: Option<&str>,
176) -> Result<Vec<ProtectionDiff>> {
177 let repos = resolve_repos_with_branches(client, manifest, system, repo).await?;
178 let config = &manifest.branch_protection;
179
180 println!();
181 println!(
182 " {} Scanning {} repositories...",
183 style("🔍").bold(),
184 repos.len()
185 );
186
187 let mut diffs = Vec::new();
188 for (repo_name, default_branch) in &repos {
189 let current = client
190 .get_branch_protection(repo_name, default_branch)
191 .await?
192 .unwrap_or_default();
193
194 diffs.push(diff_protection(repo_name, default_branch, ¤t, config));
195 }
196
197 Ok(diffs)
198}
199
200async fn plan(
201 client: &Client,
202 manifest: &Manifest,
203 system: Option<&str>,
204 repo: Option<&str>,
205) -> Result<()> {
206 let diffs = build_diffs(client, manifest, system, repo).await?;
207
208 print_diff_table(&diffs);
209
210 let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
211 if needs_changes > 0 {
212 println!(
213 "\n Run {} to apply these changes.",
214 style("ward protection apply").cyan().bold()
215 );
216 }
217
218 Ok(())
219}
220
221async fn apply(
222 client: &Client,
223 manifest: &Manifest,
224 system: Option<&str>,
225 repo: Option<&str>,
226 yes: bool,
227) -> Result<()> {
228 let diffs = build_diffs(client, manifest, system, repo).await?;
229
230 let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
231 if needs_changes == 0 {
232 println!(
233 "\n {} All repositories are up to date.",
234 style("✅").green()
235 );
236 return Ok(());
237 }
238
239 print_diff_table(&diffs);
240
241 if !yes {
242 let proceed = Confirm::new()
243 .with_prompt(format!(
244 " Apply branch protection to {needs_changes} repositories?"
245 ))
246 .default(false)
247 .interact()?;
248
249 if !proceed {
250 println!(" Aborted.");
251 return Ok(());
252 }
253 }
254
255 println!();
256 println!(" {} Applying changes...", style("⚡").bold());
257
258 let audit_log = AuditLog::new()?;
259 let config = &manifest.branch_protection;
260 let mut succeeded = 0usize;
261 let mut failed: Vec<(String, String)> = Vec::new();
262
263 for diff in diffs.iter().filter(|d| d.has_changes()) {
264 match client
265 .update_branch_protection(&diff.repo, &diff.branch, config)
266 .await
267 {
268 Ok(()) => {
269 println!(
270 " {} {}/{}: ✅ done",
271 style("▶").magenta(),
272 diff.repo,
273 diff.branch
274 );
275 audit_log.log(
276 &diff.repo,
277 "update_branch_protection",
278 "success",
279 false,
280 true,
281 )?;
282 succeeded += 1;
283 }
284 Err(e) => {
285 println!(
286 " {} {}/{}: ❌ {e}",
287 style("▶").magenta(),
288 diff.repo,
289 diff.branch
290 );
291 failed.push((diff.repo.clone(), e.to_string()));
292 }
293 }
294 }
295
296 println!();
297 if failed.is_empty() {
298 println!(
299 " {} All {} repositories updated successfully.",
300 style("✅").green(),
301 succeeded
302 );
303 } else {
304 println!(
305 " {} {} succeeded, {} failed:",
306 style("⚠️").yellow(),
307 succeeded,
308 failed.len()
309 );
310 for (repo, err) in &failed {
311 println!(" {} {}: {}", style("❌").red(), repo, err);
312 }
313 }
314
315 println!(
316 "\n {} Audit log: {}",
317 style("📋").bold(),
318 audit_log.path().display()
319 );
320
321 Ok(())
322}
323
324async fn audit(
325 client: &Client,
326 manifest: &Manifest,
327 system: Option<&str>,
328 repo: Option<&str>,
329) -> Result<()> {
330 let repos = resolve_repos_with_branches(client, manifest, system, repo).await?;
331
332 println!();
333 println!(
334 " {} Auditing branch protection for {} repositories...",
335 style("🔍").bold(),
336 repos.len()
337 );
338
339 println!();
340 println!(
341 " {:40} {:8} {:10} {:10} {:10} {:10} {:10} {:10}",
342 style("Repository").bold().underlined(),
343 style("Branch").bold().underlined(),
344 style("PR Rev").bold().underlined(),
345 style("Approvals").bold().underlined(),
346 style("Stale").bold().underlined(),
347 style("Admins").bold().underlined(),
348 style("Linear").bold().underlined(),
349 style("Force").bold().underlined(),
350 );
351
352 let mut total_ok = 0;
353 let mut total_issues = 0;
354
355 for (repo_name, default_branch) in &repos {
356 let state = client
357 .get_branch_protection(repo_name, default_branch)
358 .await?
359 .unwrap_or_default();
360
361 let protected = state.required_pull_request_reviews;
362 if protected {
363 total_ok += 1;
364 } else {
365 total_issues += 1;
366 }
367
368 let icon = |v: bool| {
369 if v {
370 format!("{}", style("✅").green())
371 } else {
372 format!("{}", style("❌").red())
373 }
374 };
375
376 println!(
377 " {:40} {:8} {:10} {:10} {:10} {:10} {:10} {:10}",
378 repo_name,
379 default_branch,
380 icon(state.required_pull_request_reviews),
381 state.required_approving_review_count,
382 icon(state.dismiss_stale_reviews),
383 icon(state.enforce_admins),
384 icon(state.required_linear_history),
385 icon(state.allow_force_pushes),
386 );
387 }
388
389 println!();
390 println!(
391 " Summary: {} protected, {} unprotected",
392 style(total_ok).green().bold(),
393 if total_issues > 0 {
394 style(total_issues).red().bold()
395 } else {
396 style(total_issues).green().bold()
397 }
398 );
399
400 Ok(())
401}
402
403fn print_diff_table(diffs: &[ProtectionDiff]) {
404 println!();
405 println!(" {}", style("Branch Protection Plan").bold().cyan());
406 println!(" {}", style("─".repeat(60)).dim());
407
408 for diff in diffs {
409 if diff.has_changes() {
410 println!(
411 " {} {} ({})",
412 style("⚡").yellow(),
413 style(&diff.repo).bold(),
414 diff.branch
415 );
416 for change in &diff.changes {
417 println!(
418 " {}: {} → {}",
419 change.field,
420 style(&change.current).red(),
421 style(&change.desired).green().bold()
422 );
423 }
424 } else {
425 println!(" {} {}", style("✓").green(), style(&diff.repo).dim());
426 }
427 }
428
429 let needs_changes = diffs.iter().filter(|d| d.has_changes()).count();
430 let up_to_date = diffs.len() - needs_changes;
431
432 println!();
433 println!(
434 " Summary: {} need changes, {} up to date",
435 if needs_changes > 0 {
436 style(needs_changes).yellow().bold()
437 } else {
438 style(needs_changes).green().bold()
439 },
440 style(up_to_date).green()
441 );
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::manifest::BranchProtectionConfig;

    /// Current state with every flag off and one required approval.
    fn baseline_state() -> BranchProtectionState {
        BranchProtectionState {
            required_pull_request_reviews: false,
            required_approving_review_count: 1,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            required_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    /// Desired configuration that matches `baseline_state` field for field.
    fn baseline_config() -> BranchProtectionConfig {
        BranchProtectionConfig {
            enabled: false,
            required_approvals: 1,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            require_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    #[test]
    fn no_changes_when_state_matches_config() {
        let diff = diff_protection("my-repo", "main", &baseline_state(), &baseline_config());
        assert!(!diff.has_changes());
    }

    #[test]
    fn all_fields_produce_changes_when_they_differ() {
        // Every field is flipped relative to the baseline state.
        let everything_differs = BranchProtectionConfig {
            enabled: true,
            required_approvals: 2,
            dismiss_stale_reviews: true,
            require_code_owner_reviews: true,
            require_status_checks: true,
            strict_status_checks: true,
            enforce_admins: true,
            required_linear_history: true,
            allow_force_pushes: true,
            allow_deletions: true,
        };
        let diff = diff_protection("my-repo", "main", &baseline_state(), &everything_differs);
        assert_eq!(diff.changes.len(), 10);
    }

    #[test]
    fn partial_changes_detected() {
        let config = BranchProtectionConfig {
            enforce_admins: true,
            required_approvals: 3,
            ..baseline_config()
        };

        let diff = diff_protection("my-repo", "main", &baseline_state(), &config);
        assert_eq!(diff.changes.len(), 2);

        let reports = |name: &str| diff.changes.iter().any(|c| c.field == name);
        assert!(reports("enforce_admins"));
        assert!(reports("required_approvals"));
    }

    #[test]
    fn repo_and_branch_preserved() {
        let diff = diff_protection(
            "acme-service",
            "develop",
            &baseline_state(),
            &baseline_config(),
        );
        assert_eq!(diff.repo, "acme-service");
        assert_eq!(diff.branch, "develop");
    }
}