// steer_tui/tui/widgets/diff.rs

1use crate::tui::theme::{Component, Theme};
2use ratatui::{
3    style::{Modifier, Style},
4    text::{Line, Span},
5};
6use similar::{Algorithm, ChangeTag, TextDiff};
7
/// Layout used by [`DiffWidget`] when rendering a diff.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DiffMode {
    /// A single column of `-` / `+` / ` ` prefixed lines. This is the default,
    /// matching what `DiffWidget::new` selects.
    #[default]
    Unified,
    /// Old text on the left and new text on the right, separated by a divider.
    Split,
}
13
14pub struct DiffWidget<'a> {
15    old: &'a str,
16    new: &'a str,
17    mode: DiffMode,
18    wrap_width: usize,
19    theme: &'a Theme,
20    context_radius: usize,
21    max_lines: Option<usize>,
22}
23
24impl<'a> DiffWidget<'a> {
25    pub fn new(old: &'a str, new: &'a str, theme: &'a Theme) -> Self {
26        Self {
27            old,
28            new,
29            mode: DiffMode::Unified,
30            wrap_width: 80,
31            theme,
32            context_radius: 3,
33            max_lines: None,
34        }
35    }
36
37    pub fn with_mode(mut self, mode: DiffMode) -> Self {
38        self.mode = mode;
39        self
40    }
41
42    pub fn with_wrap_width(mut self, width: usize) -> Self {
43        self.wrap_width = width;
44        self
45    }
46
47    pub fn with_context_radius(mut self, radius: usize) -> Self {
48        self.context_radius = radius;
49        self
50    }
51
52    pub fn with_max_lines(mut self, max: usize) -> Self {
53        self.max_lines = Some(max);
54        self
55    }
56
57    pub fn lines(&self) -> Vec<Line<'static>> {
58        match self.mode {
59            DiffMode::Unified => self.unified_diff(),
60            DiffMode::Split => {
61                // Fall back to unified if terminal is too narrow for split view
62                // We need at least ~40 chars for split view to be useful (20 per side)
63                if self.wrap_width < 40 {
64                    self.unified_diff()
65                } else {
66                    self.split_diff()
67                }
68            }
69        }
70    }
71
72    fn unified_diff(&self) -> Vec<Line<'static>> {
73        let diff = TextDiff::configure()
74            .algorithm(Algorithm::Myers)
75            .diff_lines(self.old, self.new);
76
77        let changes: Vec<_> = diff.iter_all_changes().collect();
78
79        // First pass: determine which lines to show
80        let mut show_line = vec![false; changes.len()];
81        for (idx, change) in changes.iter().enumerate() {
82            if change.tag() != ChangeTag::Equal {
83                // Always show non-equal lines
84                show_line[idx] = true;
85
86                // Show context before
87                let start = idx.saturating_sub(self.context_radius);
88                for line in show_line.iter_mut().take(idx).skip(start) {
89                    *line = true;
90                }
91
92                // Show context after
93                let end = (idx + 1 + self.context_radius).min(changes.len());
94                for line in show_line.iter_mut().take(end).skip(idx + 1) {
95                    *line = true;
96                }
97            }
98        }
99
100        // Second pass: render lines with ellipsis for gaps
101        let mut lines = Vec::new();
102        let mut last_shown: Option<usize> = None;
103
104        for (idx, (change, &should_show)) in changes.iter().zip(&show_line).enumerate() {
105            if !should_show {
106                continue;
107            }
108
109            // Add ellipsis if there's a gap
110            match last_shown {
111                None if idx > 0 => {
112                    // Gap at beginning
113                    lines.push(self.separator_line());
114                }
115                Some(last) if idx > last + 1 => {
116                    // Gap in middle
117                    lines.push(self.separator_line());
118                }
119                _ => {}
120            }
121
122            let (prefix, style) = match change.tag() {
123                ChangeTag::Delete => ("-", self.theme.style(Component::CodeDeletion)),
124                ChangeTag::Insert => ("+", self.theme.style(Component::CodeAddition)),
125                ChangeTag::Equal => (" ", self.theme.style(Component::DimText)),
126            };
127
128            let content = change.value().trim_end();
129            lines.extend(self.format_line(prefix, content, style));
130
131            last_shown = Some(idx);
132
133            // Check if we've hit the line limit
134            if let Some(max) = self.max_lines
135                && lines.len() >= max
136            {
137                let remaining = changes.len() - idx - 1;
138                if remaining > 0 {
139                    lines.push(Line::from(Span::styled(
140                        format!("... ({remaining} more lines)"),
141                        self.theme
142                            .style(Component::DimText)
143                            .add_modifier(Modifier::ITALIC),
144                    )));
145                }
146                break;
147            }
148        }
149
150        lines
151    }
152
153    fn split_diff(&self) -> Vec<Line<'static>> {
154        let mut lines = Vec::new();
155
156        // Calculate split width (account for prefix and divider)
157        let half_width = (self.wrap_width.saturating_sub(5)) / 2;
158
159        let diff = TextDiff::configure()
160            .algorithm(Algorithm::Myers)
161            .diff_lines(self.old, self.new);
162
163        // Group changes to properly handle replacements
164        let changes: Vec<_> = diff.iter_all_changes().collect();
165        let mut i = 0;
166
167        while i < changes.len() {
168            let change = changes[i];
169
170            match change.tag() {
171                ChangeTag::Equal => {
172                    // Equal lines - show on both sides
173                    let content = change.value().trim_end();
174                    let left = Self::truncate_or_pad(content, half_width);
175                    let right = Self::truncate_or_pad(content, half_width);
176
177                    lines.push(Line::from(vec![
178                        Span::styled(" ", self.theme.style(Component::DimText)),
179                        Span::styled(left, self.theme.style(Component::DimText)),
180                        Span::styled(" │ ", self.theme.style(Component::DimText)),
181                        Span::styled(" ", self.theme.style(Component::DimText)),
182                        Span::styled(right, self.theme.style(Component::DimText)),
183                    ]));
184                    i += 1;
185                }
186                ChangeTag::Delete => {
187                    // Check if this is part of a replacement (delete followed by insert)
188                    let mut deletes = vec![change];
189                    let mut j = i + 1;
190
191                    // Collect consecutive deletes
192                    while j < changes.len() && changes[j].tag() == ChangeTag::Delete {
193                        deletes.push(changes[j]);
194                        j += 1;
195                    }
196
197                    // Check if followed by inserts
198                    let mut inserts = Vec::new();
199                    while j < changes.len() && changes[j].tag() == ChangeTag::Insert {
200                        inserts.push(changes[j]);
201                        j += 1;
202                    }
203
204                    if inserts.is_empty() {
205                        // Pure deletion - show on left only
206                        for del in deletes {
207                            let content = del.value().trim_end();
208                            let left = Self::truncate_or_pad(content, half_width);
209                            let right = " ".repeat(half_width);
210
211                            lines.push(Line::from(vec![
212                                Span::styled("-", self.theme.style(Component::CodeDeletion)),
213                                Span::styled(left, self.theme.style(Component::CodeDeletion)),
214                                Span::styled(" │ ", self.theme.style(Component::DimText)),
215                                Span::styled(" ", self.theme.style(Component::DimText)),
216                                Span::styled(right, self.theme.style(Component::DimText)),
217                            ]));
218                        }
219                        i = j;
220                    } else {
221                        // This is a replacement - show side by side
222                        let max_len = deletes.len().max(inserts.len());
223
224                        for idx in 0..max_len {
225                            let left_content = if idx < deletes.len() {
226                                deletes[idx].value().trim_end()
227                            } else {
228                                ""
229                            };
230                            let right_content = if idx < inserts.len() {
231                                inserts[idx].value().trim_end()
232                            } else {
233                                ""
234                            };
235
236                            let left = Self::truncate_or_pad(left_content, half_width);
237                            let right = Self::truncate_or_pad(right_content, half_width);
238
239                            // Determine prefixes based on whether there's content
240                            let left_prefix = if left_content.is_empty() { " " } else { "-" };
241                            let right_prefix = if right_content.is_empty() { " " } else { "+" };
242
243                            lines.push(Line::from(vec![
244                                Span::styled(
245                                    left_prefix,
246                                    self.theme.style(Component::CodeDeletion),
247                                ),
248                                Span::styled(
249                                    left,
250                                    if left_content.is_empty() {
251                                        self.theme.style(Component::DimText)
252                                    } else {
253                                        self.theme.style(Component::CodeDeletion)
254                                    },
255                                ),
256                                Span::styled(" │ ", self.theme.style(Component::DimText)),
257                                Span::styled(
258                                    right_prefix,
259                                    self.theme.style(Component::CodeAddition),
260                                ),
261                                Span::styled(
262                                    right,
263                                    if right_content.is_empty() {
264                                        self.theme.style(Component::DimText)
265                                    } else {
266                                        self.theme.style(Component::CodeAddition)
267                                    },
268                                ),
269                            ]));
270                        }
271
272                        i = j;
273                    }
274                }
275                ChangeTag::Insert => {
276                    // Pure insertion (not part of replacement) - show on right only
277                    let content = change.value().trim_end();
278                    let left = " ".repeat(half_width);
279                    let right = Self::truncate_or_pad(content, half_width);
280
281                    lines.push(Line::from(vec![
282                        Span::styled(" ", self.theme.style(Component::DimText)),
283                        Span::styled(left, self.theme.style(Component::DimText)),
284                        Span::styled(" │ ", self.theme.style(Component::DimText)),
285                        Span::styled("+", self.theme.style(Component::CodeAddition)),
286                        Span::styled(right, self.theme.style(Component::CodeAddition)),
287                    ]));
288                    i += 1;
289                }
290            }
291
292            // Check line limit
293            if let Some(max) = self.max_lines
294                && lines.len() >= max
295            {
296                break;
297            }
298        }
299
300        lines
301    }
302
303    fn format_line(&self, prefix: &str, content: &str, style: Style) -> Vec<Line<'static>> {
304        let mut lines = Vec::new();
305
306        // Wrap long lines
307        let wrapped = textwrap::wrap(content, self.wrap_width.saturating_sub(2));
308
309        if wrapped.is_empty() {
310            lines.push(Line::from(vec![
311                Span::styled(prefix.to_string(), style),
312                Span::styled(" ", style),
313            ]));
314        } else {
315            for (i, wrapped_line) in wrapped.iter().enumerate() {
316                if i == 0 {
317                    lines.push(Line::from(vec![
318                        Span::styled(prefix.to_string(), style),
319                        Span::styled(format!(" {wrapped_line}"), style),
320                    ]));
321                } else {
322                    // Continuation lines
323                    lines.push(Line::from(vec![
324                        Span::styled("  ", style),
325                        Span::styled(wrapped_line.to_string(), style),
326                    ]));
327                }
328            }
329        }
330
331        lines
332    }
333
334    fn separator_line(&self) -> Line<'static> {
335        Line::from(Span::styled(
336            "···",
337            self.theme
338                .style(Component::DimText)
339                .add_modifier(Modifier::DIM),
340        ))
341    }
342
343    fn truncate_or_pad(s: &str, width: usize) -> String {
344        // Use Unicode-aware truncation
345        let char_count = s.chars().count();
346        if char_count > width {
347            // Collect chars and truncate at character boundary
348            let truncated: String = s.chars().take(width.saturating_sub(1)).collect();
349            format!("{truncated}…")
350        } else {
351            // Pad with spaces to reach desired width
352            format!("{s:width$}")
353        }
354    }
355}
356
/// Builds short, single-line previews of `old` and `new` for summaries.
///
/// Each input is trimmed of surrounding whitespace. If the trimmed text has
/// more than `max_len` characters, it is shortened to exactly `max_len`
/// characters ending in `...`. When `max_len` is below 3 there is no room for
/// the ellipsis, and the text is simply cut to `max_len` characters.
///
/// Lengths are counted in `char`s, so multi-byte UTF-8 text is never sliced
/// mid-character (byte slicing here would panic on such input).
pub fn diff_summary(old: &str, new: &str, max_len: usize) -> (String, String) {
    /// Trims `text` and shortens it to at most `max_len` characters.
    fn preview(text: &str, max_len: usize) -> String {
        const ELLIPSIS: &str = "...";

        let trimmed = text.trim();
        // `nth(max_len)` exists only when there are more than `max_len` chars.
        if trimmed.chars().nth(max_len).is_none() {
            return trimmed.to_string();
        }
        match max_len.checked_sub(ELLIPSIS.len()) {
            Some(keep) => {
                let mut shortened: String = trimmed.chars().take(keep).collect();
                shortened.push_str(ELLIPSIS);
                shortened
            }
            None => trimmed.chars().take(max_len).collect(),
        }
    }

    (preview(old, max_len), preview(new, max_len))
}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tui::theme::Theme;

    /// Renders `widget` and returns the unstyled text of every produced line.
    fn render(widget: &DiffWidget<'_>) -> Vec<String> {
        widget.lines().iter().map(plain_text).collect()
    }

    /// Renders a split-mode diff of `old` vs `new` at the given width.
    fn render_split(old: &str, new: &str, theme: &Theme, width: usize) -> Vec<String> {
        render(
            &DiffWidget::new(old, new, theme)
                .with_mode(DiffMode::Split)
                .with_wrap_width(width),
        )
    }

    /// Concatenates the contents of all spans in `line`, discarding styles.
    fn plain_text(line: &Line) -> String {
        line.spans.iter().map(|span| span.content.as_ref()).collect()
    }

    #[test]
    fn test_unified_diff_basic() {
        let theme = Theme::default();
        let widget = DiffWidget::new("hello\nworld", "hello\nthere", &theme);
        assert_eq!(render(&widget), ["  hello", "- world", "+ there"]);
    }

    #[test]
    fn test_split_diff_equal_lines() {
        let theme = Theme::default();
        let rendered = render_split("line1\nline2\nline3", "line1\nmodified2\nline3", &theme, 80);
        assert_eq!(
            rendered,
            [
                " line1                                 │  line1                                ",
                "-line2                                 │ +modified2                            ",
                " line3                                 │  line3                                ",
            ]
        );
    }

    #[test]
    fn test_split_diff_more_deletes_than_inserts() {
        let theme = Theme::default();
        let rendered = render_split(
            "line1\nline2\nline3\nline4\nline5",
            "line1\nreplacement",
            &theme,
            80,
        );
        assert_eq!(
            rendered,
            [
                " line1                                 │  line1                                ",
                "-line2                                 │ +replacement                          ",
                "-line3                                 │                                       ",
                "-line4                                 │                                       ",
                "-line5                                 │                                       ",
            ]
        );
    }

    #[test]
    fn test_split_diff_more_inserts_than_deletes() {
        let theme = Theme::default();
        let rendered = render_split("line1\nold", "line1\nnew1\nnew2\nnew3", &theme, 80);
        assert_eq!(
            rendered,
            [
                " line1                                 │  line1                                ",
                "-old                                   │ +new1                                 ",
                "                                       │ +new2                                 ",
                "                                       │ +new3                                 ",
            ]
        );
    }

    #[test]
    fn test_unicode_truncation() {
        let theme = Theme::default();
        // Width 40 forces truncation: half_width = (40 - 5) / 2 = 17
        let rendered = render_split(
            "Short",
            "This is a line with unicode: → ← ↑ ↓ — and more symbols",
            &theme,
            40,
        );
        assert_eq!(rendered, ["-Short             │ +This is a line w…"]);
    }

    #[test]
    fn test_narrow_terminal_fallback() {
        let theme = Theme::default();
        // Width 30 is too narrow for split view, so unified output is expected.
        assert_eq!(render_split("old", "new", &theme, 30), ["- old", "+ new"]);
    }

    #[test]
    fn test_context_radius() {
        let theme = Theme::default();
        let widget = DiffWidget::new("a\nb\nc\nd\ne\nf\ng", "a\nb\nc\nX\ne\nf\ng", &theme)
            .with_context_radius(1);
        // Radius 1: a and b are skipped, then c, d -> X, e are shown.
        assert_eq!(render(&widget), ["···", "  c", "- d", "+ X", "  e"]);
    }

    #[test]
    fn test_max_lines_limit() {
        let theme = Theme::default();
        let widget = DiffWidget::new(
            "line 0\nline 1\nline 2\nline 3\nline 4\nline 5",
            "modified 0\nmodified 1\nmodified 2\nmodified 3\nmodified 4\nmodified 5",
            &theme,
        )
        .with_max_lines(10);
        assert_eq!(
            render(&widget),
            [
                "- line 0",
                "- line 1",
                "- line 2",
                "- line 3",
                "- line 4",
                "- line 5",
                "+ modified 0",
                "+ modified 1",
                "+ modified 2",
                "+ modified 3",
                "... (2 more lines)",
            ]
        );
    }

    #[test]
    fn test_diff_summary() {
        let (old_preview, new_preview) = diff_summary(
            "This is a very long line that should be truncated",
            "Short",
            20,
        );
        assert_eq!(old_preview, "This is a very lo...");
        assert_eq!(new_preview, "Short");
    }

    #[test]
    fn test_empty_strings() {
        let theme = Theme::default();
        // Pure addition.
        assert_eq!(render(&DiffWidget::new("", "new content", &theme)), ["+ new content"]);
        // Pure deletion.
        assert_eq!(render(&DiffWidget::new("old content", "", &theme)), ["- old content"]);
        // Nothing on either side.
        assert!(render(&DiffWidget::new("", "", &theme)).is_empty());
    }

    #[test]
    fn test_line_wrapping() {
        let theme = Theme::default();
        let widget = DiffWidget::new(
            "short",
            "This is a very long line that should be wrapped when displayed in the diff widget because it exceeds the wrap width",
            &theme,
        )
        .with_wrap_width(30);
        // Wrap width 30 leaves 28 columns for text.
        assert_eq!(
            render(&widget),
            [
                "- short",
                "+ This is a very long line",
                "  that should be wrapped when",
                "  displayed in the diff widget",
                "  because it exceeds the wrap",
                "  width",
            ]
        );
    }

    #[test]
    fn test_unified_diff_exact_output() {
        let theme = Theme::default();
        let widget = DiffWidget::new("line1\nline2\nline3", "line1\nmodified\nline3", &theme)
            .with_context_radius(1);
        // Radius 1 covers every line here, so no separator appears.
        assert_eq!(render(&widget), ["  line1", "- line2", "+ modified", "  line3"]);
    }

    #[test]
    fn test_split_diff_exact_output() {
        let theme = Theme::default();
        // Width 80 is wide enough that split mode does not fall back.
        let rendered = render_split("same\nold\nsame", "same\nnew\nsame", &theme, 80);
        assert_eq!(
            rendered,
            [
                " same                                  │  same                                 ",
                "-old                                   │ +new                                  ",
                " same                                  │  same                                 ",
            ]
        );
    }

    #[test]
    fn test_split_diff_uneven_replacement_exact() {
        let theme = Theme::default();
        let rendered = render_split("a\nb\nc\nd\ne", "a\nX\ne", &theme, 80);
        assert_eq!(
            rendered,
            [
                " a                                     │  a                                    ",
                "-b                                     │ +X                                    ",
                "-c                                     │                                       ",
                "-d                                     │                                       ",
                " e                                     │  e                                    ",
            ]
        );
    }

    #[test]
    fn test_context_radius_exact() {
        let theme = Theme::default();
        let widget = DiffWidget::new(
            "1\n2\n3\n4\n5\n6\n7\n8\n9",
            "1\n2\n3\n4\nX\n6\n7\n8\n9",
            &theme,
        )
        .with_context_radius(2);
        // Radius 2: skip 1, 2; show 3, 4, 5 -> X, 6, 7; skip 8, 9 silently.
        assert_eq!(
            render(&widget),
            ["···", "  3", "  4", "- 5", "+ X", "  6", "  7"]
        );
    }

    #[test]
    fn test_line_wrapping_exact() {
        let theme = Theme::default();
        // Width 20 forces the inserted line to wrap.
        let widget =
            DiffWidget::new("short", "This is a long line that wraps", &theme).with_wrap_width(20);
        assert_eq!(
            render(&widget),
            ["- short", "+ This is a long", "  line that wraps"]
        );
    }
}