1use std::collections::HashSet;
9
10use ratatui::buffer::Buffer;
11use ratatui::layout::{Constraint, Rect};
12use ratatui::style::Style;
13use ratatui::widgets::{Block, Row, Scrollbar, ScrollbarState, StatefulWidget, Table, TableState, Widget};
14use unicode_width::UnicodeWidthStr as _;
15
16pub use crate::flatten::Flattened;
17pub use crate::tree_item::TreeItem;
18pub use crate::tree_state::TreeState;
19
20mod flatten;
21mod tree_item;
22mod tree_state;
23
/// A tree widget that renders [`TreeItem`]s together with a [`TreeState`], optionally
/// next to a table column built from each visible item's `data` row.
///
/// Create it with [`Tree::new`] and configure it with the builder-style setters.
#[must_use]
#[derive(Debug, Clone)]
pub struct Tree<'a, Identifier> {
    /// Root items of the tree; nested items are reached through their `children`.
    items: &'a [TreeItem<'a, Identifier>],

    /// Optional header row for the table column. When set, the tree rows are rendered
    /// one line lower so they stay aligned with the table body.
    table_header: Option<Row<'a>>,
    /// Column width constraints handed to the ratatui [`Table`].
    table_widths: &'a [Constraint],

    /// Optional block drawn around the widget; the content goes into its inner area.
    block: Option<Block<'a>>,
    /// Optional (experimental) scrollbar drawn alongside the tree rows.
    scrollbar: Option<Scrollbar<'a>>,
    /// Base style applied to the whole widget area before anything else is drawn.
    style: Style,

    /// Style applied to the selected row, both in the tree and as the table's row
    /// highlight.
    highlight_style: Style,
    /// Symbol drawn in front of the selected item. While something is selected, every
    /// other item gets blank padding of the same width so the text stays aligned.
    highlight_symbol: &'a str,

    /// Symbol in front of an item that has children and is currently closed.
    node_closed_symbol: &'a str,
    /// Symbol in front of an item that has children and is currently open.
    node_open_symbol: &'a str,
    /// Symbol in front of an item without children.
    node_no_children_symbol: &'a str,
}
78
79impl<'a, Identifier> Tree<'a, Identifier>
80where
81 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
82{
83 pub fn new(items: &'a [TreeItem<'a, Identifier>]) -> std::io::Result<Self> {
89 let identifiers = items
90 .iter()
91 .map(|item| &item.identifier)
92 .collect::<HashSet<_>>();
93 if identifiers.len() != items.len() {
94 return Err(std::io::Error::new(
95 std::io::ErrorKind::AlreadyExists,
96 "The items contain duplicate identifiers",
97 ));
98 }
99
100 Ok(Self {
101 items,
102 table_header: None,
103 table_widths: &[],
104 block: None,
105 scrollbar: None,
106 style: Style::new(),
107 highlight_style: Style::new(),
108 highlight_symbol: "",
109 node_closed_symbol: "\u{25b6} ", node_open_symbol: "\u{25bc} ", node_no_children_symbol: " ",
112 })
113 }
114
115 #[allow(clippy::missing_const_for_fn)]
116 pub fn block(mut self, block: Block<'a>) -> Self {
117 self.block = Some(block);
118 self
119 }
120
121 pub const fn experimental_scrollbar(mut self, scrollbar: Option<Scrollbar<'a>>) -> Self {
127 self.scrollbar = scrollbar;
128 self
129 }
130
131 pub const fn style(mut self, style: Style) -> Self {
132 self.style = style;
133 self
134 }
135
136 pub const fn highlight_style(mut self, style: Style) -> Self {
137 self.highlight_style = style;
138 self
139 }
140
141 pub const fn highlight_symbol(mut self, highlight_symbol: &'a str) -> Self {
142 self.highlight_symbol = highlight_symbol;
143 self
144 }
145
146 pub const fn node_closed_symbol(mut self, symbol: &'a str) -> Self {
147 self.node_closed_symbol = symbol;
148 self
149 }
150
151 pub const fn node_open_symbol(mut self, symbol: &'a str) -> Self {
152 self.node_open_symbol = symbol;
153 self
154 }
155
156 pub const fn node_no_children_symbol(mut self, symbol: &'a str) -> Self {
157 self.node_no_children_symbol = symbol;
158 self
159 }
160
161 #[must_use]
162 pub fn table_header(mut self, headers: Option<Row<'a>>) -> Self {
163 self.table_header = headers;
164 self
165 }
166
167 #[must_use]
168 pub fn table_widths(mut self, widths: &'a [Constraint]) -> Self {
169 self.table_widths = widths;
170 self
171 }
172}
173
/// `Tree::new` must reject root items that share an identifier.
#[test]
#[should_panic = "duplicate identifiers"]
fn tree_new_errors_with_duplicate_identifiers() {
    let leaf = TreeItem::new_leaf("same", "text");
    let duplicated = [leaf.clone(), leaf];
    let _: Tree<_> = Tree::new(&duplicated).unwrap();
}
182
183impl<Identifier> StatefulWidget for Tree<'_, Identifier>
184where
185 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
186{
187 type State = TreeState<Identifier>;
188
189 #[allow(clippy::too_many_lines)]
190 fn render(self, full_area: Rect, buf: &mut Buffer, state: &mut Self::State) {
191 buf.set_style(full_area, self.style);
192
193 let area = self.block.map_or(full_area, |block| {
195 let inner_area = block.inner(full_area);
196 block.render(full_area, buf);
197 inner_area
198 });
199
200 let (area, table_area) = if area.width > 24 {
202 let table_area = Rect { width: 24, ..area };
203 let mut area = Rect {
204 x: area.x + 24,
205 width: area.width - 24,
206 ..area
207 };
208 if self.table_header.is_some() {
213 area.y += 1;
214 area.height -= 1;
215 }
216 (area, Some(table_area))
217 } else {
218 (area, None)
219 };
220
221 state.last_area = area;
222 state.last_rendered_identifiers.clear();
223 if area.width < 1 || area.height < 1 {
224 return;
225 }
226
227 let visible = state.flatten(self.items);
228 state.last_biggest_index = visible.len().saturating_sub(1);
229 if visible.is_empty() {
230 return;
231 }
232 let available_height = area.height as usize;
233
234 let ensure_index_in_view =
235 if state.ensure_selected_in_view_on_next_render && !state.selected.is_empty() {
236 visible
237 .iter()
238 .position(|flattened| flattened.identifier == state.selected)
239 } else {
240 None
241 };
242
243 let mut start = state.offset.min(state.last_biggest_index);
245
246 if let Some(ensure_index_in_view) = ensure_index_in_view {
247 start = start.min(ensure_index_in_view);
248 }
249
250 let mut end = start;
251 let mut height = 0;
252 for item_height in visible
253 .iter()
254 .skip(start)
255 .map(|flattened| flattened.item.height())
256 {
257 if height + item_height > available_height {
258 break;
259 }
260 height += item_height;
261 end += 1;
262 }
263
264 if let Some(ensure_index_in_view) = ensure_index_in_view {
265 while ensure_index_in_view >= end {
266 height += visible[end].item.height();
267 end += 1;
268 while height > available_height {
269 height = height.saturating_sub(visible[start].item.height());
270 start += 1;
271 }
272 }
273 }
274
275 state.offset = start;
276 state.ensure_selected_in_view_on_next_render = false;
277
278 if let Some(scrollbar) = self.scrollbar {
279 let mut scrollbar_state = ScrollbarState::new(visible.len().saturating_sub(height))
280 .position(start)
281 .viewport_content_length(height);
282 let scrollbar_area = Rect {
283 y: area.y,
285 height: area.height,
286 x: full_area.x,
288 width: full_area.width,
289 };
290 scrollbar.render(scrollbar_area, buf, &mut scrollbar_state);
291 }
292
293 let blank_symbol = " ".repeat(self.highlight_symbol.width());
294
295 let mut current_height = 0;
296 let has_selection = !state.selected.is_empty();
297
298 if let Some(table_area) = table_area {
299 let mut selection = None;
300
301 let data_rows: Vec<_> = visible
302 .iter()
303 .skip(state.offset)
304 .take(end - start)
305 .enumerate()
306 .map(|(index, item)| {
307 if state.selected == item.identifier {
308 selection = Some(index);
309 }
310 item.item.data.clone()
311 })
312 .collect();
313
314 let mut table = Table::new(data_rows, self.table_widths)
315 .row_highlight_style(self.highlight_style);
316
317 if let Some(headers) = self.table_header {
318 table = table.header(headers);
319 }
320
321 StatefulWidget::render(
322 table,
323 table_area,
324 buf,
325 &mut TableState::default().with_selected(selection),
326 );
327 }
328
329
330 #[allow(clippy::cast_possible_truncation)]
331 for flattened in visible.iter().skip(state.offset).take(end - start) {
332 let Flattened { identifier, item } = flattened;
333
334 let x = area.x;
335 let y = area.y + current_height;
336 let height = item.height() as u16;
337 current_height += height;
338
339 let area = Rect {
340 x,
341 y,
342 width: area.width,
343 height,
344 };
345
346 let text = &item.text;
347 let item_style = text.style;
348
349 let is_selected = state.selected == *identifier;
350 let after_highlight_symbol_x = if has_selection {
351 let symbol = if is_selected {
352 self.highlight_symbol
353 } else {
354 &blank_symbol
355 };
356 let (x, _) = buf.set_stringn(x, y, symbol, area.width as usize, item_style);
357 x
358 } else {
359 x
360 };
361
362 let after_depth_x = {
363 let indent_width = flattened.depth() * 2;
364 let (after_indent_x, _) = buf.set_stringn(
365 after_highlight_symbol_x,
366 y,
367 " ".repeat(indent_width),
368 indent_width,
369 item_style,
370 );
371 let symbol = if item.children.is_empty() {
372 self.node_no_children_symbol
373 } else if state.opened.contains(identifier) {
374 self.node_open_symbol
375 } else {
376 self.node_closed_symbol
377 };
378 let max_width = area.width.saturating_sub(after_indent_x - x);
379 let (x, _) =
380 buf.set_stringn(after_indent_x, y, symbol, max_width as usize, item_style);
381 x
382 };
383
384 let text_area = Rect {
385 x: after_depth_x,
386 width: area.width.saturating_sub(after_depth_x - x),
387 ..area
388 };
389 text.render(text_area, buf);
390
391 if is_selected {
392 buf.set_style(area, self.highlight_style);
393 }
394
395 state
396 .last_rendered_identifiers
397 .push((area.y, identifier.clone()));
398 }
399 state.last_identifiers = visible
400 .into_iter()
401 .map(|flattened| flattened.identifier)
402 .collect();
403 }
404}
405
406impl<Identifier> Widget for Tree<'_, Identifier>
407where
408 Identifier: Clone + Default + Eq + core::hash::Hash,
409{
410 fn render(self, area: Rect, buf: &mut Buffer) {
411 let mut state = TreeState::default();
412 StatefulWidget::render(self, area, buf, &mut state);
413 }
414}
415
#[cfg(test)]
mod render_tests {
    use super::*;

