1use std::collections::HashSet;
4use std::hash::Hash;
5use std::marker::PhantomData;
6
7use crossterm::event::KeyCode;
8use ratatui::{
9 layout::Rect,
10 style::Style,
11 text::{Line, Span},
12 widgets::{Block, List, ListItem, ListState, ScrollbarOrientation, ScrollbarState},
13 Frame,
14};
15use tui_dispatch_core::{Component, EventKind};
16
17use crate::style::{BaseStyle, ComponentStyle, Padding, ScrollbarStyle, SelectionStyle};
18
/// A node in a tree: an identifier, a payload, and its (possibly empty) children.
#[derive(Debug, Clone)]
pub struct TreeNode<Id, T> {
    /// Identifier compared against the selected id and looked up in the expanded set.
    pub id: Id,
    /// Payload; the view never inspects it, row renderers read it through `TreeNodeRender::node`.
    pub value: T,
    /// Child nodes, in display order.
    pub children: Vec<TreeNode<Id, T>>,
}

impl<Id, T> TreeNode<Id, T> {
    /// Create a leaf node (no children).
    pub fn new(id: Id, value: T) -> Self {
        Self::with_children(id, value, Vec::new())
    }

    /// Create a node that owns the given `children`.
    pub fn with_children(id: Id, value: T, children: Vec<TreeNode<Id, T>>) -> Self {
        Self {
            id,
            value,
            children,
        }
    }
}
46
/// How the indentation column in front of each row is drawn.
#[derive(Debug, Clone, Copy, Default)]
pub enum TreeBranchMode {
    /// Plain space indentation per depth level, followed by an expand/collapse caret.
    #[default]
    Caret,
    /// Box-drawing connectors (`│`, `├─`, `└─`), followed by the caret.
    Branch,
}
56
57#[derive(Debug, Clone)]
59pub struct TreeBranchStyle {
60 pub mode: TreeBranchMode,
62 pub indent_width: usize,
64 pub connector_style: Style,
66 pub caret_style: Style,
68}
69
70impl Default for TreeBranchStyle {
71 fn default() -> Self {
72 Self {
73 mode: TreeBranchMode::default(),
74 indent_width: 2,
75 connector_style: Style::default(),
76 caret_style: Style::default(),
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct TreeViewStyle {
84 pub base: BaseStyle,
86 pub selection: SelectionStyle,
88 pub scrollbar: ScrollbarStyle,
90 pub branches: TreeBranchStyle,
92}
93
94impl Default for TreeViewStyle {
95 fn default() -> Self {
96 Self {
97 base: BaseStyle {
98 fg: Some(ratatui::style::Color::Reset),
99 ..Default::default()
100 },
101 selection: SelectionStyle::default(),
102 scrollbar: ScrollbarStyle::default(),
103 branches: TreeBranchStyle::default(),
104 }
105 }
106}
107
108impl TreeViewStyle {
109 pub fn borderless() -> Self {
111 let mut style = Self::default();
112 style.base.border = None;
113 style
114 }
115
116 pub fn minimal() -> Self {
118 let mut style = Self::default();
119 style.base.border = None;
120 style.base.padding = Padding::default();
121 style
122 }
123}
124
125impl ComponentStyle for TreeViewStyle {
126 fn base(&self) -> &BaseStyle {
127 &self.base
128 }
129}
130
/// Interaction and chrome toggles for a tree view.
#[derive(Debug, Clone)]
pub struct TreeViewBehavior {
    /// Draw a scrollbar when the rows overflow the viewport.
    pub show_scrollbar: bool,
    /// Up/Down (`k`/`j`) wrap around at the first/last row.
    pub wrap_navigation: bool,
    /// Enter toggles a node that has children (otherwise it re-selects the row).
    pub enter_toggles: bool,
    /// Space toggles a node that has children.
    pub space_toggles: bool,
}

impl Default for TreeViewBehavior {
    /// Scrollbar shown, no wrap-around, Enter and Space both toggle.
    fn default() -> Self {
        Self {
            wrap_navigation: false,
            show_scrollbar: true,
            space_toggles: true,
            enter_toggles: true,
        }
    }
}
154
155pub struct TreeNodeRender<'a, Id, T> {
157 pub node: &'a TreeNode<Id, T>,
158 pub depth: usize,
159 pub has_children: bool,
160 pub is_expanded: bool,
161 pub is_selected: bool,
162 pub available_width: usize,
163 pub leading_width: usize,
164 pub row_width: usize,
165 pub tree_column_width: usize,
166}
167
168pub struct TreeViewProps<'a, Id, T, A>
170where
171 Id: Clone + Eq + Hash + 'static,
172{
173 pub nodes: &'a [TreeNode<Id, T>],
175 pub selected_id: Option<&'a Id>,
177 pub expanded_ids: &'a HashSet<Id>,
179 pub is_focused: bool,
181 pub style: TreeViewStyle,
183 pub behavior: TreeViewBehavior,
185 #[allow(clippy::type_complexity)]
187 pub measure_node: Option<&'a dyn Fn(&TreeNode<Id, T>) -> usize>,
188 pub column_padding: usize,
190 pub on_select: fn(&Id) -> A,
192 pub on_toggle: fn(&Id, bool) -> A,
194 pub render_node: &'a dyn Fn(TreeNodeRender<'_, Id, T>) -> Line<'static>,
196}
197
198#[derive(Clone)]
199struct FlatNode<'a, Id, T> {
200 node: &'a TreeNode<Id, T>,
201 depth: usize,
202 parent_index: Option<usize>,
203 has_children: bool,
204 is_expanded: bool,
205 is_last: bool,
206 branch_mask: Vec<bool>,
207}
208
/// Stateful tree-view component. The only state it keeps between frames is the scroll offset.
pub struct TreeView<Id> {
    /// Index of the first visible row.
    scroll_offset: usize,
    _marker: PhantomData<Id>,
}

