//! Grouper module (`semantic_diff/grouper/mod.rs`): builds hunk-level
//! summaries of a diff, sends them to an LLM for semantic grouping, and
//! merges/normalizes the resulting groups across refreshes.
1pub mod llm;
2
3use crate::diff::DiffData;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
/// Response envelope from LLM grouping request.
///
/// Deserialized from the LLM's JSON reply; the top-level object is expected
/// to contain a single `groups` array.
#[derive(Debug, Clone, Deserialize)]
pub struct GroupingResponse {
    /// The semantic groups proposed by the LLM.
    pub groups: Vec<SemanticGroup>,
}
14
/// A semantic group of related changes (hunk-level granularity).
/// Accepts both `changes` (hunk-level) and `files` (file-level fallback) from LLM.
#[derive(Debug, Clone, Deserialize)]
pub struct SemanticGroup {
    /// Short human-readable name for the group, chosen by the LLM.
    pub label: String,
    /// Longer explanation of the group. Currently unread anywhere in the
    /// codebase (hence the `dead_code` allow) but kept for deserialization
    /// and potential future display.
    #[serde(default)]
    #[allow(dead_code)]
    pub description: String,
    /// Hunk-level changes (preferred format).
    /// Private: read via `changes()`, which also normalizes the `files` fallback.
    #[serde(default)]
    changes: Vec<GroupedChange>,
    /// File-level fallback: if LLM returns `"files": ["path"]` instead of `changes`.
    #[serde(default)]
    files: Vec<String>,
}
30
31impl SemanticGroup {
32    /// Create a SemanticGroup from hunk-level changes.
33    pub fn new(label: String, description: String, changes: Vec<GroupedChange>) -> Self {
34        Self {
35            label,
36            description,
37            changes,
38            files: vec![],
39        }
40    }
41
42    /// Replace the changes list directly.
43    pub fn set_changes(&mut self, changes: Vec<GroupedChange>) {
44        self.changes = changes;
45        self.files.clear();
46    }
47
48    /// Get the list of changes, normalizing the `files` fallback into `changes`.
49    pub fn changes(&self) -> Vec<GroupedChange> {
50        if !self.changes.is_empty() {
51            return self.changes.clone();
52        }
53        // Fallback: convert file-level list to changes with empty hunks (= all hunks)
54        self.files
55            .iter()
56            .map(|f| GroupedChange {
57                file: f.clone(),
58                hunks: vec![],
59            })
60            .collect()
61    }
62}
63
/// A reference to specific hunks within a file that belong to a group.
#[derive(Debug, Clone, Deserialize)]
pub struct GroupedChange {
    /// File path as it appears in the diff, with the `b/` prefix stripped.
    pub file: String,
    /// 0-based hunk indices. If empty, means all hunks in the file.
    #[serde(default)]
    pub hunks: Vec<usize>,
}
72
/// Tracks the lifecycle of an async grouping request.
///
/// State machine: `Idle` -> `Loading` -> (`Done` | `Error`). A refresh may
/// restart the cycle from any state.
#[derive(Debug, Clone, PartialEq)]
pub enum GroupingStatus {
    /// No grouping attempted yet (or no LLM backend available).
    Idle,
    /// Waiting for LLM response.
    Loading,
    /// Groups received and applied.
    Done,
    /// LLM call failed (timeout, parse error, etc.); payload is a
    /// human-readable error message.
    Error(String),
}
85
86/// Build hunk-level summaries for the LLM prompt from parsed diff data.
87///
88/// Format:
89/// ```text
90/// FILE: src/app.rs (modified, +10 -3)
91///   HUNK 0: @@ -100,6 +100,16 @@ impl App
92///     + pub fn new_method() {
93///     + ...
94///   HUNK 1: @@ -200,3 +210,5 @@ fn handle_key
95///     - old_call();
96///     + new_call();
97/// ```
98/// Max total characters for the summaries prompt to keep LLM response fast.
99const MAX_SUMMARY_CHARS: usize = 8000;
100
101pub fn hunk_summaries(diff_data: &DiffData) -> String {
102    let mut out = String::new();
103    for f in &diff_data.files {
104        let path = f.target_file.trim_start_matches("b/");
105        let status = file_status(f);
106        out.push_str(&format!(
107            "FILE: {} ({}, +{} -{})\n",
108            path, status, f.added_count, f.removed_count
109        ));
110
111        // For untracked files, use structural sampling instead of hunk-by-hunk
112        if f.is_untracked && out.len() < MAX_SUMMARY_CHARS {
113            out.push_str(&summarize_untracked_file(f));
114        } else {
115            append_hunk_samples(&mut out, f);
116        }
117
118        if out.len() >= MAX_SUMMARY_CHARS {
119            out.push_str("... (remaining files omitted for brevity)\n");
120            break;
121        }
122    }
123    out
124}
125
126/// Classify a file's change status for LLM summaries.
127fn file_status(f: &crate::diff::DiffFile) -> String {
128    if f.is_untracked {
129        "untracked/new".to_string()
130    } else if f.is_rename {
131        format!("renamed from {}", f.source_file.trim_start_matches("a/"))
132    } else if f.added_count > 0 && f.removed_count == 0 {
133        "added".to_string()
134    } else if f.removed_count > 0 && f.added_count == 0 {
135        "deleted".to_string()
136    } else {
137        "modified".to_string()
138    }
139}
140
141/// Append standard hunk-by-hunk samples (up to 4 changed lines each) to the output.
142fn append_hunk_samples(out: &mut String, f: &crate::diff::DiffFile) {
143    for (hi, hunk) in f.hunks.iter().enumerate() {
144        out.push_str(&format!("  HUNK {}: {}\n", hi, hunk.header));
145
146        if out.len() < MAX_SUMMARY_CHARS {
147            let mut shown = 0;
148            for line in &hunk.lines {
149                if shown >= 4 {
150                    out.push_str("    ...\n");
151                    break;
152                }
153                match line.line_type {
154                    crate::diff::LineType::Added => {
155                        out.push_str(&format!("    + {}\n", truncate(&line.content, 60)));
156                        shown += 1;
157                    }
158                    crate::diff::LineType::Removed => {
159                        out.push_str(&format!("    - {}\n", truncate(&line.content, 60)));
160                        shown += 1;
161                    }
162                    _ => {}
163                }
164            }
165        }
166    }
167}
168
169/// Structural sampling for untracked (new) files.
170///
171/// Instead of showing the first 4 lines (usually imports), samples from three
172/// regions of the file to give the LLM a representative view of the file's purpose:
173/// - Head: first N non-blank lines (imports, module declaration)
174/// - Mid: N lines from the middle (core logic)
175/// - Tail: last N non-blank lines (exports, closing code)
176///
177/// For short files (≤12 lines), shows all content lines.
178fn summarize_untracked_file(f: &crate::diff::DiffFile) -> String {
179    // Collect all content lines (flatten across hunks)
180    let all_lines: Vec<&str> = f
181        .hunks
182        .iter()
183        .flat_map(|h| h.lines.iter())
184        .filter(|l| l.line_type == crate::diff::LineType::Added)
185        .map(|l| l.content.as_str())
186        .collect();
187
188    let total = all_lines.len();
189    let mut out = String::new();
190
191    if total <= 12 {
192        // Short file — show everything
193        for line in &all_lines {
194            out.push_str(&format!("    + {}\n", truncate(line, 80)));
195        }
196        return out;
197    }
198
199    const SAMPLE: usize = 4;
200
201    // Head
202    out.push_str("  [head]\n");
203    for line in all_lines.iter().take(SAMPLE) {
204        out.push_str(&format!("    + {}\n", truncate(line, 80)));
205    }
206
207    // Mid
208    let mid_start = total / 2 - SAMPLE / 2;
209    out.push_str(&format!("  [mid ~line {}]\n", mid_start + 1));
210    for line in all_lines.iter().skip(mid_start).take(SAMPLE) {
211        out.push_str(&format!("    + {}\n", truncate(line, 80)));
212    }
213
214    // Tail
215    let tail_start = total.saturating_sub(SAMPLE);
216    out.push_str(&format!("  [tail ~line {}]\n", tail_start + 1));
217    for line in all_lines.iter().skip(tail_start) {
218        out.push_str(&format!("    + {}\n", truncate(line, 80)));
219    }
220
221    out
222}
223
224/// Compute a stable hash of a file's diff content (hunk headers + line types + line content).
225/// Used to detect whether a file's diff has changed between refreshes.
226pub fn compute_file_hash(file: &crate::diff::DiffFile) -> u64 {
227    let mut hasher = DefaultHasher::new();
228    for hunk in &file.hunks {
229        hunk.header.hash(&mut hasher);
230        for line in &hunk.lines {
231            // Discriminant: 0 = Added, 1 = Removed, 2 = Context
232            let discriminant: u8 = match line.line_type {
233                crate::diff::LineType::Added => 0,
234                crate::diff::LineType::Removed => 1,
235                crate::diff::LineType::Context => 2,
236            };
237            discriminant.hash(&mut hasher);
238            line.content.hash(&mut hasher);
239        }
240    }
241    hasher.finish()
242}
243
244/// Compute hashes for all files in a diff. Key = file path with `b/` prefix stripped.
245pub fn compute_all_file_hashes(diff_data: &DiffData) -> HashMap<String, u64> {
246    diff_data
247        .files
248        .iter()
249        .map(|f| {
250            let path = f.target_file.trim_start_matches("b/").to_string();
251            (path, compute_file_hash(f))
252        })
253        .collect()
254}
255
/// Categorization of files between two diff snapshots.
///
/// All paths are keyed the same way as `compute_all_file_hashes` output:
/// target path with the `b/` prefix stripped.
#[derive(Debug, Clone, Serialize)]
pub struct DiffDelta {
    /// Files that are new (not in previous grouping).
    pub new_files: Vec<String>,
    /// Files that were removed (in previous but not in new diff).
    pub removed_files: Vec<String>,
    /// Files whose diff content changed.
    pub modified_files: Vec<String>,
    /// Files whose diff content is identical.
    pub unchanged_files: Vec<String>,
}
268
269impl DiffDelta {
270    pub fn has_changes(&self) -> bool {
271        !self.new_files.is_empty()
272            || !self.removed_files.is_empty()
273            || !self.modified_files.is_empty()
274    }
275
276    pub fn is_only_removals(&self) -> bool {
277        self.new_files.is_empty()
278            && self.modified_files.is_empty()
279            && !self.removed_files.is_empty()
280    }
281}
282
283/// Compare new file hashes against previous to categorize each file.
284pub fn compute_diff_delta(
285    new_hashes: &HashMap<String, u64>,
286    previous_hashes: &HashMap<String, u64>,
287) -> DiffDelta {
288    let mut new_files = Vec::new();
289    let mut modified_files = Vec::new();
290    let mut unchanged_files = Vec::new();
291
292    for (path, &new_hash) in new_hashes {
293        match previous_hashes.get(path) {
294            None => new_files.push(path.clone()),
295            Some(&prev_hash) if prev_hash != new_hash => modified_files.push(path.clone()),
296            _ => unchanged_files.push(path.clone()),
297        }
298    }
299
300    let removed_files = previous_hashes
301        .keys()
302        .filter(|p| !new_hashes.contains_key(*p))
303        .cloned()
304        .collect();
305
306    DiffDelta {
307        new_files,
308        removed_files,
309        modified_files,
310        unchanged_files,
311    }
312}
313
314/// Build hunk summaries for ONLY new/modified files, prepended with existing group context.
315///
316/// Format:
317/// ```text
318/// EXISTING GROUPS (for context — assign new changes to these or create new groups):
319/// 1. "Auth refactor" — files: src/auth.rs, src/middleware.rs
320///
321/// NEW/MODIFIED FILES TO GROUP:
322/// FILE: src/router.rs (added, +20 -0)
323///   HUNK 0: @@ ...
324///     + pub fn new_route() {
325/// ```
326pub fn incremental_hunk_summaries(
327    diff_data: &DiffData,
328    delta: &DiffDelta,
329    existing_groups: &[SemanticGroup],
330) -> String {
331    let mut out = String::new();
332
333    // --- Existing group context ---
334    if !existing_groups.is_empty() {
335        out.push_str(
336            "EXISTING GROUPS (for context \u{2014} assign new changes to these or create new groups):\n",
337        );
338        for (i, group) in existing_groups.iter().enumerate() {
339            let changes = group.changes();
340            let file_list: Vec<&str> = changes.iter().map(|c| c.file.as_str()).collect();
341            out.push_str(&format!(
342                "{}. \"{}\" \u{2014} files: {}\n",
343                i + 1,
344                group.label,
345                file_list.join(", ")
346            ));
347        }
348        out.push('\n');
349    }
350
351    out.push_str("NEW/MODIFIED FILES TO GROUP:\n");
352
353    // Collect the set of files to include (new + modified)
354    let include: std::collections::HashSet<&str> = delta
355        .new_files
356        .iter()
357        .chain(delta.modified_files.iter())
358        .map(|s| s.as_str())
359        .collect();
360
361    for f in &diff_data.files {
362        let path = f.target_file.trim_start_matches("b/");
363        if !include.contains(path) {
364            continue;
365        }
366
367        let status = file_status(f);
368        out.push_str(&format!(
369            "FILE: {} ({}, +{} -{})\n",
370            path, status, f.added_count, f.removed_count
371        ));
372
373        if f.is_untracked && out.len() < MAX_SUMMARY_CHARS {
374            out.push_str(&summarize_untracked_file(f));
375        } else {
376            append_hunk_samples(&mut out, f);
377        }
378
379        if out.len() >= MAX_SUMMARY_CHARS {
380            out.push_str("... (remaining files omitted for brevity)\n");
381            break;
382        }
383    }
384
385    out
386}
387
388/// Post-process grouping results: fill in explicit hunk indices when `hunks` is empty
389/// and the file has multiple hunks, so the UI can filter hunks per group correctly.
390pub fn normalize_hunk_indices(groups: &mut [SemanticGroup], diff_data: &DiffData) {
391    // Build a map from file path -> hunk count
392    let hunk_counts: HashMap<String, usize> = diff_data
393        .files
394        .iter()
395        .map(|f| {
396            let path = f.target_file.trim_start_matches("b/").to_string();
397            (path, f.hunks.len())
398        })
399        .collect();
400
401    for group in groups.iter_mut() {
402        let mut updated = group.changes();
403        for change in updated.iter_mut() {
404            if change.hunks.is_empty() {
405                if let Some(&count) = hunk_counts.get(&change.file) {
406                    if count > 1 {
407                        change.hunks = (0..count).collect();
408                    }
409                }
410            }
411        }
412        group.set_changes(updated);
413    }
414}
415
416/// Remove all entries for the given file paths from existing groups.
417/// Groups that become empty after removal are dropped.
418pub fn remove_files_from_groups(groups: &mut Vec<SemanticGroup>, files_to_remove: &[String]) {
419    if files_to_remove.is_empty() {
420        return;
421    }
422    let remove_set: std::collections::HashSet<&str> =
423        files_to_remove.iter().map(|s| s.as_str()).collect();
424
425    groups.retain_mut(|group| {
426        let filtered: Vec<GroupedChange> = group
427            .changes()
428            .into_iter()
429            .filter(|c| !remove_set.contains(c.file.as_str()))
430            .collect();
431        group.set_changes(filtered);
432        !group.changes().is_empty()
433    });
434}
435
436/// Merge new LLM grouping assignments into existing groups.
437///
438/// Steps:
439/// 1. Clone existing groups.
440/// 2. Remove entries for `removed_files` and `modified_files` (stale data).
441/// 3. For each group in `new_assignments`:
442///    - If label matches an existing group (case-insensitive), merge changes into it.
443///    - Otherwise, append as a new group.
444/// 4. Remove empty groups.
445pub fn merge_groups(
446    existing: &[SemanticGroup],
447    new_assignments: &[SemanticGroup],
448    delta: &DiffDelta,
449) -> Vec<SemanticGroup> {
450    let mut merged: Vec<SemanticGroup> = existing.to_vec();
451
452    // Remove stale file entries
453    let stale: Vec<String> = delta
454        .removed_files
455        .iter()
456        .chain(delta.modified_files.iter())
457        .cloned()
458        .collect();
459    remove_files_from_groups(&mut merged, &stale);
460
461    // Integrate new assignments
462    for new_group in new_assignments {
463        let new_changes = new_group.changes();
464        if new_changes.is_empty() {
465            continue;
466        }
467
468        // Find existing group with matching label (case-insensitive)
469        let existing_pos = merged
470            .iter()
471            .position(|g| g.label.to_lowercase() == new_group.label.to_lowercase());
472
473        if let Some(pos) = existing_pos {
474            let mut combined = merged[pos].changes();
475            combined.extend(new_changes);
476            merged[pos].set_changes(combined);
477        } else {
478            merged.push(new_group.clone());
479        }
480    }
481
482    // Drop any groups that ended up empty
483    merged.retain(|g| !g.changes().is_empty());
484
485    merged
486}
487
/// Truncate a string to at most `max` bytes without splitting a UTF-8
/// character: the cut point is the largest char boundary `<= max`.
/// The returned slice is therefore always valid UTF-8.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // `is_char_boundary(0)` is always true, so this search cannot fail;
    // the unwrap_or(0) is just a belt-and-braces default.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
