treemath/
lib.rs

1use crate::bounds::LEVEL_MAX;
2use std::collections::VecDeque;
3
/// Returns the level (height above the leaves) of `node_index`.
///
/// In this flat in-order layout a node at level `k` ends in exactly `k` one
/// bits (leaves, at level 0, are the even indices), so the level is the count
/// of trailing ones — computed here as the trailing zeros of the complement.
#[inline(always)]
pub const fn level(node_index: usize) -> usize {
    let inverted = !node_index;
    inverted.trailing_zeros() as usize
}
9
10/// Returns the root node
11#[inline(always)]
12pub const fn root(leaf_count: usize) -> usize {
13    // leaf_count.wrapping_sub(1); // works only in case leaf_count is a power of 2
14    if leaf_count == 0 {
15        return 0;
16    }
17    let shl = node_width(leaf_count).ilog2();
18    let pow2: usize = 1 << shl;
19    pow2.wrapping_sub(1)
20}
21
/// Number of nodes needed to represent a tree with `leaf_count` leaves:
/// `2 * (leaf_count - 1) + 1`, or `0` for an empty tree.
///
/// The arithmetic saturates at `usize::MAX` instead of overflowing.
#[inline(always)]
pub const fn node_width(leaf_count: usize) -> usize {
    match leaf_count.checked_sub(1) {
        None => 0,
        // Every leaf beyond the first brings one leaf and one internal node.
        Some(extra_leaves) => extra_leaves.saturating_mul(2).saturating_add(1),
    }
}
34
35/// Get the parent of a node, return [None] if the node is the root
36#[inline(always)]
37pub const fn parent(node_index: usize, leaf_count: usize) -> Option<usize> {
38    if node_index == root(leaf_count) {
39        return None;
40    }
41    Some((bits::last_zero_bit(node_index) | node_index) & !bits::last_zero_bit(node_index).wrapping_shl(1))
42}
43
44/// Given a node, return the left/right child of his parent, return [None] when root
45#[inline(always)]
46pub const fn sibling(node_index: usize, leaf_count: usize) -> Option<usize> {
47    let Some(parent) = parent(node_index, leaf_count) else {
48        return None;
49    };
50    let parent = parent as isize;
51    let d = parent.overflowing_sub(node_index as isize).0;
52    let sibling = parent.overflowing_add(d).0;
53    Some(sibling as usize)
54}
55
/// Returns the left child of `node_index`, or [None] when it is a leaf
/// (even index).
///
/// A node at level `k >= 1` ends in `k` one bits; its left child is
/// `2^(k-1)` positions before it, i.e. bit `k - 1` is cleared.
///
/// `const` like [right], and computed directly from the trailing ones so it
/// cannot overflow (the former `last_zero_bit` route did `n + 1`).
#[inline(always)]
pub const fn left(node_index: usize) -> Option<usize> {
    if node_index % 2 == 0 {
        return None;
    }
    // Odd index => at least one trailing one, so `k - 1` cannot underflow.
    let k = node_index.trailing_ones();
    Some(node_index ^ (1usize << (k - 1)))
}
65
66#[inline(always)]
67pub const fn right(node_index: usize) -> Option<usize> {
68    if node_index % 2 == 0 {
69        return None;
70    }
71    let lzb = bits::last_zero_bit(node_index);
72    let right = (node_index | lzb) & !lzb.wrapping_shr(1);
73    Some(right)
74}
75
76pub fn direct_path(node_index: usize, leaf_count: usize) -> Option<VecDeque<usize>> {
77    // see https://mmapped.blog/posts/22-flat-in-order-trees.html#sec-addressing
78    let mut root = root(leaf_count);
79    if node_index == root {
80        return None;
81    }
82
83    let mut root_level = level(root);
84
85    let floor = LEVEL_MAX.wrapping_sub(root_level);
86    let node_level = level(node_index);
87    let mut path_size = root_level.wrapping_sub(node_level);
88    let mut path = VecDeque::with_capacity(path_size);
89
90    path.push_back(root);
91
92    path_size = path_size.wrapping_sub(1);
93
94    let mask = usize::MAX >> floor;
95    let chunk = node_index & mask;
96
97    for _ in 0..path_size {
98        let d = ((chunk >> root_level) & 1) == 0;
99        root = child_with_direction(root, d, root_level);
100        root_level = root_level.wrapping_sub(1);
101        path.push_front(root);
102    }
103
104    Some(path)
105}
106
/// Returns a child of the node `node_index` sitting at `level` (expected to be
/// at least 1): the left child when `direction` is `true`, the right child when
/// it is `false`.
///
/// The left child flips bit `level - 1`. The right child flips bits `level - 1`
/// and `level`, which lands `2^(level-1)` after the parent.
#[inline(always)]
pub const fn child_with_direction(node_index: usize, direction: bool, level: usize) -> usize {
    let flip_pattern: usize = if direction { 0b01 } else { 0b11 };
    let shift = level.wrapping_sub(1) as u32;
    node_index ^ flip_pattern.wrapping_shl(shift)
}
114
115#[inline(always)]
116const fn nephew(node_index: usize, is_left: bool, leaf_count: usize, mut level: usize) -> Option<usize> {
117    if level < 1 {
118        return None;
119    }
120    let Some(parent) = parent(node_index, leaf_count) else {
121        return None;
122    };
123    level = level.wrapping_add(1);
124    let sibling = child_with_direction(parent, node_index > parent, level);
125    level = level.wrapping_sub(1);
126    let nephew = child_with_direction(sibling, is_left, level);
127    Some(nephew)
128}
129
130pub fn copath(node_index: usize, leaf_count: usize) -> Option<VecDeque<usize>> {
131    // see https://mmapped.blog/posts/22-flat-in-order-trees.html#sec-addressing
132    let mut root = root(leaf_count);
133    if node_index == root {
134        return None;
135    }
136
137    let mut root_level = level(root);
138
139    let floor = LEVEL_MAX.wrapping_sub(root_level);
140    let node_level = level(node_index);
141    let path_size = root_level.wrapping_sub(node_level).wrapping_sub(1);
142    let mut copath = VecDeque::with_capacity(path_size);
143
144    let mask = usize::MAX >> floor;
145    let chunk = node_index & mask;
146
147    let b = ((chunk >> root_level) & 1) == 0;
148    root = child_with_direction(root, !b, root_level);
149    root_level = root_level.wrapping_sub(1);
150    copath.push_front(root);
151
152    for _ in 0..path_size {
153        let b = ((chunk >> root_level) & 1) == 0;
154
155        let Some(nephew) = nephew(root, !b, leaf_count, root_level) else {
156            // should never happen because we should bail before the leaf if we compute the path_size right
157            return Some(copath);
158        };
159
160        root = nephew;
161        root_level = root_level.wrapping_sub(1);
162        copath.push_front(root);
163    }
164
165    Some(copath)
166}
167
/// Returns the lowest common ancestor of two nodes: the lowest node whose
/// subtree contains both `node_index` and `other`. A node belongs to its own
/// subtree, so when one node is an ancestor of the other, that node is returned.
///
/// For `k >= level(x)`, the ancestor of `x` at level `k` keeps the bits of `x`
/// above position `k`, has bit `k` cleared and the `k` lowest bits set. The
/// answer is therefore the ancestor at the lowest level `k` that is:
/// - at least as high as both nodes, and
/// - at or above the highest bit where the two indices differ.
///
/// The former formula only used the differing bit. It was right for two leaves
/// but wrong as soon as one node was an ancestor of the other (e.g. `(3, 0)`
/// returned `1` instead of `3`).
#[inline(always)]
pub const fn common_ancestor(node_index: usize, other: usize) -> usize {
    if node_index == other {
        return node_index;
    }

    // Mask with the `bits` lowest bits set (all bits when `bits >= usize::BITS`).
    const fn low_mask(bits: u32) -> usize {
        if bits >= usize::BITS {
            usize::MAX
        } else {
            (1usize << bits) - 1
        }
    }

    // Highest differing bit; `diff != 0` because the nodes differ.
    let diff = node_index ^ other;
    let mut k = usize::BITS - 1 - diff.leading_zeros();
    // A node's level is its number of trailing ones (same as `level()`).
    let node_level = node_index.trailing_ones();
    let other_level = other.trailing_ones();
    if node_level > k {
        k = node_level;
    }
    if other_level > k {
        k = other_level;
    }

    // Keep the shared bits above `k`, clear bit `k`, set every bit below it.
    (node_index & !low_mask(k + 1)) | low_mask(k)
}
176
mod bits {
    /// Isolates the lowest set bit of `n` (`0` when `n == 0`).
    #[inline(always)]
    pub const fn last_set_bit(n: usize) -> usize {
        // Two's complement: `n & -n` keeps only the lowest set bit.
        n & n.wrapping_neg()
    }

