tree_layout/
lib.rs

1use std::{collections::HashMap, hash::Hash};
2
3pub use shape_core::{Line, Point, Rectangle};
4
5pub use crate::tree::{TreeLayout, TreeNode};
6
7mod tree;
8
/// Distances from a node's centre point to each of its four sides.
///
/// Used both for a node's own extent and for the border reserved around it.
#[derive(Debug, Clone, Copy)]
pub struct TreeBox {
    /// Distance from the centre up to the top side.
    pub top: f64,
    /// Distance from the centre across to the right side.
    pub right: f64,
    /// Distance from the centre down to the bottom side.
    pub bottom: f64,
    /// Distance from the centre across to the left side.
    pub left: f64,
}
16
17impl TreeBox {
18    pub fn square(size: f64) -> TreeBox {
19        TreeBox { top: size, right: size, bottom: size, left: size }
20    }
21    pub fn rectangle(width: f64, height: f64) -> TreeBox {
22        TreeBox { top: height, right: width, bottom: height, left: width }
23    }
24}
25
26#[allow(unused_variables)]
27pub trait NodeInfo<N>
28where
29    Self::Key: Eq + Hash,
30    N: Clone,
31{
32    type Key;
33
34    /// Returns a key that will be used to uniquely identify a given node.
35    fn key(&self, node: N) -> Self::Key;
36
37    /// Returns the children that a given node has.
38    fn children(&self, node: N) -> impl Iterator<Item = N>;
39
40    /// Returns the dimensions of a given node.
41    ///
42    /// This is the padding that you want around the centre point of the node so that you can line
43    /// things up as you want to (e.g. nodes aligned by their top border vs being aligned by their
44    /// centres).
45    ///
46    /// This value is generic over units (but all nodes must use the same unit) and the layout that
47    /// this crate calculates will be given in terms of this unit. For example if you give this
48    /// value in pixels then the layout will be given in terms of number of pixels from the left of
49    /// the tree. Alternatively you might want to give this value in terms of the proportion of the
50    /// width of your window (though note that this does not guarantee that the tree will fit in
51    /// your window).
52    ///
53    /// # Default
54    ///
55    /// By default the algorithm assumes that each node is point-like (i.e. has no width or height).
56    fn dimensions(&self, node: N) -> TreeBox {
57        TreeBox::square(0.0)
58    }
59
60    /// Returns the desired border around a given node.
61    ///
62    /// See the `dimensions` method for a description of what units this has.
63    ///
64    /// # Default
65    ///
66    /// By default the algorithm assumes that each node has a border of `0.5` on every side.
67    fn border(&self, node: N) -> TreeBox {
68        TreeBox::square(0.5)
69    }
70}
71
72#[derive(Clone, Debug)]
73pub struct TreeData<K> {
74    pub key: K,
75    x: f64,
76    y: f64,
77    modifier: f64,
78    dimensions: TreeBox,
79    border: TreeBox,
80}
81
82#[allow(dead_code)]
83impl<K> TreeData<K> {
84    fn top_space(&self) -> f64 {
85        self.dimensions.top + self.border.top
86    }
87    fn top(&self) -> f64 {
88        self.y - self.top_space()
89    }
90    fn bottom_space(&self) -> f64 {
91        self.dimensions.bottom + self.border.bottom
92    }
93
94    fn bottom(&self) -> f64 {
95        self.y + self.bottom_space()
96    }
97
98    fn left_space(&self) -> f64 {
99        self.dimensions.left + self.border.left
100    }
101
102    fn left(&self) -> f64 {
103        self.x - self.left_space()
104    }
105
106    fn right_space(&self) -> f64 {
107        self.dimensions.right + self.border.right
108    }
109
110    fn right(&self) -> f64 {
111        self.x + self.right_space()
112    }
113
114    pub fn center(&self) -> Point<f64> {
115        Point { x: self.x, y: self.y }
116    }
117
118    pub fn boundary(&self) -> Rectangle<f64> {
119        Rectangle::from_center(
120            Point::new(self.x, self.y),
121            self.dimensions.left + self.dimensions.right,
122            self.dimensions.top + self.dimensions.bottom,
123        )
124    }
125}
126
127/// Returns the coordinates for the _centre_ of each node.
128///
129/// The origin of the coordinate system will be at the top left of the tree. The coordinates take
130/// into account the width of the left-most node and shift everything so that the left-most border
131/// of the left-most node is at 0 on the x-axis.
132///
133/// # Important
134///
135/// This algorithm _does_ account for the height of nodes but this is only to allow each row of
136/// nodes to be aligned by their centre. If your tree has some nodes at a given depth which are
137/// significantly larger than others and you want to avoid large gaps between rows then a more
138/// general graph layout algorithm is required.
139pub fn layout<N, T>(tree: &T, root: N) -> TreeLayout<TreeData<<T as NodeInfo<N>>::Key>>
140where
141    N: Clone,
142    T: NodeInfo<N>,
143{
144    let mut tree = TreeLayout::new(tree, root, |t, n| TreeData {
145        key: t.key(n.clone()),
146        x: 0.0,
147        y: 0.0,
148        modifier: 0.0,
149        dimensions: t.dimensions(n.clone()),
150        border: t.border(n.clone()),
151    });
152    if let Some(root) = tree.root() {
153        initialise_y(&mut tree, root);
154        initialise_x(&mut tree, root);
155        ensure_positive_x(&mut tree, root);
156        finalise_x(&mut tree, root);
157        tree
158    }
159    else {
160        Default::default()
161    }
162}
163
164fn initialise_y<K>(tree: &mut TreeLayout<TreeData<K>>, root: usize) {
165    let mut next_row = vec![root];
166    while !next_row.is_empty() {
167        let row = next_row;
168        next_row = Vec::new();
169        let mut max = f64::NEG_INFINITY;
170        for node in &row {
171            let node = *node;
172            tree[node].data.y = if let Some(parent) = tree[node].parent { tree[parent].data.bottom() } else { 0.0 }
173                + tree[node].data.top_space();
174            if tree[node].data.y > max {
175                max = tree[node].data.y;
176            }
177            next_row.extend_from_slice(&tree[node].children);
178        }
179
180        for node in &row {
181            tree[*node].data.y = max;
182        }
183    }
184}
185
186fn initialise_x<K>(tree: &mut TreeLayout<TreeData<K>>, root: usize) {
187    for node in tree.post_order(root) {
188        if tree[node].is_leaf() {
189            tree[node].data.x = if let Some(sibling) = tree.previous_sibling(node) { tree[sibling].data.right() } else { 0.0 }
190                + tree[node].data.left_space();
191        }
192        else {
193            let mid = {
194                let first = tree[*tree[node].children.first().expect("Only leaf nodes have no children.")].data.x;
195                let last = tree[*tree[node].children.last().expect("Only leaf nodes have no children.")].data.x;
196                (first + last) / 2.0
197            };
198            if let Some(sibling) = tree.previous_sibling(node) {
199                tree[node].data.x = tree[sibling].data.right() + tree[node].data.left_space();
200                tree[node].data.modifier = tree[node].data.x - mid;
201            }
202            else {
203                tree[node].data.x = mid;
204            }
205            fix_overlaps(tree, node);
206        }
207    }
208}
209
210fn fix_overlaps<K>(tree: &mut TreeLayout<TreeData<K>>, right: usize) {
211    fn max_depth(l: &HashMap<usize, f64>, r: &HashMap<usize, f64>) -> usize {
212        if let Some(l) = l.keys().max() {
213            if let Some(r) = r.keys().max() {
214                return std::cmp::min(*l, *r);
215            }
216        }
217        0
218    }
219    let right_node_contour = left_contour(tree, right);
220    for left in tree.left_siblings(right) {
221        let left_node_contour = right_contour(tree, left);
222        let mut shift = 0.0;
223        for depth in tree[right].depth..=max_depth(&right_node_contour, &left_node_contour) {
224            let gap = right_node_contour[&depth] - left_node_contour[&depth];
225            if gap + shift < 0.0 {
226                shift = -gap;
227            }
228        }
229        tree[right].data.x += shift;
230        tree[right].data.modifier += shift;
231        centre_nodes_between(tree, left, right);
232    }
233}
234
235fn left_contour<K>(tree: &TreeLayout<TreeData<K>>, node: usize) -> HashMap<usize, f64> {
236    contour(tree, node, min, |n| n.data.left())
237}
238
239fn right_contour<K>(tree: &TreeLayout<TreeData<K>>, node: usize) -> HashMap<usize, f64> {
240    contour(tree, node, max, |n| n.data.right())
241}
242
/// Returns `l` if it compares strictly less than `r`, otherwise `r`.
///
/// Ties and incomparable pairs (e.g. involving NaN) therefore yield `r`.
fn min<T: PartialOrd>(l: T, r: T) -> T {
    if l < r {
        return l;
    }
    r
}
246
/// Returns `l` if it compares strictly greater than `r`, otherwise `r`.
///
/// Ties and incomparable pairs (e.g. involving NaN) therefore yield `r`.
fn max<T: PartialOrd>(l: T, r: T) -> T {
    if l > r {
        return l;
    }
    r
}
250
251fn contour<C, E, K>(tree: &TreeLayout<TreeData<K>>, node: usize, cmp: C, edge: E) -> HashMap<usize, f64>
252where
253    C: Fn(f64, f64) -> f64,
254    E: Fn(&TreeNode<TreeData<K>>) -> f64,
255{
256    let mut stack = vec![(0.0, node)];
257    let mut contour = HashMap::new();
258    while let Some((mod_, node)) = stack.pop() {
259        let depth = tree[node].depth;
260        let shifted = edge(&tree[node]) + mod_;
261        let new = if let Some(current) = contour.get(&depth) { cmp(*current, shifted) } else { shifted };
262        let mod_ = mod_ + tree[node].data.modifier;
263        contour.insert(depth, new);
264        stack.extend(tree[node].children.iter().map(|c| (mod_, *c)));
265    }
266    contour
267}
268
269fn centre_nodes_between<K>(tree: &mut TreeLayout<TreeData<K>>, left: usize, right: usize) {
270    let num_gaps = tree[right].order - tree[left].order;
271
272    let space_per_gap = (tree[right].data.left() - tree[left].data.right()) / (num_gaps as f64);
273
274    for (i, sibling) in tree.siblings_between(left, right).into_iter().enumerate() {
275        let i = i + 1;
276
277        let old_x = tree[sibling].data.x;
278        // HINT: We traverse the tree in post-order so we should never be moving anything to the
279        //       left.
280        // TODO: Have some kind of `move_node` method that checks things like this?
281        let new_x = max(old_x, tree[left].data.right() + space_per_gap * (i as f64));
282        let diff = new_x - old_x;
283
284        tree[sibling].data.x = new_x;
285        tree[sibling].data.modifier += diff;
286    }
287}
288
289fn ensure_positive_x<K>(tree: &mut TreeLayout<TreeData<K>>, root: usize) {
290    let contour = left_contour(tree, root);
291    let shift = -contour
292        .values()
293        .fold(None, |acc, curr| {
294            let acc = acc.unwrap_or(f64::INFINITY);
295            let curr = *curr;
296            Some(if curr < acc { curr } else { acc })
297        })
298        .unwrap_or(0.0);
299
300    tree[root].data.x += shift;
301    tree[root].data.modifier += shift;
302}
303
304fn finalise_x<K>(tree: &mut TreeLayout<TreeData<K>>, root: usize) {
305    for node in tree.breadth_first(root) {
306        let shift = if let Some(parent) = tree[node].parent { tree[parent].data.modifier } else { 0.0 };
307        tree[node].data.x += shift;
308        tree[node].data.modifier += shift;
309    }
310}