Skip to main content

ward/cli/
init.rs

1use std::collections::HashMap;
2
3use anyhow::{Context, Result};
4use clap::Args;
5use console::style;
6use dialoguer::{Confirm, Input, MultiSelect};
7use reqwest::header::{self, HeaderMap, HeaderValue};
8use serde::Deserialize;
9
10use crate::config::auth;
11
/// Default manifest written by `write_default` when the interactive wizard is skipped.
///
/// The values are placeholders that the user is expected to edit by hand (the success
/// message tells them to). The commented-out `[[systems]]` entry shows the shape of a
/// system definition without enabling one.
const EXAMPLE_MANIFEST: &str = r#"[org]
name = "your-github-org"

[security]
secret_scanning = true
secret_scanning_ai_detection = true
push_protection = true
dependabot_alerts = true
dependabot_security_updates = true

[templates]
branch = "chore/ward-setup"
reviewers = []
commit_message_prefix = "chore: "

# [[systems]]
# id = "my-system"
# name = "My System"
# exclude = ["operations?", "workflows"]
"#;

/// Manifest path created by both the wizard and the non-interactive path.
/// It is relative, so it resolves against the process's current working directory.
const OUTPUT_PATH: &str = "ward.toml";
34
35#[derive(Args)]
36pub struct InitCommand {
37    /// Skip interactive wizard, write default ward.toml
38    #[arg(long)]
39    non_interactive: bool,
40}
41
42impl InitCommand {
43    pub async fn run(&self) -> Result<()> {
44        if self.non_interactive {
45            return write_default();
46        }
47        run_wizard().await
48    }
49}
50
51fn write_default() -> Result<()> {
52    if std::path::Path::new(OUTPUT_PATH).exists() {
53        println!("  {} ward.toml already exists.", style("warning").yellow());
54        return Ok(());
55    }
56
57    std::fs::write(OUTPUT_PATH, EXAMPLE_MANIFEST)?;
58    println!(
59        "  {} Created ward.toml - edit it to configure your org and systems.",
60        style("ok").green()
61    );
62    Ok(())
63}
64
65// --- Wizard ---
66
/// Everything the wizard collects; rendered into `ward.toml` and summarised at the end.
#[derive(Debug, Clone, PartialEq, Eq)]
struct WizardState {
    /// GitHub organization login entered by the user.
    org: String,
    /// Repository count reported when the organization was verified (shown in the summary).
    repo_count: usize,
    /// Answers for the `[security]` table.
    security: SecuritySettings,
    /// Answers for the `[branch_protection]` table.
    branch_protection: BranchProtectionSettings,
    /// Systems the user selected from the discovered repository-name prefixes.
    systems: Vec<SystemEntry>,
    /// Exclude patterns as entered (the prompt describes them as regexes).
    exclude_patterns: Vec<String>,
    /// Answers for the `[templates]` table.
    templates: TemplateSettings,
}

/// Yes/no answers for the organization-wide security features.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SecuritySettings {
    secret_scanning: bool,
    push_protection: bool,
    dependabot_alerts: bool,
    dependabot_security_updates: bool,
}

/// Branch-protection answers.
///
/// `required_approvals` and `dismiss_stale_reviews` are only asked for when `enabled` is
/// true. Otherwise they hold placeholder values.
#[derive(Debug, Clone, PartialEq, Eq)]
struct BranchProtectionSettings {
    enabled: bool,
    required_approvals: u32,
    dismiss_stale_reviews: bool,
}

/// One selected system.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SystemEntry {
    /// The discovered repository-name prefix.
    id: String,
    /// Display name chosen by the user (defaults to the prefix).
    name: String,
    /// Number of non-archived repositories that share the prefix.
    repo_count: usize,
}

