1use std::collections::HashMap;
2
3use anyhow::{Context, Result};
4use clap::Args;
5use console::style;
6use dialoguer::{Confirm, Input, MultiSelect};
7use reqwest::header::{self, HeaderMap, HeaderValue};
8use serde::Deserialize;
9
10use crate::config::auth;
11
/// Manifest template written by `write_default` (the `non_interactive` path of `InitCommand`).
///
/// It contains an `[org]` placeholder, security defaults, PR template settings, and a
/// commented-out `[[systems]]` example for the user to copy.
///
/// NOTE(review): this template includes `secret_scanning_ai_detection`, which the interactive
/// wizard neither asks about nor writes. It also lacks the `[branch_protection]` section that
/// the wizard does write. Confirm whether the two outputs are meant to differ.
const EXAMPLE_MANIFEST: &str = r#"[org]
name = "your-github-org"

[security]
secret_scanning = true
secret_scanning_ai_detection = true
push_protection = true
dependabot_alerts = true
dependabot_security_updates = true

[templates]
branch = "chore/ward-setup"
reviewers = []
commit_message_prefix = "chore: "

# [[systems]]
# id = "my-system"
# name = "My System"
# exclude = ["operations?", "workflows"]
"#;

/// Path of the manifest that `init` creates, relative to the current working directory.
const OUTPUT_PATH: &str = "ward.toml";
34
35#[derive(Args)]
36pub struct InitCommand {
37 #[arg(long)]
39 non_interactive: bool,
40}
41
42impl InitCommand {
43 pub async fn run(&self) -> Result<()> {
44 if self.non_interactive {
45 return write_default();
46 }
47 run_wizard().await
48 }
49}
50
51fn write_default() -> Result<()> {
52 if std::path::Path::new(OUTPUT_PATH).exists() {
53 println!(" {} ward.toml already exists.", style("warning").yellow());
54 return Ok(());
55 }
56
57 std::fs::write(OUTPUT_PATH, EXAMPLE_MANIFEST)?;
58 println!(
59 " {} Created ward.toml - edit it to configure your org and systems.",
60 style("ok").green()
61 );
62 Ok(())
63}
64
/// Everything the interactive wizard collected. `write_toml` renders it into the manifest.
#[derive(Debug, Clone)]
struct WizardState {
    /// GitHub organization name, as entered and verified by the wizard.
    org: String,
    /// Number of repositories listed for the org during verification (archived included).
    repo_count: usize,
    security: SecuritySettings,
    branch_protection: BranchProtectionSettings,
    /// Systems the user selected from the discovered repo-name prefixes.
    systems: Vec<SystemEntry>,
    /// Regex exclude patterns entered by the user. The same list is applied to every system.
    exclude_patterns: Vec<String>,
    templates: TemplateSettings,
}

/// Answers to the security prompts; each flag becomes a key under `[security]`.
#[derive(Debug, Clone)]
struct SecuritySettings {
    secret_scanning: bool,
    push_protection: bool,
    dependabot_alerts: bool,
    dependabot_security_updates: bool,
}

/// Answers to the branch-protection prompts.
#[derive(Debug, Clone)]
struct BranchProtectionSettings {
    enabled: bool,
    /// Only written to the manifest when `enabled` is true.
    required_approvals: u32,
    /// Only written to the manifest when `enabled` is true.
    dismiss_stale_reviews: bool,
}

/// One selected system, derived from a shared repo-name prefix.
#[derive(Debug, Clone)]
struct SystemEntry {
    /// The prefix itself; written as the system `id`.
    id: String,
    /// Human-readable name entered by the user (defaults to the prefix).
    name: String,
    /// Number of non-archived repos sharing the prefix (shown in the summary only).
    repo_count: usize,
}

