Skip to main content

scry_index/
iter.rs

1//! Lock-free in-order iterator and range queries for the learned index tree.
2#![allow(unsafe_code)]
3
4use std::ops::{Bound, RangeBounds};
5
6use crossbeam_epoch::Guard;
7use crossbeam_utils::Backoff;
8
9use crate::key::Key;
10use crate::node::{is_child, Node, SLOT_DATA, SLOT_WRITING};
11
12/// An iterator over the key-value pairs in a learned index in sorted order.
13///
14/// Yields references tied to the lifetime of the epoch guard. The left-to-right
15/// DFS traversal produces keys in ascending order because the linear model is
16/// monotonic (non-negative slope fitted from sorted keys) and children at
17/// slot `s` contain only keys whose predicted position is `s`.
18///
19/// # Visibility under concurrency
20///
21/// The iterator provides a best-effort snapshot, not a linearizable one:
22///
23/// - Keys inserted into **not-yet-scanned** slots may be visible.
24/// - Keys inserted into **already-scanned** slots will not be visible.
25/// - Keys removed (tombstoned) after scanning will still be yielded if
26///   they were `DATA` when scanned.
27/// - A slot in the transient `WRITING` state is waited on (with backoff)
28///   until the insert completes, so in-flight writes are not silently missed.
29///
30/// For a fully consistent snapshot, use
31/// [`iter_sorted`](crate::LearnedMap::iter_sorted), which clones all entries
32/// under a single traversal.
33pub struct Iter<'g, K, V> {
34    /// Stack of (node, `next_slot_index`) for DFS traversal.
35    stack: Vec<(&'g Node<K, V>, usize)>,
36    /// The epoch guard that keeps referenced data alive.
37    guard: &'g Guard,
38    /// Approximate remaining entries, used for `size_hint`.
39    remaining: Option<usize>,
40}
41
42impl<'g, K: Key, V> Iter<'g, K, V> {
43    /// Create a new iterator starting from the root node.
44    pub fn new(root: &'g Node<K, V>, guard: &'g Guard) -> Self {
45        Self {
46            stack: vec![(root, 0)],
47            guard,
48            remaining: None,
49        }
50    }
51
52    /// Create a new iterator with an approximate entry count hint.
53    ///
54    /// The hint is used by [`size_hint`](Iterator::size_hint) to help
55    /// callers like `collect()` pre-allocate. It does not need to be exact.
56    pub fn with_hint(root: &'g Node<K, V>, guard: &'g Guard, count: usize) -> Self {
57        Self {
58            stack: vec![(root, 0)],
59            guard,
60            remaining: Some(count),
61        }
62    }
63}
64
65impl<'g, K: Key, V> Iterator for Iter<'g, K, V> {
66    type Item = (&'g K, &'g V);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        loop {
70            let (node, slot_idx) = self.stack.last_mut()?;
71            if *slot_idx >= node.capacity() {
72                self.stack.pop();
73                continue;
74            }
75            let current_idx = *slot_idx;
76            *slot_idx += 1;
77
78            let state = node.slot_state(current_idx);
79            match state {
80                SLOT_DATA => {
81                    if let Some(r) = &mut self.remaining {
82                        *r = r.saturating_sub(1);
83                    }
84                    // SAFETY: state is DATA, inline data is valid.
85                    let key = unsafe { node.read_key(current_idx) };
86                    let value = unsafe { node.read_value(current_idx) };
87                    return Some((key, value));
88                }
89                s if is_child(s) => {
90                    let child_shared = node.load_child(current_idx, self.guard);
91                    if !child_shared.is_null() {
92                        let child = unsafe { child_shared.deref() };
93                        self.stack.push((child, 0));
94                    }
95                }
96                SLOT_WRITING => {
97                    // A concurrent insert is claiming this slot. Back off
98                    // so rebuild snapshots don't miss in-flight writes.
99                    let backoff = Backoff::new();
100                    while node.slot_state(current_idx) == SLOT_WRITING {
101                        backoff.snooze();
102                    }
103                    *slot_idx -= 1; // re-visit this slot with resolved state
104                }
105                _ => {} // EMPTY, TOMBSTONE
106            }
107        }
108    }
109
110    fn size_hint(&self) -> (usize, Option<usize>) {
111        self.remaining.map_or((0, None), |r| (r, Some(r)))
112    }
113}
114
115/// Collect all key-value pairs from a tree in sorted order.
116///
117/// This performs a full DFS traversal and clones all entries. The traversal
118/// naturally produces sorted output (see [`Iter`] docs).
119pub fn sorted_pairs<K: Key, V: Clone>(root: &Node<K, V>, guard: &Guard) -> Vec<(K, V)> {
120    let iter = Iter::new(root, guard);
121    iter.map(|(k, v)| (k.clone(), v.clone())).collect()
122}
123
124/// A range iterator over key-value pairs in a learned index.
125///
126/// Yields only entries whose keys fall within the specified bounds, in
127/// ascending key order. Uses model-guided seek for O(depth) initialization
128/// when the start bound is specified.
129///
130/// See [`Iter`] for visibility semantics under concurrency.
131pub struct Range<'g, K, V> {
132    /// Stack of (node, `next_slot_index`) for DFS traversal.
133    stack: Vec<(&'g Node<K, V>, usize)>,
134    /// The epoch guard that keeps referenced data alive.
135    guard: &'g Guard,
136    /// Lower bound of the range.
137    start: Bound<K>,
138    /// Upper bound of the range.
139    end: Bound<K>,
140    /// Whether we have yielded at least one entry (past the start bound).
141    started: bool,
142    /// Whether we have passed the end bound (iterator exhausted).
143    done: bool,
144}
145
146impl<'g, K: Key, V> Range<'g, K, V> {
147    /// Create a new range iterator over the given bounds.
148    pub fn new<R: RangeBounds<K>>(root: &'g Node<K, V>, range: R, guard: &'g Guard) -> Self {
149        let start = match range.start_bound() {
150            Bound::Included(k) => Bound::Included(k.clone()),
151            Bound::Excluded(k) => Bound::Excluded(k.clone()),
152            Bound::Unbounded => Bound::Unbounded,
153        };
154        let end = match range.end_bound() {
155            Bound::Included(k) => Bound::Included(k.clone()),
156            Bound::Excluded(k) => Bound::Excluded(k.clone()),
157            Bound::Unbounded => Bound::Unbounded,
158        };
159
160        let is_unbounded = matches!(&start, Bound::Unbounded);
161
162        let mut iter = Self {
163            stack: Vec::new(),
164            guard,
165            start,
166            end,
167            started: is_unbounded,
168            done: false,
169        };
170
171        let seek_key = match &iter.start {
172            Bound::Included(k) | Bound::Excluded(k) => Some(k.clone()),
173            Bound::Unbounded => None,
174        };
175        if let Some(ref k) = seek_key {
176            iter.seek_to(root, k);
177        } else {
178            iter.stack.push((root, 0));
179        }
180
181        iter
182    }
183
184    /// Seek the DFS stack to the predicted position of `key`.
185    ///
186    /// At each level, predicts the slot for `key` and pushes the node starting
187    /// at that slot. If the slot contains a child, pushes a continuation for the
188    /// parent at `slot + 1` and recurses into the child.
189    fn seek_to(&mut self, node: &'g Node<K, V>, key: &K) {
190        let p = node.predict_slot(key);
191        let state = node.slot_state(p);
192        if is_child(state) {
193            let child_shared = node.load_child(p, self.guard);
194            if !child_shared.is_null() {
195                let child = unsafe { child_shared.deref() };
196                self.stack.push((node, p + 1));
197                self.seek_to(child, key);
198                return;
199            }
200        }
201        self.stack.push((node, p));
202    }
203
204    /// Check if a key is past the end bound.
205    fn past_end(&self, key: &K) -> bool {
206        match &self.end {
207            Bound::Included(end) => key > end,
208            Bound::Excluded(end) => key >= end,
209            Bound::Unbounded => false,
210        }
211    }
212
213    /// Check if a key is before the start bound.
214    fn before_start(&self, key: &K) -> bool {
215        match &self.start {
216            Bound::Included(start) => key < start,
217            Bound::Excluded(start) => key <= start,
218            Bound::Unbounded => false,
219        }
220    }
221}
222
223impl<'g, K: Key, V> Iterator for Range<'g, K, V> {
224    type Item = (&'g K, &'g V);
225
226    fn next(&mut self) -> Option<Self::Item> {
227        if self.done {
228            return None;
229        }
230
231        loop {
232            let (node, slot_idx) = self.stack.last_mut()?;
233
234            if *slot_idx >= node.capacity() {
235                self.stack.pop();
236                continue;
237            }
238
239            let current_idx = *slot_idx;
240            *slot_idx += 1;
241
242            let state = node.slot_state(current_idx);
243            match state {
244                SLOT_DATA => {
245                    // SAFETY: state is DATA, inline data is valid.
246                    let key = unsafe { node.read_key(current_idx) };
247                    let value = unsafe { node.read_value(current_idx) };
248                    if self.past_end(key) {
249                        self.done = true;
250                        return None;
251                    }
252                    if !self.started {
253                        if self.before_start(key) {
254                            continue;
255                        }
256                        self.started = true;
257                    }
258                    return Some((key, value));
259                }
260                s if is_child(s) => {
261                    let child_shared = node.load_child(current_idx, self.guard);
262                    if !child_shared.is_null() {
263                        let child = unsafe { child_shared.deref() };
264                        self.stack.push((child, 0));
265                    }
266                }
267                SLOT_WRITING => {
268                    // A concurrent insert is claiming this slot. Back off
269                    // so rebuild snapshots don't miss in-flight writes.
270                    let backoff = Backoff::new();
271                    while node.slot_state(current_idx) == SLOT_WRITING {
272                        backoff.snooze();
273                    }
274                    *slot_idx -= 1; // re-visit this slot with resolved state
275                }
276                _ => {} // EMPTY, TOMBSTONE
277            }
278        }
279    }
280}
281
282/// Return the first (minimum) key-value pair in the tree.
283///
284/// Returns `None` if the tree is empty. O(depth) typical.
285pub fn first_entry<'g, K: Key, V>(
286    root: &'g Node<K, V>,
287    guard: &'g Guard,
288) -> Option<(&'g K, &'g V)> {
289    Iter::new(root, guard).next()
290}
291
292/// Return the last (maximum) key-value pair in the tree.
293///
294/// Uses a reverse DFS: scans slots right-to-left at each level, pushing
295/// children onto a stack. Correctly backtracks when a child subtree
296/// contains no data (e.g., all entries were removed).
297///
298/// Returns `None` if the tree is empty. O(depth) typical.
299pub fn last_entry<'g, K: Key, V>(root: &'g Node<K, V>, guard: &'g Guard) -> Option<(&'g K, &'g V)> {
300    // Stack of (node, next_slot_to_scan). Slots are scanned in reverse:
301    // slot_idx starts at capacity and decrements toward 0.
302    let mut stack: Vec<(&Node<K, V>, usize)> = vec![(root, root.capacity())];
303    loop {
304        let (node, slot_idx) = stack.last_mut()?;
305        if *slot_idx == 0 {
306            // Exhausted this node. Backtrack to the parent.
307            stack.pop();
308            continue;
309        }
310        *slot_idx -= 1;
311        let current_idx = *slot_idx;
312
313        let state = node.slot_state(current_idx);
314        match state {
315            SLOT_DATA => {
316                // SAFETY: state is DATA, inline data is valid.
317                let key = unsafe { node.read_key(current_idx) };
318                let value = unsafe { node.read_value(current_idx) };
319                return Some((key, value));
320            }
321            s if is_child(s) => {
322                let child_shared = node.load_child(current_idx, guard);
323                if !child_shared.is_null() {
324                    let child = unsafe { child_shared.deref() };
325                    stack.push((child, child.capacity()));
326                }
327            }
328            SLOT_WRITING => {
329                // Concurrent insert in progress. Wait for resolution.
330                let backoff = Backoff::new();
331                while node.slot_state(current_idx) == SLOT_WRITING {
332                    backoff.snooze();
333                }
334                // Re-visit this slot with the resolved state.
335                *slot_idx += 1;
336            }
337            _ => {} // EMPTY, TOMBSTONE
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    use crossbeam_epoch as epoch;

    /// Pin the current thread and hand back the resulting epoch guard.
    fn guard() -> epoch::Guard {
        epoch::pin()
    }

    /// Panic unless `keys` is strictly ascending.
    fn assert_strictly_ascending(keys: &[u64]) {
        for pair in keys.windows(2) {
            assert!(pair[0] < pair[1], "not sorted: {} >= {}", pair[0], pair[1]);
        }
    }

    /// Collect the keys yielded by a raw `Iter` over `node`.
    fn iter_keys(node: &Node<u64, u64>, g: &epoch::Guard) -> Vec<u64> {
        Iter::new(node, g).map(|(key, _)| *key).collect()
    }

    #[test]
    fn iter_empty_tree() {
        let g = guard();
        let node = Node::<u64, ()>::with_capacity(crate::model::LinearModel::constant(), 5);
        let mut iter = Iter::new(&node, &g);
        assert!(iter.next().is_none());
    }

    #[test]
    fn iter_single_element() {
        let g = guard();
        let pairs = vec![(42u64, "answer")];
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let items: Vec<_> = Iter::new(&node, &g).collect();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0], (&42u64, &"answer"));
    }

    #[test]
    fn iter_all_elements() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i * 10)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let yielded = Iter::new(&node, &g).count();
        assert_eq!(yielded, 100);
    }

    #[test]
    fn sorted_pairs_in_order() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..50).map(|i| (i * 3 + 1, i)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let sorted = sorted_pairs(&node, &g);
        assert_eq!(sorted.len(), 50);
        let keys: Vec<u64> = sorted.iter().map(|(key, _)| *key).collect();
        assert_strictly_ascending(&keys);
    }

    #[test]
    fn sorted_pairs_match_input() {
        let g = guard();
        let pairs: Vec<(u64, &str)> = vec![(5, "e"), (10, "j"), (15, "o"), (20, "t")];
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let sorted = sorted_pairs(&node, &g);
        assert_eq!(sorted, pairs);
    }

    #[test]
    fn iter_after_inserts() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = vec![(10, 1), (30, 3), (50, 5)];
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();

        crate::insert::insert(&node, 20, &2, &Config::default(), &g);
        crate::insert::insert(&node, 40, &4, &Config::default(), &g);

        let sorted = sorted_pairs(&node, &g);
        assert_eq!(sorted.len(), 5);
        let keys: Vec<u64> = sorted.iter().map(|(key, _)| *key).collect();
        assert_eq!(keys, vec![10, 20, 30, 40, 50]);
    }

    // -----------------------------------------------------------------------
    // Iter sortedness (Part A validation)
    // -----------------------------------------------------------------------

    #[test]
    fn iter_raw_is_sorted_bulk_load() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..500).map(|i| (i * 3 + 1, i)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let keys = iter_keys(&node, &g);
        assert_eq!(keys.len(), 500);
        assert_strictly_ascending(&keys);
    }

    #[test]
    fn iter_raw_is_sorted_after_inserts() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i * 4, i)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        for i in 0..100u64 {
            crate::insert::insert(&node, i * 4 + 2, &(i + 1000), &Config::default(), &g);
        }
        let keys = iter_keys(&node, &g);
        assert_eq!(keys.len(), 200);
        assert_strictly_ascending(&keys);
    }

    #[test]
    fn iter_raw_is_sorted_reverse_inserts() {
        let g = guard();
        let node = Node::<u64, u64>::with_capacity(crate::model::LinearModel::new(0.01, 0.0), 16);
        for i in (0..200u64).rev() {
            crate::insert::insert(&node, i, &i, &Config::default(), &g);
        }
        let keys = iter_keys(&node, &g);
        assert_eq!(keys.len(), 200);
        assert_strictly_ascending(&keys);
    }

    // -----------------------------------------------------------------------
    // Range iterator
    // -----------------------------------------------------------------------

    /// Bulk-load the keys `0..100`, each mapped to ten times itself.
    fn make_0_to_99() -> Node<u64, u64> {
        let pairs: Vec<(u64, u64)> = (0..100).map(|i| (i, i * 10)).collect();
        crate::build::bulk_load(&pairs, &Config::default()).unwrap()
    }

    /// Collect the keys a `Range` over the 0..=99 tree yields for `range`.
    fn range_keys<R: std::ops::RangeBounds<u64>>(range: R) -> Vec<u64> {
        let g = guard();
        let node = make_0_to_99();
        let keys: Vec<u64> = Range::new(&node, range, &g).map(|(key, _)| *key).collect();
        keys
    }

    #[test]
    fn range_inclusive_both() {
        let items = range_keys(10..=20);
        assert_eq!(items.len(), 11);
        assert_eq!(items, (10..=20).collect::<Vec<_>>());
    }

    #[test]
    fn range_exclusive_end() {
        let items = range_keys(10..20);
        assert_eq!(items.len(), 10);
        assert_eq!(items, (10..20).collect::<Vec<_>>());
    }

    #[test]
    fn range_from() {
        let items = range_keys(90..);
        assert_eq!(items.len(), 10);
        assert_eq!(items, (90..100).collect::<Vec<_>>());
    }

    #[test]
    fn range_to() {
        let items = range_keys(..5);
        assert_eq!(items.len(), 5);
        assert_eq!(items, (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn range_to_inclusive() {
        let items = range_keys(..=5);
        assert_eq!(items.len(), 6);
        assert_eq!(items, (0..=5).collect::<Vec<_>>());
    }

    #[test]
    fn range_empty_result() {
        let items = range_keys(200..300);
        assert!(items.is_empty());
    }

    #[test]
    fn range_single_element() {
        let items = range_keys(50..=50);
        assert_eq!(items, vec![50]);
    }

    // -----------------------------------------------------------------------
    // first_entry / last_entry
    // -----------------------------------------------------------------------

    #[test]
    fn first_entry_basic() {
        let g = guard();
        let node = make_0_to_99();
        let (key, value) = first_entry(&node, &g).unwrap();
        assert_eq!(*key, 0);
        assert_eq!(*value, 0);
    }

    #[test]
    fn last_entry_basic() {
        let g = guard();
        let node = make_0_to_99();
        let (key, value) = last_entry(&node, &g).unwrap();
        assert_eq!(*key, 99);
        assert_eq!(*value, 990);
    }

    #[test]
    fn first_entry_empty() {
        let g = guard();
        let node = Node::<u64, u64>::with_capacity(crate::model::LinearModel::constant(), 5);
        assert!(first_entry(&node, &g).is_none());
    }

    #[test]
    fn last_entry_empty() {
        let g = guard();
        let node = Node::<u64, u64>::with_capacity(crate::model::LinearModel::constant(), 5);
        assert!(last_entry(&node, &g).is_none());
    }

    #[test]
    fn size_hint_without_hint() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..10).map(|i| (i, i)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let iter = Iter::new(&node, &g);
        assert_eq!(iter.size_hint(), (0, None));
    }

    #[test]
    fn size_hint_with_hint() {
        let g = guard();
        let pairs: Vec<(u64, u64)> = (0..10).map(|i| (i, i)).collect();
        let node = crate::build::bulk_load(&pairs, &Config::default()).unwrap();
        let mut iter = Iter::with_hint(&node, &g, 10);
        assert_eq!(iter.size_hint(), (10, Some(10)));
        iter.next();
        assert_eq!(iter.size_hint(), (9, Some(9)));
    }
}
589}