warg_transparency/log/
node.rs

1use alloc::vec::Vec;
2
/// Represents a node in a tree by its index.
///
/// Leaves sit at the even indices and a branch node's index lies strictly
/// between the indices of its two children; the number of trailing one bits
/// of the index is the node's height.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Node(pub usize);

/// What side of its parent a given node is on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Side {
    /// A node on the left side of its parent
    Left,
    /// A node on the right side of its parent
    Right,
}

impl Node {
    /// The index in the tree space for this node.
    #[inline]
    pub fn index(&self) -> usize {
        self.0
    }

    /// The height of this node, where leaves have height 0.
    ///
    /// Computed as the number of trailing one bits of the index.
    #[inline]
    pub fn height(&self) -> u32 {
        self.index().trailing_ones()
    }

    /// The distance, in index space, from this node to its parent.
    ///
    /// Equal to `2^height`; half of it is the distance to each child.
    #[inline]
    fn delta(&self) -> usize {
        2usize.pow(self.height())
    }

    /// The side this node is on of its parent.
    ///
    /// The decision is made from the bit located `height + 1` positions up:
    /// a clear bit means the node is a left child, a set bit a right child.
    #[inline]
    pub fn side(&self) -> Side {
        let shift = self.height() + 1;
        // A plain `>>` overflows once `shift` reaches the bit width of
        // `usize` (a panic with overflow checks, a masked shift amount
        // without). `checked_shr` reports that case as `None`; no bits remain
        // above the run of ones then, so the node is a left child.
        let upper_bits = self.index().checked_shr(shift).unwrap_or(0);
        if upper_bits & 1 == 0 {
            Side::Left
        } else {
            Side::Right
        }
    }

    /// The left sibling of this node.
    ///
    /// # Panics
    ///
    /// Panics if node is not a right-side node.
    #[inline]
    pub fn left_sibling(&self) -> Node {
        assert_eq!(self.side(), Side::Right);
        let delta = self.delta();
        // Siblings are two parent-distances apart.
        Node(self.index() - delta - delta)
    }

    /// The right sibling of this node.
    ///
    /// # Panics
    ///
    /// Panics if node is not a left-side node.
    #[inline]
    pub fn right_sibling(&self) -> Node {
        assert_eq!(self.side(), Side::Left);
        let delta = self.delta();
        // Siblings are two parent-distances apart.
        Node(self.index() + delta + delta)
    }

    /// The sibling of this node, on whichever side it lies.
    #[inline]
    pub fn sibling(&self) -> Node {
        match self.side() {
            Side::Left => self.right_sibling(),
            Side::Right => self.left_sibling(),
        }
    }

    /// Finds the parent of a given node index.
    ///
    /// A left child's parent sits `delta` indices above it, a right child's
    /// parent `delta` indices below it.
    #[inline]
    pub fn parent(&self) -> Node {
        let parent_index = match self.side() {
            Side::Left => self.index() + self.delta(),
            Side::Right => self.index() - self.delta(),
        };
        Node(parent_index)
    }

    /// Find the left and right child of this node.
    ///
    /// # Panics
    ///
    /// Panics if this node has height 0, i.e. it is a leaf.
    #[inline]
    pub fn children(&self) -> (Node, Node) {
        assert_ne!(self.height(), 0);
        let index = self.index();
        let child_delta = self.delta() / 2;
        (Node(index - child_delta), Node(index + child_delta))
    }

    /// Finds the right-most node which is a descendent of this one.
    ///
    /// For a leaf this is the leaf itself.
    #[inline]
    pub fn rightmost_descendent(&self) -> Node {
        let offset = self.delta() - 1;
        Node(self.index() + offset)
    }

    /// Finds the left-most node which is a descendent of this one.
    ///
    /// For a leaf this is the leaf itself.
    #[inline]
    pub fn leftmost_descendent(&self) -> Node {
        let offset = self.delta() - 1;
        Node(self.index() - offset)
    }

    /// Determines if a log with the given number of elements
    /// would contain this node.
    ///
    /// The node exists once its right-most descendent does; halving that
    /// descendent's index gives the number of entries that must precede it.
    #[inline]
    pub fn exists_at_length(&self, length: usize) -> bool {
        let last_child = self.rightmost_descendent();
        let required_entries = last_child.index() / 2;
        required_entries < length
    }

    /// Finds the next node after this one with the specified height.
    ///
    /// The specified height MUST be less than or equal to the height of this
    /// node.
    ///
    /// # Panics
    ///
    /// Panics if `height` is greater than the height of this node.
    #[inline]
    pub fn next_node_with_height(&self, height: u32) -> Node {
        assert!(
            self.height() >= height,
            "This algorithm is designed to only work for smaller or equal successors"
        );
        let first_with_height = Self::first_node_with_height(height);
        // Index of the first leaf to the right of this node's subtree.
        let next_leaf = self.rightmost_descendent().index() + 2;
        Node(first_with_height.index() + next_leaf)
    }

