1use std::rc::Rc;
2
/// A persistent sequence stored as a red-black tree whose leaves hold the items.
///
/// Interior nodes live behind `Rc`, so cloning a tree, or deriving a new tree from an
/// existing one (split, concat, splice), reuses the untouched subtrees instead of
/// copying them.
///
/// NOTE: the derived `PartialEq`/`Eq` compare tree *shape* (node colors, cached
/// lengths and leaf placement), not only the item sequence. Two trees holding the
/// same items can compare unequal; compare `to_vec()` when only contents matter.
#[derive(Debug, PartialEq, Eq)]
pub enum ListTree<T> {
    /// The empty sequence.
    Empty,
    /// A sequence of exactly one item.
    Leaf(T),
    /// An interior node; `ListNode::len` caches how many items lie beneath it.
    Node(Rc<ListNode<T>>),
}
9
10impl<T: Clone> Clone for ListTree<T> {
11 fn clone(&self) -> Self {
12 match self {
13 Self::Empty => Self::Empty,
14 Self::Leaf(value) => Self::Leaf(value.clone()),
15 Self::Node(node) => Self::Node(node.clone()),
16 }
17 }
18}
19
20impl<T: Clone> ListTree<T> {
21 pub fn empty() -> Self {
22 Self::Empty
23 }
24
25 pub fn singleton(value: T) -> Self {
26 Self::Leaf(value)
27 }
28
29 pub fn len(&self) -> usize {
30 match self {
31 Self::Empty => 0,
32 Self::Leaf(_) => 1,
33 Self::Node(node) => node.len,
34 }
35 }
36
37 pub fn is_empty(&self) -> bool {
38 matches!(self, Self::Empty)
39 }
40
41 pub fn view(&self) -> ListView<T> {
42 match self {
43 Self::Empty => ListView::Empty,
44 Self::Leaf(value) => ListView::Leaf(value.clone()),
45 Self::Node(node) => ListView::Node {
46 color: node.color,
47 len: node.len,
48 left: node.left.clone(),
49 right: node.right.clone(),
50 },
51 }
52 }
53
54 pub fn index(&self, index: usize) -> Option<T> {
55 match self {
56 Self::Empty => None,
57 Self::Leaf(value) => (index == 0).then_some(value.clone()),
58 Self::Node(node) => {
59 let left_len = node.left.len();
60 if index < left_len {
61 node.left.index(index)
62 } else {
63 node.right.index(index - left_len)
64 }
65 }
66 }
67 }
68
69 pub fn index_range(&self, start: usize, end: usize) -> Option<Self> {
70 if start > end || end > self.len() {
71 return None;
72 }
73 let (_, suffix) = self.split_at(start)?;
74 let (range, _) = suffix.split_at(end - start)?;
75 Some(range)
76 }
77
78 pub fn splice(&self, start: usize, end: usize, insert: Self) -> Option<Self> {
79 if start > end || end > self.len() {
80 return None;
81 }
82 let (prefix, rest) = self.split_at(start)?;
83 let (_, suffix) = rest.split_at(end - start)?;
84 Some(Self::concat(prefix, Self::concat(insert, suffix)))
85 }
86
87 pub fn split_at(&self, index: usize) -> Option<(Self, Self)> {
88 if index > self.len() {
89 return None;
90 }
91 Some(self.split_at_unchecked(index))
92 }
93
94 pub fn concat(left: Self, right: Self) -> Self {
95 match (left, right) {
96 (Self::Empty, right) => right,
97 (left, Self::Empty) => left,
98 (left, right) => {
99 let left_height = left.black_height();
100 let right_height = right.black_height();
101 if left_height == right_height {
102 Self::black_node(left, right)
103 } else if left_height > right_height {
104 Self::blacken(join_right(left, right, right_height))
105 } else {
106 Self::blacken(join_left(left, right, left_height))
107 }
108 }
109 }
110 }
111
112 pub fn black_height(&self) -> usize {
113 match self {
114 Self::Empty | Self::Leaf(_) => 0,
115 Self::Node(node) => {
116 let child_height = node.left.black_height();
117 child_height + usize::from(node.color == Color::Black)
118 }
119 }
120 }
121
122 pub fn is_red_black_well_formed(&self) -> bool {
123 self.red_black_status().is_some()
124 }
125
126 fn black_node(left: Self, right: Self) -> Self {
127 Self::node(Color::Black, left, right)
128 }
129
130 fn red_node(left: Self, right: Self) -> Self {
131 Self::node(Color::Red, left, right)
132 }
133
134 fn blacken(tree: Self) -> Self {
135 match tree {
136 Self::Node(node) if node.color == Color::Red => {
137 Self::black_node(node.left.clone(), node.right.clone())
138 }
139 tree => tree,
140 }
141 }
142
143 fn node(color: Color, left: Self, right: Self) -> Self {
144 Self::Node(Rc::new(ListNode {
145 color,
146 len: left.len() + right.len(),
147 left,
148 right,
149 }))
150 }
151
152 fn red_black_status(&self) -> Option<usize> {
153 match self {
154 Self::Empty | Self::Leaf(_) => Some(0),
155 Self::Node(node) => {
156 let left = node.left.red_black_status()?;
157 let right = node.right.red_black_status()?;
158 if left != right {
159 return None;
160 }
161 if node.color == Color::Red
162 && (node.left.node_color() == Some(Color::Red)
163 || node.right.node_color() == Some(Color::Red))
164 {
165 return None;
166 }
167 Some(left + usize::from(node.color == Color::Black))
168 }
169 }
170 }
171
172 fn node_color(&self) -> Option<Color> {
173 match self {
174 Self::Node(node) => Some(node.color),
175 _ => None,
176 }
177 }
178
179 fn split_at_unchecked(&self, index: usize) -> (Self, Self) {
180 match self {
181 Self::Empty => (Self::Empty, Self::Empty),
182 Self::Leaf(_) if index == 0 => (Self::Empty, self.clone()),
183 Self::Leaf(_) => (self.clone(), Self::Empty),
184 Self::Node(node) => {
185 let left_len = node.left.len();
186 if index < left_len {
187 let (prefix, left_suffix) = node.left.split_at_unchecked(index);
188 (prefix, Self::concat(left_suffix, node.right.clone()))
189 } else if index > left_len {
190 let (right_prefix, suffix) = node.right.split_at_unchecked(index - left_len);
191 (Self::concat(node.left.clone(), right_prefix), suffix)
192 } else {
193 (node.left.clone(), node.right.clone())
194 }
195 }
196 }
197 }
198 pub fn from_items(items: impl IntoIterator<Item = T>) -> Self {
199 let leaves = items.into_iter().map(Self::singleton).collect::<Vec<_>>();
200 build_balanced(leaves)
201 }
202
203 pub fn to_vec(&self) -> Vec<T> {
204 let mut out = Vec::with_capacity(self.len());
205 self.push_items(&mut out);
206 out
207 }
208
209 fn push_items(&self, out: &mut Vec<T>) {
210 match self {
211 Self::Empty => {}
212 Self::Leaf(value) => out.push(value.clone()),
213 Self::Node(node) => {
214 node.left.push_items(out);
215 node.right.push_items(out);
216 }
217 }
218 }
219}
220
/// An owned, one-level snapshot of a [`ListTree`] root, as returned by
/// `ListTree::view`.
///
/// Lets callers pattern-match on the root without reaching through the `Rc`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ListView<T> {
    /// The tree is empty.
    Empty,
    /// The tree holds exactly one item (a clone of it).
    Leaf(T),
    /// The root is an interior node.
    Node {
        /// Color of the root node.
        color: Color,
        /// Number of items in the whole tree (the root's cached length).
        len: usize,
        /// Left subtree: the items that come first.
        left: ListTree<T>,
        /// Right subtree: the items that follow `left`.
        right: ListTree<T>,
    },
}
232
/// Color of an interior red-black tree node.
///
/// Only black nodes count toward `ListTree::black_height`. A tree accepted by
/// `ListTree::is_red_black_well_formed` never has a red node with a red child.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Color {
    Red,
    Black,
}
238
/// An interior node of a [`ListTree`], shared between trees through `Rc`.
///
/// When built by `ListTree`'s own constructors, `len` is `left.len() + right.len()`.
/// The fields are public, so code that builds nodes by hand is responsible for
/// keeping `len` consistent and the tree red-black well formed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListNode<T> {
    /// Node color used for red-black balancing.
    pub color: Color,
    /// Cached number of items beneath this node.
    pub len: usize,
    /// Subtree holding the earlier items.
    pub left: ListTree<T>,
    /// Subtree holding the later items.
    pub right: ListTree<T>,
}
246
247fn build_balanced<T: Clone>(mut items: Vec<ListTree<T>>) -> ListTree<T> {
248 if items.is_empty() {
249 return ListTree::Empty;
250 }
251 while items.len() > 1 {
252 let count = items.len();
253 let triple_count = if count % 2 == 1 && count >= 3 { 1 } else { 0 };
254 let mut next = Vec::with_capacity(items.len().div_ceil(2));
255 let mut pairs = items.into_iter();
256 let mut remaining_triples = triple_count;
257 while let Some(first) = pairs.next() {
258 if remaining_triples > 0 {
259 let Some(second) = pairs.next() else {
260 next.push(first);
261 break;
262 };
263 let Some(third) = pairs.next() else {
264 next.push(ListTree::black_node(first, second));
265 break;
266 };
267 next.push(ListTree::black_node(
268 ListTree::red_node(first, second),
269 third,
270 ));
271 remaining_triples -= 1;
272 continue;
273 }
274 match pairs.next() {
275 Some(second) => next.push(ListTree::black_node(first, second)),
276 None => next.push(first),
277 }
278 }
279 items = next;
280 }
281 items.pop().unwrap_or(ListTree::Empty)
282}
283
284fn join_right<T: Clone>(left: ListTree<T>, right: ListTree<T>, right_height: usize) -> ListTree<T> {
285 match left {
286 ListTree::Node(node) if node.right.black_height() > right_height => {
287 let joined = join_right(node.right.clone(), right, right_height);
288 balance(node.color, node.left.clone(), joined)
289 }
290 ListTree::Node(node) => {
291 let joined = ListTree::red_node(node.right.clone(), right);
292 balance(node.color, node.left.clone(), joined)
293 }
294 left => ListTree::red_node(left, right),
295 }
296}
297
298fn join_left<T: Clone>(left: ListTree<T>, right: ListTree<T>, left_height: usize) -> ListTree<T> {
299 match right {
300 ListTree::Node(node) if node.left.black_height() > left_height => {
301 let joined = join_left(left, node.left.clone(), left_height);
302 balance(node.color, joined, node.right.clone())
303 }
304 ListTree::Node(node) => {
305 let joined = ListTree::red_node(left, node.left.clone());
306 balance(node.color, joined, node.right.clone())
307 }
308 right => ListTree::red_node(left, right),
309 }
310}
311
312fn balance<T: Clone>(color: Color, left: ListTree<T>, right: ListTree<T>) -> ListTree<T> {
313 if color != Color::Black {
314 return ListTree::node(color, left, right);
315 }
316
317 if let ListTree::Node(left_node) = &left
318 && left_node.color == Color::Red
319 {
320 if let ListTree::Node(left_left_node) = &left_node.left
321 && left_left_node.color == Color::Red
322 {
323 return ListTree::red_node(
324 ListTree::black_node(left_left_node.left.clone(), left_left_node.right.clone()),
325 ListTree::black_node(left_node.right.clone(), right),
326 );
327 }
328 if let ListTree::Node(left_right_node) = &left_node.right
329 && left_right_node.color == Color::Red
330 {
331 return ListTree::red_node(
332 ListTree::black_node(left_node.left.clone(), left_right_node.left.clone()),
333 ListTree::black_node(left_right_node.right.clone(), right),
334 );
335 }
336 }
337
338 if let ListTree::Node(right_node) = &right
339 && right_node.color == Color::Red
340 {
341 if let ListTree::Node(right_left_node) = &right_node.left
342 && right_left_node.color == Color::Red
343 {
344 return ListTree::red_node(
345 ListTree::black_node(left, right_left_node.left.clone()),
346 ListTree::black_node(right_left_node.right.clone(), right_node.right.clone()),
347 );
348 }
349 if let ListTree::Node(right_right_node) = &right_node.right
350 && right_right_node.color == Color::Red
351 {
352 return ListTree::red_node(
353 ListTree::black_node(left, right_node.left.clone()),
354 ListTree::black_node(
355 right_right_node.left.clone(),
356 right_right_node.right.clone(),
357 ),
358 );
359 }
360 }
361
362 ListTree::black_node(left, right)
363}
364
/// Unit tests for `ListTree`: construction, red-black shape, and the
/// split/concat/splice operations, including out-of-range and boundary cases.
#[cfg(test)]
mod tests {
    use super::{Color, ListTree, ListView};

