1use std::collections::HashSet;
9
10use ratatui::buffer::Buffer;
11use ratatui::layout::Rect;
12use ratatui::style::Style;
13use ratatui::widgets::{Block, Scrollbar, ScrollbarState, StatefulWidget, Widget};
14use unicode_width::UnicodeWidthStr;
15
16pub use crate::flatten::Flattened;
17pub use crate::tree_item::TreeItem;
18pub use crate::tree_state::TreeState;
19
20mod flatten;
21mod tree_item;
22mod tree_state;
23
24#[must_use]
54#[derive(Debug, Clone)]
55pub struct Tree<'a, Identifier> {
56 items: &'a [TreeItem<'a, Identifier>],
57
58 block: Option<Block<'a>>,
59 scrollbar: Option<Scrollbar<'a>>,
60 style: Style,
62
63 highlight_style: Style,
65 highlight_symbol: &'a str,
67
68 node_closed_symbol: &'a str,
70 node_open_symbol: &'a str,
72 node_no_children_symbol: &'a str,
74}
75
76impl<'a, Identifier> Tree<'a, Identifier>
77where
78 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
79{
80 pub fn new(items: &'a [TreeItem<'a, Identifier>]) -> std::io::Result<Self> {
86 let identifiers = items
87 .iter()
88 .map(|item| &item.identifier)
89 .collect::<HashSet<_>>();
90 if identifiers.len() != items.len() {
91 return Err(std::io::Error::new(
92 std::io::ErrorKind::AlreadyExists,
93 "The items contain duplicate identifiers",
94 ));
95 }
96
97 Ok(Self {
98 items,
99 block: None,
100 scrollbar: None,
101 style: Style::new(),
102 highlight_style: Style::new(),
103 highlight_symbol: "",
104 node_closed_symbol: "\u{25b6} ", node_open_symbol: "\u{25bc} ", node_no_children_symbol: " ",
107 })
108 }
109
110 #[allow(clippy::missing_const_for_fn)]
111 pub fn block(mut self, block: Block<'a>) -> Self {
112 self.block = Some(block);
113 self
114 }
115
116 pub const fn experimental_scrollbar(mut self, scrollbar: Option<Scrollbar<'a>>) -> Self {
122 self.scrollbar = scrollbar;
123 self
124 }
125
126 pub const fn style(mut self, style: Style) -> Self {
127 self.style = style;
128 self
129 }
130
131 pub const fn highlight_style(mut self, style: Style) -> Self {
132 self.highlight_style = style;
133 self
134 }
135
136 pub const fn highlight_symbol(mut self, highlight_symbol: &'a str) -> Self {
137 self.highlight_symbol = highlight_symbol;
138 self
139 }
140
141 pub const fn node_closed_symbol(mut self, symbol: &'a str) -> Self {
142 self.node_closed_symbol = symbol;
143 self
144 }
145
146 pub const fn node_open_symbol(mut self, symbol: &'a str) -> Self {
147 self.node_open_symbol = symbol;
148 self
149 }
150
151 pub const fn node_no_children_symbol(mut self, symbol: &'a str) -> Self {
152 self.node_no_children_symbol = symbol;
153 self
154 }
155}
156
/// Two root items with the same identifier must be rejected by [`Tree::new`].
#[test]
fn tree_new_errors_with_duplicate_identifiers() {
    let leaf = TreeItem::new_leaf("same", "text");
    let items = [leaf.clone(), leaf];
    let error = match Tree::new(&items) {
        Ok(_) => panic!("Tree::new accepted duplicate identifiers"),
        Err(error) => error,
    };
    assert!(error.to_string().contains("duplicate identifiers"));
}
165
166impl<Identifier> StatefulWidget for Tree<'_, Identifier>
167where
168 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
169{
170 type State = TreeState<Identifier>;
171
172 #[allow(clippy::too_many_lines)]
173 fn render(self, full_area: Rect, buf: &mut Buffer, state: &mut Self::State) {
174 buf.set_style(full_area, self.style);
175
176 let area = self.block.map_or(full_area, |block| {
178 let inner_area = block.inner(full_area);
179 block.render(full_area, buf);
180 inner_area
181 });
182
183 state.last_area = area;
184 state.last_rendered_identifiers.clear();
185 if area.width < 1 || area.height < 1 {
186 return;
187 }
188
189 let visible = state.flatten(self.items);
190 state.last_biggest_index = visible.len().saturating_sub(1);
191 if visible.is_empty() {
192 return;
193 }
194 let available_height = area.height as usize;
195
196 let ensure_index_in_view =
197 if state.ensure_selected_in_view_on_next_render && !state.selected.is_empty() {
198 visible
199 .iter()
200 .position(|flattened| flattened.identifier == state.selected)
201 } else {
202 None
203 };
204
205 let mut start = state.offset.min(state.last_biggest_index);
207
208 if let Some(ensure_index_in_view) = ensure_index_in_view {
209 start = start.min(ensure_index_in_view);
210 }
211
212 let mut end = start;
213 let mut height = 0;
214 for item_height in visible
215 .iter()
216 .skip(start)
217 .map(|flattened| flattened.item.height())
218 {
219 if height + item_height > available_height {
220 break;
221 }
222 height += item_height;
223 end += 1;
224 }
225
226 if let Some(ensure_index_in_view) = ensure_index_in_view {
227 while ensure_index_in_view >= end {
228 height += visible[end].item.height();
229 end += 1;
230 while height > available_height {
231 height = height.saturating_sub(visible[start].item.height());
232 start += 1;
233 }
234 }
235 }
236
237 state.offset = start;
238 state.ensure_selected_in_view_on_next_render = false;
239
240 if let Some(scrollbar) = self.scrollbar {
241 let mut scrollbar_state = ScrollbarState::new(visible.len().saturating_sub(height))
242 .position(start)
243 .viewport_content_length(height);
244 let scrollbar_area = Rect {
245 y: area.y,
247 height: area.height,
248 x: full_area.x,
250 width: full_area.width,
251 };
252 scrollbar.render(scrollbar_area, buf, &mut scrollbar_state);
253 }
254
255 let blank_symbol = " ".repeat(self.highlight_symbol.width());
256
257 let mut current_height = 0;
258 let has_selection = !state.selected.is_empty();
259 #[allow(clippy::cast_possible_truncation)]
260 for flattened in visible.iter().skip(state.offset).take(end - start) {
261 let Flattened { identifier, item } = flattened;
262
263 let x = area.x;
264 let y = area.y + current_height;
265 let height = item.height() as u16;
266 current_height += height;
267
268 let area = Rect {
269 x,
270 y,
271 width: area.width,
272 height,
273 };
274
275 let text = &item.text;
276 let item_style = text.style;
277
278 let is_selected = state.selected == *identifier;
279 let after_highlight_symbol_x = if has_selection {
280 let symbol = if is_selected {
281 self.highlight_symbol
282 } else {
283 &blank_symbol
284 };
285 let (x, _) = buf.set_stringn(x, y, symbol, area.width as usize, item_style);
286 x
287 } else {
288 x
289 };
290
291 let after_depth_x = {
292 let indent_width = flattened.depth() * 2;
293 let (after_indent_x, _) = buf.set_stringn(
294 after_highlight_symbol_x,
295 y,
296 " ".repeat(indent_width),
297 indent_width,
298 item_style,
299 );
300 let symbol = if item.children.is_empty() {
301 self.node_no_children_symbol
302 } else if state.opened.contains(identifier) {
303 self.node_open_symbol
304 } else {
305 self.node_closed_symbol
306 };
307 let max_width = area.width.saturating_sub(after_indent_x - x);
308 let (x, _) =
309 buf.set_stringn(after_indent_x, y, symbol, max_width as usize, item_style);
310 x
311 };
312
313 let text_area = Rect {
314 x: after_depth_x,
315 width: area.width.saturating_sub(after_depth_x - x),
316 ..area
317 };
318 text.render(text_area, buf);
319
320 if is_selected {
321 buf.set_style(area, self.highlight_style);
322 }
323
324 state
325 .last_rendered_identifiers
326 .push((area.y, identifier.clone()));
327 }
328 state.last_identifiers = visible
329 .into_iter()
330 .map(|flattened| flattened.identifier)
331 .collect();
332 }
333}
334
335impl<Identifier> Widget for Tree<'_, Identifier>
336where
337 Identifier: Clone + Default + Eq + core::hash::Hash,
338{
339 fn render(self, area: Rect, buf: &mut Buffer) {
340 let mut state = TreeState::default();
341 StatefulWidget::render(self, area, buf, &mut state);
342 }
343}
344
#[cfg(test)]
mod render_tests {
    use super::*;