impl<Id> Default for TreeView<Id> {
    // Hand-written so that `Id` is not required to implement `Default`.
    fn default() -> Self {
        Self {
            _marker: PhantomData,
            scroll_offset: 0,
        }
    }
}

impl<Id> TreeView<Id> {
    /// Create a view scrolled to the top.
    pub fn new() -> Self {
        Self::default()
    }

    /// Move `scroll_offset` by the smallest amount that puts row `selected` inside a
    /// window of `viewport_height` rows. Does nothing for a zero-height viewport.
    fn ensure_visible(&mut self, selected: usize, viewport_height: usize) {
        let Some(rows_below_top) = viewport_height.checked_sub(1) else {
            return;
        };

        if selected < self.scroll_offset {
            // Selection is above the window: scroll up so it becomes the first row.
            self.scroll_offset = selected;
        } else if selected >= self.scroll_offset + viewport_height {
            // Selection is below the window: scroll down so it becomes the last row.
            self.scroll_offset = selected.saturating_sub(rows_below_top);
        }
    }
}
242
243impl<Id> TreeView<Id> {
244 fn flatten_visible<'a, T>(
245 nodes: &'a [TreeNode<Id, T>],
246 expanded: &HashSet<Id>,
247 ) -> Vec<FlatNode<'a, Id, T>>
248 where
249 Id: Clone + Eq + Hash,
250 {
251 fn walk<'a, Id, T>(
252 nodes: &'a [TreeNode<Id, T>],
253 expanded: &HashSet<Id>,
254 depth: usize,
255 parent_index: Option<usize>,
256 branch_mask: Vec<bool>,
257 out: &mut Vec<FlatNode<'a, Id, T>>,
258 ) where
259 Id: Clone + Eq + Hash,
260 {
261 for (idx, node) in nodes.iter().enumerate() {
262 let is_last = idx + 1 == nodes.len();
263 let has_children = !node.children.is_empty();
264 let is_expanded = has_children && expanded.contains(&node.id);
265 let current_index = out.len();
266
267 out.push(FlatNode {
268 node,
269 depth,
270 parent_index,
271 has_children,
272 is_expanded,
273 is_last,
274 branch_mask: branch_mask.clone(),
275 });
276
277 if has_children && is_expanded {
278 let mut next_mask = branch_mask.clone();
279 next_mask.push(!is_last);
280 walk(
281 &node.children,
282 expanded,
283 depth + 1,
284 Some(current_index),
285 next_mask,
286 out,
287 );
288 }
289 }
290 }
291
292 let mut out = Vec::new();
293 walk(nodes, expanded, 0, None, Vec::new(), &mut out);
294 out
295 }
296
297 fn marker_prefix(marker: Option<&'static str>, is_selected: bool) -> String {
298 let Some(marker) = marker else {
299 return String::new();
300 };
301 if is_selected {
302 marker.to_string()
303 } else {
304 " ".repeat(marker.chars().count())
305 }
306 }
307
308 fn caret_prefix(
309 depth: usize,
310 indent_width: usize,
311 has_children: bool,
312 is_expanded: bool,
313 ) -> (String, String) {
314 let connector = " ".repeat(depth.saturating_mul(indent_width));
315 let caret = if has_children {
316 if is_expanded {
317 "▾ "
318 } else {
319 "▸ "
320 }
321 } else {
322 " "
323 };
324 (connector, caret.to_string())
325 }
326
327 fn branch_prefix(
328 branch_mask: &[bool],
329 indent_width: usize,
330 is_last: bool,
331 has_children: bool,
332 is_expanded: bool,
333 ) -> (String, String) {
334 let width = indent_width.max(2);
335 let mut connector = String::new();
336 for has_branch in branch_mask {
337 if *has_branch {
338 connector.push('│');
339 connector.push_str(&" ".repeat(width.saturating_sub(1)));
340 } else {
341 connector.push_str(&" ".repeat(width));
342 }
343 }
344
345 connector.push(if is_last { '└' } else { '├' });
346 connector.push_str(&"─".repeat(width.saturating_sub(1)));
347
348 let caret = if has_children {
349 if is_expanded {
350 "▾ "
351 } else {
352 "▸ "
353 }
354 } else {
355 " "
356 };
357
358 (connector, caret.to_string())
359 }
360
361 fn build_prefix<T>(style: &TreeViewStyle, node: &FlatNode<'_, Id, T>) -> (String, String) {
362 match style.branches.mode {
363 TreeBranchMode::Caret => Self::caret_prefix(
364 node.depth,
365 style.branches.indent_width,
366 node.has_children,
367 node.is_expanded,
368 ),
369 TreeBranchMode::Branch => Self::branch_prefix(
370 &node.branch_mask,
371 style.branches.indent_width,
372 node.is_last,
373 node.has_children,
374 node.is_expanded,
375 ),
376 }
377 }
378
379 fn available_width(width: usize, prefix_len: usize, marker_len: usize) -> usize {
380 width.saturating_sub(prefix_len).saturating_sub(marker_len)
381 }
382}
383
384impl<Id, A> Component<A> for TreeView<Id>
385where
386 Id: Clone + Eq + Hash + 'static,
387{
388 type Props<'a> = TreeViewProps<'a, Id, String, A>;
389
390 fn handle_event(
391 &mut self,
392 event: &EventKind,
393 props: Self::Props<'_>,
394 ) -> impl IntoIterator<Item = A> {
395 if !props.is_focused {
396 return None;
397 }
398
399 let visible = Self::flatten_visible(props.nodes, props.expanded_ids);
400 if visible.is_empty() {
401 return None;
402 }
403
404 let selected_idx = props
405 .selected_id
406 .and_then(|id| visible.iter().position(|n| &n.node.id == id));
407 let has_selection = selected_idx.is_some();
408 let current_idx = selected_idx.unwrap_or(0);
409 let last_idx = visible.len().saturating_sub(1);
410
411 let move_selection = |idx: usize| Some((props.on_select)(&visible[idx].node.id));
412 let toggle_node =
413 |idx: usize, expand: bool| Some((props.on_toggle)(&visible[idx].node.id, expand));
414
415 match event {
416 EventKind::Key(key) => match key.code {
417 KeyCode::Char('j') | KeyCode::Down => {
418 if !has_selection {
419 return move_selection(0);
420 }
421 let next = if props.behavior.wrap_navigation && current_idx == last_idx {
422 0
423 } else {
424 (current_idx + 1).min(last_idx)
425 };
426 if next != current_idx {
427 move_selection(next)
428 } else {
429 None
430 }
431 }
432 KeyCode::Char('k') | KeyCode::Up => {
433 if !has_selection {
434 return move_selection(last_idx);
435 }
436 let next = if props.behavior.wrap_navigation && current_idx == 0 {
437 last_idx
438 } else {
439 current_idx.saturating_sub(1)
440 };
441 if next != current_idx {
442 move_selection(next)
443 } else {
444 None
445 }
446 }
447 KeyCode::Char('g') | KeyCode::Home => {
448 if current_idx != 0 || !has_selection {
449 move_selection(0)
450 } else {
451 None
452 }
453 }
454 KeyCode::Char('G') | KeyCode::End => {
455 if current_idx != last_idx || !has_selection {
456 move_selection(last_idx)
457 } else {
458 None
459 }
460 }
461 KeyCode::Left => {
462 let current = &visible[current_idx];
463 if current.has_children && current.is_expanded {
464 toggle_node(current_idx, false)
465 } else if let Some(parent_idx) = current.parent_index {
466 move_selection(parent_idx)
467 } else {
468 None
469 }
470 }
471 KeyCode::Right => {
472 let current = &visible[current_idx];
473 if current.has_children && !current.is_expanded {
474 toggle_node(current_idx, true)
475 } else if current.has_children && current.is_expanded {
476 let child_idx = current_idx + 1;
477 if child_idx < visible.len()
478 && visible[child_idx].parent_index == Some(current_idx)
479 {
480 move_selection(child_idx)
481 } else {
482 None
483 }
484 } else {
485 None
486 }
487 }
488 KeyCode::Enter => {
489 let current = &visible[current_idx];
490 if props.behavior.enter_toggles && current.has_children {
491 toggle_node(current_idx, !current.is_expanded)
492 } else {
493 move_selection(current_idx)
494 }
495 }
496 KeyCode::Char(' ') => {
497 let current = &visible[current_idx];
498 if props.behavior.space_toggles && current.has_children {
499 toggle_node(current_idx, !current.is_expanded)
500 } else {
501 None
502 }
503 }
504 _ => None,
505 },
506 EventKind::Scroll { delta, .. } => {
507 if *delta == 0 {
508 None
509 } else if *delta > 0 {
510 if !has_selection {
511 move_selection(last_idx)
512 } else if current_idx > 0 {
513 move_selection(current_idx - 1)
514 } else {
515 None
516 }
517 } else if !has_selection {
518 move_selection(0)
519 } else if current_idx < last_idx {
520 move_selection(current_idx + 1)
521 } else {
522 None
523 }
524 }
525 _ => None,
526 }
527 }
528
529 fn render(&mut self, frame: &mut Frame, area: Rect, props: Self::Props<'_>) {
530 let style = &props.style;
531
532 if let Some(bg) = style.base.bg {
533 for y in area.y..area.y.saturating_add(area.height) {
534 for x in area.x..area.x.saturating_add(area.width) {
535 frame.buffer_mut()[(x, y)].set_bg(bg);
536 frame.buffer_mut()[(x, y)].set_symbol(" ");
537 }
538 }
539 }
540
541 let content_area = Rect {
542 x: area.x + style.base.padding.left,
543 y: area.y + style.base.padding.top,
544 width: area.width.saturating_sub(style.base.padding.horizontal()),
545 height: area.height.saturating_sub(style.base.padding.vertical()),
546 };
547
548 let mut inner_area = content_area;
549 if let Some(border) = &style.base.border {
550 let block = Block::default()
551 .borders(border.borders)
552 .border_style(border.style_for_focus(props.is_focused));
553 inner_area = block.inner(content_area);
554 frame.render_widget(block, content_area);
555 }
556
557 let viewport_height = inner_area.height as usize;
558 let visible = Self::flatten_visible(props.nodes, props.expanded_ids);
559 let selected_idx = props
560 .selected_id
561 .and_then(|id| visible.iter().position(|n| &n.node.id == id));
562 let selected_render_idx = selected_idx.unwrap_or(0);
563
564 if let Some(selected_idx) = selected_idx {
565 if viewport_height > 0 {
566 self.ensure_visible(selected_idx, viewport_height);
567 }
568 }
569
570 if viewport_height > 0 {
571 let max_offset = visible.len().saturating_sub(viewport_height);
572 self.scroll_offset = self.scroll_offset.min(max_offset);
573 }
574
575 let show_scrollbar = props.behavior.show_scrollbar
576 && viewport_height > 0
577 && visible.len() > viewport_height
578 && inner_area.width > 1;
579 let mut list_area = inner_area;
580 let scrollbar_area = if show_scrollbar {
581 let scrollbar_area = Rect {
582 x: inner_area.x + inner_area.width.saturating_sub(1),
583 width: 1,
584 ..inner_area
585 };
586 list_area.width = list_area.width.saturating_sub(1);
587 Some(scrollbar_area)
588 } else {
589 None
590 };
591
592 let marker_len = if style.selection.disabled {
593 0
594 } else {
595 style
596 .selection
597 .marker
598 .map(|marker| marker.chars().count())
599 .unwrap_or(0)
600 };
601
602 let row_width = list_area.width as usize;
603 let max_tree_width = visible
604 .iter()
605 .map(|node| {
606 let (connector_prefix, caret_prefix) = Self::build_prefix(style, node);
607 let prefix_len = connector_prefix.chars().count() + caret_prefix.chars().count();
608 let leading_width = prefix_len + marker_len;
609 let available_width = Self::available_width(row_width, prefix_len, marker_len);
610 let content_width = if let Some(measure_node) = props.measure_node {
611 measure_node(node.node)
612 } else {
613 let line = (props.render_node)(TreeNodeRender {
614 node: node.node,
615 depth: node.depth,
616 has_children: node.has_children,
617 is_expanded: node.is_expanded,
618 is_selected: false,
619 available_width,
620 leading_width,
621 row_width,
622 tree_column_width: available_width,
623 });
624 line.width()
625 };
626 leading_width + content_width
627 })
628 .max()
629 .unwrap_or(0)
630 .saturating_add(props.column_padding)
631 .min(row_width.saturating_sub(1).max(1));
632
633 let items: Vec<ListItem> = visible
634 .iter()
635 .enumerate()
636 .map(|(idx, node)| {
637 let is_selected = selected_idx == Some(idx);
638 let (connector_prefix, caret_prefix) = Self::build_prefix(style, node);
639 let prefix_len = connector_prefix.chars().count() + caret_prefix.chars().count();
640 let available_width = Self::available_width(row_width, prefix_len, marker_len);
641 let leading_width = prefix_len + marker_len;
642 let tree_column_width = max_tree_width
643 .saturating_sub(leading_width)
644 .min(available_width);
645
646 let content_line = (props.render_node)(TreeNodeRender {
647 node: node.node,
648 depth: node.depth,
649 has_children: node.has_children,
650 is_expanded: node.is_expanded,
651 is_selected,
652 available_width,
653 leading_width,
654 row_width,
655 tree_column_width,
656 });
657
658 let mut spans = Vec::new();
659 if !style.selection.disabled {
660 let marker_prefix = Self::marker_prefix(style.selection.marker, is_selected);
661 if !marker_prefix.is_empty() {
662 spans.push(Span::raw(marker_prefix));
663 }
664 }
665 if !connector_prefix.is_empty() {
666 spans.push(Span::styled(
667 connector_prefix,
668 style.branches.connector_style,
669 ));
670 }
671 if !caret_prefix.is_empty() {
672 spans.push(Span::styled(caret_prefix, style.branches.caret_style));
673 }
674 spans.extend(content_line.spans.iter().cloned());
675 let display_line = Line::from(spans);
676
677 if style.selection.disabled {
678 ListItem::new(display_line)
679 } else {
680 let item_style = if is_selected {
681 style.selection.style.unwrap_or_default()
682 } else {
683 let mut s = Style::default();
684 if let Some(fg) = style.base.fg {
685 s = s.fg(fg);
686 }
687 s
688 };
689 ListItem::new(display_line).style(item_style)
690 }
691 })
692 .collect();
693
694 let highlight_style = if style.selection.disabled {
695 Style::default()
696 } else {
697 style.selection.style.unwrap_or_default()
698 };
699 let list = List::new(items).highlight_style(highlight_style);
700
701 let selected = if visible.is_empty() || selected_idx.is_none() {
702 None
703 } else {
704 Some(selected_render_idx)
705 };
706 let mut state = ListState::default().with_selected(selected);
707 *state.offset_mut() = self.scroll_offset;
708
709 frame.render_stateful_widget(list, list_area, &mut state);
710
711 if let Some(scrollbar_area) = scrollbar_area {
712 let scrollbar = style.scrollbar.build(ScrollbarOrientation::VerticalRight);
713 let scrollbar_len = visible
714 .len()
715 .saturating_sub(viewport_height)
716 .saturating_add(1);
717 let mut scrollbar_state = ScrollbarState::new(scrollbar_len)
718 .position(self.scroll_offset)
719 .viewport_content_length(viewport_height.max(1));
720 frame.render_stateful_widget(scrollbar, scrollbar_area, &mut scrollbar_state);
721 }
722 }
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use tui_dispatch_core::testing::key;

    #[derive(Debug, Clone, PartialEq)]
    enum TestAction {
        Select(String),
        Toggle(String, bool),
    }

    fn select_action(id: &str) -> TestAction {
        TestAction::Select(id.to_owned())
    }

    fn toggle_action(id: &str, expanded: bool) -> TestAction {
        TestAction::Toggle(id.to_owned(), expanded)
    }

    /// Row renderer used by every test: just the node's label.
    fn render_node(ctx: TreeNodeRender<'_, String, String>) -> Line<'static> {
        Line::raw(ctx.node.value.clone())
    }

    /// `root` with a single leaf `child`.
    fn sample_tree() -> Vec<TreeNode<String, String>> {
        let child = TreeNode::new("child".to_string(), "Child".to_string());
        vec![TreeNode::with_children(
            "root".to_string(),
            "Root".to_string(),
            vec![child],
        )]
    }

    /// Focused, borderless props with default behavior.
    fn props<'a>(
        nodes: &'a [TreeNode<String, String>],
        selected: Option<&'a String>,
        expanded: &'a HashSet<String>,
    ) -> TreeViewProps<'a, String, String, TestAction> {
        TreeViewProps {
            nodes,
            selected_id: selected,
            expanded_ids: expanded,
            is_focused: true,
            style: TreeViewStyle::borderless(),
            behavior: TreeViewBehavior::default(),
            measure_node: None,
            column_padding: 0,
            on_select: |id| select_action(id),
            on_toggle: |id, expanded| toggle_action(id, expanded),
            render_node: &render_node,
        }
    }

    /// Expansion set containing only `root`.
    fn root_expanded() -> HashSet<String> {
        let mut expanded = HashSet::new();
        expanded.insert("root".to_string());
        expanded
    }

    /// Send one key to a fresh view and collect everything it emits.
    fn press(
        code: &str,
        nodes: &[TreeNode<String, String>],
        selected: Option<&String>,
        expanded: &HashSet<String>,
    ) -> Vec<TestAction> {
        let mut view: TreeView<String> = TreeView::new();
        view.handle_event(
            &EventKind::Key(key(code)),
            props(nodes, selected, expanded),
        )
        .into_iter()
        .collect()
    }

    #[test]
    fn test_expand_on_right() {
        let nodes = sample_tree();
        let expanded = HashSet::new();

        let emitted = press("right", &nodes, None, &expanded);

        assert_eq!(emitted, vec![TestAction::Toggle("root".into(), true)]);
    }

    #[test]
    fn test_collapse_on_left() {
        let nodes = sample_tree();
        let expanded = root_expanded();

        let emitted = press("left", &nodes, Some(&nodes[0].id), &expanded);

        assert_eq!(emitted, vec![TestAction::Toggle("root".into(), false)]);
    }

    #[test]
    fn test_select_child_with_down() {
        let nodes = sample_tree();
        let expanded = root_expanded();

        let emitted = press("down", &nodes, Some(&nodes[0].id), &expanded);

        assert_eq!(emitted, vec![TestAction::Select("child".into())]);
    }

    #[test]
    fn test_select_parent_with_left() {
        let nodes = sample_tree();
        let expanded = root_expanded();

        let emitted = press("left", &nodes, Some(&nodes[0].children[0].id), &expanded);

        assert_eq!(emitted, vec![TestAction::Select("root".into())]);
    }
}
849}