/// Settings used for the pull requests Ward opens.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TemplateSettings {
    /// Branch name for PRs.
    branch: String,
    /// Reviewer handles, already split on commas and trimmed.
    reviewers: Vec<String>,
    /// Prefix prepended to commit messages.
    commit_message_prefix: String,
}
101
102fn print_banner() {
103    println!();
104    println!("  {}", style("+---------------------------------+").cyan());
105    println!("  {}", style("|       Ward Setup Wizard         |").cyan());
106    println!("  {}", style("+---------------------------------+").cyan());
107    println!();
108}
109
110fn print_step(step: u8, total: u8, title: &str) {
111    println!();
112    println!(
113        "  {} {}",
114        style(format!("Step {step}/{total}:")).bold(),
115        style(title).bold()
116    );
117}
118
119async fn run_wizard() -> Result<()> {
120    if std::path::Path::new(OUTPUT_PATH).exists() {
121        println!("  {} ward.toml already exists.", style("warning").yellow());
122        return Ok(());
123    }
124
125    print_banner();
126
127    let total_steps = 6;
128
129    // Step 1: Auth
130    print_step(1, total_steps, "Authentication");
131    let token = check_auth()?;
132
133    // Step 2: Org
134    print_step(2, total_steps, "Organization");
135    let (org, repo_count) = ask_org(&token).await?;
136
137    // Step 3: Security
138    print_step(3, total_steps, "Security Settings");
139    let security = ask_security()?;
140
141    // Step 4: Branch protection
142    print_step(4, total_steps, "Branch Protection");
143    let branch_protection = ask_branch_protection()?;
144
145    // Step 5: Systems
146    print_step(5, total_steps, "Systems");
147    let (systems, exclude_patterns) = discover_systems(&token, &org).await?;
148
149    // Step 6: Templates
150    print_step(6, total_steps, "Templates");
151    let templates = ask_templates()?;
152
153    let state = WizardState {
154        org,
155        repo_count,
156        security,
157        branch_protection,
158        systems,
159        exclude_patterns,
160        templates,
161    };
162
163    write_toml(&state)?;
164    print_summary(&state);
165
166    Ok(())
167}
168
169// Step 1: Check authentication
170
171fn check_auth() -> Result<String> {
172    match auth::resolve_token() {
173        Ok(token) => {
174            let source = if std::env::var("GH_TOKEN").is_ok() {
175                "GH_TOKEN"
176            } else if std::env::var("GITHUB_TOKEN").is_ok() {
177                "GITHUB_TOKEN"
178            } else {
179                "gh auth token"
180            };
181            println!("  {} Token found via {source}", style("ok").green());
182            Ok(token)
183        }
184        Err(e) => {
185            println!("  {} No GitHub token found.", style("error").red());
186            println!("  Set GH_TOKEN, GITHUB_TOKEN, or run `gh auth login`.");
187            Err(e)
188        }
189    }
190}
191
192// Step 2: Ask for org and verify
193
194async fn verify_org(token: &str, org: &str) -> Result<usize> {
195    let client = build_http_client(token)?;
196    let resp = client
197        .get(format!("https://api.github.com/orgs/{org}"))
198        .send()
199        .await
200        .context("Failed to reach GitHub API")?;
201
202    let status = resp.status();
203    if status == reqwest::StatusCode::NOT_FOUND {
204        anyhow::bail!("Organization '{org}' not found");
205    }
206    if !status.is_success() {
207        let body = resp.text().await.unwrap_or_default();
208        anyhow::bail!("Failed to verify org (HTTP {status}): {body}");
209    }
210
211    // Org exists - consume the response body, then count repos via pagination
212    let _ = resp.bytes().await;
213    let repos = fetch_all_repos(token, org).await?;
214    Ok(repos.len())
215}
216
217async fn ask_org(token: &str) -> Result<(String, usize)> {
218    loop {
219        let org: String = Input::new()
220            .with_prompt("  GitHub organization name")
221            .interact_text()?;
222
223        match verify_org(token, &org).await {
224            Ok(count) => {
225                println!(
226                    "  {} Organization verified ({count} repos)",
227                    style("ok").green()
228                );
229                return Ok((org, count));
230            }
231            Err(e) => {
232                println!("  {} {e}", style("error").red());
233                println!("  Please try again.");
234            }
235        }
236    }
237}
238
239// Step 3: Security settings
240
241fn ask_security() -> Result<SecuritySettings> {
242    let secret_scanning = Confirm::new()
243        .with_prompt("  Enable secret scanning?")
244        .default(true)
245        .interact()?;
246
247    let push_protection = Confirm::new()
248        .with_prompt("  Enable push protection?")
249        .default(true)
250        .interact()?;
251
252    let dependabot_alerts = Confirm::new()
253        .with_prompt("  Enable Dependabot alerts?")
254        .default(true)
255        .interact()?;
256
257    let dependabot_security_updates = Confirm::new()
258        .with_prompt("  Enable Dependabot security updates?")
259        .default(true)
260        .interact()?;
261
262    Ok(SecuritySettings {
263        secret_scanning,
264        push_protection,
265        dependabot_alerts,
266        dependabot_security_updates,
267    })
268}
269
270// Step 4: Branch protection
271
272fn ask_branch_protection() -> Result<BranchProtectionSettings> {
273    let enabled = Confirm::new()
274        .with_prompt("  Enable branch protection?")
275        .default(true)
276        .interact()?;
277
278    if !enabled {
279        return Ok(BranchProtectionSettings {
280            enabled: false,
281            required_approvals: 1,
282            dismiss_stale_reviews: false,
283        });
284    }
285
286    let approvals: String = Input::new()
287        .with_prompt("  Required approvals")
288        .default("1".to_owned())
289        .interact_text()?;
290    let required_approvals: u32 = approvals.parse().unwrap_or(1);
291
292    let dismiss_stale_reviews = Confirm::new()
293        .with_prompt("  Dismiss stale reviews?")
294        .default(true)
295        .interact()?;
296
297    Ok(BranchProtectionSettings {
298        enabled,
299        required_approvals,
300        dismiss_stale_reviews,
301    })
302}
303
304// Step 5: Discover systems
305
/// A repository-name prefix shared by at least two active repositories.
/// Each one is offered to the user as a candidate system.
#[derive(Debug, Clone)]
struct DiscoveredPrefix {
    /// Text before the first `-` of the repository names.
    prefix: String,
    /// Number of non-archived repositories with this prefix.
    count: usize,
}