502
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_ascii() {
        let got = truncate("hello", 3);
        assert_eq!(got, "hel");
    }

    #[test]
    fn test_truncate_shorter_than_max() {
        // Strings already within the limit come back unchanged.
        let got = truncate("hi", 10);
        assert_eq!(got, "hi");
    }

    #[test]
    fn test_truncate_cjk_at_boundary_no_panic() {
        // Each CJK char here is 3 bytes in UTF-8; cutting at byte 4 lands
        // mid-character and must retreat to the boundary at byte 3.
        let s = "\u{4e16}\u{754c}\u{4f60}\u{597d}"; // 世界你好 (12 bytes)
        assert_eq!(truncate(s, 4), "\u{4e16}"); // 世 (3 bytes)
    }

    #[test]
    fn test_truncate_emoji_at_boundary_no_panic() {
        // 🦀 occupies 4 bytes; a cut at byte 3 falls inside it and must
        // back up to byte 1 (just the leading 'a').
        let s = "a🦀b"; // 1 + 4 + 1 = 6 bytes
        assert_eq!(truncate(s, 3), "a");
    }

    #[test]
    fn test_truncate_exact_boundary() {
        // max == len is the no-truncation edge case.
        let got = truncate("hello", 5);
        assert_eq!(got, "hello");
    }

    #[test]
    fn test_truncate_zero() {
        // max == 0 always yields the empty slice.
        let got = truncate("hello", 0);
        assert_eq!(got, "");
    }
}