1use std::rc::Rc;
2
/// Upper bound, in `char`s, on the size of a leaf produced by chunking or by merging two
/// leaves during concatenation.
pub const MAX_LEAF_CHARS: usize = 64;
4
5#[derive(Debug, PartialEq, Eq)]
6pub enum StringTree<S = Rc<str>>
7where
8 S: StringLeaf,
9{
10 Empty,
11 Leaf(S),
12 Node(Rc<StringNode<S>>),
13}
14
15impl<S> Clone for StringTree<S>
16where
17 S: StringLeaf,
18{
19 fn clone(&self) -> Self {
20 match self {
21 Self::Empty => Self::Empty,
22 Self::Leaf(value) => Self::Leaf(value.clone()),
23 Self::Node(node) => Self::Node(node.clone()),
24 }
25 }
26}
27
/// Storage type for the text held by a single leaf of the tree.
pub trait StringLeaf: Clone + std::fmt::Debug + PartialEq + Eq {
    /// Converts an owned `String` into the leaf storage type.
    fn from_string(value: String) -> Self;

    /// Borrows the stored text as a string slice.
    fn as_str(&self) -> &str;
}
32
33impl StringLeaf for Rc<str> {
34 fn from_string(value: String) -> Self {
35 Rc::from(value.into_boxed_str())
36 }
37
38 fn as_str(&self) -> &str {
39 self
40 }
41}
42
43impl StringLeaf for String {
44 fn from_string(value: String) -> Self {
45 value
46 }
47
48 fn as_str(&self) -> &str {
49 self
50 }
51}
52
53impl StringLeaf for Box<str> {
54 fn from_string(value: String) -> Self {
55 value.into_boxed_str()
56 }
57
58 fn as_str(&self) -> &str {
59 self
60 }
61}
62
63impl<S> StringTree<S>
64where
65 S: StringLeaf,
66{
67 pub fn empty() -> Self {
68 Self::Empty
69 }
70
71 pub fn from_str(value: &str) -> Self {
72 let leaves = chunk_str(value)
73 .into_iter()
74 .map(Self::leaf)
75 .collect::<Vec<_>>();
76 build_balanced(leaves)
77 }
78
79 pub fn len(&self) -> usize {
80 match self {
81 Self::Empty => 0,
82 Self::Leaf(value) => value.as_str().chars().count(),
83 Self::Node(node) => node.len_chars,
84 }
85 }
86
87 pub fn len_bytes(&self) -> usize {
88 match self {
89 Self::Empty => 0,
90 Self::Leaf(value) => value.as_str().len(),
91 Self::Node(node) => node.len_bytes,
92 }
93 }
94
95 pub fn is_empty(&self) -> bool {
96 matches!(self, Self::Empty)
97 }
98
99 pub fn concat(left: Self, right: Self) -> Self {
100 match (left, right) {
101 (Self::Empty, right) => right,
102 (left, Self::Empty) => left,
103 (Self::Leaf(left), Self::Leaf(right))
104 if left.as_str().chars().count() + right.as_str().chars().count()
105 <= MAX_LEAF_CHARS =>
106 {
107 let mut merged = String::with_capacity(left.as_str().len() + right.as_str().len());
108 merged.push_str(left.as_str());
109 merged.push_str(right.as_str());
110 Self::leaf(merged)
111 }
112 (left, right) if left.black_height() == right.black_height() => {
113 Self::black_node(left, right)
114 }
115 (left, right) => {
116 let left_height = left.black_height();
117 let right_height = right.black_height();
118 if left_height > right_height {
119 Self::blacken(join_right(left, right, right_height))
120 } else {
121 Self::blacken(join_left(left, right, left_height))
122 }
123 }
124 }
125 }
126
127 pub fn index(&self, index: usize) -> Option<char> {
128 match self {
129 Self::Empty => None,
130 Self::Leaf(value) => value.as_str().chars().nth(index),
131 Self::Node(node) => {
132 let left_len = node.left.len();
133 if index < left_len {
134 node.left.index(index)
135 } else {
136 node.right.index(index - left_len)
137 }
138 }
139 }
140 }
141
142 pub fn index_range(&self, start: usize, end: usize) -> Option<Self> {
143 if start > end || end > self.len() {
144 return None;
145 }
146 if start == 0 && end == self.len() {
147 return Some(self.clone());
148 }
149 match self {
150 Self::Empty => Some(Self::Empty),
151 Self::Leaf(value) => Some(Self::from_str(slice_str_by_chars(
152 value.as_str(),
153 start,
154 end,
155 )?)),
156 Self::Node(node) => {
157 let left_len = node.left.len();
158 if end <= left_len {
159 node.left.index_range(start, end)
160 } else if start >= left_len {
161 node.right.index_range(start - left_len, end - left_len)
162 } else {
163 let left = node.left.index_range(start, left_len)?;
164 let right = node.right.index_range(0, end - left_len)?;
165 Some(Self::concat(left, right))
166 }
167 }
168 }
169 }
170
171 pub fn splice(&self, start: usize, end: usize, insert: Self) -> Option<Self> {
172 if start > end || end > self.len() {
173 return None;
174 }
175 let prefix = self.index_range(0, start)?;
176 let suffix = self.index_range(end, self.len())?;
177 Some(Self::concat(prefix, Self::concat(insert, suffix)))
178 }
179
180 pub fn to_flat_string(&self) -> String {
181 let mut out = String::with_capacity(self.len_bytes());
182 self.push_str(&mut out);
183 out
184 }
185
186 pub fn view(&self) -> StringView<S> {
187 match self {
188 Self::Empty => StringView::Empty,
189 Self::Leaf(value) => StringView::Leaf(value.clone()),
190 Self::Node(node) => StringView::Node {
191 color: node.color,
192 len_chars: node.len_chars,
193 len_bytes: node.len_bytes,
194 left: node.left.clone(),
195 right: node.right.clone(),
196 },
197 }
198 }
199
200 pub fn black_height(&self) -> usize {
201 match self {
202 Self::Empty | Self::Leaf(_) => 0,
203 Self::Node(node) => {
204 let child_height = node.left.black_height();
205 child_height + usize::from(node.color == Color::Black)
206 }
207 }
208 }
209
210 pub fn is_red_black_well_formed(&self) -> bool {
211 self.red_black_status().is_some()
212 }
213
214 fn leaf(value: impl Into<String>) -> Self {
215 let value = value.into();
216 if value.is_empty() {
217 Self::Empty
218 } else {
219 Self::Leaf(S::from_string(value))
220 }
221 }
222
223 fn black_node(left: Self, right: Self) -> Self {
224 Self::node(Color::Black, left, right)
225 }
226
227 fn red_node(left: Self, right: Self) -> Self {
228 Self::node(Color::Red, left, right)
229 }
230
231 fn blacken(tree: Self) -> Self {
232 match tree {
233 Self::Node(node) if node.color == Color::Red => {
234 Self::black_node(node.left.clone(), node.right.clone())
235 }
236 tree => tree,
237 }
238 }
239
240 fn node(color: Color, left: Self, right: Self) -> Self {
241 Self::Node(Rc::new(StringNode {
242 color,
243 len_chars: left.len() + right.len(),
244 len_bytes: left.len_bytes() + right.len_bytes(),
245 left,
246 right,
247 }))
248 }
249
250 fn push_str(&self, out: &mut String) {
251 match self {
252 Self::Empty => {}
253 Self::Leaf(value) => out.push_str(value.as_str()),
254 Self::Node(node) => {
255 node.left.push_str(out);
256 node.right.push_str(out);
257 }
258 }
259 }
260
261 fn red_black_status(&self) -> Option<usize> {
262 match self {
263 Self::Empty | Self::Leaf(_) => Some(0),
264 Self::Node(node) => {
265 let left = node.left.red_black_status()?;
266 let right = node.right.red_black_status()?;
267 if left != right {
268 return None;
269 }
270 if node.color == Color::Red
271 && (node.left.node_color() == Some(Color::Red)
272 || node.right.node_color() == Some(Color::Red))
273 {
274 return None;
275 }
276 Some(left + usize::from(node.color == Color::Black))
277 }
278 }
279 }
280
281 fn node_color(&self) -> Option<Color> {
282 match self {
283 Self::Node(node) => Some(node.color),
284 _ => None,
285 }
286 }
287}
288
289fn join_right<S>(left: StringTree<S>, right: StringTree<S>, right_height: usize) -> StringTree<S>
290where
291 S: StringLeaf,
292{
293 match left {
294 StringTree::Node(node) if node.right.black_height() > right_height => {
295 let joined = join_right(node.right.clone(), right, right_height);
296 balance(node.color, node.left.clone(), joined)
297 }
298 StringTree::Node(node) => {
299 let joined = StringTree::red_node(node.right.clone(), right);
300 balance(node.color, node.left.clone(), joined)
301 }
302 left => StringTree::red_node(left, right),
303 }
304}
305
306fn join_left<S>(left: StringTree<S>, right: StringTree<S>, left_height: usize) -> StringTree<S>
307where
308 S: StringLeaf,
309{
310 match right {
311 StringTree::Node(node) if node.left.black_height() > left_height => {
312 let joined = join_left(left, node.left.clone(), left_height);
313 balance(node.color, joined, node.right.clone())
314 }
315 StringTree::Node(node) => {
316 let joined = StringTree::red_node(left, node.left.clone());
317 balance(node.color, joined, node.right.clone())
318 }
319 right => StringTree::red_node(left, right),
320 }
321}
322
323fn balance<S>(color: Color, left: StringTree<S>, right: StringTree<S>) -> StringTree<S>
324where
325 S: StringLeaf,
326{
327 if color != Color::Black {
328 return StringTree::node(color, left, right);
329 }
330
331 if let StringTree::Node(left_node) = &left
332 && left_node.color == Color::Red
333 {
334 if let StringTree::Node(left_left_node) = &left_node.left
335 && left_left_node.color == Color::Red
336 {
337 return StringTree::red_node(
338 StringTree::black_node(left_left_node.left.clone(), left_left_node.right.clone()),
339 StringTree::black_node(left_node.right.clone(), right),
340 );
341 }
342 if let StringTree::Node(left_right_node) = &left_node.right
343 && left_right_node.color == Color::Red
344 {
345 return StringTree::red_node(
346 StringTree::black_node(left_node.left.clone(), left_right_node.left.clone()),
347 StringTree::black_node(left_right_node.right.clone(), right),
348 );
349 }
350 }
351
352 if let StringTree::Node(right_node) = &right
353 && right_node.color == Color::Red
354 {
355 if let StringTree::Node(right_left_node) = &right_node.left
356 && right_left_node.color == Color::Red
357 {
358 return StringTree::red_node(
359 StringTree::black_node(left, right_left_node.left.clone()),
360 StringTree::black_node(right_left_node.right.clone(), right_node.right.clone()),
361 );
362 }
363 if let StringTree::Node(right_right_node) = &right_node.right
364 && right_right_node.color == Color::Red
365 {
366 return StringTree::red_node(
367 StringTree::black_node(left, right_node.left.clone()),
368 StringTree::black_node(
369 right_right_node.left.clone(),
370 right_right_node.right.clone(),
371 ),
372 );
373 }
374 }
375
376 StringTree::black_node(left, right)
377}
378
379impl<S> From<&str> for StringTree<S>
380where
381 S: StringLeaf,
382{
383 fn from(value: &str) -> Self {
384 Self::from_str(value)
385 }
386}
387
388impl<S> From<String> for StringTree<S>
389where
390 S: StringLeaf,
391{
392 fn from(value: String) -> Self {
393 Self::from_str(&value)
394 }
395}
396
397#[derive(Debug, Clone, PartialEq, Eq)]
398pub enum StringView<S = Rc<str>>
399where
400 S: StringLeaf,
401{
402 Empty,
403 Leaf(S),
404 Node {
405 color: Color,
406 len_chars: usize,
407 len_bytes: usize,
408 left: StringTree<S>,
409 right: StringTree<S>,
410 },
411}
412
/// Colour tag carried by every interior node of the tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Color {
    /// Red node; does not add to the black height.
    Red,
    /// Black node; adds one to the black height.
    Black,
}
418
419#[derive(Debug, Clone, PartialEq, Eq)]
420pub struct StringNode<S = Rc<str>>
421where
422 S: StringLeaf,
423{
424 pub color: Color,
425 pub len_chars: usize,
426 pub len_bytes: usize,
427 pub left: StringTree<S>,
428 pub right: StringTree<S>,
429}
430
431fn chunk_str(value: &str) -> Vec<String> {
432 let mut chunks = Vec::new();
433 let mut current = String::new();
434 let mut current_chars = 0usize;
435 for ch in value.chars() {
436 if current_chars >= MAX_LEAF_CHARS {
437 chunks.push(std::mem::take(&mut current));
438 current_chars = 0;
439 }
440 current.push(ch);
441 current_chars += 1;
442 }
443 if !current.is_empty() {
444 chunks.push(current);
445 }
446 chunks
447}
448
449fn build_balanced<S>(mut items: Vec<StringTree<S>>) -> StringTree<S>
450where
451 S: StringLeaf,
452{
453 items.retain(|item| !item.is_empty());
454 if items.is_empty() {
455 return StringTree::Empty;
456 }
457 while items.len() > 1 {
458 let count = items.len();
459 let triple_count = if count % 2 == 1 && count >= 3 { 1 } else { 0 };
460 let mut next = Vec::with_capacity(items.len().div_ceil(2));
461 let mut pairs = items.into_iter();
462 let mut remaining_triples = triple_count;
463 while let Some(first) = pairs.next() {
464 if remaining_triples > 0 {
465 let Some(second) = pairs.next() else {
466 next.push(first);
467 break;
468 };
469 let Some(third) = pairs.next() else {
470 next.push(StringTree::black_node(first, second));
471 break;
472 };
473 next.push(StringTree::black_node(
474 StringTree::red_node(first, second),
475 third,
476 ));
477 remaining_triples -= 1;
478 continue;
479 }
480 match pairs.next() {
481 Some(second) => next.push(StringTree::black_node(first, second)),
482 None => next.push(first),
483 }
484 }
485 items = next;
486 }
487 items.pop().unwrap_or(StringTree::Empty)
488}
489
490fn slice_str_by_chars(value: &str, start: usize, end: usize) -> Option<&str> {
491 if start > end {
492 return None;
493 }
494 let start_byte = byte_index_for_char(value, start)?;
495 let end_byte = byte_index_for_char(value, end)?;
496 value.get(start_byte..end_byte)
497}
498
/// Returns the byte offset at which the char with position `index` starts.
///
/// An `index` equal to the number of chars maps to `value.len()` (the end of the text); a
/// larger `index` yields `None`.
fn byte_index_for_char(value: &str, index: usize) -> Option<usize> {
    value
        .char_indices()
        .map(|(offset, _)| offset)
        // The position one past the last char is the end of the string.
        .chain(std::iter::once(value.len()))
        .nth(index)
}
505
#[cfg(test)]
mod tests {
    use super::{Color, MAX_LEAF_CHARS, StringTree, StringView};