    #[test]
    fn list_tree_from_items_forms_red_black_tree() {
        for len in 0..16 {
            let list = ListTree::from_items(0..len);
            assert!(list.is_red_black_well_formed(), "len={len}");
        }
    }

    #[test]
    fn list_tree_concat_preserves_binary_view() {
        let list = ListTree::concat(ListTree::from_items([1, 2]), ListTree::from_items([3, 4]));
        let ListView::Node {
            color,
            len,
            left,
            right,
        } = list.view()
        else {
            panic!("expected node view");
        };

        assert_eq!(color, Color::Black);
        assert_eq!(len, 4);
        assert_eq!(left.to_vec(), vec![1, 2]);
        assert_eq!(right.to_vec(), vec![3, 4]);
    }

    #[test]
    fn list_tree_range_and_splice_avoid_flat_runtime_shape() {
        let list = ListTree::from_items([10, 20, 30, 40]);
        assert_eq!(list.index_range(1, 3).unwrap().to_vec(), vec![20, 30]);
        assert_eq!(
            list.splice(1, 3, ListTree::from_items([99, 98]))
                .unwrap()
                .to_vec(),
            vec![10, 99, 98, 40]
        );
    }

    #[test]
    fn list_tree_split_preserves_red_black_shape() {
        let list = ListTree::from_items(0..4096);

        for index in [0, 1, 17, 2048, 4095, 4096] {
            let (prefix, suffix) = list.split_at(index).unwrap();
            assert!(prefix.is_red_black_well_formed(), "prefix index={index}");
            assert!(suffix.is_red_black_well_formed(), "suffix index={index}");
            assert_eq!(prefix.len(), index);
            assert_eq!(suffix.len(), 4096 - index);
            assert_eq!(ListTree::concat(prefix, suffix).to_vec(), list.to_vec());
        }
    }

