1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::git::GitRepo;
3use crate::ui;
4use anyhow::{Context, Result, bail};
5use indicatif::ProgressBar;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tokio::sync::mpsc;
9
10#[derive(Debug, clap::Args)]
11pub struct RebaseArgs {
12 #[arg(short = 'M', long)]
14 pub message: Option<String>,
15
16 #[arg(short = 'n', long)]
18 pub dry_run: bool,
19
20 #[arg(short, long)]
22 pub yes: bool,
23
24 #[arg(long)]
26 pub last: Option<usize>,
27}
28
/// Rebase plan returned by the AI backend (deserialized from JSON matching
/// `REORGANIZE_SCHEMA`).
///
/// Entries are turned into rebase todo lines in the order they appear here, so
/// the order of `commits` is the order of the rewritten history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReorganizePlan {
    /// One entry per original commit (the prompt asks for each SHA exactly once).
    pub commits: Vec<ReorganizedCommit>,
}
33
/// A single entry of a [`ReorganizePlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReorganizedCommit {
    /// SHA of the existing commit this entry refers to; the schema asks the
    /// backend for the short form.
    pub original_sha: String,
    /// Rebase action: `"pick"`, `"reword"`, `"squash"` or `"drop"`. Only the
    /// JSON schema constrains this; the type itself accepts any string.
    pub action: String,
    /// New commit message header, e.g. `type(scope): subject`.
    pub message: String,
    /// Optional commit body; an empty string is skipped like `None` when the
    /// final message is assembled or displayed.
    pub body: Option<String>,
    /// Optional commit footer, appended after the body (empty strings skipped).
    pub footer: Option<String>,
}
45
/// JSON Schema sent as `AiRequest::json_schema` so the backend answers with an
/// object shaped like [`ReorganizePlan`]: a `commits` array whose items carry
/// `original_sha`, `action` (`pick` / `reword` / `squash` / `drop`) and
/// `message`, plus optional `body` and `footer`.
const REORGANIZE_SCHEMA: &str = r#"{
  "type": "object",
  "properties": {
    "commits": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "original_sha": { "type": "string", "description": "Short SHA of the original commit" },
          "action": { "type": "string", "enum": ["pick", "reword", "squash", "drop"], "description": "Rebase action" },
          "message": { "type": "string", "description": "New commit message header (type(scope): subject)" },
          "body": { "type": "string", "description": "New commit body (optional)" },
          "footer": { "type": "string", "description": "New commit footer (optional)" }
        },
        "required": ["original_sha", "action", "message"]
      }
    }
  },
  "required": ["commits"]
}"#;
66
/// Owns a temporary directory and deletes it, recursively, when dropped.
///
/// Removal is best-effort: `Drop` has no way to report failure, so any error
/// from `remove_dir_all` (e.g. the directory is already gone) is discarded.
struct TmpDirGuard(std::path::PathBuf);