    type RuntimeStringTree = StringTree<std::rc::Rc<str>>;

    /// Number of one-character concatenations performed by the balance tests.
    const SINGLETON_ROUNDS: usize = 4096;

    /// Lowercase ASCII letter for `index`, cycling through `a..=z`.
    fn letter(index: usize) -> char {
        // `index % 26` is always below 26, so the cast to `u8` cannot truncate.
        char::from(b'a' + (index % 26) as u8)
    }

    /// Char and byte lengths differ for non-ASCII text and both are reported correctly.
    #[test]
    fn string_tree_tracks_char_and_byte_len() {
        let source = "aćš";
        let tree = RuntimeStringTree::from_str(source);

        assert_eq!(tree.len(), 3);
        assert_eq!(tree.len_bytes(), source.len());
        assert_eq!(tree.to_flat_string(), source);
    }

    /// Text one char longer than a leaf may hold is split into a node.
    #[test]
    fn string_tree_chunks_large_leaves() {
        let source = "x".repeat(MAX_LEAF_CHARS + 1);
        let tree = RuntimeStringTree::from_str(&source);

        assert!(matches!(tree.view(), StringView::Node { .. }));
        assert_eq!(tree.to_flat_string(), source);
        assert!(tree.is_red_black_well_formed());
    }

