1use std::collections::HashSet;
7
8use ratatui_core::buffer::Buffer;
9use ratatui_core::layout::Rect;
10use ratatui_core::style::Style;
11use ratatui_core::widgets::{StatefulWidget, Widget};
12pub use ratatui_widgets::block::Block;
13pub use ratatui_widgets::scrollbar::{Scrollbar, ScrollbarState};
14use unicode_width::UnicodeWidthStr as _;
15
16pub use crate::flatten::Flattened;
17pub use crate::tree_item::TreeItem;
18pub use crate::tree_state::TreeState;
19
20mod flatten;
21mod tree_item;
22mod tree_state;
23
24#[must_use]
54#[derive(Debug, Clone)]
55pub struct Tree<'a, Identifier> {
56 items: &'a [TreeItem<'a, Identifier>],
57
58 block: Option<Block<'a>>,
59 scrollbar: Option<Scrollbar<'a>>,
60 style: Style,
62
63 highlight_style: Style,
65 highlight_symbol: &'a str,
67
68 node_closed_symbol: &'a str,
70 node_open_symbol: &'a str,
72 node_no_children_symbol: &'a str,
74}
75
76impl<'a, Identifier> Tree<'a, Identifier>
77where
78 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
79{
80 pub fn new(items: &'a [TreeItem<'a, Identifier>]) -> std::io::Result<Self> {
86 let identifiers = items
87 .iter()
88 .map(|item| &item.identifier)
89 .collect::<HashSet<_>>();
90 if identifiers.len() != items.len() {
91 return Err(std::io::Error::new(
92 std::io::ErrorKind::AlreadyExists,
93 "The items contain duplicate identifiers",
94 ));
95 }
96
97 Ok(Self {
98 items,
99 block: None,
100 scrollbar: None,
101 style: Style::new(),
102 highlight_style: Style::new(),
103 highlight_symbol: "",
104 node_closed_symbol: "\u{25b6} ", node_open_symbol: "\u{25bc} ", node_no_children_symbol: " ",
107 })
108 }
109
110 pub fn block(mut self, block: Block<'a>) -> Self {
111 self.block = Some(block);
112 self
113 }
114
115 pub const fn experimental_scrollbar(mut self, scrollbar: Option<Scrollbar<'a>>) -> Self {
121 self.scrollbar = scrollbar;
122 self
123 }
124
125 pub const fn style(mut self, style: Style) -> Self {
126 self.style = style;
127 self
128 }
129
130 pub const fn highlight_style(mut self, style: Style) -> Self {
131 self.highlight_style = style;
132 self
133 }
134
135 pub const fn highlight_symbol(mut self, highlight_symbol: &'a str) -> Self {
136 self.highlight_symbol = highlight_symbol;
137 self
138 }
139
140 pub const fn node_closed_symbol(mut self, symbol: &'a str) -> Self {
141 self.node_closed_symbol = symbol;
142 self
143 }
144
145 pub const fn node_open_symbol(mut self, symbol: &'a str) -> Self {
146 self.node_open_symbol = symbol;
147 self
148 }
149
150 pub const fn node_no_children_symbol(mut self, symbol: &'a str) -> Self {
151 self.node_no_children_symbol = symbol;
152 self
153 }
154}
155
/// `Tree::new` must refuse top-level items that share an identifier.
#[test]
#[should_panic = "duplicate identifiers"]
fn tree_new_errors_with_duplicate_identifiers() {
    let duplicated = TreeItem::new_leaf("same", "text");
    let items = [duplicated.clone(), duplicated];
    // The constructor returns `Err`, so unwrapping it panics with its message.
    let _tree: Tree<_> = Tree::new(&items).unwrap();
}
164
165impl<Identifier> StatefulWidget for Tree<'_, Identifier>
166where
167 Identifier: Clone + PartialEq + Eq + core::hash::Hash,
168{
169 type State = TreeState<Identifier>;
170
171 #[expect(clippy::too_many_lines)]
172 fn render(self, full_area: Rect, buf: &mut Buffer, state: &mut Self::State) {
173 buf.set_style(full_area, self.style);
174
175 let area = self.block.map_or(full_area, |block| {
177 let inner_area = block.inner(full_area);
178 block.render(full_area, buf);
179 inner_area
180 });
181
182 state.last_area = area;
183 state.last_rendered_identifiers.clear();
184 if area.width < 1 || area.height < 1 {
185 return;
186 }
187
188 let visible = state.flatten(self.items);
189 state.last_biggest_index = visible.len().saturating_sub(1);
190 if visible.is_empty() {
191 return;
192 }
193 let available_height = area.height as usize;
194
195 let ensure_index_in_view =
196 if state.ensure_selected_in_view_on_next_render && !state.selected.is_empty() {
197 visible
198 .iter()
199 .position(|flattened| flattened.identifier == state.selected)
200 } else {
201 None
202 };
203
204 let mut start = state.offset.min(state.last_biggest_index);
206
207 if let Some(ensure_index_in_view) = ensure_index_in_view {
208 start = start.min(ensure_index_in_view);
209 }
210
211 let mut end = start;
212 let mut height = 0;
213 for item_height in visible
214 .iter()
215 .skip(start)
216 .map(|flattened| flattened.item.height())
217 {
218 if height + item_height > available_height {
219 break;
220 }
221 height += item_height;
222 end += 1;
223 }
224
225 if let Some(ensure_index_in_view) = ensure_index_in_view {
226 while ensure_index_in_view >= end {
227 height += visible[end].item.height();
228 end += 1;
229 while height > available_height {
230 height = height.saturating_sub(visible[start].item.height());
231 start += 1;
232 }
233 }
234 }
235
236 state.offset = start;
237 state.ensure_selected_in_view_on_next_render = false;
238
239 if let Some(scrollbar) = self.scrollbar {
240 let mut scrollbar_state = ScrollbarState::new(visible.len().saturating_sub(height))
241 .position(start)
242 .viewport_content_length(height);
243 let scrollbar_area = Rect {
244 y: area.y,
246 height: area.height,
247 x: full_area.x,
249 width: full_area.width,
250 };
251 scrollbar.render(scrollbar_area, buf, &mut scrollbar_state);
252 }
253
254 let blank_symbol = " ".repeat(self.highlight_symbol.width());
255
256 let mut current_height = 0;
257 let has_selection = !state.selected.is_empty();
258 #[expect(clippy::cast_possible_truncation)]
259 for flattened in visible.iter().skip(state.offset).take(end - start) {
260 let Flattened { identifier, item } = flattened;
261
262 let x = area.x;
263 let y = area.y + current_height;
264 let height = item.height() as u16;
265 current_height += height;
266
267 let area = Rect {
268 x,
269 y,
270 width: area.width,
271 height,
272 };
273
274 let text = &item.text;
275 let item_style = text.style;
276
277 let is_selected = state.selected == *identifier;
278 let after_highlight_symbol_x = if has_selection {
279 let symbol = if is_selected {
280 self.highlight_symbol
281 } else {
282 &blank_symbol
283 };
284 let (x, _) = buf.set_stringn(x, y, symbol, area.width as usize, item_style);
285 x
286 } else {
287 x
288 };
289
290 let after_depth_x = {
291 let indent_width = flattened.depth() * 2;
292 let (after_indent_x, _) = buf.set_stringn(
293 after_highlight_symbol_x,
294 y,
295 " ".repeat(indent_width),
296 indent_width,
297 item_style,
298 );
299 let symbol = if item.children.is_empty() {
300 self.node_no_children_symbol
301 } else if state.opened.contains(identifier) {
302 self.node_open_symbol
303 } else {
304 self.node_closed_symbol
305 };
306 let max_width = area.width.saturating_sub(after_indent_x - x);
307 let (x, _) =
308 buf.set_stringn(after_indent_x, y, symbol, max_width as usize, item_style);
309 x
310 };
311
312 let text_area = Rect {
313 x: after_depth_x,
314 width: area.width.saturating_sub(after_depth_x - x),
315 ..area
316 };
317 text.render(text_area, buf);
318
319 if is_selected {
320 buf.set_style(area, self.highlight_style);
321 }
322
323 state
324 .last_rendered_identifiers
325 .push((area.y, identifier.clone()));
326 }
327 state.last_identifiers = visible
328 .into_iter()
329 .map(|flattened| flattened.identifier)
330 .collect();
331 }
332}
333
334impl<Identifier> Widget for Tree<'_, Identifier>
335where
336 Identifier: Clone + Eq + core::hash::Hash,
337{
338 fn render(self, area: Rect, buf: &mut Buffer) {
339 let mut state = TreeState::default();
340 StatefulWidget::render(self, area, buf, &mut state);
341 }
342}
343
#[cfg(test)]
mod render_tests {
    use super::*;