    /// Renders `TreeItem::example()` with default settings into a `width` x `height` buffer.
    ///
    /// Every expected line below must be exactly `width` columns wide, because the buffers
    /// are compared cell by cell including their area.
    #[must_use]
    #[track_caller]
    fn render(width: u16, height: u16, state: &mut TreeState<&'static str>) -> Buffer {
        let items = TreeItem::example();
        let tree = Tree::new(&items).unwrap();
        let area = Rect::new(0, 0, width, height);
        let mut buffer = Buffer::empty(area);
        StatefulWidget::render(tree, area, &mut buffer, state);
        buffer
    }

    #[test]
    fn does_not_panic() {
        _ = render(0, 0, &mut TreeState::default());
        _ = render(10, 0, &mut TreeState::default());
        _ = render(0, 10, &mut TreeState::default());
        _ = render(10, 10, &mut TreeState::default());
    }

    #[test]
    fn nothing_open() {
        let buffer = render(10, 4, &mut TreeState::default());
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa    ",
            "▶ Bravo   ",
            "  Hotel   ",
            "          ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_one() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        let buffer = render(13, 7, &mut state);
        let expected = Buffer::with_lines([
            "  Alfa       ",
            "▼ Bravo      ",
            "    Charlie  ",
            "  ▶ Delta    ",
            "    Golf     ",
            "  Hotel      ",
            "             ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_two() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        state.open(vec!["b", "d"]);
        let buffer = render(15, 9, &mut state);
        let expected = Buffer::with_lines([
            "  Alfa         ",
            "▼ Bravo        ",
            "    Charlie    ",
            "  ▼ Delta      ",
            "      Echo     ",
            "      Foxtrot  ",
            "    Golf       ",
            "  Hotel        ",
            "               ",
        ]);
        assert_eq!(buffer, expected);
    }
}