    /// `concat`, `index`, `index_range` and `splice` agree on char positions.
    #[test]
    fn string_tree_concat_range_and_splice_use_tree_operations() {
        let tree = RuntimeStringTree::concat(
            RuntimeStringTree::from_str("ać"),
            RuntimeStringTree::from_str("šz"),
        );
        assert!(
            !matches!(tree.view(), StringView::Empty),
            "expected non-empty text"
        );

        assert_eq!(tree.index(1), Some('ć'));

        let middle = tree.index_range(1, 3).unwrap();
        assert_eq!(middle.to_flat_string(), "ćš");

        let spliced = tree
            .splice(1, 3, RuntimeStringTree::from_str("bc"))
            .unwrap();
        assert_eq!(spliced.to_flat_string(), "abcz");
    }

    /// Joining two full leaves yields a black node carrying the summed lengths.
    #[test]
    fn string_tree_view_reports_node_metadata() {
        let first = "a".repeat(MAX_LEAF_CHARS);
        let second = "b".repeat(MAX_LEAF_CHARS);
        let tree = RuntimeStringTree::concat(
            RuntimeStringTree::from_str(&first),
            RuntimeStringTree::from_str(&second),
        );

        match tree.view() {
            StringView::Node {
                color,
                len_chars,
                len_bytes,
                ..
            } => {
                assert_eq!(color, Color::Black);
                assert_eq!(len_chars, MAX_LEAF_CHARS * 2);
                assert_eq!(len_bytes, MAX_LEAF_CHARS * 2);
            }
            _ => panic!("expected node view"),
        }
    }