impl Drop for TmpDirGuard {
    fn drop(&mut self) {
        let Self(dir) = self;
        _ = std::fs::remove_dir_all(dir);
    }
}
75
76fn format_done_detail(count: usize, label: &str, usage: &Option<crate::ai::AiUsage>) -> String {
77 let commits = format!("{count} commit{}", if count == 1 { "" } else { "s" });
78 let extra_part = if label.is_empty() {
79 String::new()
80 } else {
81 format!(" · {label}")
82 };
83 let usage_part = match usage {
84 Some(u) => {
85 let cost = u
86 .cost_usd
87 .map(|c| format!(" · ${c:.4}"))
88 .unwrap_or_default();
89 format!(
90 " · {} in / {} out{}",
91 ui::format_tokens(u.input_tokens),
92 ui::format_tokens(u.output_tokens),
93 cost
94 )
95 }
96 None => String::new(),
97 };
98 format!("{commits}{extra_part}{usage_part}")
99}
100
101fn spawn_event_handler(
102 spinner: &ProgressBar,
103) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
104 let (tx, mut rx) = mpsc::unbounded_channel::<AiEvent>();
105 let spinner_clone = spinner.clone();
106 let handle = tokio::spawn(async move {
107 while let Some(event) = rx.recv().await {
108 match event {
109 AiEvent::ToolCall { input, .. } => ui::tool_call(&spinner_clone, &input),
110 }
111 }
112 });
113 (tx, handle)
114}
115
116pub async fn run(args: &RebaseArgs, backend_config: &BackendConfig) -> Result<()> {
117 ui::header("sr rebase");
118
119 let repo = GitRepo::discover()?;
120 ui::phase_ok("Repository found", None);
121
122 if repo.has_any_changes()? {
123 bail!("cannot rebase: you have uncommitted changes. Please commit or stash them first.");
124 }
125
126 let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
128 .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
129 .transpose()?
130 .unwrap_or_default();
131 let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
132
133 let commit_count = match args.last {
135 Some(n) => n,
136 None => {
137 let count = repo.commits_since_last_tag()?;
139 if count == 0 {
140 bail!("no commits found to rebase");
141 }
142 count
143 }
144 };
145
146 if commit_count < 2 {
147 bail!("need at least 2 commits to rebase (found {commit_count})");
148 }
149
150 let log = repo.log_detailed(commit_count)?;
152 ui::phase_ok("Commits loaded", Some(&format!("{commit_count} commits")));
153
154 let backend = resolve_backend(backend_config).await?;
156 let backend_name = backend.name().to_string();
157 let model_name = backend_config
158 .model
159 .as_deref()
160 .unwrap_or("default")
161 .to_string();
162 ui::phase_ok(
163 "Backend resolved",
164 Some(&format!("{backend_name} ({model_name})")),
165 );
166
167 let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
169 let user_prompt = build_user_prompt(&log, args.message.as_deref())?;
170
171 let spinner = ui::spinner(&format!("Analyzing commits with {backend_name}..."));
172 let (tx, event_handler) = spawn_event_handler(&spinner);
173
174 let request = AiRequest {
175 system_prompt,
176 user_prompt,
177 json_schema: Some(REORGANIZE_SCHEMA.to_string()),
178 working_dir: repo.root().to_string_lossy().to_string(),
179 };
180
181 let response = backend.request(&request, Some(tx)).await?;
182 let _ = event_handler.await;
183
184 let plan: ReorganizePlan = serde_json::from_str(&response.text)
185 .or_else(|_| {
186 let value: serde_json::Value = serde_json::from_str(&response.text)?;
187 serde_json::from_value(value)
188 })
189 .context("failed to parse rebase plan from AI response")?;
190
191 let detail = format_done_detail(plan.commits.len(), "", &response.usage);
192 ui::spinner_done(&spinner, Some(&detail));
193
194 if plan.commits.is_empty() {
195 bail!("AI returned an empty rebase plan");
196 }
197
198 display_plan(&plan);
200
201 if args.dry_run {
202 ui::info("Dry run — no changes made");
203 println!();
204 return Ok(());
205 }
206
207 if !args.yes && !ui::confirm("Execute rebase? [y/N]")? {
208 bail!(crate::error::SrAiError::Cancelled);
209 }
210
211 execute_rebase(&repo, &plan, commit_count)?;
213
214 Ok(())
215}
216
/// Builds the system prompt describing the allowed rebase actions, the commit
/// message format (`commit_pattern` regex and the allowed `type_names`, joined
/// with `", "`), and the rules the plan must follow.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at organizing git history. You will be given a list of recent commits and asked to reorganize them.

You can:
- **pick**: keep the commit as-is (but you may reword the message)
- **reword**: keep the commit but change the message
- **squash**: fold the commit into the previous one (combine their changes)
- **drop**: remove the commit entirely (use sparingly — only for truly empty or duplicate commits)

COMMIT MESSAGE FORMAT:
- Must match this regex: {pattern}
- Format: type(scope): subject
- Valid types ONLY: {types}
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

