Skip to main content

sr_ai/commands/
rebase.rs

1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::git::GitRepo;
3use crate::ui;
4use anyhow::{Context, Result, bail};
5use indicatif::ProgressBar;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tokio::sync::mpsc;
9
10#[derive(Debug, clap::Args)]
11pub struct RebaseArgs {
12    /// Additional context or instructions for reorganization
13    #[arg(short = 'M', long)]
14    pub message: Option<String>,
15
16    /// Display plan without executing
17    #[arg(short = 'n', long)]
18    pub dry_run: bool,
19
20    /// Skip confirmation prompt
21    #[arg(short, long)]
22    pub yes: bool,
23
24    /// Number of recent commits to reorganize (default: auto-detect since last tag)
25    #[arg(long)]
26    pub last: Option<usize>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ReorganizePlan {
31    pub commits: Vec<ReorganizedCommit>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ReorganizedCommit {
36    /// Original SHA (short) — use "squash" to fold into the previous commit
37    pub original_sha: String,
38    /// Action: "pick", "reword", "squash", "drop"
39    pub action: String,
40    /// New commit message (required for pick/reword/squash)
41    pub message: String,
42    pub body: Option<String>,
43    pub footer: Option<String>,
44}
45
/// JSON Schema for the AI's structured response, passed as `AiRequest::json_schema`.
///
/// It mirrors [`ReorganizePlan`] / [`ReorganizedCommit`]: each entry requires
/// `original_sha`, `action` (one of pick/reword/squash/drop) and `message`, while
/// `body` and `footer` are optional. Keep the two definitions in sync.
const REORGANIZE_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "original_sha": { "type": "string", "description": "Short SHA of the original commit" },
                    "action": { "type": "string", "enum": ["pick", "reword", "squash", "drop"], "description": "Rebase action" },
                    "message": { "type": "string", "description": "New commit message header (type(scope): subject)" },
                    "body": { "type": "string", "description": "New commit body (optional)" },
                    "footer": { "type": "string", "description": "New commit footer (optional)" }
                },
                "required": ["original_sha", "action", "message"]
            }
        }
    },
    "required": ["commits"]
}"#;
66
/// Owns a scratch directory path and deletes the whole directory tree when dropped.
///
/// Removal is best effort: failures (for example, the directory is already gone) are
/// ignored.
struct TmpDirGuard(std::path::PathBuf);

