Skip to main content

ward/cli/
audit.rs

1use anyhow::Result;
2use clap::Args;
3use console::style;
4use serde::Serialize;
5
6use crate::config::Manifest;
7use crate::detection::versions;
8use crate::github::Client;
9
10#[derive(Args)]
11pub struct AuditCommand {
12    /// Output format (table or json)
13    #[arg(long, default_value = "table")]
14    format: String,
15}
16
17#[derive(Debug, Serialize)]
18struct AuditReport {
19    generated_at: String,
20    organization: String,
21    repositories: Vec<RepoAudit>,
22}
23
24#[derive(Debug, Serialize)]
25struct RepoAudit {
26    name: String,
27    system_id: String,
28    project_type: String,
29    language: Option<String>,
30    description: Option<String>,
31    default_branch: String,
32    versions: VersionInfo,
33    security: SecurityAudit,
34    settings: SettingsAudit,
35}
36
37#[derive(Debug, Default, Serialize)]
38struct VersionInfo {
39    java: Option<String>,
40    node: Option<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    spring_boot: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    kotlin: Option<String>,
45}
46
47#[derive(Debug, Default, Serialize)]
48struct SecurityAudit {
49    dependabot_alerts: bool,
50    dependabot_security_updates: bool,
51    secret_scanning: bool,
52    secret_scanning_ai: bool,
53    push_protection: bool,
54    has_dependabot_config: bool,
55    has_codeql: bool,
56    has_dependency_submission: bool,
57    alert_counts: AlertCounts,
58}
59
60#[derive(Debug, Default, Serialize)]
61struct AlertCounts {
62    critical: u32,
63    high: u32,
64    medium: u32,
65    low: u32,
66}
67
68#[derive(Debug, Default, Serialize)]
69struct SettingsAudit {
70    has_copilot_review_ruleset: bool,
71    has_copilot_instructions: bool,
72}
73
74impl AuditCommand {
75    pub async fn run(
76        &self,
77        client: &Client,
78        manifest: &Manifest,
79        system: Option<&str>,
80    ) -> Result<()> {
81        let sys = system.ok_or_else(|| anyhow::anyhow!("--system is required for audit"))?;
82
83        let excludes = manifest.exclude_patterns_for_system(sys);
84        let explicit = manifest.explicit_repos_for_system(sys);
85        let repos = client
86            .list_repos_for_system(sys, &excludes, &explicit)
87            .await?;
88
89        println!();
90        println!(
91            "  {} Full audit: {} repos in system {}",
92            style("🔍").bold(),
93            repos.len(),
94            style(sys).cyan()
95        );
96
97        let mut audits = Vec::new();
98
99        for repo in &repos {
100            tracing::info!("Auditing {}...", repo.name);
101            let audit = audit_repo(client, &repo.name, sys).await?;
102            audits.push(audit);
103        }
104
105        if self.format == "json" {
106            let report = AuditReport {
107                generated_at: chrono::Utc::now().to_rfc3339(),
108                organization: client.org.clone(),
109                repositories: audits,
110            };
111            println!("{}", serde_json::to_string_pretty(&report)?);
112        } else {
113            print_table(&audits);
114        }
115
116        Ok(())
117    }
118}
119
120async fn audit_repo(client: &Client, repo: &str, system_id: &str) -> Result<RepoAudit> {
121    let repo_info = client.get_repo(repo).await?;
122    let security_state = client.get_security_state(repo).await?;
123
124    // Detect project type and versions
125    let mut project_type = "unknown".to_owned();
126    let mut version_info = VersionInfo::default();
127
128    if let Some(content) = client.get_file(repo, "build.gradle.kts", None).await? {
129        project_type = "gradle".to_owned();
130        let text = Client::decode_content(&content).unwrap_or_default();
131        if let Some(v) = versions::extract_java_version(&text) {
132            version_info.java = Some(v.to_string());
133        }
134        // Try to detect Spring Boot version
135        version_info.spring_boot = extract_spring_boot_version(&text);
136        if text.contains("kotlin") {
137            version_info.kotlin = Some("detected".to_owned());
138        }
139    } else if let Some(content) = client.get_file(repo, "build.gradle", None).await? {
140        project_type = "gradle".to_owned();
141        let text = Client::decode_content(&content).unwrap_or_default();
142        if let Some(v) = versions::extract_java_version(&text) {
143            version_info.java = Some(v.to_string());
144        }
145        version_info.spring_boot = extract_spring_boot_version(&text);
146    } else if let Some(content) = client.get_file(repo, "package.json", None).await? {
147        project_type = "npm".to_owned();
148        let text = Client::decode_content(&content).unwrap_or_default();
149        version_info.node = versions::extract_node_version(&text);
150    }
151
152    // Check for config files
153    let has_dependabot_config = client
154        .get_file(repo, ".github/dependabot.yml", None)
155        .await?
156        .is_some();
157    let has_codeql = client
158        .get_file(repo, ".github/workflows/codeql.yml", None)
159        .await?
160        .is_some();
161    let has_dependency_submission = client
162        .get_file(repo, ".github/workflows/dependency-submission.yml", None)
163        .await?
164        .is_some();
165
166    // Check rulesets and instructions
167    let rulesets = client.list_rulesets(repo).await.unwrap_or_default();
168    let has_copilot_review = rulesets.iter().any(|r| r.name == "Copilot Code Review");
169    let has_copilot_instructions = client
170        .get_file(repo, ".github/copilot-instructions.md", None)
171        .await?
172        .is_some();
173
174    // Get alert counts
175    let alert_counts = get_alert_counts(client, repo).await.unwrap_or_default();
176
177    Ok(RepoAudit {
178        name: repo.to_owned(),
179        system_id: system_id.to_owned(),
180        project_type,
181        language: repo_info.language,
182        description: repo_info.description,
183        default_branch: repo_info.default_branch,
184        versions: version_info,
185        security: SecurityAudit {
186            dependabot_alerts: security_state.dependabot_alerts,
187            dependabot_security_updates: security_state.dependabot_security_updates,
188            secret_scanning: security_state.secret_scanning,
189            secret_scanning_ai: security_state.secret_scanning_ai_detection,
190            push_protection: security_state.push_protection,
191            has_dependabot_config,
192            has_codeql,
193            has_dependency_submission,
194            alert_counts,
195        },
196        settings: SettingsAudit {
197            has_copilot_review_ruleset: has_copilot_review,
198            has_copilot_instructions,
199        },
200    })
201}
202
203async fn get_alert_counts(client: &Client, repo: &str) -> Result<AlertCounts> {
204    let mut counts = AlertCounts::default();
205
206    for severity in ["critical", "high", "medium", "low"] {
207        let resp = client
208            .get(&format!(
209                "/repos/{}/{repo}/dependabot/alerts?state=open&severity={severity}&per_page=1",
210                client.org
211            ))
212            .await?;
213
214        if resp.status().is_success() {
215            // Use the array length or a header if available
216            let alerts: Vec<serde_json::Value> = resp.json().await.unwrap_or_default();
217            // This gives us a rough count (limited by per_page, but indicates presence)
218            // For exact counts, we'd need to paginate, but this is fast enough for audit
219            let count = if alerts.is_empty() { 0 } else { 1 };
220
221            match severity {
222                "critical" => counts.critical = count,
223                "high" => counts.high = count,
224                "medium" => counts.medium = count,
225                "low" => counts.low = count,
226                _ => {}
227            }
228        }
229    }
230
231    // Re-fetch with higher limit for a more accurate count
232    let resp = client
233        .get(&format!(
234            "/repos/{}/{repo}/dependabot/alerts?state=open&per_page=100",
235            client.org
236        ))
237        .await?;
238
239    if resp.status().is_success() {
240        let alerts: Vec<serde_json::Value> = resp.json().await.unwrap_or_default();
241        counts = AlertCounts::default();
242        for alert in &alerts {
243            let severity = alert
244                .get("security_vulnerability")
245                .and_then(|v| v.get("severity"))
246                .and_then(|s| s.as_str())
247                .unwrap_or("unknown");
248            match severity {
249                "critical" => counts.critical += 1,
250                "high" => counts.high += 1,
251                "medium" => counts.medium += 1,
252                "low" => counts.low += 1,
253                _ => {}
254            }
255        }
256    }
257
258    Ok(counts)
259}
260
/// Find the Spring Boot version declared in a Gradle build script.
///
/// Recognises the plugin form, e.g. `id("org.springframework.boot") version "3.5.6"`
/// (Kotlin DSL) or `id 'org.springframework.boot' version '3.5.6'` (Groovy DSL), and a
/// `springBootVersion = "3.5.6"` variable. A line that mentions one of these patterns
/// but carries no literal version (`version springBootVersion`, a comment, a version
/// catalog reference, ...) is skipped and the search continues with the next line,
/// rather than giving up on the whole file.
fn extract_spring_boot_version(content: &str) -> Option<String> {
    content.lines().map(str::trim).find_map(|line| {
        let plugin_line = line.contains("org.springframework.boot") && line.contains("version");
        if plugin_line || line.contains("springBootVersion") {
            extract_quoted_version(line)
        } else {
            None
        }
    })
}

