Skip to main content

yulang_runtime/runtime/
list_tree.rs

1use std::rc::Rc;
2
3#[derive(Debug, PartialEq, Eq)]
4pub enum ListTree<T> {
5    Empty,
6    Leaf(T),
7    Node(Rc<ListNode<T>>),
8}
9
10impl<T: Clone> Clone for ListTree<T> {
11    fn clone(&self) -> Self {
12        match self {
13            Self::Empty => Self::Empty,
14            Self::Leaf(value) => Self::Leaf(value.clone()),
15            Self::Node(node) => Self::Node(node.clone()),
16        }
17    }
18}
19
20impl<T: Clone> ListTree<T> {
21    pub fn empty() -> Self {
22        Self::Empty
23    }
24
25    pub fn singleton(value: T) -> Self {
26        Self::Leaf(value)
27    }
28
29    pub fn len(&self) -> usize {
30        match self {
31            Self::Empty => 0,
32            Self::Leaf(_) => 1,
33            Self::Node(node) => node.len,
34        }
35    }
36
37    pub fn is_empty(&self) -> bool {
38        matches!(self, Self::Empty)
39    }
40
41    pub fn view(&self) -> ListView<T> {
42        match self {
43            Self::Empty => ListView::Empty,
44            Self::Leaf(value) => ListView::Leaf(value.clone()),
45            Self::Node(node) => ListView::Node {
46                color: node.color,
47                len: node.len,
48                left: node.left.clone(),
49                right: node.right.clone(),
50            },
51        }
52    }
53
54    pub fn index(&self, index: usize) -> Option<T> {
55        match self {
56            Self::Empty => None,
57            Self::Leaf(value) => (index == 0).then_some(value.clone()),
58            Self::Node(node) => {
59                let left_len = node.left.len();
60                if index < left_len {
61                    node.left.index(index)
62                } else {
63                    node.right.index(index - left_len)
64                }
65            }
66        }
67    }
68
69    pub fn index_range(&self, start: usize, end: usize) -> Option<Self> {
70        if start > end || end > self.len() {
71            return None;
72        }
73        let (_, suffix) = self.split_at(start)?;
74        let (range, _) = suffix.split_at(end - start)?;
75        Some(range)
76    }
77
78    pub fn splice(&self, start: usize, end: usize, insert: Self) -> Option<Self> {
79        if start > end || end > self.len() {
80            return None;
81        }
82        let (prefix, rest) = self.split_at(start)?;
83        let (_, suffix) = rest.split_at(end - start)?;
84        Some(Self::concat(prefix, Self::concat(insert, suffix)))
85    }
86
87    pub fn split_at(&self, index: usize) -> Option<(Self, Self)> {
88        if index > self.len() {
89            return None;
90        }
91        Some(self.split_at_unchecked(index))
92    }
93
94    pub fn concat(left: Self, right: Self) -> Self {
95        match (left, right) {
96            (Self::Empty, right) => right,
97            (left, Self::Empty) => left,
98            (left, right) => {
99                let left_height = left.black_height();
100                let right_height = right.black_height();
101                if left_height == right_height {
102                    Self::black_node(left, right)
103                } else if left_height > right_height {
104                    Self::blacken(join_right(left, right, right_height))
105                } else {
106                    Self::blacken(join_left(left, right, left_height))
107                }
108            }
109        }
110    }
111
112    pub fn black_height(&self) -> usize {
113        match self {
114            Self::Empty | Self::Leaf(_) => 0,
115            Self::Node(node) => {
116                let child_height = node.left.black_height();
117                child_height + usize::from(node.color == Color::Black)
118            }
119        }
120    }
121
122    pub fn is_red_black_well_formed(&self) -> bool {
123        self.red_black_status().is_some()
124    }
125
126    fn black_node(left: Self, right: Self) -> Self {
127        Self::node(Color::Black, left, right)
128    }
129
130    fn red_node(left: Self, right: Self) -> Self {
131        Self::node(Color::Red, left, right)
132    }
133
134    fn blacken(tree: Self) -> Self {
135        match tree {
136            Self::Node(node) if node.color == Color::Red => {
137                Self::black_node(node.left.clone(), node.right.clone())
138            }
139            tree => tree,
140        }
141    }
142
143    fn node(color: Color, left: Self, right: Self) -> Self {
144        Self::Node(Rc::new(ListNode {
145            color,
146            len: left.len() + right.len(),
147            left,
148            right,
149        }))
150    }
151
152    fn red_black_status(&self) -> Option<usize> {
153        match self {
154            Self::Empty | Self::Leaf(_) => Some(0),
155            Self::Node(node) => {
156                let left = node.left.red_black_status()?;
157                let right = node.right.red_black_status()?;
158                if left != right {
159                    return None;
160                }
161                if node.color == Color::Red
162                    && (node.left.node_color() == Some(Color::Red)
163                        || node.right.node_color() == Some(Color::Red))
164                {
165                    return None;
166                }
167                Some(left + usize::from(node.color == Color::Black))
168            }
169        }
170    }
171
172    fn node_color(&self) -> Option<Color> {
173        match self {
174            Self::Node(node) => Some(node.color),
175            _ => None,
176        }
177    }
178
179    fn split_at_unchecked(&self, index: usize) -> (Self, Self) {
180        match self {
181            Self::Empty => (Self::Empty, Self::Empty),
182            Self::Leaf(_) if index == 0 => (Self::Empty, self.clone()),
183            Self::Leaf(_) => (self.clone(), Self::Empty),
184            Self::Node(node) => {
185                let left_len = node.left.len();
186                if index < left_len {
187                    let (prefix, left_suffix) = node.left.split_at_unchecked(index);
188                    (prefix, Self::concat(left_suffix, node.right.clone()))
189                } else if index > left_len {
190                    let (right_prefix, suffix) = node.right.split_at_unchecked(index - left_len);
191                    (Self::concat(node.left.clone(), right_prefix), suffix)
192                } else {
193                    (node.left.clone(), node.right.clone())
194                }
195            }
196        }
197    }
198    pub fn from_items(items: impl IntoIterator<Item = T>) -> Self {
199        let leaves = items.into_iter().map(Self::singleton).collect::<Vec<_>>();
200        build_balanced(leaves)
201    }
202
203    pub fn to_vec(&self) -> Vec<T> {
204        let mut out = Vec::with_capacity(self.len());
205        self.push_items(&mut out);
206        out
207    }
208
209    fn push_items(&self, out: &mut Vec<T>) {
210        match self {
211            Self::Empty => {}
212            Self::Leaf(value) => out.push(value.clone()),
213            Self::Node(node) => {
214                node.left.push_items(out);
215                node.right.push_items(out);
216            }
217        }
218    }
219}
220
221#[derive(Debug, Clone, PartialEq, Eq)]
222pub enum ListView<T> {
223    Empty,
224    Leaf(T),
225    Node {
226        color: Color,
227        len: usize,
228        left: ListTree<T>,
229        right: ListTree<T>,
230    },
231}
232
/// Node color used to keep a [`ListTree`] balanced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Color {
    /// Does not add to black height; must not have a red child in a well-formed tree.
    Red,
    /// Adds one to black height.
    Black,
}
238
239#[derive(Debug, Clone, PartialEq, Eq)]
240pub struct ListNode<T> {
241    pub color: Color,
242    pub len: usize,
243    pub left: ListTree<T>,
244    pub right: ListTree<T>,
245}
246
247fn build_balanced<T: Clone>(mut items: Vec<ListTree<T>>) -> ListTree<T> {
248    if items.is_empty() {
249        return ListTree::Empty;
250    }
251    while items.len() > 1 {
252        let count = items.len();
253        let triple_count = if count % 2 == 1 && count >= 3 { 1 } else { 0 };
254        let mut next = Vec::with_capacity(items.len().div_ceil(2));
255        let mut pairs = items.into_iter();
256        let mut remaining_triples = triple_count;
257        while let Some(first) = pairs.next() {
258            if remaining_triples > 0 {
259                let Some(second) = pairs.next() else {
260                    next.push(first);
261                    break;
262                };
263                let Some(third) = pairs.next() else {
264                    next.push(ListTree::black_node(first, second));
265                    break;
266                };
267                next.push(ListTree::black_node(
268                    ListTree::red_node(first, second),
269                    third,
270                ));
271                remaining_triples -= 1;
272                continue;
273            }
274            match pairs.next() {
275                Some(second) => next.push(ListTree::black_node(first, second)),
276                None => next.push(first),
277            }
278        }
279        items = next;
280    }
281    items.pop().unwrap_or(ListTree::Empty)
282}
283
284fn join_right<T: Clone>(left: ListTree<T>, right: ListTree<T>, right_height: usize) -> ListTree<T> {
285    match left {
286        ListTree::Node(node) if node.right.black_height() > right_height => {
287            let joined = join_right(node.right.clone(), right, right_height);
288            balance(node.color, node.left.clone(), joined)
289        }
290        ListTree::Node(node) => {
291            let joined = ListTree::red_node(node.right.clone(), right);
292            balance(node.color, node.left.clone(), joined)
293        }
294        left => ListTree::red_node(left, right),
295    }
296}
297
298fn join_left<T: Clone>(left: ListTree<T>, right: ListTree<T>, left_height: usize) -> ListTree<T> {
299    match right {
300        ListTree::Node(node) if node.left.black_height() > left_height => {
301            let joined = join_left(left, node.left.clone(), left_height);
302            balance(node.color, joined, node.right.clone())
303        }
304        ListTree::Node(node) => {
305            let joined = ListTree::red_node(left, node.left.clone());
306            balance(node.color, joined, node.right.clone())
307        }
308        right => ListTree::red_node(left, right),
309    }
310}
311
312fn balance<T: Clone>(color: Color, left: ListTree<T>, right: ListTree<T>) -> ListTree<T> {
313    if color != Color::Black {
314        return ListTree::node(color, left, right);
315    }
316
317    if let ListTree::Node(left_node) = &left
318        && left_node.color == Color::Red
319    {
320        if let ListTree::Node(left_left_node) = &left_node.left
321            && left_left_node.color == Color::Red
322        {
323            return ListTree::red_node(
324                ListTree::black_node(left_left_node.left.clone(), left_left_node.right.clone()),
325                ListTree::black_node(left_node.right.clone(), right),
326            );
327        }
328        if let ListTree::Node(left_right_node) = &left_node.right
329            && left_right_node.color == Color::Red
330        {
331            return ListTree::red_node(
332                ListTree::black_node(left_node.left.clone(), left_right_node.left.clone()),
333                ListTree::black_node(left_right_node.right.clone(), right),
334            );
335        }
336    }
337
338    if let ListTree::Node(right_node) = &right
339        && right_node.color == Color::Red
340    {
341        if let ListTree::Node(right_left_node) = &right_node.left
342            && right_left_node.color == Color::Red
343        {
344            return ListTree::red_node(
345                ListTree::black_node(left, right_left_node.left.clone()),
346                ListTree::black_node(right_left_node.right.clone(), right_node.right.clone()),
347            );
348        }
349        if let ListTree::Node(right_right_node) = &right_node.right
350            && right_right_node.color == Color::Red
351        {
352            return ListTree::red_node(
353                ListTree::black_node(left, right_node.left.clone()),
354                ListTree::black_node(
355                    right_right_node.left.clone(),
356                    right_right_node.right.clone(),
357                ),
358            );
359        }
360    }
361
362    ListTree::black_node(left, right)
363}
364
#[cfg(test)]
mod tests {
    use super::{Color, ListTree, ListView};