impl Drop for TmpDirGuard {
    fn drop(&mut self) {
        let Self(dir) = self;
        // There is nothing useful to do with an error while dropping.
        let _best_effort = std::fs::remove_dir_all(dir.as_path());
    }
}
75
76fn format_done_detail(count: usize, label: &str, usage: &Option<crate::ai::AiUsage>) -> String {
77    let commits = format!("{count} commit{}", if count == 1 { "" } else { "s" });
78    let extra_part = if label.is_empty() {
79        String::new()
80    } else {
81        format!(" · {label}")
82    };
83    let usage_part = match usage {
84        Some(u) => {
85            let cost = u
86                .cost_usd
87                .map(|c| format!(" · ${c:.4}"))
88                .unwrap_or_default();
89            format!(
90                " · {} in / {} out{}",
91                ui::format_tokens(u.input_tokens),
92                ui::format_tokens(u.output_tokens),
93                cost
94            )
95        }
96        None => String::new(),
97    };
98    format!("{commits}{extra_part}{usage_part}")
99}
100
101fn spawn_event_handler(
102    spinner: &ProgressBar,
103) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
104    let (tx, mut rx) = mpsc::unbounded_channel::<AiEvent>();
105    let spinner_clone = spinner.clone();
106    let handle = tokio::spawn(async move {
107        while let Some(event) = rx.recv().await {
108            match event {
109                AiEvent::ToolCall { input, .. } => ui::tool_call(&spinner_clone, &input),
110            }
111        }
112    });
113    (tx, handle)
114}
115
116pub async fn run(args: &RebaseArgs, backend_config: &BackendConfig) -> Result<()> {
117    ui::header("sr rebase");
118
119    let repo = GitRepo::discover()?;
120    ui::phase_ok("Repository found", None);
121
122    if repo.has_any_changes()? {
123        bail!("cannot rebase: you have uncommitted changes. Please commit or stash them first.");
124    }
125
126    // Load config for commit pattern/types
127    let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
128        .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
129        .transpose()?
130        .unwrap_or_default();
131    let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
132
133    // Determine how many commits to reorganize
134    let commit_count = match args.last {
135        Some(n) => n,
136        None => {
137            // Auto-detect: count commits since last tag
138            let count = repo.commits_since_last_tag()?;
139            if count == 0 {
140                bail!("no commits found to rebase");
141            }
142            count
143        }
144    };
145
146    if commit_count < 2 {
147        bail!("need at least 2 commits to rebase (found {commit_count})");
148    }
149
150    // Get commit details
151    let log = repo.log_detailed(commit_count)?;
152    ui::phase_ok("Commits loaded", Some(&format!("{commit_count} commits")));
153
154    // Resolve AI backend
155    let backend = resolve_backend(backend_config).await?;
156    let backend_name = backend.name().to_string();
157    let model_name = backend_config
158        .model
159        .as_deref()
160        .unwrap_or("default")
161        .to_string();
162    ui::phase_ok(
163        "Backend resolved",
164        Some(&format!("{backend_name} ({model_name})")),
165    );
166
167    // Build prompt
168    let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
169    let user_prompt = build_user_prompt(&log, args.message.as_deref())?;
170
171    let spinner = ui::spinner(&format!("Analyzing commits with {backend_name}..."));
172    let (tx, event_handler) = spawn_event_handler(&spinner);
173
174    let request = AiRequest {
175        system_prompt,
176        user_prompt,
177        json_schema: Some(REORGANIZE_SCHEMA.to_string()),
178        working_dir: repo.root().to_string_lossy().to_string(),
179    };
180
181    let response = backend.request(&request, Some(tx)).await?;
182    let _ = event_handler.await;
183
184    let plan: ReorganizePlan = serde_json::from_str(&response.text)
185        .or_else(|_| {
186            let value: serde_json::Value = serde_json::from_str(&response.text)?;
187            serde_json::from_value(value)
188        })
189        .context("failed to parse rebase plan from AI response")?;
190
191    let detail = format_done_detail(plan.commits.len(), "", &response.usage);
192    ui::spinner_done(&spinner, Some(&detail));
193
194    if plan.commits.is_empty() {
195        bail!("AI returned an empty rebase plan");
196    }
197
198    // Display the plan
199    display_plan(&plan);
200
201    if args.dry_run {
202        ui::info("Dry run — no changes made");
203        println!();
204        return Ok(());
205    }
206
207    if !args.yes && !ui::confirm("Execute rebase? [y/N]")? {
208        bail!(crate::error::SrAiError::Cancelled);
209    }
210
211    // Execute via git rebase
212    execute_rebase(&repo, &plan, commit_count)?;
213
214    Ok(())
215}
216
/// Build the system prompt that explains the available rebase actions and the
/// commit-message rules to the model.
///
/// `commit_pattern` is quoted verbatim as the regex every message must match, and
/// `type_names` become the comma-separated list of allowed commit types.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at organizing git history. You will be given a list of recent commits and asked to reorganize them.

You can:
- **pick**: keep the commit as-is (but you may reword the message)
- **reword**: keep the commit but change the message
- **squash**: fold the commit into the previous one (combine their changes)
- **drop**: remove the commit entirely (use sparingly — only for truly empty or duplicate commits)

COMMIT MESSAGE FORMAT:
- Must match this regex: {commit_pattern}
- Format: type(scope): subject
- Valid types ONLY: {types_list}
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