/// Return the first quoted string in `s` that looks like a version number.
///
/// Both `"…"` and `'…'` quoting are understood. A string only ends at the same quote
/// character that opened it, so an apostrophe inside a double-quoted string does not
/// throw the parser out of sync.
fn extract_quoted_version(s: &str) -> Option<String> {
    let mut open_quote: Option<char> = None;
    let mut current = String::new();
    for ch in s.chars() {
        match open_quote {
            None => {
                if ch == '"' || ch == '\'' {
                    open_quote = Some(ch);
                    current.clear();
                }
            }
            Some(quote) if ch == quote => {
                if looks_like_version(&current) {
                    return Some(current);
                }
                open_quote = None;
            }
            Some(_) => current.push(ch),
        }
    }
    None
}

/// Heuristic for a version string: it starts with a digit, contains a dot, and otherwise
/// only uses characters seen in release qualifiers, e.g. `3.5.6`, `3.5.0-M1` or
/// `2.3.12.RELEASE`. Plugin ids and `${...}` interpolations are rejected.
fn looks_like_version(candidate: &str) -> bool {
    candidate.starts_with(|c: char| c.is_ascii_digit())
        && candidate.contains('.')
        && candidate
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_' | '+'))
}
297
298fn print_table(audits: &[RepoAudit]) {
299    println!();
300    println!(
301        "  {:35} {:7} {:6} {:6} {:5} {:5} {:5} {:5} {:5} {:5} {:5}",
302        style("Repository").bold().underlined(),
303        style("Type").bold().underlined(),
304        style("Java").bold().underlined(),
305        style("SBoot").bold().underlined(),
306        style("Dep.A").bold().underlined(),
307        style("SecSc").bold().underlined(),
308        style("Push").bold().underlined(),
309        style("DBot").bold().underlined(),
310        style("CQL").bold().underlined(),
311        style("CopRv").bold().underlined(),
312        style("Alert").bold().underlined(),
313    );
314
315    let mut total_alerts = 0u32;
316    let mut fully_secured = 0;
317
318    for a in audits {
319        let java = a.versions.java.as_deref().unwrap_or("-");
320        let sboot = a.versions.spring_boot.as_deref().unwrap_or("-");
321
322        let icon = |b: bool| {
323            if b {
324                format!("{}", style("✅").green())
325            } else {
326                format!("{}", style("❌").red())
327            }
328        };
329
330        let alert_total = a.security.alert_counts.critical
331            + a.security.alert_counts.high
332            + a.security.alert_counts.medium
333            + a.security.alert_counts.low;
334        total_alerts += alert_total;
335
336        let alert_str = if alert_total == 0 {
337            format!("{}", style("0").green())
338        } else if a.security.alert_counts.critical > 0 {
339            format!("{}", style(alert_total).red().bold())
340        } else if a.security.alert_counts.high > 0 {
341            format!("{}", style(alert_total).yellow().bold())
342        } else {
343            format!("{}", style(alert_total).yellow())
344        };
345
346        let all_security = a.security.dependabot_alerts
347            && a.security.secret_scanning
348            && a.security.push_protection
349            && a.security.has_dependabot_config
350            && a.security.has_codeql;
351
352        if all_security {
353            fully_secured += 1;
354        }
355
356        println!(
357            "  {:35} {:7} {:6} {:6} {:5} {:5} {:5} {:5} {:5} {:5} {:5}",
358            &a.name,
359            &a.project_type,
360            java,
361            sboot,
362            icon(a.security.dependabot_alerts),
363            icon(a.security.secret_scanning),
364            icon(a.security.push_protection),
365            icon(a.security.has_dependabot_config),
366            icon(a.security.has_codeql),
367            icon(a.settings.has_copilot_review_ruleset),
368            alert_str,
369        );
370    }
371
372    println!();
373    println!(
374        "  {} repos audited | {} fully secured | {} total open alerts",
375        style(audits.len()).bold(),
376        style(fully_secured).green().bold(),
377        if total_alerts > 0 {
378            style(total_alerts).red().bold()
379        } else {
380            style(total_alerts).green().bold()
381        }
382    );
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert the Spring Boot version detected in `build_script`, reporting the input on failure.
    fn assert_boot(build_script: &str, expected: Option<&str>) {
        assert_eq!(
            extract_spring_boot_version(build_script),
            expected.map(str::to_owned),
            "input: {build_script:?}"
        );
    }

    #[test]
    fn test_extract_spring_boot_version() {
        assert_boot(
            r#"id("org.springframework.boot") version "3.5.6""#,
            Some("3.5.6"),
        );
    }

    #[test]
    fn test_extract_spring_boot_variable() {
        assert_boot(r#"val springBootVersion = "3.4.1""#, Some("3.4.1"));
    }

    #[test]
    fn test_no_spring_boot() {
        assert_boot("plugins { id(\"java\") }", None);
    }
}
412}