// sr_ai/commands/commit.rs

1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::cache::{CacheLookup, CacheManager};
3use crate::git::{GitRepo, SnapshotGuard};
4use crate::ui;
5use anyhow::{Context, Result, bail};
6use indicatif::ProgressBar;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::sync::mpsc;
11
/// An ordered set of commits proposed for the current changes.
///
/// This is the shape the AI response is deserialized into (see `parse_plan`)
/// and the value that is cached, displayed and executed by `run`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitPlan {
    /// Planned commits, in the order they will be created.
    pub commits: Vec<PlannedCommit>,
}
16
/// A single commit within a [`CommitPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedCommit {
    /// Position of the commit within the plan, when known.
    pub order: Option<u32>,
    /// Commit header line (`type(scope): subject`).
    pub message: String,
    /// Optional commit body, appended after a blank line.
    pub body: Option<String>,
    /// Optional footer (breaking-change notes, issue references), appended
    /// after the body.
    pub footer: Option<String>,
    /// Paths staged for this commit.
    pub files: Vec<String>,
}
25
/// Command-line arguments accepted by the commit command (see [`run`]).
#[derive(Debug, clap::Args)]
pub struct CommitArgs {
    /// Only analyze staged changes
    #[arg(short, long)]
    pub staged: bool,

    /// Additional context or instructions for commit generation
    #[arg(short = 'M', long)]
    pub message: Option<String>,

    /// Display plan without executing
    #[arg(short = 'n', long)]
    pub dry_run: bool,

    /// Skip confirmation prompt
    #[arg(short, long)]
    pub yes: bool,

