1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::cache::{CacheLookup, CacheManager};
3use crate::git::{GitRepo, SnapshotGuard};
4use crate::ui;
5use anyhow::{Context, Result, bail};
6use indicatif::ProgressBar;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::sync::mpsc;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CommitPlan {
14 pub commits: Vec<PlannedCommit>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PlannedCommit {
19 pub order: Option<u32>,
20 pub message: String,
21 pub body: Option<String>,
22 pub footer: Option<String>,
23 pub files: Vec<String>,
24}
25
26#[derive(Debug, clap::Args)]
27pub struct CommitArgs {
28 #[arg(short, long)]
30 pub staged: bool,
31
32 #[arg(short = 'M', long)]
34 pub message: Option<String>,
35
36 #[arg(short = 'n', long)]
38 pub dry_run: bool,
39
40 #[arg(short, long)]
42 pub yes: bool,
43
44 #[arg(long)]
46 pub no_cache: bool,
47}
48
/// JSON Schema handed to the AI backend (via `AiRequest::json_schema`) so the
/// response has the shape of a [`CommitPlan`].
///
/// `footer` is intentionally absent from the per-commit `required` list.
const COMMIT_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "order": { "type": "integer" },
                    "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
                    "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
                    "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
                    "files": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["order", "message", "body", "files"]
            }
        }
    },
    "required": ["commits"]
}"#;
69
/// Renders the system prompt that instructs the model how to write and group
/// commits.
///
/// * `commit_pattern` – regular expression the commit header must match; it is
///   inserted verbatim into the prompt.
/// * `type_names` – allowed commit types; listed comma-separated in the prompt.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    let allowed_types = type_names.join(", ");
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {pattern}
- Format: type(scope): subject
- Valid types ONLY: {types}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        pattern = commit_pattern,
        types = allowed_types,
    )
}
102
/// How the commit plan used by `run` was obtained.
///
/// Derives the usual value-type traits so a status can be logged, compared
/// and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CacheStatus {
    /// The plan was generated from scratch; no cached plan was used.
    None,
    /// The plan was taken unchanged from the cache (exact hit).
    Cached,
    /// The plan was regenerated with a previously cached plan as a hint.
    Incremental,
}
111
112pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
113 ui::header("sr commit");
114
115 let repo = GitRepo::discover()?;
117 ui::phase_ok("Repository found", None);
118
119 let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
121 .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
122 .transpose()?
123 .unwrap_or_default();
124 let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
125 let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
126
127 let has_changes = if args.staged {
129 repo.has_staged_changes()?
130 } else {
131 repo.has_any_changes()?
132 };
133
134 if !has_changes {
135 bail!(crate::error::SrAiError::NoChanges);
136 }
137
138 let statuses = repo.file_statuses().unwrap_or_default();
139 let file_count = statuses.len();
140 ui::phase_ok(
141 "Changes detected",
142 Some(&format!(
143 "{file_count} file{}",
144 if file_count == 1 { "" } else { "s" }
145 )),
146 );
147
148 let backend = resolve_backend(backend_config).await?;
150 let backend_name = backend.name().to_string();
151 let model_name = backend_config
152 .model
153 .as_deref()
154 .unwrap_or("default")
155 .to_string();
156 ui::phase_ok(
157 "Backend resolved",
158 Some(&format!("{backend_name} ({model_name})")),
159 );
160
161 let cache = if args.no_cache {
163 None
164 } else {
165 CacheManager::new(
166 repo.root(),
167 args.staged,
168 args.message.as_deref(),
169 &backend_name,
170 &model_name,
171 )
172 };
173
174 let snapshot = SnapshotGuard::new(&repo)?;
178 ui::phase_ok("Working tree snapshot saved", None);
179
180 let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
182 Some(CacheLookup::ExactHit(cached_plan)) => {
183 ui::phase_ok(
184 "Plan loaded",
185 Some(&format!("{} commits · cached", cached_plan.commits.len())),
186 );
187 (cached_plan, CacheStatus::Cached)
188 }
189 Some(CacheLookup::IncrementalHit {
190 previous_plan,
191 delta_summary,
192 }) => {
193 let spinner = ui::spinner(&format!(
194 "Analyzing changes with {backend_name} (incremental)..."
195 ));
196 let (tx, event_handler) = spawn_event_handler(&spinner);
197
198 let user_prompt =
199 build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;
200
201 let request = AiRequest {
202 system_prompt: system_prompt.clone(),
203 user_prompt,
204 json_schema: Some(COMMIT_SCHEMA.to_string()),
205 working_dir: repo.root().to_string_lossy().to_string(),
206 };
207
208 let response = backend.request(&request, Some(tx)).await?;
209 let _ = event_handler.await;
210
211 let p: CommitPlan = parse_plan(&response.text)?;
212
213 let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
214 ui::spinner_done(&spinner, Some(&detail));
215
216 (p, CacheStatus::Incremental)
217 }
218 _ => {
219 let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
220 let (tx, event_handler) = spawn_event_handler(&spinner);
221
222 let user_prompt = build_user_prompt(args, &repo)?;
223
224 let request = AiRequest {
225 system_prompt: system_prompt.clone(),
226 user_prompt,
227 json_schema: Some(COMMIT_SCHEMA.to_string()),
228 working_dir: repo.root().to_string_lossy().to_string(),
229 };
230
231 let response = backend.request(&request, Some(tx)).await?;
232 let _ = event_handler.await;
233
234 let p: CommitPlan = parse_plan(&response.text)?;
235
236 let detail = format_done_detail(p.commits.len(), "", &response.usage);
237 ui::spinner_done(&spinner, Some(&detail));
238
239 (p, CacheStatus::None)
240 }
241 };
242
243 if plan.commits.is_empty() {
244 bail!(crate::error::SrAiError::EmptyPlan);
245 }
246
247 let pre_validate_count = plan.commits.len();
249 plan = validate_plan(plan);
250 if plan.commits.len() < pre_validate_count {
251 ui::warn(&format!(
252 "Shared files detected — merged {} commits into 1",
253 pre_validate_count - plan.commits.len() + 1
254 ));
255 }
256
257 if let Some(cache) = &cache {
259 cache.store(&plan, &backend_name, &model_name);
260 }
261
262 let cache_label: Option<&str> = match &cache_status {
264 CacheStatus::Cached => Some("cached"),
265 CacheStatus::Incremental => Some("incremental"),
266 CacheStatus::None => None,
267 };
268 ui::display_plan(&plan, &statuses, cache_label);
269
270 if args.dry_run {
271 ui::info("Dry run — no commits created");
272 println!();
273 return Ok(());
274 }
275
276 if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
278 bail!(crate::error::SrAiError::Cancelled);
279 }
280
281 let invalid = validate_messages(&plan, &config.commit_pattern);
283 if !invalid.is_empty() {
284 ui::invalid_messages(&invalid);
285 if !args.yes && !ui::confirm("Continue anyway? Invalid commits will likely fail. [y/N]")? {
286 bail!(crate::error::SrAiError::Cancelled);
287 }
288 }
289
290 execute_plan(&repo, &plan)?;
292
293 snapshot.success();
295
296 Ok(())
297}
298
299fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
300 let git_root = repo.root().to_string_lossy();
301
302 let mut prompt = if args.staged {
303 "Analyze the staged git changes and group them into atomic commits.\n\
304 Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
305 .to_string()
306 } else {
307 "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
308 Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
309 `git ls-files --others --exclude-standard` to inspect changes."
310 .to_string()
311 };
312
313 prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
314
315 if let Some(msg) = &args.message {
316 prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
317 }
318
319 Ok(prompt)
320}
321
322fn build_incremental_prompt(
323 args: &CommitArgs,
324 repo: &GitRepo,
325 previous_plan: &CommitPlan,
326 delta_summary: &str,
327) -> Result<String> {
328 let mut prompt = build_user_prompt(args, repo)?;
329
330 let previous_json =
331 serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
332
333 prompt.push_str(&format!(
334 "\n\n--- INCREMENTAL HINTS ---\n\
335 A previous commit plan exists for a similar set of changes. \
336 Maintain the groupings for unchanged files where possible. \
337 Only re-analyze files that have changed.\n\n\
338 Previous plan:\n```json\n{previous_json}\n```\n\n\
339 File delta:\n{delta_summary}"
340 ));
341
342 Ok(prompt)
343}
344
345fn validate_plan(plan: CommitPlan) -> CommitPlan {
348 let mut file_counts: HashMap<String, usize> = HashMap::new();
350 for commit in &plan.commits {
351 for file in &commit.files {
352 *file_counts.entry(file.clone()).or_default() += 1;
353 }
354 }
355
356 let dupes: Vec<&String> = file_counts
357 .iter()
358 .filter(|(_, count)| **count > 1)
359 .map(|(file, _)| file)
360 .collect();
361
362 if dupes.is_empty() {
363 return plan;
364 }
365
366 let mut tainted = Vec::new();
368 let mut clean = Vec::new();
369
370 for commit in plan.commits {
371 let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
372 if is_tainted {
373 tainted.push(commit);
374 } else {
375 clean.push(commit);
376 }
377 }
378
379 let merged_message = tainted
381 .first()
382 .map(|c| c.message.clone())
383 .unwrap_or_default();
384
385 let merged_body = tainted
386 .iter()
387 .filter_map(|c| c.body.as_ref())
388 .filter(|b| !b.is_empty())
389 .cloned()
390 .collect::<Vec<_>>()
391 .join("\n\n");
392
393 let merged_footer = tainted
394 .iter()
395 .filter_map(|c| c.footer.as_ref())
396 .filter(|f| !f.is_empty())
397 .cloned()
398 .collect::<Vec<_>>()
399 .join("\n");
400
401 let mut merged_files: Vec<String> = tainted
402 .iter()
403 .flat_map(|c| c.files.iter().cloned())
404 .collect();
405 merged_files.sort();
406 merged_files.dedup();
407
408 let merged_commit = PlannedCommit {
409 order: Some(1),
410 message: merged_message,
411 body: if merged_body.is_empty() {
412 None
413 } else {
414 Some(merged_body)
415 },
416 footer: if merged_footer.is_empty() {
417 None
418 } else {
419 Some(merged_footer)
420 },
421 files: merged_files,
422 };
423
424 let mut result = vec![merged_commit];
426 for (i, mut commit) in clean.into_iter().enumerate() {
427 commit.order = Some(i as u32 + 2);
428 result.push(commit);
429 }
430
431 CommitPlan { commits: result }
432}
433
434fn parse_plan(text: &str) -> Result<CommitPlan> {
436 let value: serde_json::Value =
440 serde_json::from_str(text).context("failed to parse JSON from AI response")?;
441 serde_json::from_value(value).context("failed to parse commit plan from AI response")
442}
443
444fn spawn_event_handler(
446 spinner: &ProgressBar,
447) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
448 let (tx, mut rx) = mpsc::unbounded_channel();
449 let pb = spinner.clone();
450 let handle = tokio::spawn(async move {
451 while let Some(event) = rx.recv().await {
452 match event {
453 AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
454 }
455 }
456 });
457 (tx, handle)
458}
459
460fn format_done_detail(
462 commit_count: usize,
463 extra: &str,
464 usage: &Option<crate::ai::AiUsage>,
465) -> String {
466 let commits = format!(
467 "{commit_count} commit{}",
468 if commit_count == 1 { "" } else { "s" }
469 );
470 let extra_part = if extra.is_empty() {
471 String::new()
472 } else {
473 format!(" · {extra}")
474 };
475 let usage_part = match usage {
476 Some(u) => {
477 let cost = u
478 .cost_usd
479 .map(|c| format!(" · ${c:.4}"))
480 .unwrap_or_default();
481 format!(
482 " · {} in / {} out{}",
483 ui::format_tokens(u.input_tokens),
484 ui::format_tokens(u.output_tokens),
485 cost
486 )
487 }
488 None => String::new(),
489 };
490 format!("{commits}{extra_part}{usage_part}")
491}
492
493fn validate_messages(plan: &CommitPlan, commit_pattern: &str) -> Vec<(usize, String, String)> {
496 let re = match Regex::new(commit_pattern) {
497 Ok(re) => re,
498 Err(e) => {
499 return plan
501 .commits
502 .iter()
503 .enumerate()
504 .map(|(i, c)| (i + 1, c.message.clone(), format!("invalid pattern: {e}")))
505 .collect();
506 }
507 };
508
509 plan.commits
510 .iter()
511 .enumerate()
512 .filter(|(_, c)| !re.is_match(&c.message))
513 .map(|(i, c)| {
514 (
515 i + 1,
516 c.message.clone(),
517 format!("does not match pattern: {commit_pattern}"),
518 )
519 })
520 .collect()
521}
522
523fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
524 repo.reset_head()?;
526
527 let total = plan.commits.len();
528 let mut created: Vec<(String, String)> = Vec::new();
529 let mut failed: Vec<(usize, String, String)> = Vec::new();
530
531 for (i, commit) in plan.commits.iter().enumerate() {
532 ui::commit_start(i + 1, total, &commit.message);
533
534 for file in &commit.files {
536 let ok = repo.stage_file(file)?;
537 ui::file_staged(file, ok);
538 }
539
540 let mut full_message = commit.message.clone();
542 if let Some(body) = &commit.body
543 && !body.is_empty()
544 {
545 full_message.push_str("\n\n");
546 full_message.push_str(body);
547 }
548 if let Some(footer) = &commit.footer
549 && !footer.is_empty()
550 {
551 full_message.push_str("\n\n");
552 full_message.push_str(footer);
553 }
554
555 if repo.has_staged_after_add()? {
557 match repo.commit(&full_message) {
558 Ok(()) => {
559 let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
560 ui::commit_created(&sha);
561 created.push((sha, commit.message.clone()));
562 }
563 Err(e) => {
564 ui::commit_failed(&format!("{e:#}"));
565 failed.push((i + 1, commit.message.clone(), format!("{e:#}")));
566 repo.reset_head()?;
568 }
569 }
570 } else {
571 ui::commit_skipped();
572 }
573 }
574
575 ui::summary(&created);
576
577 if !failed.is_empty() {
578 ui::failed_commits(&failed);
579 if created.is_empty() {
580 bail!("all {} commits failed", failed.len());
581 }
582 }
583
584 Ok(())
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a footer-less [`PlannedCommit`] fixture.
    fn planned(order: u32, message: &str, body: Option<&str>, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.into(),
            body: body.map(Into::into),
            footer: None,
            files: files.iter().map(|f| f.to_string()).collect(),
        }
    }

    /// Wraps commits into a [`CommitPlan`].
    fn plan_of(commits: Vec<PlannedCommit>) -> CommitPlan {
        CommitPlan { commits }
    }

    #[test]
    fn validate_plan_no_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason"), &["a.rs"]),
            planned(2, "fix: fix bar", Some("reason"), &["b.rs"]),
        ]);

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason 1"), &["shared.rs", "a.rs"]),
            planned(2, "fix: fix bar", Some("reason 2"), &["shared.rs", "b.rs"]),
            planned(3, "docs: update readme", Some("docs"), &["README.md"]),
        ]);

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);

        // The two commits sharing `shared.rs` collapse into the first entry.
        let merged = &result.commits[0];
        assert_eq!(merged.message, "feat: add foo");
        assert!(merged.files.contains(&"shared.rs".to_string()));
        assert!(merged.files.contains(&"a.rs".to_string()));
        assert!(merged.files.contains(&"b.rs".to_string()));

        // The untouched commit follows and is renumbered.
        let untouched = &result.commits[1];
        assert_eq!(untouched.message, "docs: update readme");
        assert_eq!(untouched.order, Some(2));
    }

    #[test]
    fn validate_messages_all_valid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "fix(core): null check", None, &[]),
        ]);

        let pattern = sr_core::commit::DEFAULT_COMMIT_PATTERN;
        let invalid = validate_messages(&plan, pattern);
        assert!(invalid.is_empty());
    }

    #[test]
    fn validate_messages_catches_invalid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "not a conventional commit", None, &[]),
            planned(3, "fix: valid one", None, &[]),
        ]);

        let pattern = sr_core::commit::DEFAULT_COMMIT_PATTERN;
        let invalid = validate_messages(&plan, pattern);
        assert_eq!(invalid.len(), 1);
        // Positions are 1-based, so the second commit is reported as 2.
        assert_eq!(invalid[0].0, 2);
        assert_eq!(invalid[0].1, "not a conventional commit");
    }

    #[test]
    fn validate_messages_invalid_pattern() {
        let plan = plan_of(vec![planned(1, "feat: add foo", None, &[])]);

        let invalid = validate_messages(&plan, "[invalid regex");
        assert_eq!(invalid.len(), 1);
        assert!(invalid[0].2.contains("invalid pattern"));
    }
}