/// Settings for the pull requests the tool opens.
#[derive(Debug, Clone)]
struct TemplateSettings {
    branch: String,
    reviewers: Vec<String>,
    commit_message_prefix: String,
}
101
102fn print_banner() {
103 println!();
104 println!(" {}", style("+---------------------------------+").cyan());
105 println!(" {}", style("| Ward Setup Wizard |").cyan());
106 println!(" {}", style("+---------------------------------+").cyan());
107 println!();
108}
109
110fn print_step(step: u8, total: u8, title: &str) {
111 println!();
112 println!(
113 " {} {}",
114 style(format!("Step {step}/{total}:")).bold(),
115 style(title).bold()
116 );
117}
118
119async fn run_wizard() -> Result<()> {
120 if std::path::Path::new(OUTPUT_PATH).exists() {
121 println!(" {} ward.toml already exists.", style("warning").yellow());
122 return Ok(());
123 }
124
125 print_banner();
126
127 let total_steps = 6;
128
129 print_step(1, total_steps, "Authentication");
131 let token = check_auth()?;
132
133 print_step(2, total_steps, "Organization");
135 let (org, repo_count) = ask_org(&token).await?;
136
137 print_step(3, total_steps, "Security Settings");
139 let security = ask_security()?;
140
141 print_step(4, total_steps, "Branch Protection");
143 let branch_protection = ask_branch_protection()?;
144
145 print_step(5, total_steps, "Systems");
147 let (systems, exclude_patterns) = discover_systems(&token, &org).await?;
148
149 print_step(6, total_steps, "Templates");
151 let templates = ask_templates()?;
152
153 let state = WizardState {
154 org,
155 repo_count,
156 security,
157 branch_protection,
158 systems,
159 exclude_patterns,
160 templates,
161 };
162
163 write_toml(&state)?;
164 print_summary(&state);
165
166 Ok(())
167}
168
169fn check_auth() -> Result<String> {
172 match auth::resolve_token() {
173 Ok(token) => {
174 let source = if std::env::var("GH_TOKEN").is_ok() {
175 "GH_TOKEN"
176 } else if std::env::var("GITHUB_TOKEN").is_ok() {
177 "GITHUB_TOKEN"
178 } else {
179 "gh auth token"
180 };
181 println!(" {} Token found via {source}", style("ok").green());
182 Ok(token)
183 }
184 Err(e) => {
185 println!(" {} No GitHub token found.", style("error").red());
186 println!(" Set GH_TOKEN, GITHUB_TOKEN, or run `gh auth login`.");
187 Err(e)
188 }
189 }
190}
191
192async fn verify_org(token: &str, org: &str) -> Result<usize> {
195 let client = build_http_client(token)?;
196 let resp = client
197 .get(format!("https://api.github.com/orgs/{org}"))
198 .send()
199 .await
200 .context("Failed to reach GitHub API")?;
201
202 let status = resp.status();
203 if status == reqwest::StatusCode::NOT_FOUND {
204 anyhow::bail!("Organization '{org}' not found");
205 }
206 if !status.is_success() {
207 let body = resp.text().await.unwrap_or_default();
208 anyhow::bail!("Failed to verify org (HTTP {status}): {body}");
209 }
210
211 let _ = resp.bytes().await;
213 let repos = fetch_all_repos(token, org).await?;
214 Ok(repos.len())
215}
216
217async fn ask_org(token: &str) -> Result<(String, usize)> {
218 loop {
219 let org: String = Input::new()
220 .with_prompt(" GitHub organization name")
221 .interact_text()?;
222
223 match verify_org(token, &org).await {
224 Ok(count) => {
225 println!(
226 " {} Organization verified ({count} repos)",
227 style("ok").green()
228 );
229 return Ok((org, count));
230 }
231 Err(e) => {
232 println!(" {} {e}", style("error").red());
233 println!(" Please try again.");
234 }
235 }
236 }
237}
238
239fn ask_security() -> Result<SecuritySettings> {
242 let secret_scanning = Confirm::new()
243 .with_prompt(" Enable secret scanning?")
244 .default(true)
245 .interact()?;
246
247 let push_protection = Confirm::new()
248 .with_prompt(" Enable push protection?")
249 .default(true)
250 .interact()?;
251
252 let dependabot_alerts = Confirm::new()
253 .with_prompt(" Enable Dependabot alerts?")
254 .default(true)
255 .interact()?;
256
257 let dependabot_security_updates = Confirm::new()
258 .with_prompt(" Enable Dependabot security updates?")
259 .default(true)
260 .interact()?;
261
262 Ok(SecuritySettings {
263 secret_scanning,
264 push_protection,
265 dependabot_alerts,
266 dependabot_security_updates,
267 })
268}
269
270fn ask_branch_protection() -> Result<BranchProtectionSettings> {
273 let enabled = Confirm::new()
274 .with_prompt(" Enable branch protection?")
275 .default(true)
276 .interact()?;
277
278 if !enabled {
279 return Ok(BranchProtectionSettings {
280 enabled: false,
281 required_approvals: 1,
282 dismiss_stale_reviews: false,
283 });
284 }
285
286 let approvals: String = Input::new()
287 .with_prompt(" Required approvals")
288 .default("1".to_owned())
289 .interact_text()?;
290 let required_approvals: u32 = approvals.parse().unwrap_or(1);
291
292 let dismiss_stale_reviews = Confirm::new()
293 .with_prompt(" Dismiss stale reviews?")
294 .default(true)
295 .interact()?;
296
297 Ok(BranchProtectionSettings {
298 enabled,
299 required_approvals,
300 dismiss_stale_reviews,
301 })
302}
303
/// A repository-name prefix shared by several non-archived repos. The wizard offers each one
/// as a candidate system.
#[derive(Debug, Clone)]
struct DiscoveredPrefix {
    /// First `-`-separated segment of the repo names (the whole name if it has no `-`).
    prefix: String,
    /// Number of non-archived repos with this prefix; `discover_prefixes` keeps only counts >= 2.
    count: usize,
}