    /// Isolates the lowest zero bit of `n`, i.e. `1 << n.trailing_ones()`.
    ///
    /// Returns `0` for `usize::MAX`, which has no zero bit.
    #[inline(always)]
    pub const fn last_zero_bit(n: usize) -> usize {
        // `wrapping_add` so `usize::MAX` yields 0 instead of an overflow panic
        // in debug builds.
        last_set_bit(n.wrapping_add(1))
    }

    /// Isolates the highest set bit of `n` (`0` when `n == 0`).
    #[inline(always)]
    pub const fn most_significant_bit(n: usize) -> usize {
        if n == 0 {
            0
        } else {
            1usize << (usize::BITS - 1 - n.leading_zeros())
        }
    }

    /// Rounds `n` up to the next power of two (`1` for `0`).
    #[allow(dead_code)] // helper kept for callers/tests; not used by the tree math yet
    #[inline(always)]
    pub const fn round_up_power_2(n: usize) -> usize {
        n.next_power_of_two()
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn msb_should_succeed() {
            assert_eq!(0, most_significant_bit(0));
            assert_eq!(1, most_significant_bit(1));
            assert_eq!(2, most_significant_bit(2));
            assert_eq!(2, most_significant_bit(3));
            assert_eq!(4, most_significant_bit(4));
            assert_eq!(4, most_significant_bit(5));
            assert_eq!(4, most_significant_bit(6));
            // The highest bit of an all-ones word is the top bit, not the word itself.
            assert_eq!(1usize << (usize::BITS - 1), most_significant_bit(usize::MAX));
        }

        #[test]
        fn lzb_should_succeed() {
            assert_eq!(1, last_zero_bit(0));
            assert_eq!(4, last_zero_bit(3));
            assert_eq!(2, last_zero_bit(5));
            assert_eq!(0, last_zero_bit(usize::MAX));
        }
    }
}
229
#[cfg(test)]
mod tests {
    //! Checks the bit-twiddling implementations against the reference versions
    //! in `naive`, against hand-computed trees, and at the domain bounds.
    use super::{bounds::*, naive::*, *};