    #[test]
    fn list_tree_from_items_forms_red_black_tree() {
        for len in 0..16 {
            let list = ListTree::from_items(0..len);
            assert!(list.is_red_black_well_formed(), "len={len}");
        }
    }

    #[test]
    fn list_tree_concat_preserves_binary_view() {
        let list = ListTree::concat(ListTree::from_items([1, 2]), ListTree::from_items([3, 4]));
        let ListView::Node {
            color,
            len,
            left,
            right,
        } = list.view()
        else {
            panic!("expected node view");
        };

        assert_eq!(color, Color::Black);
        assert_eq!(len, 4);
        assert_eq!(left.to_vec(), vec![1, 2]);
        assert_eq!(right.to_vec(), vec![3, 4]);
    }

    #[test]
    fn list_tree_range_and_splice_avoid_flat_runtime_shape() {
        let list = ListTree::from_items([10, 20, 30, 40]);
        assert_eq!(list.index_range(1, 3).unwrap().to_vec(), vec![20, 30]);
        assert_eq!(
            list.splice(1, 3, ListTree::from_items([99, 98]))
                .unwrap()
                .to_vec(),
            vec![10, 99, 98, 40]
        );
    }

    #[test]
    fn list_tree_split_preserves_red_black_shape() {
        let list = ListTree::from_items(0..4096);

        for index in [0, 1, 17, 2048, 4095, 4096] {
            let (prefix, suffix) = list.split_at(index).unwrap();
            assert!(prefix.is_red_black_well_formed(), "prefix index={index}");
            assert!(suffix.is_red_black_well_formed(), "suffix index={index}");
            assert_eq!(prefix.len(), index);
            assert_eq!(suffix.len(), 4096 - index);
            assert_eq!(ListTree::concat(prefix, suffix).to_vec(), list.to_vec());
        }
    }