/// The fields of a repository entry from GitHub's org repo listing that the wizard uses.
/// Other fields in the JSON are not declared here, so serde's default behavior skips them.
#[derive(Debug, Clone, Deserialize)]
struct MinimalRepo {
    name: String,
    archived: bool,
}
317
318async fn fetch_all_repos(token: &str, org: &str) -> Result<Vec<MinimalRepo>> {
319 let client = build_http_client(token)?;
320 let mut all = Vec::new();
321 let mut page = 1u32;
322
323 loop {
324 let resp = client
325 .get(format!(
326 "https://api.github.com/orgs/{org}/repos?per_page=100&page={page}&type=all"
327 ))
328 .send()
329 .await
330 .context("Failed to fetch repos")?;
331
332 let status = resp.status();
333 if !status.is_success() {
334 let body = resp.text().await.unwrap_or_default();
335 anyhow::bail!("Failed to list repos (HTTP {status}): {body}");
336 }
337
338 let repos: Vec<MinimalRepo> = resp.json().await.context("Failed to parse repos")?;
339 if repos.is_empty() {
340 break;
341 }
342
343 all.extend(repos);
344 page += 1;
345 }
346
347 Ok(all)
348}
349
350fn discover_prefixes(repos: &[MinimalRepo]) -> Vec<DiscoveredPrefix> {
351 let mut counts: HashMap<String, usize> = HashMap::new();
352
353 for repo in repos {
354 if repo.archived {
355 continue;
356 }
357 if let Some(prefix) = repo.name.split('-').next()
358 && !prefix.is_empty()
359 {
360 *counts.entry(prefix.to_owned()).or_default() += 1;
361 }
362 }
363
364 let mut prefixes: Vec<DiscoveredPrefix> = counts
365 .into_iter()
366 .filter(|(_, count)| *count >= 2)
367 .map(|(prefix, count)| DiscoveredPrefix { prefix, count })
368 .collect();
369
370 prefixes.sort_by(|a, b| b.count.cmp(&a.count));
371 prefixes
372}
373
374async fn discover_systems(token: &str, org: &str) -> Result<(Vec<SystemEntry>, Vec<String>)> {
375 println!(" Scanning repos...");
376 let repos = fetch_all_repos(token, org).await?;
377 let active_count = repos.iter().filter(|r| !r.archived).count();
378 println!(
379 " Found {active_count} active repos (of {} total)",
380 repos.len()
381 );
382
383 let prefixes = discover_prefixes(&repos);
384
385 if prefixes.is_empty() {
386 println!(
387 " {} No common prefixes found. You can add systems manually to ward.toml.",
388 style("info").blue()
389 );
390 let exclude = ask_exclude_patterns()?;
391 return Ok((Vec::new(), exclude));
392 }
393
394 println!();
395 println!(" Discovered prefixes:");
396 for p in &prefixes {
397 println!(" {} - {} repos", style(&p.prefix).cyan(), p.count);
398 }
399 println!();
400
401 let items: Vec<String> = prefixes
402 .iter()
403 .map(|p| format!("{} ({} repos)", p.prefix, p.count))
404 .collect();
405
406 let selected = MultiSelect::new()
407 .with_prompt(" Select systems to manage")
408 .items(&items)
409 .defaults(&vec![true; items.len()])
410 .interact()?;
411
412 let mut systems = Vec::new();
413 for idx in selected {
414 let p = &prefixes[idx];
415 let name: String = Input::new()
416 .with_prompt(format!(" Name for system {}", style(&p.prefix).cyan()))
417 .default(p.prefix.clone())
418 .interact_text()?;
419
420 systems.push(SystemEntry {
421 id: p.prefix.clone(),
422 name,
423 repo_count: p.count,
424 });
425 }
426
427 let exclude = ask_exclude_patterns()?;
428
429 Ok((systems, exclude))
430}
431
432fn ask_exclude_patterns() -> Result<Vec<String>> {
433 let raw: String = Input::new()
434 .with_prompt(" Exclude patterns (comma-separated, regex)")
435 .default("operations?,workflows".to_owned())
436 .interact_text()?;
437
438 let patterns: Vec<String> = raw
439 .split(',')
440 .map(|s| s.trim().to_owned())
441 .filter(|s| !s.is_empty())
442 .collect();
443
444 Ok(patterns)
445}
446
447fn ask_templates() -> Result<TemplateSettings> {
450 let branch: String = Input::new()
451 .with_prompt(" Branch name for PRs")
452 .default("chore/ward-setup".to_owned())
453 .interact_text()?;
454
455 let reviewers_raw: String = Input::new()
456 .with_prompt(" Reviewers (comma-separated)")
457 .default(String::new())
458 .allow_empty(true)
459 .interact_text()?;
460
461 let reviewers: Vec<String> = reviewers_raw
462 .split(',')
463 .map(|s| s.trim().to_owned())
464 .filter(|s| !s.is_empty())
465 .collect();
466
467 let commit_message_prefix: String = Input::new()
468 .with_prompt(" Commit message prefix")
469 .default("chore: ".to_owned())
470 .interact_text()?;
471
472 Ok(TemplateSettings {
473 branch,
474 reviewers,
475 commit_message_prefix,
476 })
477}
478
479fn write_toml(state: &WizardState) -> Result<()> {
482 let mut out = String::new();
483
484 out.push_str("[org]\n");
485 out.push_str(&format!("name = {:?}\n", state.org));
486
487 out.push_str("\n[security]\n");
488 out.push_str(&format!(
489 "secret_scanning = {}\n",
490 state.security.secret_scanning
491 ));
492 out.push_str(&format!(
493 "push_protection = {}\n",
494 state.security.push_protection
495 ));
496 out.push_str(&format!(
497 "dependabot_alerts = {}\n",
498 state.security.dependabot_alerts
499 ));
500 out.push_str(&format!(
501 "dependabot_security_updates = {}\n",
502 state.security.dependabot_security_updates
503 ));
504
505 out.push_str("\n[branch_protection]\n");
506 out.push_str(&format!("enabled = {}\n", state.branch_protection.enabled));
507 if state.branch_protection.enabled {
508 out.push_str(&format!(
509 "required_approvals = {}\n",
510 state.branch_protection.required_approvals
511 ));
512 out.push_str(&format!(
513 "dismiss_stale_reviews = {}\n",
514 state.branch_protection.dismiss_stale_reviews
515 ));
516 }
517
518 out.push_str("\n[templates]\n");
519 out.push_str(&format!("branch = {:?}\n", state.templates.branch));
520 let reviewers_toml: Vec<String> = state
521 .templates
522 .reviewers
523 .iter()
524 .map(|r| format!("{r:?}"))
525 .collect();
526 out.push_str(&format!("reviewers = [{}]\n", reviewers_toml.join(", ")));
527 out.push_str(&format!(
528 "commit_message_prefix = {:?}\n",
529 state.templates.commit_message_prefix
530 ));
531
532 for sys in &state.systems {
533 out.push('\n');
534 out.push_str("[[systems]]\n");
535 out.push_str(&format!("id = {:?}\n", sys.id));
536 out.push_str(&format!("name = {:?}\n", sys.name));
537 if !state.exclude_patterns.is_empty() {
538 let exclude_toml: Vec<String> = state
539 .exclude_patterns
540 .iter()
541 .map(|p| format!("{p:?}"))
542 .collect();
543 out.push_str(&format!("exclude = [{}]\n", exclude_toml.join(", ")));
544 }
545 }
546
547 std::fs::write(OUTPUT_PATH, &out).context("Failed to write ward.toml")?;
548 Ok(())
549}
550
551fn print_summary(state: &WizardState) {
552 println!();
553 println!(
554 " {} Created ward.toml with {} system(s) for {} ({} repos)",
555 style("ok").green(),
556 state.systems.len(),
557 style(&state.org).cyan(),
558 state.repo_count,
559 );
560 for sys in &state.systems {
561 println!(
562 " - {} ({}) - {} repos",
563 style(&sys.id).cyan(),
564 sys.name,
565 sys.repo_count,
566 );
567 }
568 println!();
569 println!(" Next steps:");
570 println!(" ward repos list - see matched repos");
571 println!(" ward security plan - preview security changes");
572 println!(" ward security apply - apply changes");
573 println!();
574}
575
576fn build_http_client(token: &str) -> Result<reqwest::Client> {
579 let mut headers = HeaderMap::new();
580 headers.insert(
581 header::ACCEPT,
582 HeaderValue::from_static("application/vnd.github+json"),
583 );
584 headers.insert(
585 "X-GitHub-Api-Version",
586 HeaderValue::from_static("2022-11-28"),
587 );
588 headers.insert(
589 header::AUTHORIZATION,
590 HeaderValue::from_str(&format!("Bearer {token}")).context("Invalid token characters")?,
591 );
592 headers.insert(
593 header::USER_AGENT,
594 HeaderValue::from_static("ward-cli/0.1.0"),
595 );
596
597 reqwest::Client::builder()
598 .default_headers(headers)
599 .build()
600 .context("Failed to build HTTP client")
601}
602
#[cfg(test)]
mod tests {
    //! Unit tests for `discover_prefixes`, which needs no network access or prompts.