    mod level {
        use super::*;

        #[test]
        fn should_succeed() {
            for i in 0usize..100_000 {
                assert_eq!(level(i), level_naive(i), "failed for node index {}", i);
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            assert_eq!(level(NODE_INDEX_MAX), level_naive(NODE_INDEX_MAX));
            assert_eq!(level(0), level_naive(0));
        }
    }

    mod root {
        use super::*;

        #[test]
        fn should_succeed() {
            for lc in leaf_count_range() {
                assert_eq!(root(lc), root_naive(lc));
                assert_eq!(root(lc), lc - 1);
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            assert_eq!(root(0), root_naive(0));
            assert_eq!(root(0), 0);
            assert_eq!(root(LEAF_COUNT_MAX), root_naive(LEAF_COUNT_MAX));
            assert_eq!(root(LEAF_COUNT_MAX), ROOT_MAX);
        }
    }

    mod node_width {
        use super::*;

        #[test]
        fn should_succeed() {
            assert_eq!(node_width(1), 1);
            assert_eq!(node_width(2), 3);
        }

        #[test]
        fn should_succeed_at_boundaries() {
            assert_eq!(node_width(0), 0);
            assert_eq!(node_width(LEAF_COUNT_MAX), NODE_WIDTH_MAX);
            assert_eq!(node_width(usize::MAX), NODE_WIDTH_MAX);
        }
    }

    mod parent {
        use super::*;

        #[test]
        fn should_fail_when_root() {
            for (r, lc) in root_range() {
                assert!(parent(r, lc).is_none());
                assert!(parent_naive(r, lc).is_none());
            }
        }

