Skip to main content

ward/cli/
drift.rs

use anyhow::{Context, Result};
use clap::Args;
use console::style;
use serde::Serialize;

use crate::config::Manifest;
use crate::config::manifest::{BranchProtectionConfig, SecurityConfig};
use crate::github::Client;
use crate::github::branch_protection::BranchProtectionState;
use crate::github::security::SecurityState;
11
12#[derive(Args)]
13pub struct DriftCommand {
14    #[command(subcommand)]
15    action: DriftAction,
16}
17
18#[derive(clap::Subcommand)]
19enum DriftAction {
20    /// Check for configuration drift across repos
21    Check,
22}
23
24#[derive(Debug, Serialize)]
25pub struct DriftResult {
26    pub repo: String,
27    pub security_drifts: Vec<DriftItem>,
28    pub protection_drifts: Vec<DriftItem>,
29}
30
31#[derive(Debug, Serialize)]
32pub struct DriftItem {
33    pub field: String,
34    pub expected: String,
35    pub actual: String,
36}
37
38impl DriftResult {
39    fn status(&self) -> &str {
40        if self.is_drifted() { "drifted" } else { "ok" }
41    }
42
43    fn is_drifted(&self) -> bool {
44        !self.security_drifts.is_empty() || !self.protection_drifts.is_empty()
45    }
46}
47
48pub fn compare_security(desired: &SecurityConfig, actual: &SecurityState) -> Vec<DriftItem> {
49    let mut drifts = Vec::new();
50
51    let checks: &[(&str, bool, bool)] = &[
52        (
53            "secret_scanning",
54            desired.secret_scanning,
55            actual.secret_scanning,
56        ),
57        (
58            "push_protection",
59            desired.push_protection,
60            actual.push_protection,
61        ),
62        (
63            "dependabot_alerts",
64            desired.dependabot_alerts,
65            actual.dependabot_alerts,
66        ),
67        (
68            "dependabot_security_updates",
69            desired.dependabot_security_updates,
70            actual.dependabot_security_updates,
71        ),
72        (
73            "secret_scanning_ai_detection",
74            desired.secret_scanning_ai_detection,
75            actual.secret_scanning_ai_detection,
76        ),
77    ];
78
79    for &(field, expected, actual_val) in checks {
80        if expected != actual_val {
81            drifts.push(DriftItem {
82                field: field.to_string(),
83                expected: expected.to_string(),
84                actual: actual_val.to_string(),
85            });
86        }
87    }
88
89    drifts
90}
91
92pub fn compare_protection(
93    desired: &BranchProtectionConfig,
94    actual: &BranchProtectionState,
95) -> Vec<DriftItem> {
96    let mut drifts = Vec::new();
97
98    let checks: &[(&str, bool, bool)] = &[
99        (
100            "required_approvals_enabled",
101            desired.enabled,
102            actual.required_pull_request_reviews,
103        ),
104        (
105            "dismiss_stale_reviews",
106            desired.dismiss_stale_reviews,
107            actual.dismiss_stale_reviews,
108        ),
109        (
110            "require_code_owner_reviews",
111            desired.require_code_owner_reviews,
112            actual.require_code_owner_reviews,
113        ),
114        (
115            "require_status_checks",
116            desired.require_status_checks,
117            actual.required_status_checks,
118        ),
119        (
120            "strict_status_checks",
121            desired.strict_status_checks,
122            actual.strict_status_checks,
123        ),
124        (
125            "enforce_admins",
126            desired.enforce_admins,
127            actual.enforce_admins,
128        ),
129        (
130            "required_linear_history",
131            desired.required_linear_history,
132            actual.required_linear_history,
133        ),
134        (
135            "allow_force_pushes",
136            desired.allow_force_pushes,
137            actual.allow_force_pushes,
138        ),
139        (
140            "allow_deletions",
141            desired.allow_deletions,
142            actual.allow_deletions,
143        ),
144    ];
145
146    for &(field, expected, actual_val) in checks {
147        if expected != actual_val {
148            drifts.push(DriftItem {
149                field: field.to_string(),
150                expected: expected.to_string(),
151                actual: actual_val.to_string(),
152            });
153        }
154    }
155
156    if desired.required_approvals != actual.required_approving_review_count {
157        drifts.push(DriftItem {
158            field: "required_approvals".to_string(),
159            expected: desired.required_approvals.to_string(),
160            actual: actual.required_approving_review_count.to_string(),
161        });
162    }
163
164    drifts
165}
166
167impl DriftCommand {
168    pub async fn run(
169        &self,
170        client: &Client,
171        manifest: &Manifest,
172        system: Option<&str>,
173        repo: Option<&str>,
174        json: bool,
175    ) -> Result<()> {
176        match &self.action {
177            DriftAction::Check => check(client, manifest, system, repo, json).await,
178        }
179    }
180}
181
182async fn resolve_repos(
183    client: &Client,
184    manifest: &Manifest,
185    system: Option<&str>,
186    repo: Option<&str>,
187) -> Result<Vec<(String, String)>> {
188    if let Some(repo_name) = repo {
189        let r = client.get_repo(repo_name).await?;
190        return Ok(vec![(r.name, r.default_branch)]);
191    }
192
193    let sys = system.ok_or_else(|| {
194        anyhow::anyhow!("Either --system or --repo is required for drift commands")
195    })?;
196
197    let excludes = manifest.exclude_patterns_for_system(sys);
198    let explicit = manifest.explicit_repos_for_system(sys);
199    let repos = client
200        .list_repos_for_system(sys, &excludes, &explicit)
201        .await?;
202    Ok(repos
203        .into_iter()
204        .map(|r| (r.name, r.default_branch))
205        .collect())
206}
207
208async fn check(
209    client: &Client,
210    manifest: &Manifest,
211    system: Option<&str>,
212    repo: Option<&str>,
213    json: bool,
214) -> Result<()> {
215    let repos = resolve_repos(client, manifest, system, repo).await?;
216    let sys_id = system.unwrap_or("default");
217    let desired_security = manifest.security_for_system(sys_id);
218    let desired_protection = &manifest.branch_protection;
219
220    if !json {
221        println!();
222        println!(
223            "  {} Checking drift for {} repositories...",
224            style("[..]").dim(),
225            repos.len()
226        );
227    }
228
229    let mut results = Vec::new();
230
231    for (repo_name, default_branch) in &repos {
232        let (security_result, protection_result) = tokio::join!(
233            client.get_security_state(repo_name),
234            client.get_branch_protection(repo_name, default_branch)
235        );
236
237        let security_state = security_result?;
238        let protection_state = protection_result?.unwrap_or_default();
239
240        let security_drifts = compare_security(desired_security, &security_state);
241        let protection_drifts = compare_protection(desired_protection, &protection_state);
242
243        results.push(DriftResult {
244            repo: repo_name.clone(),
245            security_drifts,
246            protection_drifts,
247        });
248    }
249
250    if json {
251        print_json(&results);
252    } else {
253        print_table(&results);
254    }
255
256    let drifted = results.iter().filter(|r| r.is_drifted()).count();
257    if drifted > 0 {
258        std::process::exit(1);
259    }
260
261    Ok(())
262}
263
264fn print_json(results: &[DriftResult]) {
265    #[derive(Serialize)]
266    struct JsonEntry<'a> {
267        repo: &'a str,
268        security_drift: &'a [DriftItem],
269        protection_drift: &'a [DriftItem],
270        status: &'a str,
271    }
272
273    let output: Vec<JsonEntry<'_>> = results
274        .iter()
275        .map(|r| JsonEntry {
276            repo: &r.repo,
277            security_drift: &r.security_drifts,
278            protection_drift: &r.protection_drifts,
279            status: r.status(),
280        })
281        .collect();
282
283    println!(
284        "{}",
285        serde_json::to_string_pretty(&output).unwrap_or_default()
286    );
287}
288
289fn print_table(results: &[DriftResult]) {
290    println!();
291    println!(
292        "  {} {} {} {}",
293        style(format!("{:<40}", "Repository")).bold().underlined(),
294        style(format!("{:<15}", "Security")).bold().underlined(),
295        style(format!("{:<15}", "Protection")).bold().underlined(),
296        style("Status").bold().underlined(),
297    );
298    println!("  {}", style("\u{2500}".repeat(80)).dim());
299
300    for result in results {
301        let sec = if result.security_drifts.is_empty() {
302            format!("{}", style(format!("{:<15}", "[ok]")).green())
303        } else {
304            format!("{}", style(format!("{:<15}", "[!!]")).red())
305        };
306        let prot = if result.protection_drifts.is_empty() {
307            format!("{}", style(format!("{:<15}", "[ok]")).green())
308        } else {
309            format!("{}", style(format!("{:<15}", "[!!]")).red())
310        };
311        let status = if result.is_drifted() {
312            format!("{}", style("DRIFTED").red().bold())
313        } else {
314            format!("{}", style("In sync").green())
315        };
316
317        println!("  {:<40} {} {} {}", result.repo, sec, prot, status);
318
319        for drift in &result.security_drifts {
320            println!(
321                "    - {}: expected {}, got {}",
322                drift.field,
323                style(&drift.expected).green(),
324                style(&drift.actual).red()
325            );
326        }
327        for drift in &result.protection_drifts {
328            println!(
329                "    - {}: expected {}, got {}",
330                drift.field,
331                style(&drift.expected).green(),
332                style(&drift.actual).red()
333            );
334        }
335    }
336
337    let total = results.len();
338    let in_sync = results.iter().filter(|r| !r.is_drifted()).count();
339    let drifted = total - in_sync;
340
341    println!();
342    println!(
343        "  Summary: {}/{} in sync, {}/{} drifted",
344        style(in_sync).green().bold(),
345        total,
346        if drifted > 0 {
347            style(drifted).red().bold()
348        } else {
349            style(drifted).green().bold()
350        },
351        total,
352    );
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Desired security settings with every compared feature enabled.
    fn all_enabled_security() -> SecurityConfig {
        SecurityConfig {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: true,
            codeql_advanced_setup: false,
        }
    }

    /// Desired branch protection with every flag off and zero required
    /// approvals. `test_no_drift_returns_empty` pins that this matches
    /// `BranchProtectionState::default()`.
    fn all_disabled_protection() -> BranchProtectionConfig {
        BranchProtectionConfig {
            enabled: false,
            required_approvals: 0,
            dismiss_stale_reviews: false,
            require_code_owner_reviews: false,
            require_status_checks: false,
            strict_status_checks: false,
            enforce_admins: false,
            required_linear_history: false,
            allow_force_pushes: false,
            allow_deletions: false,
        }
    }

    /// A boolean drift item reporting `expected true, got false` for `field`.
    fn bool_drift(field: &str) -> DriftItem {
        DriftItem {
            field: field.to_string(),
            expected: "true".to_string(),
            actual: "false".to_string(),
        }
    }

    #[test]
    fn test_security_drift_detection() {
        let desired = all_enabled_security();
        let actual = SecurityState {
            secret_scanning: false,
            push_protection: false,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: true,
        };

        let drifts = compare_security(&desired, &actual);
        assert_eq!(drifts.len(), 2);
        assert_eq!(drifts[0].field, "secret_scanning");
        assert_eq!(drifts[0].expected, "true");
        assert_eq!(drifts[0].actual, "false");
        assert_eq!(drifts[1].field, "push_protection");
    }

    #[test]
    fn test_protection_drift_detection() {
        let desired = BranchProtectionConfig {
            enabled: true,
            required_approvals: 1,
            ..all_disabled_protection()
        };
        let actual = BranchProtectionState {
            required_pull_request_reviews: true,
            required_approving_review_count: 0,
            ..BranchProtectionState::default()
        };

        let drifts = compare_protection(&desired, &actual);
        assert_eq!(drifts.len(), 1);
        assert_eq!(drifts[0].field, "required_approvals");
        assert_eq!(drifts[0].expected, "1");
        assert_eq!(drifts[0].actual, "0");
    }

    #[test]
    fn test_protection_reports_each_mismatched_flag_in_order() {
        let desired = BranchProtectionConfig {
            enabled: true,
            dismiss_stale_reviews: true,
            enforce_admins: true,
            ..all_disabled_protection()
        };

        let drifts = compare_protection(&desired, &BranchProtectionState::default());
        let fields: Vec<&str> = drifts.iter().map(|d| d.field.as_str()).collect();
        assert_eq!(
            fields,
            ["required_approvals_enabled", "dismiss_stale_reviews", "enforce_admins"]
        );
        assert!(drifts.iter().all(|d| d.expected == "true" && d.actual == "false"));
    }

    #[test]
    fn test_required_approvals_drift_is_reported_last() {
        let desired = BranchProtectionConfig {
            required_approvals: 2,
            allow_deletions: true,
            ..all_disabled_protection()
        };

        let drifts = compare_protection(&desired, &BranchProtectionState::default());
        assert_eq!(drifts.len(), 2);
        assert_eq!(drifts[0].field, "allow_deletions");
        assert_eq!(drifts[1].field, "required_approvals");
        assert_eq!(drifts[1].expected, "2");
        assert_eq!(drifts[1].actual, "0");
    }

    #[test]
    fn test_no_drift_returns_empty() {
        let actual_sec = SecurityState {
            secret_scanning: true,
            push_protection: true,
            dependabot_alerts: true,
            dependabot_security_updates: true,
            secret_scanning_ai_detection: true,
        };

        assert!(compare_security(&all_enabled_security(), &actual_sec).is_empty());
        assert!(
            compare_protection(&all_disabled_protection(), &BranchProtectionState::default())
                .is_empty()
        );
    }

    #[test]
    fn test_drift_result_status() {
        let in_sync = DriftResult {
            repo: "in-sync".to_string(),
            security_drifts: Vec::new(),
            protection_drifts: Vec::new(),
        };
        assert!(!in_sync.is_drifted());
        assert_eq!(in_sync.status(), "ok");

        let security_only = DriftResult {
            repo: "security-only".to_string(),
            security_drifts: vec![bool_drift("secret_scanning")],
            protection_drifts: Vec::new(),
        };
        assert!(security_only.is_drifted());
        assert_eq!(security_only.status(), "drifted");

        let protection_only = DriftResult {
            repo: "protection-only".to_string(),
            security_drifts: Vec::new(),
            protection_drifts: vec![bool_drift("enforce_admins")],
        };
        assert!(protection_only.is_drifted());
        assert_eq!(protection_only.status(), "drifted");
    }

    #[test]
    fn test_drift_item_formatting() {
        let json = serde_json::to_string(&bool_drift("secret_scanning")).unwrap();
        // Key order follows the struct's field declaration order.
        assert_eq!(
            json,
            r#"{"field":"secret_scanning","expected":"true","actual":"false"}"#
        );
    }
}