    #[test]
    fn list_tree_range_preserves_red_black_shape() {
        let list = ListTree::from_items(0..4096);
        let range = list.index_range(17, 4095).unwrap();

        assert!(range.is_red_black_well_formed());
        assert_eq!(range.len(), 4078);
        assert_eq!(range.index(0), Some(17));
        assert_eq!(range.index(4077), Some(4094));
    }

    #[test]
    fn list_tree_repeated_singleton_concat_stays_balanced() {
        let mut list = ListTree::empty();
        for item in 0..4096 {
            list = ListTree::concat(list, ListTree::singleton(item));
        }

        assert!(list.is_red_black_well_formed());
        assert_eq!(
            list.index_range(4090, 4096).unwrap().to_vec(),
            vec![4090, 4091, 4092, 4093, 4094, 4095]
        );
    }

    #[test]
    fn list_tree_repeated_singleton_prepend_stays_balanced() {
        let mut list = ListTree::empty();
        for item in 0..4096 {
            list = ListTree::concat(ListTree::singleton(item), list);
        }

        assert!(list.is_red_black_well_formed());
        assert_eq!(
            list.index_range(0, 6).unwrap().to_vec(),
            vec![4095, 4094, 4093, 4092, 4091, 4090]
        );
    }

    /// Out-of-range positions and inverted ranges are rejected with `None`
    /// instead of panicking.
    #[test]
    fn list_tree_out_of_range_queries_return_none() {
        let list = ListTree::from_items([1, 2, 3]);
        assert_eq!(list.index(3), None);
        assert!(list.split_at(4).is_none());
        assert!(list.index_range(2, 1).is_none());
        assert!(list.index_range(0, 4).is_none());
        assert!(list.splice(1, 4, ListTree::empty()).is_none());
    }