    use super::*;

    /// Build a `MinimalRepo` fixture with the given name and archived flag.
    fn make_repo(name: &str, archived: bool) -> MinimalRepo {
        MinimalRepo {
            name: name.to_owned(),
            archived,
        }
    }

    // Repos are grouped by the text before the first '-'; the larger group comes first.
    #[test]
    fn discover_prefixes_groups_by_first_segment() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07411-bar", false),
            make_repo("s07411-baz", false),
            make_repo("s07252-one", false),
            make_repo("s07252-two", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 2);
        assert_eq!(prefixes[0].prefix, "s07411");
        assert_eq!(prefixes[0].count, 3);
        assert_eq!(prefixes[1].prefix, "s07252");
        assert_eq!(prefixes[1].count, 2);
    }

    // Archived repos do not count, so a single active repo is not enough for a prefix.
    #[test]
    fn discover_prefixes_ignores_archived() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07411-bar", true),
            make_repo("s07411-baz", true),
        ];

        let prefixes = discover_prefixes(&repos);
        assert!(prefixes.is_empty());
    }

    // A prefix used by only one repo is dropped.
    #[test]
    fn discover_prefixes_filters_singletons() {
        let repos = vec![
            make_repo("s07411-foo", false),
            make_repo("s07252-one", false),
            make_repo("s07252-two", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 1);
        assert_eq!(prefixes[0].prefix, "s07252");
    }

    // Output is ordered by descending repo count.
    #[test]
    fn discover_prefixes_sorts_by_count() {
        let repos = vec![
            make_repo("alpha-a", false),
            make_repo("alpha-b", false),
            make_repo("beta-a", false),
            make_repo("beta-b", false),
            make_repo("beta-c", false),
            make_repo("gamma-a", false),
            make_repo("gamma-b", false),
            make_repo("gamma-c", false),
            make_repo("gamma-d", false),
        ];

        let prefixes = discover_prefixes(&repos);
        assert_eq!(prefixes.len(), 3);
        assert_eq!(prefixes[0].prefix, "gamma");
        assert_eq!(prefixes[0].count, 4);
        assert_eq!(prefixes[1].prefix, "beta");
        assert_eq!(prefixes[1].count, 3);
        assert_eq!(prefixes[2].prefix, "alpha");
        assert_eq!(prefixes[2].count, 2);
    }

    // Dash-less names count under their full (unique) name, so each stays a singleton.
    #[test]
    fn discover_prefixes_handles_no_dash_names() {
        let repos = vec![make_repo("standalone", false), make_repo("another", false)];

        let prefixes = discover_prefixes(&repos);
        assert!(prefixes.is_empty());
    }

    #[test]
    fn discover_prefixes_empty_input() {
        let prefixes = discover_prefixes(&[]);
        assert!(prefixes.is_empty());
    }
}