RULES:
- Maintain the chronological order of commits (oldest first) unless reordering improves logical grouping
- The first commit in the list CANNOT be "squash" — squash folds into the previous commit
- Prefer "reword" over "squash" when commits are logically distinct
- Only squash commits that are genuinely part of the same logical change
- Every original commit SHA must appear exactly once in your output
- If the commits are already well-organized, return them all as "pick" with improved messages if needed"#,
        pattern = commit_pattern,
        types = type_names.join(", "),
    )
}
243
244fn build_user_prompt(log: &str, extra: Option<&str>) -> Result<String> {
245 let mut prompt = format!(
246 "Analyze these recent commits and suggest how to reorganize them for a cleaner history.\n\n\
247 Commits (oldest first):\n```\n{log}\n```"
248 );
249
250 if let Some(msg) = extra {
251 prompt.push_str(&format!(
252 "\n\nAdditional instructions from the user:\n{msg}"
253 ));
254 }
255
256 Ok(prompt)
257}
258
259fn display_plan(plan: &ReorganizePlan) {
260 use crossterm::style::Stylize;
261
262 println!();
263 println!(
264 " {} {}",
265 "REBASE PLAN".bold(),
266 format!("· {} commits", plan.commits.len()).dim()
267 );
268 let rule = "─".repeat(50);
269 println!(" {}", rule.as_str().dim());
270 println!();
271
272 for commit in &plan.commits {
273 let action_styled = match commit.action.as_str() {
274 "pick" => format!("{}", "pick".green()),
275 "reword" => format!("{}", "reword".yellow()),
276 "squash" => format!("{}", "squash".magenta()),
277 "drop" => format!("{}", "drop".red()),
278 other => other.to_string(),
279 };
280
281 println!(
282 " {} {} {}",
283 action_styled,
284 commit.original_sha.as_str().dim(),
285 commit.message.as_str().bold()
286 );
287
288 if let Some(body) = &commit.body
289 && !body.is_empty()
290 {
291 for line in body.lines() {
292 println!(" {} {}", "│".dim(), line.dim());
293 }
294 }
295 }
296
297 println!();
298 println!(" {}", rule.as_str().dim());
299 println!();
300}
301
302fn execute_rebase(repo: &GitRepo, plan: &ReorganizePlan, commit_count: usize) -> Result<()> {
303 let mut todo_lines = Vec::new();
305 for commit in &plan.commits {
306 let action = match commit.action.as_str() {
307 "pick" | "reword" => "pick", "squash" => "squash",
309 "drop" => "drop",
310 other => bail!("unknown rebase action: {other}"),
311 };
312 todo_lines.push(format!("{action} {}", commit.original_sha));
313 }
314 let todo_content = todo_lines.join("\n") + "\n";
315
316 let mut rewrites: HashMap<String, String> = HashMap::new();
318 let mut squash_messages: Vec<String> = Vec::new();
320 let mut last_pick_sha: Option<String> = None;
321
322 for commit in &plan.commits {
323 let mut full_msg = commit.message.clone();
324 if let Some(body) = &commit.body
325 && !body.is_empty()
326 {
327 full_msg.push_str("\n\n");
328 full_msg.push_str(body);
329 }
330 if let Some(footer) = &commit.footer
331 && !footer.is_empty()
332 {
333 full_msg.push_str("\n\n");
334 full_msg.push_str(footer);
335 }
336
337 match commit.action.as_str() {
338 "pick" | "reword" => {
339 if !squash_messages.is_empty() {
341 if let Some(ref sha) = last_pick_sha
342 && let Some(existing) = rewrites.get_mut(sha)
343 {
344 for sq_msg in &squash_messages {
345 existing.push_str("\n\n");
346 existing.push_str(sq_msg);
347 }
348 }
349 squash_messages.clear();
350 }
351 last_pick_sha = Some(commit.original_sha.clone());
352 rewrites.insert(commit.original_sha.clone(), full_msg);
353 }
354 "squash" => {
355 squash_messages.push(full_msg);
356 }
357 _ => {}
358 }
359 }
360 if !squash_messages.is_empty()
362 && let Some(ref sha) = last_pick_sha
363 && let Some(existing) = rewrites.get_mut(sha)
364 {
365 for sq_msg in &squash_messages {
366 existing.push_str("\n\n");
367 existing.push_str(sq_msg);
368 }
369 }
370
371 let tmp_dir = std::env::temp_dir().join(format!("sr-rebase-{}", std::process::id()));
373 std::fs::create_dir_all(&tmp_dir).context("failed to create temp dir")?;
374 let _cleanup = TmpDirGuard(tmp_dir.clone());
376
377 let todo_script_path = tmp_dir.join("sequence-editor.sh");
379 {
380 let todo_file_path = tmp_dir.join("todo.txt");
381 std::fs::write(&todo_file_path, &todo_content)?;
382
383 let script = format!("#!/bin/sh\ncp '{}' \"$1\"\n", todo_file_path.display());
384 std::fs::write(&todo_script_path, &script)?;
385 #[cfg(unix)]
386 {
387 use std::os::unix::fs::PermissionsExt;
388 std::fs::set_permissions(&todo_script_path, std::fs::Permissions::from_mode(0o755))?;
389 }
390 }
391
392 let editor_script_path = tmp_dir.join("commit-editor.sh");
394 {
395 let msgs_dir = tmp_dir.join("msgs");
397 std::fs::create_dir_all(&msgs_dir)?;
398 for (sha, msg) in &rewrites {
399 std::fs::write(msgs_dir.join(sha), msg)?;
400 }
401
402 let script = format!(
406 r#"#!/bin/sh
407MSGS_DIR='{msgs_dir}'
408MSG_FILE="$1"
409
410# Try to find a matching SHA in the message file
411for sha_file in "$MSGS_DIR"/*; do
412 sha=$(basename "$sha_file")
413 if grep -q "$sha" "$MSG_FILE" 2>/dev/null; then
414 cp "$sha_file" "$MSG_FILE"
415 exit 0
416 fi
417done
418
419# For squash: the combined message won't contain a single SHA.
420# Find the first pick/reword SHA that's referenced in the todo.
421# Just use the message as-is if we can't match.
422exit 0
423"#,
424 msgs_dir = msgs_dir.display()
425 );
426 std::fs::write(&editor_script_path, &script)?;
427 #[cfg(unix)]
428 {
429 use std::os::unix::fs::PermissionsExt;
430 std::fs::set_permissions(&editor_script_path, std::fs::Permissions::from_mode(0o755))?;
431 }
432 }
433
434 let base = format!("HEAD~{commit_count}");
436
437 ui::info(&format!("Rebasing {commit_count} commits..."));
438
439 let output = std::process::Command::new("git")
440 .args(["-C", repo.root().to_str().unwrap()])
441 .args(["rebase", "-i", &base])
442 .env("GIT_SEQUENCE_EDITOR", todo_script_path.to_str().unwrap())
443 .env("GIT_EDITOR", editor_script_path.to_str().unwrap())
444 .env("EDITOR", editor_script_path.to_str().unwrap())
445 .stdout(std::process::Stdio::piped())
446 .stderr(std::process::Stdio::piped())
447 .output()
448 .context("failed to run git rebase")?;
449
450 if !output.status.success() {
451 let stderr = String::from_utf8_lossy(&output.stderr);
452 let _ = std::process::Command::new("git")
454 .args(["-C", repo.root().to_str().unwrap()])
455 .args(["rebase", "--abort"])
456 .output();
457 bail!("git rebase failed: {}", stderr.trim());
458 }
459
460 let new_log = repo.recent_commits(commit_count)?;
462 println!();
463 ui::phase_ok("Rebase complete", None);
464 println!();
465 for line in new_log.lines() {
466 println!(" {line}");
467 }
468 println!();
469
470 Ok(())
471}