    /// Render `TreeItem::example()` with default settings into a
    /// `width` x `height` buffer and return the buffer.
    #[must_use]
    #[track_caller]
    fn render(width: u16, height: u16, state: &mut TreeState<&'static str>) -> Buffer {
        let items = TreeItem::example();
        let tree = Tree::new(&items).unwrap();
        let area = Rect::new(0, 0, width, height);
        let mut buffer = Buffer::empty(area);
        StatefulWidget::render(tree, area, &mut buffer, state);
        buffer
    }

    #[test]
    fn does_not_panic() {
        _ = render(0, 0, &mut TreeState::default());
        _ = render(10, 0, &mut TreeState::default());
        _ = render(0, 10, &mut TreeState::default());
        _ = render(10, 10, &mut TreeState::default());
    }

    // Each expected line is exactly as wide as the rendered buffer: a
    // two-column node symbol, two columns of indent per depth level, the
    // label, then blank padding.

    #[test]
    fn nothing_open() {
        let buffer = render(10, 4, &mut TreeState::default());
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa    ",
            "▶ Bravo   ",
            "  Hotel   ",
            "          ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_one() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        let buffer = render(13, 7, &mut state);
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa       ",
            "▼ Bravo      ",
            "    Charlie  ",
            "  ▶ Delta    ",
            "    Golf     ",
            "  Hotel      ",
            "             ",
        ]);
        assert_eq!(buffer, expected);
    }

    #[test]
    fn depth_two() {
        let mut state = TreeState::default();
        state.open(vec!["b"]);
        state.open(vec!["b", "d"]);
        let buffer = render(15, 9, &mut state);
        #[rustfmt::skip]
        let expected = Buffer::with_lines([
            "  Alfa         ",
            "▼ Bravo        ",
            "    Charlie    ",
            "  ▼ Delta      ",
            "      Echo     ",
            "      Foxtrot  ",
            "    Golf       ",
            "  Hotel        ",
            "               ",
        ]);
        assert_eq!(buffer, expected);
    }
}