    /// Compute the left-most node which has a given height.
    #[inline]
    pub fn first_node_with_height(height: u32) -> Node {
        Node(2usize.pow(height) - 1)
    }

    /// Compute the balanced roots for a log with a given
    /// log length in number of leaves.
    ///
    /// Every set bit of `length` contributes one root whose height is the
    /// position of that bit. Roots are returned from the tallest to the
    /// shortest; a length of 0 yields an empty vector.
    #[inline]
    pub fn broots_for_len(length: usize) -> Vec<Node> {
        let mut broots = Vec::with_capacity(length.count_ones() as usize);
        let mut current: Option<Node> = None;

        // Walk the set bits from the most significant downwards so that each
        // root is no taller than the one placed before it.
        let heights = (0..usize::BITS)
            .rev()
            .filter(|&bit| (length >> bit) & 1 == 1);
        for broot_height in heights {
            let next = match current {
                None => Self::first_node_with_height(broot_height),
                Some(last) => last.next_node_with_height(broot_height),
            };
            broots.push(next);
            current = Some(next);
        }

        broots
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `height()` and `delta()` for the first fifteen node indices.
    #[test]
    fn test_node_height_delta() {
        // Heights for nodes at indices
        assert_eq!(Node(0).height(), 0);
        assert_eq!(Node(1).height(), 1);
        assert_eq!(Node(2).height(), 0);
        assert_eq!(Node(3).height(), 2);
        assert_eq!(Node(4).height(), 0);
        assert_eq!(Node(5).height(), 1);
        assert_eq!(Node(6).height(), 0);
        assert_eq!(Node(7).height(), 3);
        assert_eq!(Node(8).height(), 0);
        assert_eq!(Node(9).height(), 1);
        assert_eq!(Node(10).height(), 0);
        assert_eq!(Node(11).height(), 2);
        assert_eq!(Node(12).height(), 0);
        assert_eq!(Node(13).height(), 1);
        assert_eq!(Node(14).height(), 0);

        // Deltas for nodes at indices (2 raised to the node's height)
        assert_eq!(Node(0).delta(), 1);
        assert_eq!(Node(1).delta(), 2);
        assert_eq!(Node(2).delta(), 1);
        assert_eq!(Node(3).delta(), 4);
        assert_eq!(Node(4).delta(), 1);
        assert_eq!(Node(5).delta(), 2);
        assert_eq!(Node(6).delta(), 1);
        assert_eq!(Node(7).delta(), 8);
        assert_eq!(Node(8).delta(), 1);
        assert_eq!(Node(9).delta(), 2);
        assert_eq!(Node(10).delta(), 1);
        assert_eq!(Node(11).delta(), 4);
        assert_eq!(Node(12).delta(), 1);
        assert_eq!(Node(13).delta(), 2);
        assert_eq!(Node(14).delta(), 1);
    }

    /// Checks side, sibling, parent and children relationships.
    #[test]
    fn test_node_neighbors() {
        // Whether each node is a left or right side child
        assert_eq!(Node(0).side(), Side::Left);
        assert_eq!(Node(1).side(), Side::Left);
        assert_eq!(Node(2).side(), Side::Right);
        assert_eq!(Node(3).side(), Side::Left);
        assert_eq!(Node(4).side(), Side::Left);
        assert_eq!(Node(5).side(), Side::Right);
        assert_eq!(Node(6).side(), Side::Right);
        assert_eq!(Node(7).side(), Side::Left);
        assert_eq!(Node(8).side(), Side::Left);
        assert_eq!(Node(9).side(), Side::Left);
        assert_eq!(Node(10).side(), Side::Right);
        assert_eq!(Node(11).side(), Side::Right);
        assert_eq!(Node(12).side(), Side::Left);
        assert_eq!(Node(13).side(), Side::Right);
        assert_eq!(Node(14).side(), Side::Right);

        // Sibling index for each node
        assert_eq!(Node(0).right_sibling(), Node(2));
        assert_eq!(Node(1).right_sibling(), Node(5));
        assert_eq!(Node(2).left_sibling(), Node(0));
        assert_eq!(Node(3).right_sibling(), Node(11));
        assert_eq!(Node(4).right_sibling(), Node(6));
        assert_eq!(Node(5).left_sibling(), Node(1));
        assert_eq!(Node(6).left_sibling(), Node(4));
        assert_eq!(Node(7).right_sibling(), Node(23));
        assert_eq!(Node(8).right_sibling(), Node(10));
        assert_eq!(Node(9).right_sibling(), Node(13));
        assert_eq!(Node(10).left_sibling(), Node(8));
        assert_eq!(Node(11).left_sibling(), Node(3));
        assert_eq!(Node(12).right_sibling(), Node(14));
        assert_eq!(Node(13).left_sibling(), Node(9));
        assert_eq!(Node(14).left_sibling(), Node(12));

        // Parent index for each node
        assert_eq!(Node(0).parent(), Node(1));
        assert_eq!(Node(1).parent(), Node(3));
        assert_eq!(Node(2).parent(), Node(1));
        assert_eq!(Node(3).parent(), Node(7));
        assert_eq!(Node(4).parent(), Node(5));
        assert_eq!(Node(5).parent(), Node(3));
        assert_eq!(Node(6).parent(), Node(5));
        assert_eq!(Node(7).parent(), Node(15));
        assert_eq!(Node(8).parent(), Node(9));
        assert_eq!(Node(9).parent(), Node(11));
        assert_eq!(Node(10).parent(), Node(9));
        assert_eq!(Node(11).parent(), Node(7));
        assert_eq!(Node(12).parent(), Node(13));
        assert_eq!(Node(13).parent(), Node(11));
        assert_eq!(Node(14).parent(), Node(13));

        // Children indices for each branch node
        assert_eq!(Node(1).children(), (Node(0), Node(2)));
        assert_eq!(Node(3).children(), (Node(1), Node(5)));
        assert_eq!(Node(5).children(), (Node(4), Node(6)));
        assert_eq!(Node(7).children(), (Node(3), Node(11)));
        assert_eq!(Node(9).children(), (Node(8), Node(10)));
        assert_eq!(Node(11).children(), (Node(9), Node(13)));
        assert_eq!(Node(13).children(), (Node(12), Node(14)));
    }

    /// Checks right-most descendents and the minimum log length at which
    /// each branch node exists.
    #[test]
    fn test_node_existence() {
        // The rightmost descendent of each branch node
        assert_eq!(Node(1).rightmost_descendent(), Node(2));
        assert_eq!(Node(3).rightmost_descendent(), Node(6));
        assert_eq!(Node(5).rightmost_descendent(), Node(6));
        assert_eq!(Node(7).rightmost_descendent(), Node(14));
        assert_eq!(Node(9).rightmost_descendent(), Node(10));
        assert_eq!(Node(11).rightmost_descendent(), Node(14));
        assert_eq!(Node(13).rightmost_descendent(), Node(14));

        // Whether each branch node exists at a given length;
        // each pair is (node index, smallest length containing the node)
        let cases = [(1, 2), (3, 4), (5, 4), (7, 8), (9, 6), (11, 8), (13, 8)];
        for (index, min_len) in cases {
            let node = Node(index);
            for len in 0..=8 {
                if len >= min_len {
                    assert!(
                        node.exists_at_length(len),
                        "Node {} should exist when length is {}",
                        index,
                        len
                    );
                } else {
                    assert!(
                        !node.exists_at_length(len),
                        "Node {} should not exist when length is {}",
                        index,
                        len
                    );
                }
            }
        }
    }

    /// Checks the first node of each height and the successors found from it.
    #[test]
    fn test_first_nodes() {
        // First node with each height
        let first_0 = Node::first_node_with_height(0);
        assert_eq!(first_0, Node(0));
        assert_eq!(first_0.next_node_with_height(0), Node(2));

        let first_1 = Node::first_node_with_height(1);
        assert_eq!(first_1, Node(1));
        assert_eq!(first_1.next_node_with_height(0), Node(4));
        assert_eq!(first_1.next_node_with_height(1), Node(5));

        let first_2 = Node::first_node_with_height(2);
        assert_eq!(first_2, Node(3));
        assert_eq!(first_2.next_node_with_height(0), Node(8));
        assert_eq!(first_2.next_node_with_height(1), Node(9));
        assert_eq!(first_2.next_node_with_height(2), Node(11));

        let first_3 = Node::first_node_with_height(3);
        assert_eq!(first_3, Node(7));
        assert_eq!(first_3.next_node_with_height(0), Node(16));
        assert_eq!(first_3.next_node_with_height(1), Node(17));
        assert_eq!(first_3.next_node_with_height(2), Node(19));
        assert_eq!(first_3.next_node_with_height(3), Node(23));

        assert_eq!(Node::first_node_with_height(4), Node(15));
    }

    /// Checks the balanced roots produced for log lengths 0 through 8.
    #[test]
    fn test_broots() {
        use alloc::vec;

        // This math is used when computing which roots are available
        assert_eq!(Node::broots_for_len(0), vec![]);
        assert_eq!(Node::broots_for_len(1), vec![Node(0)]);
        assert_eq!(Node::broots_for_len(2), vec![Node(1)]);
        assert_eq!(Node::broots_for_len(3), vec![Node(1), Node(4)]);
        assert_eq!(Node::broots_for_len(4), vec![Node(3)]);
        assert_eq!(Node::broots_for_len(5), vec![Node(3), Node(8)]);
        assert_eq!(Node::broots_for_len(6), vec![Node(3), Node(9)]);
        assert_eq!(Node::broots_for_len(7), vec![Node(3), Node(9), Node(12)]);
        assert_eq!(Node::broots_for_len(8), vec![Node(7)]);
    }
}