    /// Empty and one-item trees behave sensibly through the public API.
    #[test]
    fn list_tree_empty_and_singleton_edges() {
        let empty: ListTree<i32> = ListTree::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.index(0), None);
        assert_eq!(empty.view(), ListView::Empty);
        assert!(empty.to_vec().is_empty());

        let single = ListTree::singleton(7);
        assert_eq!(single.view(), ListView::Leaf(7));
        let (prefix, suffix) = single.split_at(0).unwrap();
        assert!(prefix.is_empty());
        assert_eq!(suffix.to_vec(), vec![7]);
        let (prefix, suffix) = single.split_at(1).unwrap();
        assert_eq!(prefix.to_vec(), vec![7]);
        assert!(suffix.is_empty());
    }

    /// Inserting at either end, or deleting everything, keeps the result well
    /// formed and leaves the source list untouched.
    #[test]
    fn list_tree_splice_at_boundaries_stays_balanced() {
        let list = ListTree::from_items([1, 2, 3]);

        let appended = list.splice(3, 3, ListTree::from_items([4, 5])).unwrap();
        assert!(appended.is_red_black_well_formed());
        assert_eq!(appended.to_vec(), vec![1, 2, 3, 4, 5]);

        let prepended = list.splice(0, 0, ListTree::singleton(0)).unwrap();
        assert!(prepended.is_red_black_well_formed());
        assert_eq!(prepended.to_vec(), vec![0, 1, 2, 3]);

        assert!(list.splice(0, 3, ListTree::empty()).unwrap().is_empty());
        assert_eq!(list.to_vec(), vec![1, 2, 3], "splice must not modify the source list");
    }
}