    #[test]
    fn list_tree_range_preserves_red_black_shape() {
        let list = ListTree::from_items(0..4096);
        let range = list.index_range(17, 4095).unwrap();

        assert!(range.is_red_black_well_formed());
        assert_eq!(range.len(), 4078);
        assert_eq!(range.index(0), Some(17));
        assert_eq!(range.index(4077), Some(4094));
    }

    #[test]
    fn list_tree_repeated_singleton_concat_stays_balanced() {
        let mut list = ListTree::empty();
        for item in 0..4096 {
            list = ListTree::concat(list, ListTree::singleton(item));
        }

        assert!(list.is_red_black_well_formed());
        assert_eq!(
            list.index_range(4090, 4096).unwrap().to_vec(),
            vec![4090, 4091, 4092, 4093, 4094, 4095]
        );
    }

    #[test]
    fn list_tree_repeated_singleton_prepend_stays_balanced() {
        let mut list = ListTree::empty();
        for item in 0..4096 {
            list = ListTree::concat(ListTree::singleton(item), list);
        }

        assert!(list.is_red_black_well_formed());
        assert_eq!(
            list.index_range(0, 6).unwrap().to_vec(),
            vec![4095, 4094, 4093, 4092, 4091, 4090]
        );
    }

    #[test]
    fn list_tree_index_out_of_range_is_none() {
        let empty = ListTree::<i32>::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.index(0), None);