    /// Appending single chars many times keeps the tree well formed and the text intact.
    #[test]
    fn string_tree_repeated_singleton_concat_stays_balanced() {
        let mut tree = RuntimeStringTree::empty();
        let mut expected = String::with_capacity(SINGLETON_ROUNDS);
        for ch in (0..SINGLETON_ROUNDS).map(letter) {
            expected.push(ch);
            tree = RuntimeStringTree::concat(tree, RuntimeStringTree::from_str(&ch.to_string()));
        }

        assert!(tree.is_red_black_well_formed());
        assert_eq!(tree.len(), 4096);
        assert_eq!(tree.to_flat_string(), expected);
    }

    /// Prepending single chars many times keeps the tree well formed and the text intact.
    #[test]
    fn string_tree_repeated_singleton_prepend_stays_balanced() {
        let mut tree = RuntimeStringTree::empty();
        for ch in (0..SINGLETON_ROUNDS).map(letter) {
            tree = RuntimeStringTree::concat(RuntimeStringTree::from_str(&ch.to_string()), tree);
        }
        // Every char was placed in front of the previous ones, so the text reads in reverse
        // insertion order.
        let expected: String = (0..SINGLETON_ROUNDS).rev().map(letter).collect();

        assert!(tree.is_red_black_well_formed());
        assert_eq!(tree.len(), 4096);
        assert_eq!(tree.to_flat_string(), expected);
    }
}