        #[test]
        fn should_succeed_for_leaves() {
            let lc = LEAF_COUNT_MAX;
            for left_leaf in (0..=u16::MAX).step_by(4) {
                assert_eq!(parent(left_leaf as usize, lc), parent_naive(left_leaf as usize, lc));
                assert_eq!(parent(left_leaf as usize, lc), Some(left_leaf as usize + 1));
            }
            for right_leaf in (2..=u16::MAX).step_by(4) {
                assert_eq!(parent(right_leaf as usize, lc), parent_naive(right_leaf as usize, lc));
                assert_eq!(parent(right_leaf as usize, lc), Some(right_leaf as usize - 1));
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            assert!(parent(NODE_INDEX_MAX, LEAF_COUNT_MAX).is_some());
            assert_eq!(parent(NODE_INDEX_MAX, LEAF_COUNT_MAX), parent_naive(NODE_INDEX_MAX, LEAF_COUNT_MAX));
        }
    }

    mod direct_path {
        use super::*;

        #[test]
        fn should_succeed() {
            for (lc, i) in leaf_count_range_with_node_index().take(100_000) {
                assert_eq!(direct_path(i, lc), direct_path_naive(i, lc));
            }
        }

        #[test]
        fn should_succeed_for_remarkable_values() {
            let lc = 8usize;
            let values = [
                (0usize, vec![1usize, 3, 7]),
                (1, vec![3, 7]),
                (2, vec![1, 3, 7]),
                (3, vec![7]),
                (4, vec![5, 3, 7]),
                (5, vec![3, 7]),
                (6, vec![5, 3, 7]),
                (8, vec![9, 11, 7]),
                (9, vec![11, 7]),
                (10, vec![9, 11, 7]),
                (11, vec![7]),
                (12, vec![13, 11, 7]),
                (13, vec![11, 7]),
                (14, vec![13, 11, 7]),
            ]
            .map(|(i, e)| (i, VecDeque::from_iter(e)));
            for (i, expected) in values {
                assert_eq!(direct_path(i, lc), direct_path_naive(i, lc));
                assert_eq!(direct_path(i, lc), Some(expected));
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            let values = [
                (0, 1),
                (0, 2),
                (1, 2),
                (NODE_INDEX_MAX - 2, LEAF_COUNT_MAX),
                (NODE_INDEX_MAX - 1, LEAF_COUNT_MAX),
                (NODE_INDEX_MAX, LEAF_COUNT_MAX),
            ];
            for (i, lc) in values {
                assert_eq!(direct_path(i, lc), direct_path_naive(i, lc));
            }
        }

        #[test]
        fn should_fail_for_roots() {
            for (r, lc) in root_range() {
                assert!(direct_path(r, lc).is_none());
                assert_eq!(direct_path(r, lc), direct_path_naive(r, lc));
            }
        }
    }

    mod copath {
        use super::*;

        #[test]
        fn should_succeed() {
            // Same coverage as `direct_path::should_succeed` (the first trees in
            // full); 10 cases only reached the 4-leaf tree.
            for (lc, i) in leaf_count_range_with_node_index().take(100_000) {
                assert_eq!(copath(i, lc), copath_naive(i, lc), "failed for node {} with {} leaves", i, lc);
            }
        }

        #[test]
        fn should_succeed_for_remarkable_values() {
            let lc = 8usize;
            let values = [
                (0usize, vec![2usize, 5, 11]),
                (1, vec![5, 11]),
                (2, vec![0, 5, 11]),
                (3, vec![11]),
                (4, vec![6, 1, 11]),
                (5, vec![1, 11]),
                (6, vec![4, 1, 11]),
                (8, vec![10, 13, 3]),
                (9, vec![13, 3]),
                (10, vec![8, 13, 3]),
                (11, vec![3]),
                (12, vec![14, 9, 3]),
                (13, vec![9, 3]),
                (14, vec![12, 9, 3]),
            ]
            .map(|(i, e)| (i, VecDeque::from_iter(e)));
            for (i, expected) in values {
                assert_eq!(copath(i, lc), copath_naive(i, lc));
                assert_eq!(copath(i, lc), Some(expected));
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            let values = [
                (0, 1),
                (0, 2),
                (1, 2),
                (NODE_INDEX_MAX - 2, LEAF_COUNT_MAX),
                (NODE_INDEX_MAX - 1, LEAF_COUNT_MAX),
                (NODE_INDEX_MAX, LEAF_COUNT_MAX),
            ];
            for (i, lc) in values {
                assert_eq!(copath(i, lc), copath_naive(i, lc));
            }
        }

        #[test]
        fn should_fail_for_roots() {
            for (r, lc) in root_range() {
                assert!(copath(r, lc).is_none());
                assert_eq!(copath(r, lc), copath_naive(r, lc));
            }
        }
    }

