sorted_index_buffer/
lib.rs

1#[cfg(test)]
2mod tests;
3
/// A map-like buffer keyed by a `u64` sequence number.
///
/// This behaves identically to a `BTreeMap<u64, T>`, but is optimized for the case
/// where the indices are mostly contiguous, with only occasional gaps.
///
/// Memory overhead becomes large when the indices are very sparse, so it should
/// not be used as a general-purpose sorted map.
///
/// # Internals
///
/// The underlying storage is a contiguous `Vec<Option<T>>` whose length is twice
/// the size of the key range, rounded up to the next power of two (two "pages").
/// The vector may have additional unused capacity. The lowest stored index
/// (`min`) and one past the highest stored index (`max`) are tracked alongside.
///
/// ```text
/// |_x_xxxx_|________| max-min=6, page size 8, buffer fits in first page
///   ^ min
///         ^ max
/// |______x_|xxxx____| max-min=6, page size 8, buffer spans two pages
///        ^ min
///               ^ max
/// ```
///
/// Access is O(1). Insertion and removal are usually O(1), but will occasionally
/// move contents around to resize the buffer, which is O(n). Moving will only
/// happen every O(n.next_power_of_two()) operations, so amortized complexity is
/// still O(1).
#[derive(Debug, Clone)]
pub struct SortedIndexBuffer<T> {
    /// Backing storage; a slot is `None` where no value is stored for that index.
    /// Its length is a power of two covering two "pages".
    data: Vec<Option<T>>,
    /// Lowest valid index (inclusive).
    min: u64,
    /// Upper bound of the valid indices (exclusive, so that the empty buffer can
    /// be modelled as `min == max`).
    max: u64,
}
39
40impl<T> Default for SortedIndexBuffer<T> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<T> SortedIndexBuffer<T> {
47    /// Create a new SortedIndexBuffer with the given initial capacity.
48    pub fn with_capacity(capacity: usize) -> Self {
49        let data = Vec::with_capacity(capacity);
50        Self {
51            data,
52            min: 0,
53            max: 0,
54        }
55    }
56
57    /// Create a new, empty SortedIndexBuffer.
58    pub fn new() -> Self {
59        Self::with_capacity(0)
60    }
61
62    /// Returns true if the buffer contains no elements.
63    pub fn is_empty(&self) -> bool {
64        self.data.is_empty()
65    }
66
67    /// Returns true if the buffer contains an element at the given index.
68    pub fn contains_key(&self, index: u64) -> bool {
69        self.get(index).is_some()
70    }
71
72    /// Iterate over all keys in the given index range in ascending order.
73    pub fn keys_range<R: std::ops::RangeBounds<u64>>(
74        &self,
75        range: R,
76    ) -> impl DoubleEndedIterator<Item = u64> + '_ {
77        let (buf_start, buf_end, start) = self.resolve_range(range);
78        self.data[buf_start..buf_end]
79            .iter()
80            .enumerate()
81            .filter_map(move |(i, slot)| slot.as_ref().map(|_| start + i as u64))
82    }
83
84    /// Iterate over all values in the given index range in ascending order of their keys.
85    pub fn values_range<R: std::ops::RangeBounds<u64>>(
86        &self,
87        range: R,
88    ) -> impl DoubleEndedIterator<Item = &T> + '_ {
89        let (buf_start, buf_end, _) = self.resolve_range(range);
90        self.data[buf_start..buf_end]
91            .iter()
92            .filter_map(|slot| slot.as_ref())
93    }
94
95    /// Iterate over all values in the given index range in ascending order of their keys.
96    pub fn values_range_mut<R: std::ops::RangeBounds<u64>>(
97        &mut self,
98        range: R,
99    ) -> impl DoubleEndedIterator<Item = &mut T> + '_ {
100        let (buf_start, buf_end, _) = self.resolve_range(range);
101        self.data[buf_start..buf_end]
102            .iter_mut()
103            .filter_map(|slot| slot.as_mut())
104    }
105
106    /// Iterate over all (index, value) pairs in the given index range in ascending order of their keys.
107    pub fn iter_range<R: std::ops::RangeBounds<u64>>(
108        &self,
109        range: R,
110    ) -> impl DoubleEndedIterator<Item = (u64, &T)> + '_ {
111        let (buf_start, buf_end, start) = self.resolve_range(range);
112        self.data[buf_start..buf_end]
113            .iter()
114            .enumerate()
115            .filter_map(move |(i, slot)| slot.as_ref().map(|v| (start + i as u64, v)))
116    }
117
118    pub fn retain(&mut self, f: impl FnMut(u64, &mut T) -> bool) {
119        if self.is_empty() {
120            return;
121        }
122        let mut f = f;
123        let base = base(self.min, self.max);
124        for i in self.min..self.max {
125            let offset = (i - base) as usize;
126            let Some(v) = &mut self.data[offset] else {
127                continue;
128            };
129            if !f(i, v) {
130                self.data[offset] = None;
131            }
132        }
133        // Now adjust min and max
134        let start = (self.min - base) as usize;
135        let end = (self.max - base) as usize;
136        let min1 = self.data[start..end]
137            .iter()
138            .position(|slot| slot.is_some())
139            .map(|p| p + start)
140            .unwrap_or(end) as u64
141            + base;
142        let max1 = self.data[start..end]
143            .iter()
144            .rev()
145            .position(|slot| slot.is_some())
146            .map(|p| end - p)
147            .unwrap_or(start + 1) as u64
148            + base;
149        self.resize(min1, max1);
150        self.check_invariants();
151    }
152
153    /// Retain only the elements in the given index range.
154    pub fn retain_range<R: std::ops::RangeBounds<u64>>(&mut self, range: R) {
155        let (min1, max1) = self.clip_bounds(range);
156        if min1 >= max1 {
157            self.data.clear();
158            self.min = 0;
159            self.max = 0;
160            self.check_invariants();
161            return;
162        }
163        let base = base(self.min, self.max);
164        for i in self.min..min1 {
165            self.data[(i - base) as usize] = None;
166        }
167        for i in max1..self.max {
168            self.data[(i - base) as usize] = None;
169        }
170        self.resize(min1, max1);
171        self.check_invariants();
172    }
173
174    /// Iterate over all keys in the buffer in ascending order.
175    pub fn keys(&self) -> impl DoubleEndedIterator<Item = u64> + '_ {
176        self.keys_range(..)
177    }
178
179    /// Iterate over all values in the buffer in ascending order of their keys.
180    pub fn values(&self) -> impl DoubleEndedIterator<Item = &T> + '_ {
181        self.values_range(..)
182    }
183
184    /// Iterate over all values in the buffer in ascending order of their keys.
185    pub fn values_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut T> + '_ {
186        self.values_range_mut(..)
187    }
188
189    /// Iterate over all (index, value) pairs in the buffer in ascending order of their keys.
190    ///
191    /// Values are returned by reference.
192    pub fn iter(&self) -> impl DoubleEndedIterator<Item = (u64, &T)> + '_ {
193        self.iter_range(..)
194    }
195
196    /// Turn into an iterator over all (index, value) pairs in the buffer in ascending order of their keys.
197    ///
198    /// This is an explicit method instead of implementing IntoIterator, so we can return a
199    /// DoubleEndedIterator without having to name the iterator type.
200    #[allow(clippy::should_implement_trait)]
201    pub fn into_iter(self) -> impl DoubleEndedIterator<Item = (u64, T)> {
202        let base = base(self.min, self.max);
203        self.data
204            .into_iter()
205            .enumerate()
206            .filter_map(move |(i, slot)| slot.map(|v| (base + i as u64, v)))
207    }
208
209    /// Convert range bounds into an inclusive start and exclusive end, clipped to the current
210    /// bounds.
211    ///
212    /// The resulting range may be empty, which has to be handled by the caller.
213    #[inline]
214    fn clip_bounds<R: std::ops::RangeBounds<u64>>(&self, range: R) -> (u64, u64) {
215        use std::ops::Bound;
216
217        let start = match range.start_bound() {
218            Bound::Included(&n) => n.max(self.min),
219            Bound::Excluded(&n) => (n + 1).max(self.min),
220            Bound::Unbounded => self.min,
221        };
222        let end = match range.end_bound() {
223            Bound::Included(&n) => (n + 1).min(self.max),
224            Bound::Excluded(&n) => n.min(self.max),
225            Bound::Unbounded => self.max,
226        };
227        (start, end)
228    }
229
230    #[inline]
231    fn resolve_range<R: std::ops::RangeBounds<u64>>(&self, range: R) -> (usize, usize, u64) {
232        let (start, end) = self.clip_bounds(range);
233        if start >= end {
234            return (0, 0, start);
235        }
236        let base = base(self.min, self.max);
237        let buf_start = (start - base) as usize;
238        let buf_end = (end - base) as usize;
239        (buf_start, buf_end, start)
240    }
241
242    /// Get a reference to the value at the given index, if it exists.
243    pub fn get(&self, index: u64) -> Option<&T> {
244        if index < self.min || index >= self.max {
245            return None;
246        }
247        let base = base(self.min, self.max);
248        let offset = (index - base) as usize;
249        self.data[offset].as_ref()
250    }
251
252    /// Insert value at index.
253    pub fn insert(&mut self, index: u64, value: T) {
254        if self.is_empty() {
255            self.min = index;
256            self.max = index + 1;
257            self.data.push(Some(value));
258        } else {
259            let (min1, max1) = (self.min.min(index), self.max.max(index + 1));
260            self.resize(min1, max1);
261            self.insert0(index, value);
262        }
263        self.check_invariants();
264    }
265
266    /// Remove value at index.
267    pub fn remove(&mut self, index: u64) -> Option<T> {
268        if self.is_empty() {
269            return None;
270        }
271        let res = self.remove0(index);
272        if index == self.min {
273            let base = base(self.min, self.max);
274            let start = (self.min - base) as usize;
275            let end = (self.max - base) as usize;
276            // no need to check start, since we just removed that element
277            let skip = self.data[start + 1..end]
278                .iter()
279                .position(|slot| slot.is_some())
280                .map(|p| p + 1)
281                .unwrap_or(end - start);
282            self.resize(self.min + skip as u64, self.max);
283        } else if index + 1 == self.max {
284            let base = base(self.min, self.max);
285            let start = (self.min - base) as usize;
286            let end = (self.max - base) as usize;
287            // no need to check end-1, since we just removed that element
288            let skip = self.data[start..end - 1]
289                .iter()
290                .rev()
291                .position(|slot| slot.is_some())
292                .map(|p| p + 1)
293                .unwrap_or(end - start);
294            self.resize(self.min, self.max - skip as u64);
295        }
296        self.check_invariants();
297        res
298    }
299
300    /// Insert value at index, assuming the buffer already covers that index.
301    ///
302    /// The resulting buffer may violate the invariants.
303    #[inline(always)]
304    fn insert0(&mut self, index: u64, value: T) {
305        let base = base(self.min, self.max);
306        let offset = (index - base) as usize;
307        self.buf_mut()[offset] = Some(value);
308    }
309
310    /// Remove value at index without resizing the buffer.
311    ///
312    /// The resulting buffer may violate the invariants.
313    #[inline(always)]
314    fn remove0(&mut self, index: u64) -> Option<T> {
315        if index < self.min || index >= self.max {
316            return None;
317        }
318        let base = base(self.min, self.max);
319        let offset = (index - base) as usize;
320        self.buf_mut()[offset].take()
321    }
322
323    /// Resize the buffer to cover the range [`min1`, `max1`), while preserving existing
324    /// elements.
325    #[inline(always)]
326    fn resize(&mut self, min1: u64, max1: u64) {
327        if min1 == self.min && max1 == self.max {
328            // nothing to do
329            return;
330        }
331        if min1 >= max1 {
332            // resizing to empty buffer
333            self.data.clear();
334            self.min = 0;
335            self.max = 0;
336            return;
337        }
338        let len0 = self.buf().len();
339        let len1 = buf_len(min1, max1);
340        let base0 = base(self.min, self.max);
341        let base1 = base(min1, max1);
342        if len0 == len1 {
343            // just need to move data around within the existing buffer
344            //
345            // we use rotate even though half the buffer is empty, because
346            // otherwise we would require Copy on T.
347            if base1 < base0 {
348                let shift = (base0 - base1) as usize;
349                self.buf_mut().rotate_right(shift);
350            } else if base1 > base0 {
351                let shift = (base1 - base0) as usize;
352                self.buf_mut().rotate_left(shift);
353            }
354        } else if len0 < len1 {
355            for _ in len0..len1 {
356                self.data.push(None);
357            }
358            // Grow
359            if len0 != 0 {
360                let start0 = (self.min - base0) as usize;
361                let start1 = (self.min - base1) as usize;
362                let count = (self.max - self.min) as usize;
363                for i in 0..count {
364                    self.data.swap(start0 + i, start1 + i);
365                }
366            }
367        } else {
368            // Shrink
369            let start0 = (min1 - base0) as usize;
370            let start1 = (min1 - base1) as usize;
371            let count = (max1 - min1) as usize;
372
373            for i in 0..count {
374                self.data.swap(start0 + i, start1 + i);
375            }
376            self.data.truncate(len1);
377        }
378        self.min = min1;
379        self.max = max1;
380    }
381
382    #[inline(always)]
383    fn buf(&self) -> &[Option<T>] {
384        &self.data
385    }
386
387    #[inline(always)]
388    fn buf_mut(&mut self) -> &mut [Option<T>] {
389        &mut self.data
390    }
391
392    /// Check that the invariants of the SortedIndexBuffer hold.
393    ///
394    /// This should be called after each public &mut method.
395    ///
396    /// It is a noop in release builds.
397    fn check_invariants(&self) {
398        if self.is_empty() {
399            // for the empty buffer, min and max must be zero
400            debug_assert_eq!(self.min, 0);
401            debug_assert_eq!(self.max, 0);
402        } else {
403            // for a non-empty buffer, elements min and max-1 must be valid
404            debug_assert!(self.min < self.max);
405            debug_assert!(self.get(self.min).is_some() && self.get(self.max - 1).is_some());
406        }
407    }
408
409    /// Same as `check_invariants`, but also checks that the unused parts of the buffer are empty.
410    ///
411    /// This is more expensive, so only used in tests.
412    #[cfg(test)]
413    fn check_invariants_expensive(&self) {
414        self.check_invariants();
415        let base = base(self.min, self.max);
416        let start = (self.min - base) as usize;
417        let end = (self.max - base) as usize;
418        for i in 0..start {
419            debug_assert!(self.data[i].is_none());
420        }
421        for i in end..self.data.len() {
422            debug_assert!(self.data[i].is_none());
423        }
424    }
425}
426
/// Returns the buffer length required to hold [min, max) regardless of where the
/// range falls relative to a page boundary.
///
/// The page size is the span `max - min` rounded up to a power of two; the
/// buffer is two pages long.
#[inline(always)]
fn buf_len(min: u64, max: u64) -> usize {
    let span = max - min;
    let page_size = span.next_power_of_two() as usize;
    2 * page_size
}
434
435/// Compute the base index for the buffer covering [min, max).
436#[inline(always)]
437fn base(min: u64, max: u64) -> u64 {
438    let buf_len = buf_len(min, max);
439    let mask = (buf_len as u64) / 2 - 1;
440    min & !mask
441}