1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::cache::{CacheLookup, CacheManager};
3use crate::git::{GitRepo, SnapshotGuard};
4use crate::ui;
5use anyhow::{Context, Result, bail};
6use indicatif::ProgressBar;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use tokio::sync::mpsc;
10
/// The full set of commits the AI proposes for the current changes.
///
/// Serialized to/from JSON both when parsing the backend response
/// (constrained by `COMMIT_SCHEMA`) and when persisting plans via the
/// cache layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitPlan {
    pub commits: Vec<PlannedCommit>,
}
15
/// One commit in the plan: conventional-commit message parts plus the
/// exact set of files to stage for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedCommit {
    /// Position in the planned sequence. Required by the schema, though the
    /// executor iterates the `commits` Vec in order rather than sorting by
    /// this field.
    pub order: Option<u32>,
    /// Header line: `type(scope): subject`.
    pub message: String,
    /// Optional body explaining why the change was made.
    pub body: Option<String>,
    /// Optional footer (BREAKING CHANGE notes, issue references).
    pub footer: Option<String>,
    /// Repo-root-relative paths; each file is staged wholesale, so a file
    /// must not appear in more than one commit (enforced by `validate_plan`).
    pub files: Vec<String>,
}
24
// Command-line flags for `sr commit`.
//
// NOTE(review): plain `//` comments are used intentionally — `///` doc
// comments on a clap-derived struct and its fields become user-visible
// `--help` text, which would change CLI output.
#[derive(Debug, clap::Args)]
pub struct CommitArgs {
    // -s / --staged: analyze only what is already staged.
    #[arg(short, long)]
    pub staged: bool,

    // -M / --message: free-form context appended to the AI prompt.
    // (Capital 'M' presumably avoids a clash with another short flag.)
    #[arg(short = 'M', long)]
    pub message: Option<String>,

    // -n / --dry-run: display the plan without creating commits.
    #[arg(short = 'n', long)]
    pub dry_run: bool,

    // -y / --yes: skip the interactive confirmation prompt.
    #[arg(short, long)]
    pub yes: bool,

