1use anyhow::Result;
2use clap::Args;
3use console::style;
4use serde::Serialize;
5
6use crate::config::Manifest;
7use crate::config::manifest::{BranchProtectionConfig, SecurityConfig};
8use crate::github::Client;
9use crate::github::branch_protection::BranchProtectionState;
10use crate::github::security::SecurityState;
11
// Arguments for the `drift` command group; the actual work is selected by
// the nested subcommand stored in `action` and dispatched by `DriftCommand::run`.
// NOTE: plain `//` comments are used here on purpose — clap derives help text
// from `///` doc comments, so adding them would change `--help` output.
#[derive(Args)]
pub struct DriftCommand {
    // The chosen `drift` subcommand (currently only `check`).
    #[command(subcommand)]
    action: DriftAction,
}
17
18#[derive(clap::Subcommand)]
19enum DriftAction {
20 Check,
22}
23
/// Drift findings for a single repository.
///
/// A repository is considered drifted when either list is non-empty.
#[derive(Debug, Serialize)]
pub struct DriftResult {
    /// Repository name as returned by the GitHub client.
    pub repo: String,
    /// Mismatches between the manifest's security settings and the live
    /// repository state (produced by `compare_security`).
    pub security_drifts: Vec<DriftItem>,
    /// Mismatches between the manifest's branch-protection settings and the
    /// protection on the repository's default branch (produced by
    /// `compare_protection`).
    pub protection_drifts: Vec<DriftItem>,
}
30
/// One setting whose desired value differs from its live value.
///
/// Values are stored pre-rendered via `to_string()` so that boolean and
/// numeric settings can share a single representation in table and JSON output.
#[derive(Debug, Serialize)]
pub struct DriftItem {
    /// Name of the compared setting (e.g. `"secret_scanning"`).
    pub field: String,
    /// Value requested by the manifest.
    pub expected: String,
    /// Value currently observed on the repository.
    pub actual: String,
}
37
38impl DriftResult {
39 fn status(&self) -> &str {
40 if self.is_drifted() { "drifted" } else { "ok" }
41 }
42
43 fn is_drifted(&self) -> bool {
44 !self.security_drifts.is_empty() || !self.protection_drifts.is_empty()
45 }
46}
47
48pub fn compare_security(desired: &SecurityConfig, actual: &SecurityState) -> Vec<DriftItem> {
49 let mut drifts = Vec::new();
50
51 let checks: &[(&str, bool, bool)] = &[
52 (
53 "secret_scanning",
54 desired.secret_scanning,
55 actual.secret_scanning,
56 ),
57 (
58 "push_protection",
59 desired.push_protection,
60 actual.push_protection,
61 ),
62 (
63 "dependabot_alerts",
64 desired.dependabot_alerts,
65 actual.dependabot_alerts,
66 ),
67 (
68 "dependabot_security_updates",
69 desired.dependabot_security_updates,
70 actual.dependabot_security_updates,
71 ),
72 (
73 "secret_scanning_ai_detection",
74 desired.secret_scanning_ai_detection,
75 actual.secret_scanning_ai_detection,
76 ),
77 ];
78
79 for &(field, expected, actual_val) in checks {
80 if expected != actual_val {
81 drifts.push(DriftItem {
82 field: field.to_string(),
83 expected: expected.to_string(),
84 actual: actual_val.to_string(),
85 });
86 }
87 }
88
89 drifts
90}
91
92pub fn compare_protection(
93 desired: &BranchProtectionConfig,
94 actual: &BranchProtectionState,
95) -> Vec<DriftItem> {
96 let mut drifts = Vec::new();
97
98 let checks: &[(&str, bool, bool)] = &[
99 (
100 "required_approvals_enabled",
101 desired.enabled,
102 actual.required_pull_request_reviews,
103 ),
104 (
105 "dismiss_stale_reviews",
106 desired.dismiss_stale_reviews,
107 actual.dismiss_stale_reviews,
108 ),
109 (
110 "require_code_owner_reviews",
111 desired.require_code_owner_reviews,
112 actual.require_code_owner_reviews,
113 ),
114 (
115 "require_status_checks",
116 desired.require_status_checks,
117 actual.required_status_checks,
118 ),
119 (
120 "strict_status_checks",
121 desired.strict_status_checks,
122 actual.strict_status_checks,
123 ),
124 (
125 "enforce_admins",
126 desired.enforce_admins,
127 actual.enforce_admins,
128 ),
129 (
130 "required_linear_history",
131 desired.required_linear_history,
132 actual.required_linear_history,
133 ),
134 (
135 "allow_force_pushes",
136 desired.allow_force_pushes,
137 actual.allow_force_pushes,
138 ),
139 (
140 "allow_deletions",
141 desired.allow_deletions,
142 actual.allow_deletions,
143 ),
144 ];
145
146 for &(field, expected, actual_val) in checks {
147 if expected != actual_val {
148 drifts.push(DriftItem {
149 field: field.to_string(),
150 expected: expected.to_string(),
151 actual: actual_val.to_string(),
152 });
153 }
154 }
155
156 if desired.required_approvals != actual.required_approving_review_count {
157 drifts.push(DriftItem {
158 field: "required_approvals".to_string(),
159 expected: desired.required_approvals.to_string(),
160 actual: actual.required_approving_review_count.to_string(),
161 });
162 }
163
164 drifts
165}
166
167impl DriftCommand {
168 pub async fn run(
169 &self,
170 client: &Client,
171 manifest: &Manifest,
172 system: Option<&str>,
173 repo: Option<&str>,
174 json: bool,
175 ) -> Result<()> {
176 match &self.action {
177 DriftAction::Check => check(client, manifest, system, repo, json).await,
178 }
179 }
180}
181
182async fn resolve_repos(
183 client: &Client,
184 manifest: &Manifest,
185 system: Option<&str>,
186 repo: Option<&str>,
187) -> Result<Vec<(String, String)>> {
188 if let Some(repo_name) = repo {
189 let r = client.get_repo(repo_name).await?;
190 return Ok(vec![(r.name, r.default_branch)]);
191 }
192
193 let sys = system.ok_or_else(|| {
194 anyhow::anyhow!("Either --system or --repo is required for drift commands")
195 })?;
196
197 let excludes = manifest.exclude_patterns_for_system(sys);
198 let explicit = manifest.explicit_repos_for_system(sys);
199 let repos = client
200 .list_repos_for_system(sys, &excludes, &explicit)
201 .await?;
202 Ok(repos
203 .into_iter()
204 .map(|r| (r.name, r.default_branch))
205 .collect())
206}
207
208async fn check(
209 client: &Client,
210 manifest: &Manifest,
211 system: Option<&str>,
212 repo: Option<&str>,
213 json: bool,
214) -> Result<()> {
215 let repos = resolve_repos(client, manifest, system, repo).await?;
216 let sys_id = system.unwrap_or("default");
217 let desired_security = manifest.security_for_system(sys_id);
218 let desired_protection = &manifest.branch_protection;
219
220 if !json {
221 println!();
222 println!(
223 " {} Checking drift for {} repositories...",
224 style("[..]").dim(),
225 repos.len()
226 );
227 }
228
229 let mut results = Vec::new();
230
231 for (repo_name, default_branch) in &repos {
232 let (security_result, protection_result) = tokio::join!(
233 client.get_security_state(repo_name),
234 client.get_branch_protection(repo_name, default_branch)
235 );
236
237 let security_state = security_result?;
238 let protection_state = protection_result?.unwrap_or_default();
239
240 let security_drifts = compare_security(desired_security, &security_state);
241 let protection_drifts = compare_protection(desired_protection, &protection_state);
242
243 results.push(DriftResult {
244 repo: repo_name.clone(),
245 security_drifts,
246 protection_drifts,
247 });
248 }
249
250 if json {
251 print_json(&results);
252 } else {
253 print_table(&results);
254 }
255
256 let drifted = results.iter().filter(|r| r.is_drifted()).count();
257 if drifted > 0 {
258 std::process::exit(1);
259 }
260
261 Ok(())
262}
263
264fn print_json(results: &[DriftResult]) {
265 #[derive(Serialize)]
266 struct JsonEntry<'a> {
267 repo: &'a str,
268 security_drift: &'a [DriftItem],
269 protection_drift: &'a [DriftItem],
270 status: &'a str,
271 }
272
273 let output: Vec<JsonEntry<'_>> = results
274 .iter()
275 .map(|r| JsonEntry {
276 repo: &r.repo,
277 security_drift: &r.security_drifts,
278 protection_drift: &r.protection_drifts,
279 status: r.status(),
280 })
281 .collect();
282
283 println!(
284 "{}",
285 serde_json::to_string_pretty(&output).unwrap_or_default()
286 );
287}
288
289fn print_table(results: &[DriftResult]) {
290 println!();
291 println!(
292 " {} {} {} {}",
293 style(format!("{:<40}", "Repository")).bold().underlined(),
294 style(format!("{:<15}", "Security")).bold().underlined(),
295 style(format!("{:<15}", "Protection")).bold().underlined(),
296 style("Status").bold().underlined(),
297 );
298 println!(" {}", style("\u{2500}".repeat(80)).dim());
299
300 for result in results {
301 let sec = if result.security_drifts.is_empty() {
302 format!("{}", style(format!("{:<15}", "[ok]")).green())
303 } else {
304 format!("{}", style(format!("{:<15}", "[!!]")).red())
305 };
306 let prot = if result.protection_drifts.is_empty() {
307 format!("{}", style(format!("{:<15}", "[ok]")).green())
308 } else {
309 format!("{}", style(format!("{:<15}", "[!!]")).red())
310 };
311 let status = if result.is_drifted() {
312 format!("{}", style("DRIFTED").red().bold())
313 } else {
314 format!("{}", style("In sync").green())
315 };
316
317 println!(" {:<40} {} {} {}", result.repo, sec, prot, status);
318
319 for drift in &result.security_drifts {
320 println!(
321 " - {}: expected {}, got {}",
322 drift.field,
323 style(&drift.expected).green(),
324 style(&drift.actual).red()
325 );
326 }
327 for drift in &result.protection_drifts {
328 println!(
329 " - {}: expected {}, got {}",
330 drift.field,
331 style(&drift.expected).green(),
332 style(&drift.actual).red()
333 );
334 }
335 }
336
337 let total = results.len();
338 let in_sync = results.iter().filter(|r| !r.is_drifted()).count();
339 let drifted = total - in_sync;
340
341 println!();
342 println!(
343 " Summary: {}/{} in sync, {}/{} drifted",
344 style(in_sync).green().bold(),
345 total,
346 if drifted > 0 {
347 style(drifted).red().bold()
348 } else {
349 style(drifted).green().bold()
350 },
351 total,
352 );
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Desired security config with every compared feature switched on.
    fn security_config_all_on() -> SecurityConfig {
        SecurityConfig {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: true,
            codeql_advanced_setup: false,
        }
    }

    /// Live security state with every feature switched on.
    fn security_state_all_on() -> SecurityState {
        SecurityState {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: true,
        }
    }

    /// Desired branch protection with every flag off and no required approvals.
    fn protection_config_all_off() -> BranchProtectionConfig {
        BranchProtectionConfig {
            enabled: false,
            required_approvals: 0,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            require_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    #[test]
    fn test_security_drift_detection() {
        let desired = security_config_all_on();
        let actual = SecurityState {
            secret_scanning: false,
            push_protection: false,
            ..security_state_all_on()
        };

        let drifts = compare_security(&desired, &actual);
        let fields: Vec<&str> = drifts.iter().map(|d| d.field.as_str()).collect();
        assert_eq!(fields, ["secret_scanning", "push_protection"]);
        assert_eq!(drifts[0].expected, "true");
        assert_eq!(drifts[0].actual, "false");
    }

    #[test]
    fn test_protection_drift_detection() {
        let desired = BranchProtectionConfig {
            enabled: true,
            required_approvals: 1,
            ..protection_config_all_off()
        };
        // Reviews required, but the default state requires zero approvals.
        let actual = BranchProtectionState {
            required_pull_request_reviews: true,
            ..BranchProtectionState::default()
        };

        let drifts = compare_protection(&desired, &actual);
        assert_eq!(drifts.len(), 1);
        let only = &drifts[0];
        assert_eq!(
            (only.field.as_str(), only.expected.as_str(), only.actual.as_str()),
            ("required_approvals", "1", "0")
        );
    }

    #[test]
    fn test_no_drift_returns_empty() {
        assert!(compare_security(&security_config_all_on(), &security_state_all_on()).is_empty());
        assert!(
            compare_protection(&protection_config_all_off(), &BranchProtectionState::default())
                .is_empty()
        );
    }

    #[test]
    fn test_drift_item_formatting() {
        let item = DriftItem {
            field: "secret_scanning".into(),
            expected: "true".into(),
            actual: "false".into(),
        };
        assert_eq!(item.field, "secret_scanning");
        assert_eq!(item.expected, "true");
        assert_eq!(item.actual, "false");

        let json = serde_json::to_string(&item).unwrap();
        let needles = [
            "secret_scanning",
            r#""expected":"true""#,
            r#""actual":"false""#,
        ];
        for needle in needles.iter() {
            assert!(json.contains(*needle));
        }
    }
}