1use super::*;
2use heapless::Vec as ArrayVec;
3use std::{cmp::Ordering, mem, sync::Arc};
4
5#[derive(Clone)]
6struct StackEntry<'a, T: Item, D> {
7 tree: &'a SumTree<T>,
8 index: u32,
9 position: D,
10}
11
12impl<'a, T: Item, D> StackEntry<'a, T, D> {
13 #[inline]
14 fn index(&self) -> usize {
15 self.index as usize
16 }
17}
18
19impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 f.debug_struct("StackEntry")
22 .field("index", &self.index)
23 .field("position", &self.position)
24 .finish()
25 }
26}
27
/// A movable position within a [`SumTree`] that tracks an accumulated dimension `D`.
///
/// The cursor keeps an explicit stack of [`StackEntry`] values describing the path
/// from the root to the current leaf. The stack is a fixed-capacity vector of 16 entries.
#[derive(Clone)]
pub struct Cursor<'a, 'b, T: Item, D> {
    /// The tree being traversed.
    tree: &'a SumTree<T>,
    /// Path from the root to the leaf holding the current item.
    ///
    /// Empty before the first seek, after moving past the last item, or after
    /// moving backward past the first item.
    stack: ArrayVec<StackEntry<'a, T, D>, 16, u8>,
    /// Dimension value accumulated up to the start of the current item.
    pub position: D,
    /// Whether the cursor has been positioned since construction or the last `reset`.
    did_seek: bool,
    /// Whether the cursor is past the last item (also true for an empty tree).
    at_end: bool,
    /// Summary context passed to every summary/dimension operation.
    cx: <T::Summary as Summary>::Context<'b>,
}
37
38impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for Cursor<'_, '_, T, D>
39where
40 T::Summary: fmt::Debug,
41{
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("Cursor")
44 .field("tree", &self.tree)
45 .field("stack", &self.stack)
46 .field("position", &self.position)
47 .field("did_seek", &self.did_seek)
48 .field("at_end", &self.at_end)
49 .finish()
50 }
51}
52
/// Forward iterator over the items of a [`SumTree`], in tree order.
///
/// Unlike [`Cursor`], it accumulates no dimension (its stack entries use `D = ()`).
pub struct Iter<'a, T: Item> {
    /// The tree being iterated.
    tree: &'a SumTree<T>,
    /// Path from the root to the leaf holding the most recently returned item;
    /// empty before iteration starts and after it finishes.
    stack: ArrayVec<StackEntry<'a, T, ()>, 16, u8>,
}
57
// Construction, inspection, and item-by-item movement.
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
where
    T: Item,
    D: Dimension<'a, T::Summary>,
{
    /// Creates a cursor over `tree` that has not been positioned yet.
    ///
    /// Call `seek`, `next` or `prev` before reading the current item. For an empty
    /// tree the cursor starts out `at_end`.
    pub fn new(tree: &'a SumTree<T>, cx: <T::Summary as Summary>::Context<'b>) -> Self {
        Self {
            tree,
            stack: ArrayVec::new(),
            position: D::zero(cx),
            did_seek: false,
            at_end: tree.is_empty(),
            cx,
        }
    }

    /// Returns the cursor to its unpositioned state.
    ///
    /// Clears the stack and `did_seek`, zeroes `position`, and sets `at_end` only for
    /// an empty tree.
    pub fn reset(&mut self) {
        self.did_seek = false;
        self.at_end = self.tree.is_empty();
        self.stack.truncate(0);
        self.position = D::zero(self.cx);
    }

    /// Dimension value at the start of the current item.
    pub fn start(&self) -> &D {
        &self.position
    }

    /// Dimension value at the end of the current item: `start()` plus the item's
    /// summary, or just `start()` when there is no current item.
    ///
    /// # Panics
    /// Panics if the cursor has not been positioned yet.
    #[track_caller]
    pub fn end(&self) -> D {
        if let Some(item_summary) = self.item_summary() {
            let mut end = self.start().clone();
            end.add_summary(item_summary, self.cx);
            end
        } else {
            self.start().clone()
        }
    }

    /// The item the cursor points at, or `None` when the stack is empty (before the
    /// start or past the end).
    ///
    /// # Panics
    /// Panics if the cursor has not been positioned yet.
    #[track_caller]
    pub fn item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            match *entry.tree.0 {
                Node::Leaf { ref items, .. } => {
                    if entry.index() == items.len() {
                        None
                    } else {
                        Some(&items[entry.index()])
                    }
                }
                // The top of the stack is expected to be a leaf once positioned
                // (see the `debug_assert!` at the end of `search_forward`).
                _ => unreachable!(),
            }
        } else {
            None
        }
    }

    /// Summary of the current item, or `None` when there is no current item.
    ///
    /// # Panics
    /// Panics if the cursor has not been positioned yet.
    #[track_caller]
    pub fn item_summary(&self) -> Option<&'a T::Summary> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            match *entry.tree.0 {
                Node::Leaf {
                    ref item_summaries, ..
                } => {
                    if entry.index() == item_summaries.len() {
                        None
                    } else {
                        Some(&item_summaries[entry.index()])
                    }
                }
                _ => unreachable!(),
            }
        } else {
            None
        }
    }

    /// The item after the current one, without moving the cursor.
    ///
    /// Before the start this is the tree's first item; past the end it is `None`.
    ///
    /// # Panics
    /// Panics if the cursor has not been positioned yet.
    #[track_caller]
    pub fn next_item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            // On the last item of this leaf: the next item starts the following leaf.
            if entry.index() == entry.tree.0.items().len() - 1 {
                if let Some(next_leaf) = self.next_leaf() {
                    Some(next_leaf.0.items().first().unwrap())
                } else {
                    None
                }
            } else {
                match *entry.tree.0 {
                    Node::Leaf { ref items, .. } => Some(&items[entry.index() + 1]),
                    _ => unreachable!(),
                }
            }
        } else if self.at_end {
            None
        } else {
            self.tree.first()
        }
    }

    /// Finds the leaf after the current one.
    ///
    /// Walks up the ancestors (skipping the leaf itself) to the nearest one that has a
    /// child to the right, then takes that child's leftmost leaf. Returns `None` when
    /// the current leaf is the last one.
    #[track_caller]
    fn next_leaf(&self) -> Option<&'a SumTree<T>> {
        for entry in self.stack.iter().rev().skip(1) {
            if entry.index() < entry.tree.0.child_trees().len() - 1 {
                match *entry.tree.0 {
                    Node::Internal {
                        ref child_trees, ..
                    } => return Some(child_trees[entry.index() + 1].leftmost_leaf()),
                    Node::Leaf { .. } => unreachable!(),
                };
            }
        }
        None
    }

    /// The item before the current one, without moving the cursor.
    ///
    /// Past the end this is the tree's last item; before the start it is `None`.
    ///
    /// # Panics
    /// Panics if the cursor has not been positioned yet.
    #[track_caller]
    pub fn prev_item(&self) -> Option<&'a T> {
        self.assert_did_seek();
        if let Some(entry) = self.stack.last() {
            // On the first item of this leaf: the previous item ends the preceding leaf.
            if entry.index() == 0 {
                if let Some(prev_leaf) = self.prev_leaf() {
                    Some(prev_leaf.0.items().last().unwrap())
                } else {
                    None
                }
            } else {
                match *entry.tree.0 {
                    Node::Leaf { ref items, .. } => Some(&items[entry.index() - 1]),
                    _ => unreachable!(),
                }
            }
        } else if self.at_end {
            self.tree.last()
        } else {
            None
        }
    }

    /// Finds the leaf before the current one.
    ///
    /// Walks up the ancestors (skipping the leaf itself) to the nearest one that has a
    /// child to the left, then takes that child's rightmost leaf. Returns `None` when
    /// the current leaf is the first one.
    #[track_caller]
    fn prev_leaf(&self) -> Option<&'a SumTree<T>> {
        for entry in self.stack.iter().rev().skip(1) {
            if entry.index() != 0 {
                match *entry.tree.0 {
                    Node::Internal {
                        ref child_trees, ..
                    } => return Some(child_trees[entry.index() - 1].rightmost_leaf()),
                    Node::Leaf { .. } => unreachable!(),
                };
            }
        }
        None
    }

    /// Moves to the previous item. From the unpositioned or past-the-end state this
    /// moves to the tree's last item.
    #[track_caller]
    pub fn prev(&mut self) {
        self.search_backward(|_| true)
    }

    /// Moves backward to the previous item that `filter_node` accepts.
    ///
    /// `filter_node` is applied to each subtree summary on the way down; a subtree
    /// whose summary is rejected is skipped without descending into it. It is also
    /// applied to individual item summaries in leaves. If the cursor is unpositioned
    /// or past the end, the search starts from the end of the tree. If no earlier item
    /// is accepted, the stack is emptied and the cursor is left before the start.
    #[track_caller]
    pub fn search_backward<F>(&mut self, mut filter_node: F)
    where
        F: FnMut(&T::Summary) -> bool,
    {
        // An unpositioned cursor searches backward from the end of the tree.
        if !self.did_seek {
            self.did_seek = true;
            self.at_end = true;
        }

        if self.at_end {
            self.position = D::zero(self.cx);
            self.at_end = self.tree.is_empty();
            if !self.tree.is_empty() {
                // Start at the root with `index` one past its last child, so the
                // first decrement in the loop selects the last child.
                self.stack
                    .push(StackEntry {
                        tree: self.tree,
                        index: self.tree.0.child_summaries().len() as u32,
                        position: D::from_summary(self.tree.summary(), self.cx),
                    })
                    .unwrap_oob();
            }
        }

        // `descending` is true when the top entry was just pushed and its `index`
        // (the last child) should be examined without stepping left first.
        let mut descending = false;
        while !self.stack.is_empty() {
            // Recompute `position` from the parent's recorded position (zero at the root).
            if let Some(StackEntry { position, .. }) = self.stack.iter().rev().nth(1) {
                self.position = position.clone();
            } else {
                self.position = D::zero(self.cx);
            }

            let entry = self.stack.last_mut().unwrap();
            if !descending {
                // Step one child to the left, or ascend when this node is exhausted.
                if entry.index() == 0 {
                    self.stack.pop();
                    continue;
                } else {
                    entry.index -= 1;
                }
            }

            // Add the siblings to the left of `index` to reach the start of this child.
            for summary in &entry.tree.0.child_summaries()[..entry.index()] {
                self.position.add_summary(summary, self.cx);
            }
            entry.position = self.position.clone();

            descending = filter_node(&entry.tree.0.child_summaries()[entry.index()]);
            match entry.tree.0.as_ref() {
                Node::Internal { child_trees, .. } => {
                    if descending {
                        // Enter the accepted child, starting from its last entry.
                        let tree = &child_trees[entry.index()];
                        self.stack
                            .push(StackEntry {
                                position: D::zero(self.cx),
                                tree,
                                index: tree.0.child_summaries().len() as u32 - 1,
                            })
                            .unwrap_oob();
                    }
                }
                Node::Leaf { .. } => {
                    if descending {
                        // The cursor now rests on an accepted item.
                        break;
                    }
                }
            }
        }
    }

    /// Moves to the next item. From the unpositioned state this moves to the first item.
    #[track_caller]
    pub fn next(&mut self) {
        self.search_forward(|_| true)
    }

    /// Moves forward to the next item that `filter_node` accepts.
    ///
    /// `filter_node` is applied to subtree summaries (rejected subtrees are skipped
    /// whole) and to item summaries in leaves. Summaries of skipped subtrees and items
    /// are still added to `position`. If no later item is accepted, the cursor ends up
    /// past the end.
    #[track_caller]
    pub fn search_forward<F>(&mut self, mut filter_node: F)
    where
        F: FnMut(&T::Summary) -> bool,
    {
        // `descend` is true when the top entry was just pushed and its current
        // `index` has not been examined yet. False means step past it first.
        let mut descend = false;

        if self.stack.is_empty() {
            // Unpositioned or before the start: begin at the root, unless already past the end.
            if !self.at_end {
                self.stack
                    .push(StackEntry {
                        tree: self.tree,
                        index: 0,
                        position: D::zero(self.cx),
                    })
                    .unwrap_oob();
                descend = true;
            }
            self.did_seek = true;
        }

        while !self.stack.is_empty() {
            let new_subtree = {
                let entry = self.stack.last_mut().unwrap();
                match entry.tree.0.as_ref() {
                    Node::Internal {
                        child_trees,
                        child_summaries,
                        ..
                    } => {
                        if !descend {
                            // Returning from a finished child: move to its next sibling.
                            entry.index += 1;
                            entry.position = self.position.clone();
                        }

                        // Skip children rejected by the filter, still accounting for their summaries.
                        while entry.index() < child_summaries.len() {
                            let next_summary = &child_summaries[entry.index()];
                            if filter_node(next_summary) {
                                break;
                            } else {
                                entry.index += 1;
                                entry.position.add_summary(next_summary, self.cx);
                                self.position.add_summary(next_summary, self.cx);
                            }
                        }

                        // `None` once every child of this node has been passed.
                        child_trees.get(entry.index())
                    }
                    Node::Leaf { item_summaries, .. } => {
                        if !descend {
                            // Step past the item the cursor was resting on.
                            let item_summary = &item_summaries[entry.index()];
                            entry.index += 1;
                            entry.position.add_summary(item_summary, self.cx);
                            self.position.add_summary(item_summary, self.cx);
                        }

                        loop {
                            if let Some(next_item_summary) = item_summaries.get(entry.index()) {
                                if filter_node(next_item_summary) {
                                    // The cursor now rests on an accepted item.
                                    return;
                                } else {
                                    entry.index += 1;
                                    entry.position.add_summary(next_item_summary, self.cx);
                                    self.position.add_summary(next_item_summary, self.cx);
                                }
                            } else {
                                break None;
                            }
                        }
                    }
                }
            };

            if let Some(subtree) = new_subtree {
                descend = true;
                self.stack
                    .push(StackEntry {
                        tree: subtree,
                        index: 0,
                        position: self.position.clone(),
                    })
                    .unwrap_oob();
            } else {
                // This node is exhausted: ascend.
                descend = false;
                self.stack.pop();
            }
        }

        // Falling out of the loop means no accepted item remains.
        self.at_end = self.stack.is_empty();
        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());
    }

    /// Panics unless the cursor has been positioned by a seek or move.
    #[track_caller]
    fn assert_did_seek(&self) {
        assert!(
            self.did_seek,
            "Must call `seek`, `next` or `prev` before calling this method"
        );
    }

    /// Whether the cursor has been positioned since construction or the last `reset`.
    pub fn did_seek(&self) -> bool {
        self.did_seek
    }
}
397
// Seeking by target, optionally collecting what the seek passes over.
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
where
    T: Item,
    D: Dimension<'a, T::Summary>,
{
    /// Resets the cursor and seeks from the start of the tree to `pos`.
    ///
    /// With `Bias::Left` the cursor stops on the first item whose end is not before
    /// `pos`. With `Bias::Right` it stops on the first item whose end is after `pos`.
    ///
    /// Returns `true` when `pos` compares equal to the boundary it landed on: the end
    /// of the current item for `Bias::Left`, or its start for `Bias::Right`. When there
    /// is no current item, the boundary is the cursor position.
    #[track_caller]
    pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        self.reset();
        self.seek_internal(pos, bias, &mut ())
    }

    /// Like `seek`, but continues from the current position instead of resetting.
    ///
    /// # Panics
    /// Panics if `pos` is before the current position.
    #[track_caller]
    pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        self.seek_internal(pos, bias, &mut ())
    }

    /// Seeks forward to `end` and returns a new tree containing every item passed
    /// over on the way. The item the cursor stops on is not included.
    #[track_caller]
    pub fn slice<Target>(&mut self, end: &Target, bias: Bias) -> SumTree<T>
    where
        Target: SeekTarget<'a, T::Summary, D>,
    {
        let mut slice = SliceSeekAggregate {
            tree: SumTree::new(self.cx),
            leaf_items: ArrayVec::new(),
            leaf_item_summaries: ArrayVec::new(),
            leaf_summary: <T::Summary as Summary>::zero(self.cx),
        };
        self.seek_internal(end, bias, &mut slice);
        slice.tree
    }

    /// Returns every item from the current position to the end of the tree, leaving
    /// the cursor past the end.
    #[track_caller]
    pub fn suffix(&mut self) -> SumTree<T> {
        self.slice(&End::new(), Bias::Right)
    }

    /// Seeks forward to `end` and returns the `Output` dimension accumulated over
    /// every subtree and item passed over on the way.
    #[track_caller]
    pub fn summary<Target, Output>(&mut self, end: &Target, bias: Bias) -> Output
    where
        Target: SeekTarget<'a, T::Summary, D>,
        Output: Dimension<'a, T::Summary>,
    {
        let mut summary = SummarySeekAggregate(Output::zero(self.cx));
        self.seek_internal(end, bias, &mut summary);
        summary.0
    }

    /// Shared implementation of all forward seeks.
    ///
    /// Advances from the current position toward `target` under `bias`. Subtrees
    /// that lie entirely before the target are skipped without descending. Every
    /// skipped subtree and item is reported to `aggregate`. The return value is
    /// described on `seek`.
    ///
    /// # Panics
    /// Panics if `target` is before the current position.
    #[track_caller]
    fn seek_internal(
        &mut self,
        target: &dyn SeekTarget<'a, T::Summary, D>,
        bias: Bias,
        aggregate: &mut dyn SeekAggregate<'a, T>,
    ) -> bool {
        assert!(
            target.cmp(&self.position, self.cx).is_ge(),
            "cannot seek backward",
        );

        // An unpositioned cursor starts at the root.
        if !self.did_seek {
            self.did_seek = true;
            self.stack
                .push(StackEntry {
                    tree: self.tree,
                    index: 0,
                    position: D::zero(self.cx),
                })
                .unwrap_oob();
        }

        // `ascending` is true when the top entry's current child was just finished and
        // popped, so the entry must advance to its next child before scanning.
        let mut ascending = false;
        'outer: while let Some(entry) = self.stack.last_mut() {
            match *entry.tree.0 {
                Node::Internal {
                    ref child_summaries,
                    ref child_trees,
                    ..
                } => {
                    if ascending {
                        entry.index += 1;
                        entry.position = self.position.clone();
                    }

                    for (child_tree, child_summary) in child_trees[entry.index()..]
                        .iter()
                        .zip(&child_summaries[entry.index()..])
                    {
                        let mut child_end = self.position.clone();
                        child_end.add_summary(child_summary, self.cx);

                        let comparison = target.cmp(&child_end, self.cx);
                        if comparison == Ordering::Greater
                            || (comparison == Ordering::Equal && bias == Bias::Right)
                        {
                            // Target lies beyond this child: pass over it whole.
                            self.position = child_end;
                            aggregate.push_tree(child_tree, child_summary, self.cx);
                            entry.index += 1;
                            entry.position = self.position.clone();
                        } else {
                            // Target lies within this child: descend into it.
                            self.stack
                                .push(StackEntry {
                                    tree: child_tree,
                                    index: 0,
                                    position: self.position.clone(),
                                })
                                .unwrap_oob();
                            ascending = false;
                            continue 'outer;
                        }
                    }
                }
                Node::Leaf {
                    ref items,
                    ref item_summaries,
                    ..
                } => {
                    aggregate.begin_leaf();

                    // NOTE(review): unlike the internal branch, the leaf entry's
                    // `position` is not updated as items are passed over here.
                    for (item, item_summary) in items[entry.index()..]
                        .iter()
                        .zip(&item_summaries[entry.index()..])
                    {
                        let mut child_end = self.position.clone();
                        child_end.add_summary(item_summary, self.cx);

                        let comparison = target.cmp(&child_end, self.cx);
                        if comparison == Ordering::Greater
                            || (comparison == Ordering::Equal && bias == Bias::Right)
                        {
                            // Target lies beyond this item: pass over it.
                            self.position = child_end;
                            aggregate.push_item(item, item_summary, self.cx);
                            entry.index += 1;
                        } else {
                            // Stop on this item.
                            aggregate.end_leaf(self.cx);
                            break 'outer;
                        }
                    }

                    aggregate.end_leaf(self.cx);
                }
            }

            // Every remaining child of this node was passed over: ascend.
            self.stack.pop();
            ascending = true;
        }

        self.at_end = self.stack.is_empty();
        debug_assert!(self.stack.is_empty() || self.stack.last().unwrap().tree.0.is_leaf());

        // With `Bias::Left`, compare the target against the end of the item the cursor
        // stopped on. Otherwise compare against its start.
        let mut end = self.position.clone();
        if bias == Bias::Left
            && let Some(summary) = self.item_summary()
        {
            end.add_summary(summary, self.cx);
        }

        target.cmp(&end, self.cx) == Ordering::Equal
    }
}
571
572impl<'a, T: Item> Iter<'a, T> {
573 pub(crate) fn new(tree: &'a SumTree<T>) -> Self {
574 Self {
575 tree,
576 stack: Default::default(),
577 }
578 }
579}
580
581impl<'a, T: Item> Iterator for Iter<'a, T> {
582 type Item = &'a T;
583
584 fn next(&mut self) -> Option<Self::Item> {
585 let mut descend = false;
586
587 if self.stack.is_empty() {
588 self.stack
589 .push(StackEntry {
590 tree: self.tree,
591 index: 0,
592 position: (),
593 })
594 .unwrap_oob();
595 descend = true;
596 }
597
598 while let Some(entry) = self.stack.last_mut() {
599 let new_subtree = {
600 match entry.tree.0.as_ref() {
601 Node::Internal { child_trees, .. } => {
602 if !descend {
603 entry.index += 1;
604 }
605 child_trees.get(entry.index())
606 }
607 Node::Leaf { items, .. } => {
608 if !descend {
609 entry.index += 1;
610 }
611
612 if let Some(next_item) = items.get(entry.index()) {
613 return Some(next_item);
614 } else {
615 None
616 }
617 }
618 }
619 };
620
621 if let Some(subtree) = new_subtree {
622 descend = true;
623 self.stack
624 .push(StackEntry {
625 tree: subtree,
626 index: 0,
627 position: (),
628 })
629 .unwrap_oob();
630 } else {
631 descend = false;
632 self.stack.pop();
633 }
634 }
635
636 None
637 }
638
639 fn last(mut self) -> Option<Self::Item> {
640 self.stack.clear();
641 self.tree.rightmost_leaf().last()
642 }
643
644 fn size_hint(&self) -> (usize, Option<usize>) {
645 let lower_bound = match self.stack.last() {
646 Some(top) => top.tree.0.child_summaries().len() - top.index as usize,
647 None => self.tree.0.child_summaries().len(),
648 };
649
650 (lower_bound, None)
651 }
652}
653
654impl<'a, 'b, T: Item, D> Iterator for Cursor<'a, 'b, T, D>
655where
656 D: Dimension<'a, T::Summary>,
657{
658 type Item = &'a T;
659
660 fn next(&mut self) -> Option<Self::Item> {
661 if !self.did_seek {
662 self.next();
663 }
664
665 if let Some(item) = self.item() {
666 self.next();
667 Some(item)
668 } else {
669 None
670 }
671 }
672}
673
/// A [`Cursor`] that moves only between items accepted by `filter_node`.
///
/// Its `next` and `prev` delegate to the cursor's `search_forward` and
/// `search_backward`, passing `filter_node`.
pub struct FilterCursor<'a, 'b, F, T: Item, D> {
    /// Underlying cursor that performs the traversal.
    cursor: Cursor<'a, 'b, T, D>,
    /// Predicate over subtree and item summaries that decides what to visit.
    filter_node: F,
}
678
679impl<'a, 'b, F, T: Item, D> FilterCursor<'a, 'b, F, T, D>
680where
681 F: FnMut(&T::Summary) -> bool,
682 T: Item,
683 D: Dimension<'a, T::Summary>,
684{
685 pub fn new(
686 tree: &'a SumTree<T>,
687 cx: <T::Summary as Summary>::Context<'b>,
688 filter_node: F,
689 ) -> Self {
690 let cursor = tree.cursor::<D>(cx);
691 Self {
692 cursor,
693 filter_node,
694 }
695 }
696
697 pub fn start(&self) -> &D {
698 self.cursor.start()
699 }
700
701 pub fn end(&self) -> D {
702 self.cursor.end()
703 }
704
705 pub fn item(&self) -> Option<&'a T> {
706 self.cursor.item()
707 }
708
709 pub fn item_summary(&self) -> Option<&'a T::Summary> {
710 self.cursor.item_summary()
711 }
712
713 pub fn next(&mut self) {
714 self.cursor.search_forward(&mut self.filter_node);
715 }
716
717 pub fn prev(&mut self) {
718 self.cursor.search_backward(&mut self.filter_node);
719 }
720}
721
722impl<'a, 'b, F, T: Item, U> Iterator for FilterCursor<'a, 'b, F, T, U>
723where
724 F: FnMut(&T::Summary) -> bool,
725 U: Dimension<'a, T::Summary>,
726{
727 type Item = &'a T;
728
729 fn next(&mut self) -> Option<Self::Item> {
730 if !self.cursor.did_seek {
731 self.next();
732 }
733
734 if let Some(item) = self.item() {
735 self.cursor.search_forward(&mut self.filter_node);
736 Some(item)
737 } else {
738 None
739 }
740 }
741}
742
/// Receives the subtrees and items that a cursor seek passes over.
trait SeekAggregate<'a, T: Item> {
    /// Called before the items of a leaf are scanned.
    fn begin_leaf(&mut self);
    /// Called after a leaf has been scanned, whether the seek stopped inside it or
    /// passed over all of it.
    fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>);
    /// Called for each individual item passed over inside a leaf.
    fn push_item(
        &mut self,
        item: &'a T,
        summary: &'a T::Summary,
        cx: <T::Summary as Summary>::Context<'_>,
    );
    /// Called for each whole subtree passed over without descending into it.
    fn push_tree(
        &mut self,
        tree: &'a SumTree<T>,
        summary: &'a T::Summary,
        cx: <T::Summary as Summary>::Context<'_>,
    );
}
759
/// Aggregate that rebuilds the subtrees and items passed over into a new [`SumTree`].
struct SliceSeekAggregate<T: Item> {
    /// Output tree built so far.
    tree: SumTree<T>,
    /// Items collected for the leaf currently being built.
    leaf_items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
    /// Summaries parallel to `leaf_items`.
    leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
    /// Running sum of `leaf_item_summaries`.
    leaf_summary: T::Summary,
}

