// sr_ai/commands/commit.rs

use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
use crate::cache::{CacheLookup, CacheManager};
use crate::git::{GitRepo, SnapshotGuard};
use crate::ui;
use anyhow::{Context, Result, bail};
use indicatif::ProgressBar;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;

12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CommitPlan {
14    pub commits: Vec<PlannedCommit>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PlannedCommit {
19    pub order: Option<u32>,
20    pub message: String,
21    pub body: Option<String>,
22    pub footer: Option<String>,
23    pub files: Vec<String>,
24}
25
26#[derive(Debug, clap::Args)]
27pub struct CommitArgs {
28    /// Only analyze staged changes
29    #[arg(short, long)]
30    pub staged: bool,
31
32    /// Additional context or instructions for commit generation
33    #[arg(short = 'M', long)]
34    pub message: Option<String>,
35
36    /// Display plan without executing
37    #[arg(short = 'n', long)]
38    pub dry_run: bool,
39
40    /// Skip confirmation prompt
41    #[arg(short, long)]
42    pub yes: bool,
43
44    /// Bypass cache (always call AI)
45    #[arg(long)]
46    pub no_cache: bool,
47}
48
/// JSON Schema handed to the AI backend describing the expected [`CommitPlan`]
/// response: an object with a `commits` array whose items require `order`,
/// `message`, `body` and `files`, with `footer` optional.
const COMMIT_SCHEMA: &str = r##"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "order": { "type": "integer" },
                    "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
                    "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
                    "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
                    "files": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["order", "message", "body", "files"]
            }
        }
    },
    "required": ["commits"]
}"##;

/// Build the system prompt that tells the model how to write and group commits.
///
/// `commit_pattern` is quoted verbatim as the regex every header must match;
/// `type_names` is rendered as a comma-separated list of the permitted types.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {pattern}
- Format: type(scope): subject
- Valid types ONLY: {allowed}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        pattern = commit_pattern,
        allowed = type_names.join(", "),
    )
}

/// Where the commit plan shown to the user came from.
enum CacheStatus {
    /// Fresh AI analysis: the cache was bypassed (`--no-cache`), unavailable,
    /// or had no usable entry.
    None,
    /// Reused unchanged from an exact cache hit.
    Cached,
    /// Regenerated by the AI with a cached plan supplied as a hint.
    Incremental,
}

112pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
113    ui::header("sr commit");
114
115    // Phase 1: Discover repository
116    let repo = GitRepo::discover()?;
117    ui::phase_ok("Repository found", None);
118
119    // Load project config for commit types and pattern
120    let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
121        .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
122        .transpose()?
123        .unwrap_or_default();
124    let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
125    let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
126
127    // Phase 2: Check for changes
128    let has_changes = if args.staged {
129        repo.has_staged_changes()?
130    } else {
131        repo.has_any_changes()?
132    };
133
134    if !has_changes {
135        bail!(crate::error::SrAiError::NoChanges);
136    }
137
138    let statuses = repo.file_statuses().unwrap_or_default();
139    let file_count = statuses.len();
140    ui::phase_ok(
141        "Changes detected",
142        Some(&format!(
143            "{file_count} file{}",
144            if file_count == 1 { "" } else { "s" }
145        )),
146    );
147
148    // Phase 3: Resolve AI backend
149    let backend = resolve_backend(backend_config).await?;
150    let backend_name = backend.name().to_string();
151    let model_name = backend_config
152        .model
153        .as_deref()
154        .unwrap_or("default")
155        .to_string();
156    ui::phase_ok(
157        "Backend resolved",
158        Some(&format!("{backend_name} ({model_name})")),
159    );
160
161    // Build cache manager (may be None if cache dir unavailable)
162    let cache = if args.no_cache {
163        None
164    } else {
165        CacheManager::new(
166            repo.root(),
167            args.staged,
168            args.message.as_deref(),
169            &backend_name,
170            &model_name,
171        )
172    };
173
174    // Snapshot the working tree before the agent runs.
175    // If anything goes wrong (agent failure, unexpected mutations),
176    // the guard restores the working tree from the snapshot on drop.
177    let snapshot = SnapshotGuard::new(&repo)?;
178    ui::phase_ok("Working tree snapshot saved", None);
179
180    // Phase 4: Generate plan (cache or AI)
181    let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
182        Some(CacheLookup::ExactHit(cached_plan)) => {
183            ui::phase_ok(
184                "Plan loaded",
185                Some(&format!("{} commits · cached", cached_plan.commits.len())),
186            );
187            (cached_plan, CacheStatus::Cached)
188        }
189        Some(CacheLookup::IncrementalHit {
190            previous_plan,
191            delta_summary,
192        }) => {
193            let spinner = ui::spinner(&format!(
194                "Analyzing changes with {backend_name} (incremental)..."
195            ));
196            let (tx, event_handler) = spawn_event_handler(&spinner);
197
198            let user_prompt =
199                build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;
200
201            let request = AiRequest {
202                system_prompt: system_prompt.clone(),
203                user_prompt,
204                json_schema: Some(COMMIT_SCHEMA.to_string()),
205                working_dir: repo.root().to_string_lossy().to_string(),
206            };
207
208            let response = backend.request(&request, Some(tx)).await?;
209            let _ = event_handler.await;
210
211            let p: CommitPlan = parse_plan(&response.text)?;
212
213            let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
214            ui::spinner_done(&spinner, Some(&detail));
215
216            (p, CacheStatus::Incremental)
217        }
218        _ => {
219            let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
220            let (tx, event_handler) = spawn_event_handler(&spinner);
221
222            let user_prompt = build_user_prompt(args, &repo)?;
223
224            let request = AiRequest {
225                system_prompt: system_prompt.clone(),
226                user_prompt,
227                json_schema: Some(COMMIT_SCHEMA.to_string()),
228                working_dir: repo.root().to_string_lossy().to_string(),
229            };
230
231            let response = backend.request(&request, Some(tx)).await?;
232            let _ = event_handler.await;
233
234            let p: CommitPlan = parse_plan(&response.text)?;
235
236            let detail = format_done_detail(p.commits.len(), "", &response.usage);
237            ui::spinner_done(&spinner, Some(&detail));
238
239            (p, CacheStatus::None)
240        }
241    };
242
243    if plan.commits.is_empty() {
244        bail!(crate::error::SrAiError::EmptyPlan);
245    }
246
247    // Validate: merge commits with shared files
248    let pre_validate_count = plan.commits.len();
249    plan = validate_plan(plan);
250    if plan.commits.len() < pre_validate_count {
251        ui::warn(&format!(
252            "Shared files detected — merged {} commits into 1",
253            pre_validate_count - plan.commits.len() + 1
254        ));
255    }
256
257    // Store in cache (before display/execute so dry-runs populate cache too)
258    if let Some(cache) = &cache {
259        cache.store(&plan, &backend_name, &model_name);
260    }
261
262    // Display plan
263    let cache_label: Option<&str> = match &cache_status {
264        CacheStatus::Cached => Some("cached"),
265        CacheStatus::Incremental => Some("incremental"),
266        CacheStatus::None => None,
267    };
268    ui::display_plan(&plan, &statuses, cache_label);
269
270    if args.dry_run {
271        ui::info("Dry run — no commits created");
272        println!();
273        return Ok(());
274    }
275
276    // Confirm
277    if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
278        bail!(crate::error::SrAiError::Cancelled);
279    }
280
281    // Pre-validate commit messages against the configured pattern
282    let invalid = validate_messages(&plan, &config.commit_pattern);
283    if !invalid.is_empty() {
284        ui::invalid_messages(&invalid);
285        if !args.yes && !ui::confirm("Continue anyway? Invalid commits will likely fail. [y/N]")? {
286            bail!(crate::error::SrAiError::Cancelled);
287        }
288    }
289
290    // Execute
291    execute_plan(&repo, &plan)?;
292
293    // All commits succeeded (or at least some did) — clear the snapshot
294    snapshot.success();
295
296    Ok(())
297}
298
299fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
300    let git_root = repo.root().to_string_lossy();
301
302    let mut prompt = if args.staged {
303        "Analyze the staged git changes and group them into atomic commits.\n\
304         Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
305            .to_string()
306    } else {
307        "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
308         Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
309         `git ls-files --others --exclude-standard` to inspect changes."
310            .to_string()
311    };
312
313    prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
314
315    if let Some(msg) = &args.message {
316        prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
317    }
318
319    Ok(prompt)
320}
321
322fn build_incremental_prompt(
323    args: &CommitArgs,
324    repo: &GitRepo,
325    previous_plan: &CommitPlan,
326    delta_summary: &str,
327) -> Result<String> {
328    let mut prompt = build_user_prompt(args, repo)?;
329
330    let previous_json =
331        serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
332
333    prompt.push_str(&format!(
334        "\n\n--- INCREMENTAL HINTS ---\n\
335         A previous commit plan exists for a similar set of changes. \
336         Maintain the groupings for unchanged files where possible. \
337         Only re-analyze files that have changed.\n\n\
338         Previous plan:\n```json\n{previous_json}\n```\n\n\
339         File delta:\n{delta_summary}"
340    ));
341
342    Ok(prompt)
343}
344
345/// Validate that no file appears in multiple commits. If duplicates are found,
346/// merge affected commits into one.
347fn validate_plan(plan: CommitPlan) -> CommitPlan {
348    // Count file occurrences
349    let mut file_counts: HashMap<String, usize> = HashMap::new();
350    for commit in &plan.commits {
351        for file in &commit.files {
352            *file_counts.entry(file.clone()).or_default() += 1;
353        }
354    }
355
356    let dupes: Vec<&String> = file_counts
357        .iter()
358        .filter(|(_, count)| **count > 1)
359        .map(|(file, _)| file)
360        .collect();
361
362    if dupes.is_empty() {
363        return plan;
364    }
365
366    // Partition into tainted (has any dupe file) and clean
367    let mut tainted = Vec::new();
368    let mut clean = Vec::new();
369
370    for commit in plan.commits {
371        let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
372        if is_tainted {
373            tainted.push(commit);
374        } else {
375            clean.push(commit);
376        }
377    }
378
379    // Merge all tainted commits into one
380    let merged_message = tainted
381        .first()
382        .map(|c| c.message.clone())
383        .unwrap_or_default();
384
385    let merged_body = tainted
386        .iter()
387        .filter_map(|c| c.body.as_ref())
388        .filter(|b| !b.is_empty())
389        .cloned()
390        .collect::<Vec<_>>()
391        .join("\n\n");
392
393    let merged_footer = tainted
394        .iter()
395        .filter_map(|c| c.footer.as_ref())
396        .filter(|f| !f.is_empty())
397        .cloned()
398        .collect::<Vec<_>>()
399        .join("\n");
400
401    let mut merged_files: Vec<String> = tainted
402        .iter()
403        .flat_map(|c| c.files.iter().cloned())
404        .collect();
405    merged_files.sort();
406    merged_files.dedup();
407
408    let merged_commit = PlannedCommit {
409        order: Some(1),
410        message: merged_message,
411        body: if merged_body.is_empty() {
412            None
413        } else {
414            Some(merged_body)
415        },
416        footer: if merged_footer.is_empty() {
417            None
418        } else {
419            Some(merged_footer)
420        },
421        files: merged_files,
422    };
423
424    // Re-number: merged first, then clean commits
425    let mut result = vec![merged_commit];
426    for (i, mut commit) in clean.into_iter().enumerate() {
427        commit.order = Some(i as u32 + 2);
428        result.push(commit);
429    }
430
431    CommitPlan { commits: result }
432}
433
434/// Parse a commit plan from JSON text, tolerating duplicate fields.
435fn parse_plan(text: &str) -> Result<CommitPlan> {
436    // Parse to Value first — serde_json::Value keeps the last value for duplicate keys,
437    // while #[derive(Deserialize)] rejects them. This handles AI responses that
438    // occasionally produce duplicate fields when schema is embedded in the prompt.
439    let value: serde_json::Value =
440        serde_json::from_str(text).context("failed to parse JSON from AI response")?;
441    serde_json::from_value(value).context("failed to parse commit plan from AI response")
442}
443
444/// Spawn a background task that renders AI events (tool calls) above a spinner.
445fn spawn_event_handler(
446    spinner: &ProgressBar,
447) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
448    let (tx, mut rx) = mpsc::unbounded_channel();
449    let pb = spinner.clone();
450    let handle = tokio::spawn(async move {
451        while let Some(event) = rx.recv().await {
452            match event {
453                AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
454            }
455        }
456    });
457    (tx, handle)
458}
459
460/// Format the detail string for spinner_done, including usage if available.
461fn format_done_detail(
462    commit_count: usize,
463    extra: &str,
464    usage: &Option<crate::ai::AiUsage>,
465) -> String {
466    let commits = format!(
467        "{commit_count} commit{}",
468        if commit_count == 1 { "" } else { "s" }
469    );
470    let extra_part = if extra.is_empty() {
471        String::new()
472    } else {
473        format!(" · {extra}")
474    };
475    let usage_part = match usage {
476        Some(u) => {
477            let cost = u
478                .cost_usd
479                .map(|c| format!(" · ${c:.4}"))
480                .unwrap_or_default();
481            format!(
482                " · {} in / {} out{}",
483                ui::format_tokens(u.input_tokens),
484                ui::format_tokens(u.output_tokens),
485                cost
486            )
487        }
488        None => String::new(),
489    };
490    format!("{commits}{extra_part}{usage_part}")
491}
492
493/// Validate that all commit messages match the configured pattern.
494/// Returns a list of (index, message, error) for invalid commits.
495fn validate_messages(plan: &CommitPlan, commit_pattern: &str) -> Vec<(usize, String, String)> {
496    let re = match Regex::new(commit_pattern) {
497        Ok(re) => re,
498        Err(e) => {
499            // If the pattern itself is invalid, report all commits as invalid
500            return plan
501                .commits
502                .iter()
503                .enumerate()
504                .map(|(i, c)| (i + 1, c.message.clone(), format!("invalid pattern: {e}")))
505                .collect();
506        }
507    };
508
509    plan.commits
510        .iter()
511        .enumerate()
512        .filter(|(_, c)| !re.is_match(&c.message))
513        .map(|(i, c)| {
514            (
515                i + 1,
516                c.message.clone(),
517                format!("does not match pattern: {commit_pattern}"),
518            )
519        })
520        .collect()
521}
522
523fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
524    // Unstage everything first
525    repo.reset_head()?;
526
527    let total = plan.commits.len();
528    let mut created: Vec<(String, String)> = Vec::new();
529    let mut failed: Vec<(usize, String, String)> = Vec::new();
530
531    for (i, commit) in plan.commits.iter().enumerate() {
532        ui::commit_start(i + 1, total, &commit.message);
533
534        // Stage files for this commit
535        for file in &commit.files {
536            let ok = repo.stage_file(file)?;
537            ui::file_staged(file, ok);
538        }
539
540        // Build full commit message
541        let mut full_message = commit.message.clone();
542        if let Some(body) = &commit.body
543            && !body.is_empty()
544        {
545            full_message.push_str("\n\n");
546            full_message.push_str(body);
547        }
548        if let Some(footer) = &commit.footer
549            && !footer.is_empty()
550        {
551            full_message.push_str("\n\n");
552            full_message.push_str(footer);
553        }
554
555        // Create commit (only if there are staged files)
556        if repo.has_staged_after_add()? {
557            match repo.commit(&full_message) {
558                Ok(()) => {
559                    let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
560                    ui::commit_created(&sha);
561                    created.push((sha, commit.message.clone()));
562                }
563                Err(e) => {
564                    ui::commit_failed(&format!("{e:#}"));
565                    failed.push((i + 1, commit.message.clone(), format!("{e:#}")));
566                    // Unstage files from the failed commit so the next commit starts clean
567                    repo.reset_head()?;
568                }
569            }
570        } else {
571            ui::commit_skipped();
572        }
573    }
574
575    ui::summary(&created);
576
577    if !failed.is_empty() {
578        ui::failed_commits(&failed);
579        if created.is_empty() {
580            bail!("all {} commits failed", failed.len());
581        }
582    }
583
584    Ok(())
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a planned commit without a footer.
    fn planned(order: u32, message: &str, body: Option<&str>, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.to_string(),
            body: body.map(str::to_string),
            footer: None,
            files: files.iter().map(|f| f.to_string()).collect(),
        }
    }

    fn plan_of(commits: Vec<PlannedCommit>) -> CommitPlan {
        CommitPlan { commits }
    }

    #[test]
    fn validate_plan_no_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason"), &["a.rs"]),
            planned(2, "fix: fix bar", Some("reason"), &["b.rs"]),
        ]);

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason 1"), &["shared.rs", "a.rs"]),
            planned(2, "fix: fix bar", Some("reason 2"), &["shared.rs", "b.rs"]),
            planned(3, "docs: update readme", Some("docs"), &["README.md"]),
        ]);

        let result = validate_plan(plan);
        // Two tainted merged into one + one clean = 2
        assert_eq!(result.commits.len(), 2);
        assert_eq!(result.commits[0].message, "feat: add foo");
        assert!(result.commits[0].files.contains(&"shared.rs".to_string()));
        assert!(result.commits[0].files.contains(&"a.rs".to_string()));
        assert!(result.commits[0].files.contains(&"b.rs".to_string()));
        assert_eq!(result.commits[1].message, "docs: update readme");
        assert_eq!(result.commits[1].order, Some(2));
    }

    #[test]
    fn validate_plan_merges_bodies_and_footers() {
        let plan = plan_of(vec![
            PlannedCommit {
                footer: Some("Closes #1".into()),
                ..planned(1, "feat: add foo", Some("why foo"), &["shared.rs"])
            },
            PlannedCommit {
                footer: Some("Refs #2".into()),
                ..planned(2, "fix: fix bar", Some("why bar"), &["shared.rs", "b.rs"])
            },
            // Empty bodies must not add stray blank lines.
            planned(3, "refactor: tidy", Some(""), &["shared.rs"]),
        ]);

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 1);
        let merged = &result.commits[0];
        assert_eq!(merged.body.as_deref(), Some("why foo\n\nwhy bar"));
        assert_eq!(merged.footer.as_deref(), Some("Closes #1\nRefs #2"));
        assert_eq!(merged.files, vec!["b.rs".to_string(), "shared.rs".to_string()]);
    }

    #[test]
    fn validate_messages_all_valid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "fix(core): null check", None, &[]),
        ]);

        let invalid = validate_messages(&plan, sr_core::commit::DEFAULT_COMMIT_PATTERN);
        assert!(invalid.is_empty());
    }

    #[test]
    fn validate_messages_catches_invalid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "not a conventional commit", None, &[]),
            planned(3, "fix: valid one", None, &[]),
        ]);

        let invalid = validate_messages(&plan, sr_core::commit::DEFAULT_COMMIT_PATTERN);
        assert_eq!(invalid.len(), 1);
        assert_eq!(invalid[0].0, 2); // 1-indexed
        assert_eq!(invalid[0].1, "not a conventional commit");
    }

    #[test]
    fn validate_messages_invalid_pattern() {
        let plan = plan_of(vec![planned(1, "feat: add foo", None, &[])]);

        let invalid = validate_messages(&plan, "[invalid regex");
        assert_eq!(invalid.len(), 1);
        assert!(invalid[0].2.contains("invalid pattern"));
    }

    #[test]
    fn parse_plan_keeps_last_duplicate_field() {
        let text = r#"{"commits":[{"order":1,"message":"feat: first","message":"feat: second","body":"why","files":["a.rs"]}]}"#;

        let plan = parse_plan(text).expect("duplicate fields should be tolerated");
        assert_eq!(plan.commits.len(), 1);
        assert_eq!(plan.commits[0].message, "feat: second");
        assert_eq!(plan.commits[0].footer, None);
        assert_eq!(plan.commits[0].files, vec!["a.rs".to_string()]);
    }

    #[test]
    fn parse_plan_rejects_bad_input() {
        assert!(parse_plan("not json").is_err());
        // `files` is required by `PlannedCommit`.
        assert!(parse_plan(r#"{"commits":[{"message":"feat: x"}]}"#).is_err());
    }
}