    mod common_ancestor {
        use super::*;

        #[test]
        fn should_succeed() {
            let e = 10;
            for a in level_range(0).take(1 << e) {
                for b in level_range(0).take(1 << e) {
                    assert_eq!(common_ancestor(a, b), common_ancestor_naive(a, b));
                }
            }
        }

        #[test]
        fn should_succeed_at_boundaries() {
            let values = [(0, 2), (0, NODE_INDEX_MAX)];
            for (a, b) in values {
                assert_eq!(common_ancestor(a, b), common_ancestor_naive(a, b));
            }
        }

        /// Yields every node index at `level` in increasing order, up to
        /// [`NODE_INDEX_MAX`]. The range is strictly increasing, so it needs no
        /// deduplication (and no itertools).
        fn level_range(level: usize) -> impl Iterator<Item = usize> {
            let lower = (1 << level) - 1;
            let step = 1 << (level + 1);
            (lower..=NODE_INDEX_MAX).step_by(step)
        }
    }
}
468
pub mod bounds {
    //! Domain limits of the tree math, plus iterators over representative
    //! inputs used by the tests.

    /// Largest node index handled (`usize::MAX` itself is excluded).
    pub const NODE_INDEX_MAX: usize = usize::MAX - 1;
    /// Leaf count whose tree spans every index up to [`NODE_INDEX_MAX`].
    pub const LEAF_COUNT_MAX: usize = NODE_INDEX_MAX / 2 + 1;
    /// Node width of a tree with [`LEAF_COUNT_MAX`] leaves.
    pub const NODE_WIDTH_MAX: usize = 2 * (LEAF_COUNT_MAX - 1) + 1;
    /// Root index of a tree with [`LEAF_COUNT_MAX`] leaves.
    pub const ROOT_MAX: usize = LEAF_COUNT_MAX - 1;
    /// Highest bit position of a `usize`.
    pub const LEVEL_MAX: usize = usize::BITS as usize - 1;

    /// Largest exponent `e` used by [`leaf_count_range`] (leaf counts `2^e`).
    pub const LEAF_COUNT_BITS: usize = 31;
    /// Largest exponent `e` used by [`root_range`] (roots `2^e - 1`).
    pub const ROOT_BITS: usize = 31;

    /// Yields the power-of-two leaf counts `1, 2, 4, …, 2^LEAF_COUNT_BITS`.
    pub fn leaf_count_range() -> impl Iterator<Item = usize> {
        (0..=LEAF_COUNT_BITS).map(|exp| 1usize << exp)
    }

    /// Yields `(root, leaf_count)` pairs for the power-of-two leaf counts
    /// `1, 2, 4, …, 2^ROOT_BITS`, where `root = leaf_count - 1`.
    pub fn root_range() -> impl Iterator<Item = (usize, usize)> {
        (0..=ROOT_BITS).map(|exp| {
            let leaf_count = 1usize << exp;
            (leaf_count - 1, leaf_count)
        })
    }

    /// Yields `(leaf_count, node_index)` for every node of every tree produced
    /// by [`leaf_count_range`], smallest tree first.
    pub fn leaf_count_range_with_node_index() -> impl Iterator<Item = (usize, usize)> {
        leaf_count_range().flat_map(|leaf_count| {
            let last_node = leaf_count.saturating_sub(1) * 2;
            (0..=last_node).map(move |node_index| (leaf_count, node_index))
        })
    }
}
492
#[cfg(any(test, feature = "bench"))]
pub mod naive {
    //! Straightforward reference implementations of the tree math, compiled for
    //! tests and for the `bench` feature.
    //!
    //! `copath_naive` is built only from the other `*_naive` helpers. That way
    //! it stays an independent oracle for the fast `copath` instead of reusing
    //! the code it is compared against.
    use super::*;
    use std::collections::VecDeque;
    use std::ops::Shl;

    /// Level of `node_index`, found by counting its trailing one bits one at a time.
    #[inline(always)]
    pub fn level_naive(node_index: usize) -> usize {
        if node_index & 0x01 == 0 {
            return 0;
        }

        let mut k = 0;
        while node_index.checked_shr(k).is_some() && (node_index >> k) & 0x01 == 1 {
            k += 1;
        }
        k as usize
    }