/// The two fields of GitHub's repository JSON that the wizard needs.
/// Other fields in the response are ignored by serde's default behaviour.
#[derive(Debug, Clone, Deserialize)]
struct MinimalRepo {
    /// Repository name as returned by the repository listing.
    name: String,
    /// Whether the repository is archived. Archived repos are skipped during prefix discovery.
    archived: bool,
}
317
318async fn fetch_all_repos(token: &str, org: &str) -> Result<Vec<MinimalRepo>> {
319    let client = build_http_client(token)?;
320    let mut all = Vec::new();
321    let mut page = 1u32;
322
323    loop {
324        let resp = client
325            .get(format!(
326                "https://api.github.com/orgs/{org}/repos?per_page=100&page={page}&type=all"
327            ))
328            .send()
329            .await
330            .context("Failed to fetch repos")?;
331
332        let status = resp.status();
333        if !status.is_success() {
334            let body = resp.text().await.unwrap_or_default();
335            anyhow::bail!("Failed to list repos (HTTP {status}): {body}");
336        }
337
338        let repos: Vec<MinimalRepo> = resp.json().await.context("Failed to parse repos")?;
339        if repos.is_empty() {
340            break;
341        }
342
343        all.extend(repos);
344        page += 1;
345    }
346
347    Ok(all)
348}
349
350fn discover_prefixes(repos: &[MinimalRepo]) -> Vec<DiscoveredPrefix> {
351    let mut counts: HashMap<String, usize> = HashMap::new();
352
353    for repo in repos {
354        if repo.archived {
355            continue;
356        }
357        if let Some(prefix) = repo.name.split('-').next()
358            && !prefix.is_empty()
359        {
360            *counts.entry(prefix.to_owned()).or_default() += 1;
361        }
362    }
363
364    let mut prefixes: Vec<DiscoveredPrefix> = counts
365        .into_iter()
366        .filter(|(_, count)| *count >= 2)
367        .map(|(prefix, count)| DiscoveredPrefix { prefix, count })
368        .collect();
369
370    prefixes.sort_by(|a, b| b.count.cmp(&a.count));
371    prefixes
372}
373
374async fn discover_systems(token: &str, org: &str) -> Result<(Vec<SystemEntry>, Vec<String>)> {
375    println!("  Scanning repos...");
376    let repos = fetch_all_repos(token, org).await?;
377    let active_count = repos.iter().filter(|r| !r.archived).count();
378    println!(
379        "  Found {active_count} active repos (of {} total)",
380        repos.len()
381    );
382
383    let prefixes = discover_prefixes(&repos);
384
385    if prefixes.is_empty() {
386        println!(
387            "  {} No common prefixes found. You can add systems manually to ward.toml.",
388            style("info").blue()
389        );
390        let exclude = ask_exclude_patterns()?;
391        return Ok((Vec::new(), exclude));
392    }
393
394    println!();
395    println!("  Discovered prefixes:");
396    for p in &prefixes {
397        println!("    {} - {} repos", style(&p.prefix).cyan(), p.count);
398    }
399    println!();
400
401    let items: Vec<String> = prefixes
402        .iter()
403        .map(|p| format!("{} ({} repos)", p.prefix, p.count))
404        .collect();
405
406    let selected = MultiSelect::new()
407        .with_prompt("  Select systems to manage")
408        .items(&items)
409        .defaults(&vec![true; items.len()])
410        .interact()?;
411
412    let mut systems = Vec::new();
413    for idx in selected {
414        let p = &prefixes[idx];
415        let name: String = Input::new()
416            .with_prompt(format!("  Name for system {}", style(&p.prefix).cyan()))
417            .default(p.prefix.clone())
418            .interact_text()?;
419
420        systems.push(SystemEntry {
421            id: p.prefix.clone(),
422            name,
423            repo_count: p.count,
424        });
425    }
426
427    let exclude = ask_exclude_patterns()?;
428
429    Ok((systems, exclude))
430}
431
432fn ask_exclude_patterns() -> Result<Vec<String>> {
433    let raw: String = Input::new()
434        .with_prompt("  Exclude patterns (comma-separated, regex)")
435        .default("operations?,workflows".to_owned())
436        .interact_text()?;
437
438    let patterns: Vec<String> = raw
439        .split(',')
440        .map(|s| s.trim().to_owned())
441        .filter(|s| !s.is_empty())
442        .collect();
443
444    Ok(patterns)
445}
446
447// Step 6: Templates
448
449fn ask_templates() -> Result<TemplateSettings> {
450    let branch: String = Input::new()
451        .with_prompt("  Branch name for PRs")
452        .default("chore/ward-setup".to_owned())
453        .interact_text()?;
454
455    let reviewers_raw: String = Input::new()
456        .with_prompt("  Reviewers (comma-separated)")
457        .default(String::new())
458        .allow_empty(true)
459        .interact_text()?;
460
461    let reviewers: Vec<String> = reviewers_raw
462        .split(',')
463        .map(|s| s.trim().to_owned())
464        .filter(|s| !s.is_empty())
465        .collect();
466
467    let commit_message_prefix: String = Input::new()
468        .with_prompt("  Commit message prefix")
469        .default("chore: ".to_owned())
470        .interact_text()?;
471
472    Ok(TemplateSettings {
473        branch,
474        reviewers,
475        commit_message_prefix,
476    })
477}
478
479// TOML generation
480
481fn write_toml(state: &WizardState) -> Result<()> {
482    let mut out = String::new();
483
484    out.push_str("[org]\n");
485    out.push_str(&format!("name = {:?}\n", state.org));
486
487    out.push_str("\n[security]\n");
488    out.push_str(&format!(
489        "secret_scanning = {}\n",
490        state.security.secret_scanning
491    ));
492    out.push_str(&format!(
493        "push_protection = {}\n",
494        state.security.push_protection
495    ));
496    out.push_str(&format!(
497        "dependabot_alerts = {}\n",
498        state.security.dependabot_alerts
499    ));
500    out.push_str(&format!(
501        "dependabot_security_updates = {}\n",
502        state.security.dependabot_security_updates
503    ));
504
505    out.push_str("\n[branch_protection]\n");
506    out.push_str(&format!("enabled = {}\n", state.branch_protection.enabled));
507    if state.branch_protection.enabled {
508        out.push_str(&format!(
509            "required_approvals = {}\n",
510            state.branch_protection.required_approvals
511        ));
512        out.push_str(&format!(
513            "dismiss_stale_reviews = {}\n",
514            state.branch_protection.dismiss_stale_reviews
515        ));
516    }
517
518    out.push_str("\n[templates]\n");
519    out.push_str(&format!("branch = {:?}\n", state.templates.branch));
520    let reviewers_toml: Vec<String> = state
521        .templates
522        .reviewers
523        .iter()
524        .map(|r| format!("{r:?}"))
525        .collect();
526    out.push_str(&format!("reviewers = [{}]\n", reviewers_toml.join(", ")));
527    out.push_str(&format!(
528        "commit_message_prefix = {:?}\n",
529        state.templates.commit_message_prefix
530    ));
531
532    for sys in &state.systems {
533        out.push('\n');
534        out.push_str("[[systems]]\n");
535        out.push_str(&format!("id = {:?}\n", sys.id));
536        out.push_str(&format!("name = {:?}\n", sys.name));
537        if !state.exclude_patterns.is_empty() {
538            let exclude_toml: Vec<String> = state
539                .exclude_patterns
540                .iter()
541                .map(|p| format!("{p:?}"))
542                .collect();
543            out.push_str(&format!("exclude = [{}]\n", exclude_toml.join(", ")));
544        }
545    }
546
547    std::fs::write(OUTPUT_PATH, &out).context("Failed to write ward.toml")?;
548    Ok(())
549}
550
551fn print_summary(state: &WizardState) {
552    println!();
553    println!(
554        "  {} Created ward.toml with {} system(s) for {} ({} repos)",
555        style("ok").green(),
556        state.systems.len(),
557        style(&state.org).cyan(),
558        state.repo_count,
559    );
560    for sys in &state.systems {
561        println!(
562            "    - {} ({}) - {} repos",
563            style(&sys.id).cyan(),
564            sys.name,
565            sys.repo_count,
566        );
567    }
568    println!();
569    println!("  Next steps:");
570    println!("    ward repos list              - see matched repos");
571    println!("    ward security plan            - preview security changes");
572    println!("    ward security apply           - apply changes");
573    println!();
574}
575
576// HTTP helpers
577
578fn build_http_client(token: &str) -> Result<reqwest::Client> {
579    let mut headers = HeaderMap::new();
580    headers.insert(
581        header::ACCEPT,
582        HeaderValue::from_static("application/vnd.github+json"),
583    );
584    headers.insert(
585        "X-GitHub-Api-Version",
586        HeaderValue::from_static("2022-11-28"),
587    );
588    headers.insert(
589        header::AUTHORIZATION,
590        HeaderValue::from_str(&format!("Bearer {token}")).context("Invalid token characters")?,
591    );
592    headers.insert(
593        header::USER_AGENT,
594        HeaderValue::from_static("ward-cli/0.1.0"),
595    );
596
597    reqwest::Client::builder()
598        .default_headers(headers)
599        .build()
600        .context("Failed to build HTTP client")
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    fn make_repo(name: &str, archived: bool) -> MinimalRepo {
        MinimalRepo {
            name: name.to_owned(),
            archived,
        }
    }

    #[test]
    fn discover_prefixes_groups_by_first_segment() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07411-bar", false),
            make_repo("s07411-baz", false),
            make_repo("s07252-one", false),
            make_repo("s07252-two", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 2);
        assert_eq!(prefixes[0].prefix, "s07411");
        assert_eq!(prefixes[0].count, 3);
        assert_eq!(prefixes[1].prefix, "s07252");
        assert_eq!(prefixes[1].count, 2);
    }

    #[test]
    fn discover_prefixes_ignores_archived() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07411-bar", true),
            make_repo("s07411-baz", true),
        ];

        let prefixes = discover_prefixes(&repos);
        assert!(prefixes.is_empty());
    }

    #[test]
    fn discover_prefixes_filters_singletons() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07252-one", false),
            make_repo("s07252-two", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 1);
        assert_eq!(prefixes[0].prefix, "s07252");
    }

    #[test]
    fn discover_prefixes_sorts_by_count() {
        let repos = vec![
            make_repo("alpha-a", false),
            make_repo("alpha-b", false),
            make_repo("beta-a", false),
            make_repo("beta-b", false),
            make_repo("beta-c", false),
            make_repo("gamma-a", false),
            make_repo("gamma-b", false),
            make_repo("gamma-c", false),
            make_repo("gamma-d", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 3);
        assert_eq!(prefixes[0].prefix, "gamma");
        assert_eq!(prefixes[0].count, 4);
        assert_eq!(prefixes[1].prefix, "beta");
        assert_eq!(prefixes[1].count, 3);
        assert_eq!(prefixes[2].prefix, "alpha");
        assert_eq!(prefixes[2].count, 2);
    }

    #[test]
    fn discover_prefixes_handles_no_dash_names() {
        let repos = vec![make_repo("standalone", false), make_repo("another", false)];

        let prefixes = discover_prefixes(&repos);
        assert!(prefixes.is_empty());
    }

    #[test]
    fn discover_prefixes_empty_input() {
        let prefixes = discover_prefixes(&[]);
        assert!(prefixes.is_empty());
    }

    /// A repo named exactly like a prefix is grouped with the dashed repos,
    /// because the "prefix" of a name without '-' is the whole name.
    #[test]
    fn discover_prefixes_groups_bare_name_with_dashed_names() {
        let repos = vec![make_repo("platform", false), make_repo("platform-api", false)];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 1);
        assert_eq!(prefixes[0].prefix, "platform");
        assert_eq!(prefixes[0].count, 2);
    }

    /// Names starting with '-' have an empty first segment and must not form a group.
    #[test]
    fn discover_prefixes_skips_names_with_leading_dash() {
        let repos = vec![make_repo("-foo", false), make_repo("-bar", false)];

        let prefixes = discover_prefixes(&repos);
        assert!(prefixes.is_empty());
    }
}