    // --no-cache: bypass plan cache lookup and storage entirely.
    #[arg(long)]
    pub no_cache: bool,
}
47
/// JSON Schema sent to the AI backend to constrain its response: an object
/// holding a `commits` array whose items mirror [`PlannedCommit`]. The
/// `description` strings double as field-level prompting for the model.
/// `footer` is intentionally absent from `required`.
const COMMIT_SCHEMA: &str = r#"{
  "type": "object",
  "properties": {
    "commits": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "order": { "type": "integer" },
          "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
          "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
          "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
          "files": { "type": "array", "items": { "type": "string" } }
        },
        "required": ["order", "message", "body", "files"]
      }
    }
  },
  "required": ["commits"]
}"#;
68
/// Render the system prompt for the AI backend, embedding the configured
/// commit-header regex and the comma-separated list of allowed
/// conventional-commit types.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {commit_pattern}
- Format: type(scope): subject
- Valid types ONLY: {types_list}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        commit_pattern = commit_pattern,
        types_list = type_names.join(", "),
    )
}
101
/// How the commit plan was obtained, used only to pick a UI label.
enum CacheStatus {
    /// Fresh full analysis; no cache involvement.
    None,
    /// Plan served verbatim from the cache (exact hit).
    Cached,
    /// Plan re-generated from a previous cached plan plus a file delta.
    Incremental,
}
110
/// Entry point for `sr commit`: analyze working-tree changes with an AI
/// backend, present a commit plan, and (unless `--dry-run`) execute it.
///
/// Flow: discover repo → load config → detect changes → resolve backend →
/// optional cache lookup → AI analysis → de-duplication validation →
/// confirmation → execution. A `SnapshotGuard` is taken before anything
/// destructive; it is marked successful only after `execute_plan`, so any
/// earlier error or cancellation leaves the snapshot in place
/// (presumably for restore-on-drop — confirm in `SnapshotGuard`).
pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
    ui::header("sr commit");

    let repo = GitRepo::discover()?;
    ui::phase_ok("Repository found", None);

    // Load the release config if one exists anywhere above the repo root;
    // a load error propagates, a missing config falls back to defaults.
    let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
        .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
        .transpose()?
        .unwrap_or_default();
    let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
    let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);

    // Scope of the analysis depends on --staged.
    let has_changes = if args.staged {
        repo.has_staged_changes()?
    } else {
        repo.has_any_changes()?
    };

    if !has_changes {
        bail!(crate::error::SrAiError::NoChanges);
    }

    // Statuses are display-only here, so errors degrade to an empty list.
    let statuses = repo.file_statuses().unwrap_or_default();
    let file_count = statuses.len();
    ui::phase_ok(
        "Changes detected",
        Some(&format!(
            "{file_count} file{}",
            if file_count == 1 { "" } else { "s" }
        )),
    );

    let backend = resolve_backend(backend_config).await?;
    let backend_name = backend.name().to_string();
    let model_name = backend_config
        .model
        .as_deref()
        .unwrap_or("default")
        .to_string();
    ui::phase_ok(
        "Backend resolved",
        Some(&format!("{backend_name} ({model_name})")),
    );

    // The cache key incorporates the staged flag, user message, backend,
    // and model, so a plan is only reused under identical conditions.
    // CacheManager::new itself returns an Option (see the .as_ref().map
    // below), so cache construction may also yield None.
    let cache = if args.no_cache {
        None
    } else {
        CacheManager::new(
            repo.root(),
            args.staged,
            args.message.as_deref(),
            &backend_name,
            &model_name,
        )
    };

    let snapshot = SnapshotGuard::new(&repo)?;
    ui::phase_ok("Working tree snapshot saved", None);

    // Three ways to obtain a plan: exact cache hit (no AI call),
    // incremental hit (AI call seeded with the previous plan), or a full
    // fresh analysis.
    let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
        Some(CacheLookup::ExactHit(cached_plan)) => {
            ui::phase_ok(
                "Plan loaded",
                Some(&format!("{} commits · cached", cached_plan.commits.len())),
            );
            (cached_plan, CacheStatus::Cached)
        }
        Some(CacheLookup::IncrementalHit {
            previous_plan,
            delta_summary,
        }) => {
            let spinner = ui::spinner(&format!(
                "Analyzing changes with {backend_name} (incremental)..."
            ));
            let (tx, event_handler) = spawn_event_handler(&spinner);

            let user_prompt =
                build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;

            let request = AiRequest {
                system_prompt: system_prompt.clone(),
                user_prompt,
                json_schema: Some(COMMIT_SCHEMA.to_string()),
                working_dir: repo.root().to_string_lossy().to_string(),
            };

            let response = backend.request(&request, Some(tx)).await?;
            // Drain the event task; its outcome doesn't affect the plan.
            let _ = event_handler.await;

            let p: CommitPlan = parse_plan(&response.text)?;

            let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
            ui::spinner_done(&spinner, Some(&detail));

            (p, CacheStatus::Incremental)
        }
        // Covers both "no cache" and a cache miss.
        _ => {
            let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
            let (tx, event_handler) = spawn_event_handler(&spinner);

            let user_prompt = build_user_prompt(args, &repo)?;

            let request = AiRequest {
                system_prompt: system_prompt.clone(),
                user_prompt,
                json_schema: Some(COMMIT_SCHEMA.to_string()),
                working_dir: repo.root().to_string_lossy().to_string(),
            };

            let response = backend.request(&request, Some(tx)).await?;
            let _ = event_handler.await;

            let p: CommitPlan = parse_plan(&response.text)?;

            let detail = format_done_detail(p.commits.len(), "", &response.usage);
            ui::spinner_done(&spinner, Some(&detail));

            (p, CacheStatus::None)
        }
    };

    if plan.commits.is_empty() {
        bail!(crate::error::SrAiError::EmptyPlan);
    }

    // Collapse commits that illegally share files (see validate_plan) and
    // tell the user how many were folded together.
    let pre_validate_count = plan.commits.len();
    plan = validate_plan(plan);
    if plan.commits.len() < pre_validate_count {
        ui::warn(&format!(
            "Shared files detected — merged {} commits into 1",
            pre_validate_count - plan.commits.len() + 1
        ));
    }

    // Store the validated plan so a re-run can hit the cache.
    if let Some(cache) = &cache {
        cache.store(&plan, &backend_name, &model_name);
    }

    let cache_label: Option<&str> = match &cache_status {
        CacheStatus::Cached => Some("cached"),
        CacheStatus::Incremental => Some("incremental"),
        CacheStatus::None => None,
    };
    ui::display_plan(&plan, &statuses, cache_label);

    if args.dry_run {
        ui::info("Dry run — no commits created");
        println!();
        return Ok(());
    }

    // --yes bypasses the prompt; otherwise a decline aborts before any
    // commits are made (snapshot remains un-committed for restore).
    if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
        bail!(crate::error::SrAiError::Cancelled);
    }

    execute_plan(&repo, &plan)?;

    // Only after successful execution do we disarm the snapshot guard.
    snapshot.success();

    Ok(())
}
288
289fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
290 let git_root = repo.root().to_string_lossy();
291
292 let mut prompt = if args.staged {
293 "Analyze the staged git changes and group them into atomic commits.\n\
294 Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
295 .to_string()
296 } else {
297 "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
298 Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
299 `git ls-files --others --exclude-standard` to inspect changes."
300 .to_string()
301 };
302
303 prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
304
305 if let Some(msg) = &args.message {
306 prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
307 }
308
309 Ok(prompt)
310}
311
312fn build_incremental_prompt(
313 args: &CommitArgs,
314 repo: &GitRepo,
315 previous_plan: &CommitPlan,
316 delta_summary: &str,
317) -> Result<String> {
318 let mut prompt = build_user_prompt(args, repo)?;
319
320 let previous_json =
321 serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
322
323 prompt.push_str(&format!(
324 "\n\n--- INCREMENTAL HINTS ---\n\
325 A previous commit plan exists for a similar set of changes. \
326 Maintain the groupings for unchanged files where possible. \
327 Only re-analyze files that have changed.\n\n\
328 Previous plan:\n```json\n{previous_json}\n```\n\n\
329 File delta:\n{delta_summary}"
330 ));
331
332 Ok(prompt)
333}
334
335fn validate_plan(plan: CommitPlan) -> CommitPlan {
338 let mut file_counts: HashMap<String, usize> = HashMap::new();
340 for commit in &plan.commits {
341 for file in &commit.files {
342 *file_counts.entry(file.clone()).or_default() += 1;
343 }
344 }
345
346 let dupes: Vec<&String> = file_counts
347 .iter()
348 .filter(|(_, count)| **count > 1)
349 .map(|(file, _)| file)
350 .collect();
351
352 if dupes.is_empty() {
353 return plan;
354 }
355
356 let mut tainted = Vec::new();
358 let mut clean = Vec::new();
359
360 for commit in plan.commits {
361 let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
362 if is_tainted {
363 tainted.push(commit);
364 } else {
365 clean.push(commit);
366 }
367 }
368
369 let merged_message = tainted
371 .first()
372 .map(|c| c.message.clone())
373 .unwrap_or_default();
374
375 let merged_body = tainted
376 .iter()
377 .filter_map(|c| c.body.as_ref())
378 .filter(|b| !b.is_empty())
379 .cloned()
380 .collect::<Vec<_>>()
381 .join("\n\n");
382
383 let merged_footer = tainted
384 .iter()
385 .filter_map(|c| c.footer.as_ref())
386 .filter(|f| !f.is_empty())
387 .cloned()
388 .collect::<Vec<_>>()
389 .join("\n");
390
391 let mut merged_files: Vec<String> = tainted
392 .iter()
393 .flat_map(|c| c.files.iter().cloned())
394 .collect();
395 merged_files.sort();
396 merged_files.dedup();
397
398 let merged_commit = PlannedCommit {
399 order: Some(1),
400 message: merged_message,
401 body: if merged_body.is_empty() {
402 None
403 } else {
404 Some(merged_body)
405 },
406 footer: if merged_footer.is_empty() {
407 None
408 } else {
409 Some(merged_footer)
410 },
411 files: merged_files,
412 };
413
414 let mut result = vec![merged_commit];
416 for (i, mut commit) in clean.into_iter().enumerate() {
417 commit.order = Some(i as u32 + 2);
418 result.push(commit);
419 }
420
421 CommitPlan { commits: result }
422}
423
424fn parse_plan(text: &str) -> Result<CommitPlan> {
426 let value: serde_json::Value =
430 serde_json::from_str(text).context("failed to parse JSON from AI response")?;
431 serde_json::from_value(value).context("failed to parse commit plan from AI response")
432}
433
434fn spawn_event_handler(
436 spinner: &ProgressBar,
437) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
438 let (tx, mut rx) = mpsc::unbounded_channel();
439 let pb = spinner.clone();
440 let handle = tokio::spawn(async move {
441 while let Some(event) = rx.recv().await {
442 match event {
443 AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
444 }
445 }
446 });
447 (tx, handle)
448}
449
450fn format_done_detail(
452 commit_count: usize,
453 extra: &str,
454 usage: &Option<crate::ai::AiUsage>,
455) -> String {
456 let commits = format!(
457 "{commit_count} commit{}",
458 if commit_count == 1 { "" } else { "s" }
459 );
460 let extra_part = if extra.is_empty() {
461 String::new()
462 } else {
463 format!(" · {extra}")
464 };
465 let usage_part = match usage {
466 Some(u) => {
467 let cost = u
468 .cost_usd
469 .map(|c| format!(" · ${c:.4}"))
470 .unwrap_or_default();
471 format!(
472 " · {} in / {} out{}",
473 ui::format_tokens(u.input_tokens),
474 ui::format_tokens(u.output_tokens),
475 cost
476 )
477 }
478 None => String::new(),
479 };
480 format!("{commits}{extra_part}{usage_part}")
481}
482
/// Apply the plan: reset the index, then for each planned commit stage its
/// files and create the commit.
///
/// The index is cleared up front so each commit contains exactly the files
/// the plan assigns to it. A commit whose staged set ends up empty
/// (e.g. every listed file failed to stage — TODO confirm `stage_file`'s
/// failure semantics) is skipped rather than aborting the run.
fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
    // Unstage everything so only plan-selected files end up in each commit.
    repo.reset_head()?;

    let total = plan.commits.len();
    let mut created: Vec<(String, String)> = Vec::new();

    for (i, commit) in plan.commits.iter().enumerate() {
        ui::commit_start(i + 1, total, &commit.message);

        // Stage each file, reporting per-file success/failure to the UI.
        for file in &commit.files {
            let ok = repo.stage_file(file)?;
            ui::file_staged(file, ok);
        }

        // Assemble the full message: header, then optional non-empty body
        // and footer, each separated by a blank line (conventional layout).
        let mut full_message = commit.message.clone();
        if let Some(body) = &commit.body
            && !body.is_empty()
        {
            full_message.push_str("\n\n");
            full_message.push_str(body);
        }
        if let Some(footer) = &commit.footer
            && !footer.is_empty()
        {
            full_message.push_str("\n\n");
            full_message.push_str(footer);
        }

        // Only commit when something was actually staged; otherwise skip.
        if repo.has_staged_after_add()? {
            repo.commit(&full_message)?;
            // A failure to read HEAD shouldn't abort; show a placeholder.
            let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
            ui::commit_created(&sha);
            created.push((sha, commit.message.clone()));
        } else {
            ui::commit_skipped();
        }
    }

    ui::summary(&created);

    Ok(())
}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture helper: a `PlannedCommit` with the given order, header,
    /// body, and file list (footer always `None`).
    fn planned(order: u32, message: &str, body: &str, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.into(),
            body: Some(body.into()),
            footer: None,
            files: files.iter().map(|f| (*f).into()).collect(),
        }
    }

    #[test]
    fn validate_plan_no_dupes() {
        // Disjoint file sets: the plan must pass through untouched.
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", "reason", &["a.rs"]),
                planned(2, "fix: fix bar", "reason", &["b.rs"]),
            ],
        };

        assert_eq!(validate_plan(plan).commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        // Two commits share "shared.rs"; a third is independent.
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", "reason 1", &["shared.rs", "a.rs"]),
                planned(2, "fix: fix bar", "reason 2", &["shared.rs", "b.rs"]),
                planned(3, "docs: update readme", "docs", &["README.md"]),
            ],
        };

        let result = validate_plan(plan);

        // The two conflicting commits collapse into one merged commit that
        // keeps the first offender's header and the union of files.
        assert_eq!(result.commits.len(), 2);
        assert_eq!(result.commits[0].message, "feat: add foo");
        let merged_files = &result.commits[0].files;
        assert!(merged_files.contains(&"shared.rs".to_string()));
        assert!(merged_files.contains(&"a.rs".to_string()));
        assert!(merged_files.contains(&"b.rs".to_string()));

        // The independent commit survives, renumbered after the merge.
        assert_eq!(result.commits[1].message, "docs: update readme");
        assert_eq!(result.commits[1].order, Some(2));
    }
}