    /// Bypass cache (always call AI)
    #[arg(long)]
    pub no_cache: bool,
}
48
/// JSON Schema describing the expected AI response: an object with a
/// `commits` array whose items mirror the fields of `PlannedCommit`.
///
/// Sent with every AI request as `AiRequest::json_schema`. Note that `footer`
/// is the only per-commit property not listed as required.
const COMMIT_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "order": { "type": "integer" },
                    "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
                    "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
                    "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
                    "files": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["order", "message", "body", "files"]
            }
        }
    },
    "required": ["commits"]
}"#;
69
/// Render the system prompt that instructs the model how to write commits.
///
/// `commit_pattern` is interpolated verbatim as the regex the header must
/// match; `type_names` is rendered as a comma-separated list of the only
/// commit types the model is allowed to use.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {commit_pattern}
- Format: type(scope): subject
- Valid types ONLY: {allowed_types}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        allowed_types = type_names.join(", "),
    )
}
102
/// How the plan shown to the user was obtained; `run` maps this to the
/// optional label passed to `ui::display_plan`.
enum CacheStatus {
    /// No cache used (--no-cache, or cache unavailable)
    None,
    /// Exact cache hit
    Cached,
    /// Incremental hit
    Incremental,
}
111
112pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
113    ui::header("sr commit");
114
115    // Phase 1: Discover repository
116    let repo = GitRepo::discover()?;
117    ui::phase_ok("Repository found", None);
118
119    // Load project config for commit types and pattern
120    let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
121        .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
122        .transpose()?
123        .unwrap_or_default();
124    let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
125    let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
126
127    // Phase 2: Check for changes
128    let has_changes = if args.staged {
129        repo.has_staged_changes()?
130    } else {
131        repo.has_any_changes()?
132    };
133
134    if !has_changes {
135        bail!(crate::error::SrAiError::NoChanges);
136    }
137
138    let statuses = repo.file_statuses().unwrap_or_default();
139    let file_count = statuses.len();
140    ui::phase_ok(
141        "Changes detected",
142        Some(&format!(
143            "{file_count} file{}",
144            if file_count == 1 { "" } else { "s" }
145        )),
146    );
147
148    // Phase 3: Resolve AI backend
149    let backend = resolve_backend(backend_config).await?;
150    let backend_name = backend.name().to_string();
151    let model_name = backend_config
152        .model
153        .as_deref()
154        .unwrap_or("default")
155        .to_string();
156    ui::phase_ok(
157        "Backend resolved",
158        Some(&format!("{backend_name} ({model_name})")),
159    );
160
161    // Build cache manager (may be None if cache dir unavailable)
162    let cache = if args.no_cache {
163        None
164    } else {
165        CacheManager::new(
166            repo.root(),
167            args.staged,
168            args.message.as_deref(),
169            &backend_name,
170            &model_name,
171        )
172    };
173
174    // Snapshot the working tree before the agent runs.
175    // If anything goes wrong (agent failure, unexpected mutations),
176    // the guard restores the working tree from the snapshot on drop.
177    let snapshot = SnapshotGuard::new(&repo)?;
178    ui::phase_ok("Working tree snapshot saved", None);
179
180    // Phase 4: Generate plan (cache or AI)
181    let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
182        Some(CacheLookup::ExactHit(cached_plan)) => {
183            ui::phase_ok(
184                "Plan loaded",
185                Some(&format!("{} commits · cached", cached_plan.commits.len())),
186            );
187            (cached_plan, CacheStatus::Cached)
188        }
189        Some(CacheLookup::IncrementalHit {
190            previous_plan,
191            delta_summary,
192        }) => {
193            let spinner = ui::spinner(&format!(
194                "Analyzing changes with {backend_name} (incremental)..."
195            ));
196            let (tx, event_handler) = spawn_event_handler(&spinner);
197
198            let user_prompt =
199                build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;
200
201            let request = AiRequest {
202                system_prompt: system_prompt.clone(),
203                user_prompt,
204                json_schema: Some(COMMIT_SCHEMA.to_string()),
205                working_dir: repo.root().to_string_lossy().to_string(),
206            };
207
208            let response = backend.request(&request, Some(tx)).await?;
209            let _ = event_handler.await;
210
211            let p: CommitPlan = parse_plan(&response.text)?;
212
213            let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
214            ui::spinner_done(&spinner, Some(&detail));
215
216            (p, CacheStatus::Incremental)
217        }
218        _ => {
219            let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
220            let (tx, event_handler) = spawn_event_handler(&spinner);
221
222            let user_prompt = build_user_prompt(args, &repo)?;
223
224            let request = AiRequest {
225                system_prompt: system_prompt.clone(),
226                user_prompt,
227                json_schema: Some(COMMIT_SCHEMA.to_string()),
228                working_dir: repo.root().to_string_lossy().to_string(),
229            };
230
231            let response = backend.request(&request, Some(tx)).await?;
232            let _ = event_handler.await;
233
234            let p: CommitPlan = parse_plan(&response.text)?;
235
236            let detail = format_done_detail(p.commits.len(), "", &response.usage);
237            ui::spinner_done(&spinner, Some(&detail));
238
239            (p, CacheStatus::None)
240        }
241    };
242
243    if plan.commits.is_empty() {
244        bail!(crate::error::SrAiError::EmptyPlan);
245    }
246
247    // Validate: merge commits with shared files
248    let pre_validate_count = plan.commits.len();
249    plan = validate_plan(plan);
250    if plan.commits.len() < pre_validate_count {
251        ui::warn(&format!(
252            "Shared files detected — merged {} commits into 1",
253            pre_validate_count - plan.commits.len() + 1
254        ));
255    }
256
257    // Store in cache (before display/execute so dry-runs populate cache too)
258    if let Some(cache) = &cache {
259        cache.store(&plan, &backend_name, &model_name);
260    }
261
262    // Display plan
263    let cache_label: Option<&str> = match &cache_status {
264        CacheStatus::Cached => Some("cached"),
265        CacheStatus::Incremental => Some("incremental"),
266        CacheStatus::None => None,
267    };
268    ui::display_plan(&plan, &statuses, cache_label);
269
270    if args.dry_run {
271        ui::info("Dry run — no commits created");
272        println!();
273        snapshot.success();
274        return Ok(());
275    }
276
277    // Confirm
278    if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
279        bail!(crate::error::SrAiError::Cancelled);
280    }
281
282    // Pre-validate commit messages against the configured pattern
283    let invalid = validate_messages(&plan, &config.commit_pattern);
284    if !invalid.is_empty() {
285        ui::invalid_messages(&invalid);
286        if !args.yes && !ui::confirm("Continue anyway? Invalid commits will likely fail. [y/N]")? {
287            bail!(crate::error::SrAiError::Cancelled);
288        }
289    }
290
291    // Execute
292    execute_plan(&repo, &plan)?;
293
294    // All commits succeeded (or at least some did) — clear the snapshot
295    snapshot.success();
296
297    Ok(())
298}
299
300fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
301    let git_root = repo.root().to_string_lossy();
302
303    let mut prompt = if args.staged {
304        "Analyze the staged git changes and group them into atomic commits.\n\
305         Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
306            .to_string()
307    } else {
308        "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
309         Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
310         `git ls-files --others --exclude-standard` to inspect changes."
311            .to_string()
312    };
313
314    prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
315
316    if let Some(msg) = &args.message {
317        prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
318    }
319
320    Ok(prompt)
321}
322
323fn build_incremental_prompt(
324    args: &CommitArgs,
325    repo: &GitRepo,
326    previous_plan: &CommitPlan,
327    delta_summary: &str,
328) -> Result<String> {
329    let mut prompt = build_user_prompt(args, repo)?;
330
331    let previous_json =
332        serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
333
334    prompt.push_str(&format!(
335        "\n\n--- INCREMENTAL HINTS ---\n\
336         A previous commit plan exists for a similar set of changes. \
337         Maintain the groupings for unchanged files where possible. \
338         Only re-analyze files that have changed.\n\n\
339         Previous plan:\n```json\n{previous_json}\n```\n\n\
340         File delta:\n{delta_summary}"
341    ));
342
343    Ok(prompt)
344}
345
346/// Validate that no file appears in multiple commits. If duplicates are found,
347/// merge affected commits into one.
348fn validate_plan(plan: CommitPlan) -> CommitPlan {
349    // Count file occurrences
350    let mut file_counts: HashMap<String, usize> = HashMap::new();
351    for commit in &plan.commits {
352        for file in &commit.files {
353            *file_counts.entry(file.clone()).or_default() += 1;
354        }
355    }
356
357    let dupes: Vec<&String> = file_counts
358        .iter()
359        .filter(|(_, count)| **count > 1)
360        .map(|(file, _)| file)
361        .collect();
362
363    if dupes.is_empty() {
364        return plan;
365    }
366
367    // Partition into tainted (has any dupe file) and clean
368    let mut tainted = Vec::new();
369    let mut clean = Vec::new();
370
371    for commit in plan.commits {
372        let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
373        if is_tainted {
374            tainted.push(commit);
375        } else {
376            clean.push(commit);
377        }
378    }
379
380    // Merge all tainted commits into one
381    let merged_message = tainted
382        .first()
383        .map(|c| c.message.clone())
384        .unwrap_or_default();
385
386    let merged_body = tainted
387        .iter()
388        .filter_map(|c| c.body.as_ref())
389        .filter(|b| !b.is_empty())
390        .cloned()
391        .collect::<Vec<_>>()
392        .join("\n\n");
393
394    let merged_footer = tainted
395        .iter()
396        .filter_map(|c| c.footer.as_ref())
397        .filter(|f| !f.is_empty())
398        .cloned()
399        .collect::<Vec<_>>()
400        .join("\n");
401
402    let mut merged_files: Vec<String> = tainted
403        .iter()
404        .flat_map(|c| c.files.iter().cloned())
405        .collect();
406    merged_files.sort();
407    merged_files.dedup();
408
409    let merged_commit = PlannedCommit {
410        order: Some(1),
411        message: merged_message,
412        body: if merged_body.is_empty() {
413            None
414        } else {
415            Some(merged_body)
416        },
417        footer: if merged_footer.is_empty() {
418            None
419        } else {
420            Some(merged_footer)
421        },
422        files: merged_files,
423    };
424
425    // Re-number: merged first, then clean commits
426    let mut result = vec![merged_commit];
427    for (i, mut commit) in clean.into_iter().enumerate() {
428        commit.order = Some(i as u32 + 2);
429        result.push(commit);
430    }
431
432    CommitPlan { commits: result }
433}
434
435/// Parse a commit plan from JSON text, tolerating duplicate fields.
436fn parse_plan(text: &str) -> Result<CommitPlan> {
437    // Parse to Value first — serde_json::Value keeps the last value for duplicate keys,
438    // while #[derive(Deserialize)] rejects them. This handles AI responses that
439    // occasionally produce duplicate fields when schema is embedded in the prompt.
440    let value: serde_json::Value =
441        serde_json::from_str(text).context("failed to parse JSON from AI response")?;
442    serde_json::from_value(value).context("failed to parse commit plan from AI response")
443}
444
445/// Spawn a background task that renders AI events (tool calls) above a spinner.
446fn spawn_event_handler(
447    spinner: &ProgressBar,
448) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
449    let (tx, mut rx) = mpsc::unbounded_channel();
450    let pb = spinner.clone();
451    let handle = tokio::spawn(async move {
452        while let Some(event) = rx.recv().await {
453            match event {
454                AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
455            }
456        }
457    });
458    (tx, handle)
459}
460
461/// Format the detail string for spinner_done, including usage if available.
462fn format_done_detail(
463    commit_count: usize,
464    extra: &str,
465    usage: &Option<crate::ai::AiUsage>,
466) -> String {
467    let commits = format!(
468        "{commit_count} commit{}",
469        if commit_count == 1 { "" } else { "s" }
470    );
471    let extra_part = if extra.is_empty() {
472        String::new()
473    } else {
474        format!(" · {extra}")
475    };
476    let usage_part = match usage {
477        Some(u) => {
478            let cost = u
479                .cost_usd
480                .map(|c| format!(" · ${c:.4}"))
481                .unwrap_or_default();
482            format!(
483                " · {} in / {} out{}",
484                ui::format_tokens(u.input_tokens),
485                ui::format_tokens(u.output_tokens),
486                cost
487            )
488        }
489        None => String::new(),
490    };
491    format!("{commits}{extra_part}{usage_part}")
492}
493
494/// Validate that all commit messages match the configured pattern.
495/// Returns a list of (index, message, error) for invalid commits.
496fn validate_messages(plan: &CommitPlan, commit_pattern: &str) -> Vec<(usize, String, String)> {
497    let re = match Regex::new(commit_pattern) {
498        Ok(re) => re,
499        Err(e) => {
500            // If the pattern itself is invalid, report all commits as invalid
501            return plan
502                .commits
503                .iter()
504                .enumerate()
505                .map(|(i, c)| (i + 1, c.message.clone(), format!("invalid pattern: {e}")))
506                .collect();
507        }
508    };
509
510    plan.commits
511        .iter()
512        .enumerate()
513        .filter(|(_, c)| !re.is_match(&c.message))
514        .map(|(i, c)| {
515            (
516                i + 1,
517                c.message.clone(),
518                format!("does not match pattern: {commit_pattern}"),
519            )
520        })
521        .collect()
522}
523
524fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
525    // Unstage everything first
526    repo.reset_head()?;
527
528    let total = plan.commits.len();
529    let mut created: Vec<(String, String)> = Vec::new();
530    let mut failed: Vec<(usize, String, String)> = Vec::new();
531
532    for (i, commit) in plan.commits.iter().enumerate() {
533        ui::commit_start(i + 1, total, &commit.message);
534
535        // Stage files for this commit
536        for file in &commit.files {
537            let ok = repo.stage_file(file)?;
538            ui::file_staged(file, ok);
539        }
540
541        // Build full commit message
542        let mut full_message = commit.message.clone();
543        if let Some(body) = &commit.body
544            && !body.is_empty()
545        {
546            full_message.push_str("\n\n");
547            full_message.push_str(body);
548        }
549        if let Some(footer) = &commit.footer
550            && !footer.is_empty()
551        {
552            full_message.push_str("\n\n");
553            full_message.push_str(footer);
554        }
555
556        // Create commit (only if there are staged files)
557        if repo.has_staged_after_add()? {
558            match repo.commit(&full_message) {
559                Ok(()) => {
560                    let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
561                    ui::commit_created(&sha);
562                    created.push((sha, commit.message.clone()));
563                }
564                Err(e) => {
565                    ui::commit_failed(&format!("{e:#}"));
566                    failed.push((i + 1, commit.message.clone(), format!("{e:#}")));
567                    // Unstage files from the failed commit so the next commit starts clean
568                    repo.reset_head()?;
569                }
570            }
571        } else {
572            ui::commit_skipped();
573        }
574    }
575
576    ui::summary(&created);
577
578    if !failed.is_empty() {
579        ui::failed_commits(&failed);
580        if created.is_empty() {
581            bail!("all {} commits failed", failed.len());
582        }
583    }
584
585    Ok(())
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a footer-less commit fixture.
    fn planned(order: u32, message: &str, body: Option<&str>, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.into(),
            body: body.map(str::to_owned),
            footer: None,
            files: files.iter().map(|f| f.to_string()).collect(),
        }
    }

    /// Build a commit fixture that only carries a header.
    fn header_only(order: u32, message: &str) -> PlannedCommit {
        planned(order, message, None, &[])
    }

    #[test]
    fn validate_plan_no_dupes() {
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", Some("reason"), &["a.rs"]),
                planned(2, "fix: fix bar", Some("reason"), &["b.rs"]),
            ],
        };

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", Some("reason 1"), &["shared.rs", "a.rs"]),
                planned(2, "fix: fix bar", Some("reason 2"), &["shared.rs", "b.rs"]),
                planned(3, "docs: update readme", Some("docs"), &["README.md"]),
            ],
        };

        let result = validate_plan(plan);
        // Two tainted merged into one + one clean = 2
        assert_eq!(result.commits.len(), 2);

        let merged = &result.commits[0];
        assert_eq!(merged.message, "feat: add foo");
        for expected in ["shared.rs", "a.rs", "b.rs"] {
            assert!(merged.files.contains(&expected.to_string()));
        }

        let untouched = &result.commits[1];
        assert_eq!(untouched.message, "docs: update readme");
        assert_eq!(untouched.order, Some(2));
    }

    #[test]
    fn validate_messages_all_valid() {
        let plan = CommitPlan {
            commits: vec![
                header_only(1, "feat: add foo"),
                header_only(2, "fix(core): null check"),
            ],
        };

        let pattern = sr_core::commit::DEFAULT_COMMIT_PATTERN;
        let invalid = validate_messages(&plan, pattern);
        assert!(invalid.is_empty());
    }

    #[test]
    fn validate_messages_catches_invalid() {
        let plan = CommitPlan {
            commits: vec![
                header_only(1, "feat: add foo"),
                header_only(2, "not a conventional commit"),
                header_only(3, "fix: valid one"),
            ],
        };

        let pattern = sr_core::commit::DEFAULT_COMMIT_PATTERN;
        let invalid = validate_messages(&plan, pattern);
        assert_eq!(invalid.len(), 1);
        assert_eq!(invalid[0].0, 2); // 1-indexed
        assert_eq!(invalid[0].1, "not a conventional commit");
    }

    #[test]
    fn validate_messages_invalid_pattern() {
        let plan = CommitPlan {
            commits: vec![header_only(1, "feat: add foo")],
        };

        let invalid = validate_messages(&plan, "[invalid regex");
        assert_eq!(invalid.len(), 1);
        assert!(invalid[0].2.contains("invalid pattern"));
    }
}