    /// Root of a tree with `leaf_count` leaves (`0` when there are none).
    #[inline(always)]
    pub fn root_naive(leaf_count: usize) -> usize {
        if leaf_count == 0 {
            return 0;
        }
        let width = node_width(leaf_count);
        let pow2: usize = 1 << width.ilog2();
        pow2.wrapping_sub(1)
    }

    /// Parent of `node_index`, or [None] for the root. For a node at level `k`,
    /// sets bit `k` and clears bit `k + 1`.
    #[inline(always)]
    pub fn parent_naive(node_index: usize, leaf_count: usize) -> Option<usize> {
        if node_index == root_naive(leaf_count) {
            return None;
        }

        let k = level_naive(node_index);
        let b = (node_index >> (k + 1)) & 0x01;
        Some((node_index | (1 << k)) ^ (b << (k + 1)))
    }

    /// The parent's other child, or [None] for the root.
    #[inline(always)]
    pub fn sibling_naive(node_index: usize, leaf_count: usize) -> Option<usize> {
        let parent = parent_naive(node_index, leaf_count)?;
        if node_index < parent { right_naive(parent) } else { left_naive(parent) }
    }

    /// Left child (flip bit `k - 1`), or [None] for a leaf.
    #[inline(always)]
    pub fn left_naive(node_index: usize) -> Option<usize> {
        let k = level_naive(node_index);
        if k == 0 {
            return None;
        }
        let node_index = node_index ^ (0x01 << k.wrapping_sub(1));
        Some(node_index)
    }

    /// Right child (flip bits `k - 1` and `k`), or [None] for a leaf.
    #[inline(always)]
    pub fn right_naive(node_index: usize) -> Option<usize> {
        let k = level_naive(node_index);
        if k == 0 {
            return None;
        }
        let node_index = node_index ^ (0x03 << k.wrapping_sub(1));
        Some(node_index)
    }

    /// Ancestors from the parent up to the root, found by repeated [parent_naive];
    /// [None] for the root.
    #[inline(always)]
    pub fn direct_path_naive(mut node_index: usize, leaf_count: usize) -> Option<VecDeque<usize>> {
        let root = root_naive(leaf_count);
        if node_index == root {
            return None;
        }

        let mut ret = VecDeque::new();
        while node_index != root {
            match parent_naive(node_index, leaf_count) {
                Some(parent_idx) => node_index = parent_idx,
                None => return None,
            }
            ret.push_back(node_index);
        }

        Some(ret)
    }

    /// Siblings of the node and of every ancestor below the root, ordered from
    /// the node upwards; [None] for the root.
    pub fn copath_naive(node_index: usize, leaf_count: usize) -> Option<VecDeque<usize>> {
        if node_index == root_naive(leaf_count) {
            return None;
        }

        let mut path = direct_path_naive(node_index, leaf_count)?;
        path.push_front(node_index);
        // The root has no sibling.
        let _ = path.pop_back();

        path.into_iter().map(|path_idx| sibling_naive(path_idx, leaf_count)).collect()
    }

    /// Lowest common ancestor of two nodes. When one node is an ancestor of the
    /// other, that node is returned.
    pub fn common_ancestor_naive(mut node_index: usize, mut other: usize) -> usize {
        // Plain `>>` panics for shift amounts >= usize::BITS. The `level + 1`
        // values below reach that for high nodes (e.g. ROOT_MAX at level 63).
        let shr = |value: usize, amount: usize| value.checked_shr(amount as u32).unwrap_or(0);

        let self_lvl = level_naive(node_index).saturating_add(1);
        let other_lvl = level_naive(other).saturating_add(1);
        if self_lvl <= other_lvl && shr(node_index, other_lvl) == shr(other, other_lvl) {
            return other;
        } else if other_lvl <= self_lvl && shr(node_index, self_lvl) == shr(other, self_lvl) {
            return node_index;
        }

        let mut k = 0u32;
        while node_index != other {
            node_index >>= 1;
            other >>= 1;
            k = k.saturating_add(1);
        }

        let s = 1usize.shl(k.saturating_sub(1));
        (node_index.overflowing_shl(k).0).saturating_add(s).saturating_sub(1)
    }
}
609}