/// Aggregate that only folds the summaries passed over into a dimension `D`.
struct SummarySeekAggregate<D>(D);
768
769impl<T: Item> SeekAggregate<'_, T> for () {
770 fn begin_leaf(&mut self) {}
771 fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
772 fn push_item(&mut self, _: &T, _: &T::Summary, _: <T::Summary as Summary>::Context<'_>) {}
773 fn push_tree(
774 &mut self,
775 _: &SumTree<T>,
776 _: &T::Summary,
777 _: <T::Summary as Summary>::Context<'_>,
778 ) {
779 }
780}
781
782impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
783 fn begin_leaf(&mut self) {}
784 fn end_leaf(&mut self, cx: <T::Summary as Summary>::Context<'_>) {
785 self.tree.append(
786 SumTree(Arc::new(Node::Leaf {
787 summary: mem::replace(&mut self.leaf_summary, <T::Summary as Summary>::zero(cx)),
788 items: mem::take(&mut self.leaf_items),
789 item_summaries: mem::take(&mut self.leaf_item_summaries),
790 })),
791 cx,
792 );
793 }
794 fn push_item(
795 &mut self,
796 item: &T,
797 summary: &T::Summary,
798 cx: <T::Summary as Summary>::Context<'_>,
799 ) {
800 self.leaf_items.push(item.clone()).unwrap_oob();
801 self.leaf_item_summaries.push(summary.clone()).unwrap_oob();
802 Summary::add_summary(&mut self.leaf_summary, summary, cx);
803 }
804 fn push_tree(
805 &mut self,
806 tree: &SumTree<T>,
807 _: &T::Summary,
808 cx: <T::Summary as Summary>::Context<'_>,
809 ) {
810 self.tree.append(tree.clone(), cx);
811 }
812}
813
814impl<'a, T: Item, D> SeekAggregate<'a, T> for SummarySeekAggregate<D>
815where
816 D: Dimension<'a, T::Summary>,
817{
818 fn begin_leaf(&mut self) {}
819 fn end_leaf(&mut self, _: <T::Summary as Summary>::Context<'_>) {}
820 fn push_item(
821 &mut self,
822 _: &T,
823 summary: &'a T::Summary,
824 cx: <T::Summary as Summary>::Context<'_>,
825 ) {
826 self.0.add_summary(summary, cx);
827 }
828 fn push_tree(
829 &mut self,
830 _: &SumTree<T>,
831 summary: &'a T::Summary,
832 cx: <T::Summary as Summary>::Context<'_>,
833 ) {
834 self.0.add_summary(summary, cx);
835 }
836}
837
838struct End<D>(PhantomData<D>);
839
840impl<D> End<D> {
841 fn new() -> Self {
842 Self(PhantomData)
843 }
844}
845
846impl<'a, S: Summary, D: Dimension<'a, S>> SeekTarget<'a, S, D> for End<D> {
847 fn cmp(&self, _: &D, _: S::Context<'_>) -> Ordering {
848 Ordering::Greater
849 }
850}
851
852impl<D> fmt::Debug for End<D> {
853 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
854 f.debug_tuple("End").finish()
855 }
856}