RULES:
- Maintain the chronological order of commits (oldest first) unless reordering improves logical grouping
- The first commit in the list CANNOT be "squash" — squash folds into the previous commit
- Prefer "reword" over "squash" when commits are logically distinct
- Only squash commits that are genuinely part of the same logical change
- Every original commit SHA must appear exactly once in your output
- If the commits are already well-organized, return them all as "pick" with improved messages if needed"#,
        types_list = type_names.join(", "),
    )
}
243
/// Build the user prompt: the commit log (oldest first) inside a fenced block, followed
/// by the user's additional instructions when `extra` is provided.
///
/// This never fails today; the `Result` return type is kept for existing callers.
fn build_user_prompt(log: &str, extra: Option<&str>) -> Result<String> {
    const INTRO: &str =
        "Analyze these recent commits and suggest how to reorganize them for a cleaner history.";

    let commits_section = format!("{INTRO}\n\nCommits (oldest first):\n```\n{log}\n```");

    Ok(match extra {
        None => commits_section,
        Some(instructions) => format!(
            "{commits_section}\n\nAdditional instructions from the user:\n{instructions}"
        ),
    })
}
258
259fn display_plan(plan: &ReorganizePlan) {
260    use crossterm::style::Stylize;
261
262    println!();
263    println!(
264        "  {} {}",
265        "REBASE PLAN".bold(),
266        format!("· {} commits", plan.commits.len()).dim()
267    );
268    let rule = "─".repeat(50);
269    println!("  {}", rule.as_str().dim());
270    println!();
271
272    for commit in &plan.commits {
273        let action_styled = match commit.action.as_str() {
274            "pick" => format!("{}", "pick".green()),
275            "reword" => format!("{}", "reword".yellow()),
276            "squash" => format!("{}", "squash".magenta()),
277            "drop" => format!("{}", "drop".red()),
278            other => other.to_string(),
279        };
280
281        println!(
282            "  {} {} {}",
283            action_styled,
284            commit.original_sha.as_str().dim(),
285            commit.message.as_str().bold()
286        );
287
288        if let Some(body) = &commit.body
289            && !body.is_empty()
290        {
291            for line in body.lines() {
292                println!("   {}  {}", "│".dim(), line.dim());
293            }
294        }
295    }
296
297    println!();
298    println!("  {}", rule.as_str().dim());
299    println!();
300}
301
302fn execute_rebase(repo: &GitRepo, plan: &ReorganizePlan, commit_count: usize) -> Result<()> {
303    // Build the rebase todo script
304    let mut todo_lines = Vec::new();
305    for commit in &plan.commits {
306        let action = match commit.action.as_str() {
307            "pick" | "reword" => "pick", // we'll force-reword via GIT_SEQUENCE_EDITOR
308            "squash" => "squash",
309            "drop" => "drop",
310            other => bail!("unknown rebase action: {other}"),
311        };
312        todo_lines.push(format!("{action} {}", commit.original_sha));
313    }
314    let todo_content = todo_lines.join("\n") + "\n";
315
316    // Build commit message rewrites: map SHA -> new full message
317    let mut rewrites: HashMap<String, String> = HashMap::new();
318    // Also track squash messages to combine
319    let mut squash_messages: Vec<String> = Vec::new();
320    let mut last_pick_sha: Option<String> = None;
321
322    for commit in &plan.commits {
323        let mut full_msg = commit.message.clone();
324        if let Some(body) = &commit.body
325            && !body.is_empty()
326        {
327            full_msg.push_str("\n\n");
328            full_msg.push_str(body);
329        }
330        if let Some(footer) = &commit.footer
331            && !footer.is_empty()
332        {
333            full_msg.push_str("\n\n");
334            full_msg.push_str(footer);
335        }
336
337        match commit.action.as_str() {
338            "pick" | "reword" => {
339                // Flush any pending squash messages into the last pick
340                if !squash_messages.is_empty() {
341                    if let Some(ref sha) = last_pick_sha
342                        && let Some(existing) = rewrites.get_mut(sha)
343                    {
344                        for sq_msg in &squash_messages {
345                            existing.push_str("\n\n");
346                            existing.push_str(sq_msg);
347                        }
348                    }
349                    squash_messages.clear();
350                }
351                last_pick_sha = Some(commit.original_sha.clone());
352                rewrites.insert(commit.original_sha.clone(), full_msg);
353            }
354            "squash" => {
355                squash_messages.push(full_msg);
356            }
357            _ => {}
358        }
359    }
360    // Flush remaining squash messages
361    if !squash_messages.is_empty()
362        && let Some(ref sha) = last_pick_sha
363        && let Some(existing) = rewrites.get_mut(sha)
364    {
365        for sq_msg in &squash_messages {
366            existing.push_str("\n\n");
367            existing.push_str(sq_msg);
368        }
369    }
370
371    // Create a temporary directory for our editor scripts
372    let tmp_dir = std::env::temp_dir().join(format!("sr-rebase-{}", std::process::id()));
373    std::fs::create_dir_all(&tmp_dir).context("failed to create temp dir")?;
374    // Ensure cleanup on exit
375    let _cleanup = TmpDirGuard(tmp_dir.clone());
376
377    // Write the todo script (used as GIT_SEQUENCE_EDITOR)
378    let todo_script_path = tmp_dir.join("sequence-editor.sh");
379    {
380        let todo_file_path = tmp_dir.join("todo.txt");
381        std::fs::write(&todo_file_path, &todo_content)?;
382
383        let script = format!("#!/bin/sh\ncp '{}' \"$1\"\n", todo_file_path.display());
384        std::fs::write(&todo_script_path, &script)?;
385        #[cfg(unix)]
386        {
387            use std::os::unix::fs::PermissionsExt;
388            std::fs::set_permissions(&todo_script_path, std::fs::Permissions::from_mode(0o755))?;
389        }
390    }
391
392    // Write the commit message editor script (used as GIT_EDITOR / EDITOR)
393    let editor_script_path = tmp_dir.join("commit-editor.sh");
394    {
395        // Write each rewrite message to a file named by SHA
396        let msgs_dir = tmp_dir.join("msgs");
397        std::fs::create_dir_all(&msgs_dir)?;
398        for (sha, msg) in &rewrites {
399            std::fs::write(msgs_dir.join(sha), msg)?;
400        }
401
402        // The editor script: given a commit message file, find the matching SHA
403        // and replace with our rewritten message. For squash commits, git presents
404        // a combined message — we replace it entirely with the pick commit's message.
405        let script = format!(
406            r#"#!/bin/sh
407MSGS_DIR='{msgs_dir}'
408MSG_FILE="$1"
409
410# Try to find a matching SHA in the message file
411for sha_file in "$MSGS_DIR"/*; do
412    sha=$(basename "$sha_file")
413    if grep -q "$sha" "$MSG_FILE" 2>/dev/null; then
414        cp "$sha_file" "$MSG_FILE"
415        exit 0
416    fi
417done
418
419# For squash: the combined message won't contain a single SHA.
420# Find the first pick/reword SHA that's referenced in the todo.
421# Just use the message as-is if we can't match.
422exit 0
423"#,
424            msgs_dir = msgs_dir.display()
425        );
426        std::fs::write(&editor_script_path, &script)?;
427        #[cfg(unix)]
428        {
429            use std::os::unix::fs::PermissionsExt;
430            std::fs::set_permissions(&editor_script_path, std::fs::Permissions::from_mode(0o755))?;
431        }
432    }
433
434    // Run git rebase -i with our custom editors
435    let base = format!("HEAD~{commit_count}");
436
437    ui::info(&format!("Rebasing {commit_count} commits..."));
438
439    let output = std::process::Command::new("git")
440        .args(["-C", repo.root().to_str().unwrap()])
441        .args(["rebase", "-i", &base])
442        .env("GIT_SEQUENCE_EDITOR", todo_script_path.to_str().unwrap())
443        .env("GIT_EDITOR", editor_script_path.to_str().unwrap())
444        .env("EDITOR", editor_script_path.to_str().unwrap())
445        .stdout(std::process::Stdio::piped())
446        .stderr(std::process::Stdio::piped())
447        .output()
448        .context("failed to run git rebase")?;
449
450    if !output.status.success() {
451        let stderr = String::from_utf8_lossy(&output.stderr);
452        // Abort the rebase if it failed
453        let _ = std::process::Command::new("git")
454            .args(["-C", repo.root().to_str().unwrap()])
455            .args(["rebase", "--abort"])
456            .output();
457        bail!("git rebase failed: {}", stderr.trim());
458    }
459
460    // Show the new history
461    let new_log = repo.recent_commits(commit_count)?;
462    println!();
463    ui::phase_ok("Rebase complete", None);
464    println!();
465    for line in new_log.lines() {
466        println!("    {line}");
467    }
468    println!();
469
470    Ok(())
471}