    /// Renders `TreeItem::example()` into a fresh `width` x `height` buffer using `state`.
    #[must_use]
    #[track_caller]
    fn render(width: u16, height: u16, state: &mut TreeState<&'static str>) -> Buffer {
        let area = Rect::new(0, 0, width, height);
        let example_items = TreeItem::example();
        let widget = Tree::new(&example_items).unwrap();
        let mut rendered = Buffer::empty(area);
        StatefulWidget::render(widget, area, &mut rendered, state);
        rendered
    }

    #[test]
    fn does_not_panic() {
        for (width, height) in [(0, 0), (10, 0), (0, 10), (10, 10)] {
            _ = render(width, height, &mut TreeState::default());
        }
    }

    #[test]
    fn nothing_open() {
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa    ",
            "▶ Bravo   ",
            "  Hotel   ",
            "          ",
        ]);
        let actual = render(10, 4, &mut TreeState::default());
        assert_eq!(actual, expected);
    }

    #[test]
    fn depth_one() {
        let mut tree_state = TreeState::default();
        tree_state.open(vec!["b"]);
        let expected = Buffer::with_lines([
            "  Alfa       ",
            "▼ Bravo      ",
            "    Charlie  ",
            "  ▶ Delta    ",
            "    Golf     ",
            "  Hotel      ",
            "             ",
        ]);
        let actual = render(13, 7, &mut tree_state);
        assert_eq!(actual, expected);
    }

    #[test]
    fn depth_two() {
        let mut tree_state = TreeState::default();
        tree_state.open(vec!["b"]);
        tree_state.open(vec!["b", "d"]);
        let expected = Buffer::with_lines([
            "  Alfa         ",
            "▼ Bravo        ",
            "    Charlie    ",
            "  ▼ Delta      ",
            "      Echo     ",
            "      Foxtrot  ",
            "    Golf       ",
            "  Hotel        ",
            "               ",
        ]);
        let actual = render(15, 9, &mut tree_state);
        assert_eq!(actual, expected);
    }
}