        let list = ListTree::from_items([1, 2, 3]);
        assert!(!list.is_empty());
        assert_eq!(list.index(0), Some(1));
        assert_eq!(list.index(2), Some(3));
        assert_eq!(list.index(3), None);
        assert_eq!(list.index(usize::MAX), None);
    }

    #[test]
    fn list_tree_rejects_invalid_ranges() {
        let list = ListTree::from_items([1, 2, 3]);
        assert!(list.split_at(4).is_none());
        assert!(list.index_range(2, 1).is_none());
        assert!(list.index_range(0, 4).is_none());
        assert!(list.splice(2, 1, ListTree::empty()).is_none());
        assert!(list.splice(0, 4, ListTree::empty()).is_none());
        assert_eq!(list.index_range(1, 1).unwrap().len(), 0);
        assert_eq!(list.index_range(0, 3).unwrap().to_vec(), vec![1, 2, 3]);
    }

    #[test]
    fn list_tree_splice_at_boundaries() {
        let list = ListTree::from_items([1, 2, 3]);
        assert_eq!(
            list.splice(0, 0, ListTree::singleton(0)).unwrap().to_vec(),
            vec![0, 1, 2, 3]
        );
        assert_eq!(
            list.splice(3, 3, ListTree::singleton(4)).unwrap().to_vec(),
            vec![1, 2, 3, 4]
        );
        assert_eq!(
            list.splice(0, 3, ListTree::empty()).unwrap().to_vec(),
            Vec::<i32>::new()
        );
    }

    #[test]
    fn list_tree_concat_of_mixed_heights_stays_balanced() {
        for left_len in 0..24 {
            for right_len in 0..24 {
                let left = ListTree::from_items(0..left_len);
                let right = ListTree::from_items(left_len..left_len + right_len);
                let joined = ListTree::concat(left, right);
                assert!(
                    joined.is_red_black_well_formed(),
                    "left={left_len} right={right_len}"
                );
                assert_eq!(joined.len(), left_len + right_len);
                assert_eq!(
                    joined.to_vec(),
                    (0..left_len + right_len).collect::<Vec<_>>()
                );
            }
        }
    }

    #[test]
    fn list_tree_split_then_rotate_stays_balanced() {
        let list = ListTree::from_items(0..64);
        let items = list.to_vec();
        for index in 0..=64 {
            let (prefix, suffix) = list.split_at(index).unwrap();
            let rotated = ListTree::concat(suffix, prefix);
            assert!(rotated.is_red_black_well_formed(), "index={index}");
            let expected: Vec<_> = items[index..]
                .iter()
                .chain(&items[..index])
                .copied()
                .collect();
            assert_eq!(rotated.to_vec(), expected, "index={index}");
        }
    }
}