// sorted_index_buffer/lib.rs

// Unit tests live in a separate `tests` module file and are only compiled for test builds.
#[cfg(test)]
mod tests;

/// A buffer indexed by a u64 sequence number.
///
/// It offers a `BTreeMap<u64, T>`-like API (ordered iteration, `get`, `insert`,
/// `remove`, `retain`), but is optimized for the case where the indices are
/// mostly contiguous, with occasional gaps.
///
/// It will have large memory overhead if the indices are very sparse, so it
/// should not be used as a general-purpose sorted map. Because the upper bound
/// is tracked as an exclusive `u64`, the index `u64::MAX` cannot be stored.
///
/// # Internals
///
/// The underlying storage is a contiguous buffer of `Option<T>` made of two
/// "pages". The page size is the key range `max - min` rounded up to the next
/// power of two. Slot 0 of the buffer holds the index `min` rounded down to a
/// multiple of the page size, so the range always fits, even when it straddles
/// the page boundary. We track `min` (inclusive) and `max` (exclusive).
///
/// The buffer is a `Vec<Option<T>>`, which can have some additional unused capacity.
///
/// ```text
/// |_x_xxxx_|________| max-min=6, page size 8, buffer fits in first page
///   ^ min
///         ^ max
/// |______x_|xxxx____| max-min=6, page size 8, buffer spans two pages
///        ^ min
///               ^ max
/// ```
///
/// Access is O(1). Insertion and removal are usually O(1), but the contents are
/// moved (O(n)) whenever the key range crosses a power-of-two size boundary or
/// its page-aligned start changes. When the range grows or shrinks steadily
/// this is rare enough that the amortized cost stays O(1). A workload that
/// repeatedly crosses the same boundary back and forth pays the O(n) move every
/// time.
#[derive(Debug, Clone)]
pub struct SortedIndexBuffer<T> {
    /// The underlying data buffer. It is empty when no elements are stored;
    /// otherwise it is exactly two pages long (a power of two).
    data: Vec<Option<T>>,
    /// The minimum valid index (inclusive). Occupied whenever the buffer is
    /// non-empty; 0 when empty.
    min: u64,
    /// The maximum valid index (exclusive, so we can model the empty buffer).
    /// `max - 1` is occupied whenever the buffer is non-empty; 0 when empty.
    max: u64,
}

40impl<T> Default for SortedIndexBuffer<T> {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl<T> SortedIndexBuffer<T> {
47    /// Create a new SortedIndexBuffer with the given initial capacity.
48    pub fn with_capacity(capacity: usize) -> Self {
49        let data = Vec::with_capacity(capacity);
50        Self {
51            data,
52            min: 0,
53            max: 0,
54        }
55    }
56
57    /// Create a new, empty SortedIndexBuffer.
58    pub fn new() -> Self {
59        Self::with_capacity(0)
60    }
61
62    /// Returns true if the buffer contains no elements.
63    pub fn is_empty(&self) -> bool {
64        self.data.is_empty()
65    }
66
67    /// Returns true if the buffer contains an element at the given index.
68    pub fn contains_key(&self, index: u64) -> bool {
69        self.get(index).is_some()
70    }
71
72    /// Iterate over all keys in the given index range in ascending order.
73    pub fn keys_range<R: std::ops::RangeBounds<u64>>(
74        &self,
75        range: R,
76    ) -> impl DoubleEndedIterator<Item = u64> + '_ {
77        let (buf_start, buf_end, start) = self.resolve_range(range);
78        self.data[buf_start..buf_end]
79            .iter()
80            .enumerate()
81            .filter_map(move |(i, slot)| slot.as_ref().map(|_| start + i as u64))
82    }
83
84    /// Iterate over all values in the given index range in ascending order of their keys.
85    pub fn values_range<R: std::ops::RangeBounds<u64>>(
86        &self,
87        range: R,
88    ) -> impl DoubleEndedIterator<Item = &T> + '_ {
89        let (buf_start, buf_end, _) = self.resolve_range(range);
90        self.data[buf_start..buf_end]
91            .iter()
92            .filter_map(|slot| slot.as_ref())
93    }
94
95    /// Iterate over all values in the given index range in ascending order of their keys.
96    pub fn values_range_mut<R: std::ops::RangeBounds<u64>>(
97        &mut self,
98        range: R,
99    ) -> impl DoubleEndedIterator<Item = &mut T> + '_ {
100        let (buf_start, buf_end, _) = self.resolve_range(range);
101        self.data[buf_start..buf_end]
102            .iter_mut()
103            .filter_map(|slot| slot.as_mut())
104    }
105
106    /// Iterate over all (index, value) pairs in the given index range in ascending order of their keys.
107    pub fn iter_range<R: std::ops::RangeBounds<u64>>(
108        &self,
109        range: R,
110    ) -> impl DoubleEndedIterator<Item = (u64, &T)> + '_ {
111        let (buf_start, buf_end, start) = self.resolve_range(range);
112        self.data[buf_start..buf_end]
113            .iter()
114            .enumerate()
115            .filter_map(move |(i, slot)| slot.as_ref().map(|v| (start + i as u64, v)))
116    }
117
118    pub fn retain(&mut self, f: impl FnMut(u64, &mut T) -> bool) {
119        if self.is_empty() {
120            return;
121        }
122        let mut f = f;
123        let base = base(self.min, self.max);
124        for i in self.min..self.max {
125            let offset = (i - base) as usize;
126            let Some(v) = &mut self.data[offset] else {
127                continue;
128            };
129            if !f(i, v) {
130                self.data[offset] = None;
131            }
132        }
133        // Now adjust min and max
134        let start = (self.min - base) as usize;
135        let end = (self.max - base) as usize;
136        let min1 = self.data[start..end]
137            .iter()
138            .position(|slot| slot.is_some())
139            .map(|p| p + start)
140            .unwrap_or(end) as u64
141            + base;
142        let max1 = self.data[start..end]
143            .iter()
144            .rev()
145            .position(|slot| slot.is_some())
146            .map(|p| end - p)
147            .unwrap_or(start + 1) as u64
148            + base;
149        self.resize(min1, max1);
150        self.check_invariants();
151    }
152
153    /// Retain only the elements in the given index range.
154    pub fn retain_range<R: std::ops::RangeBounds<u64>>(&mut self, range: R) {
155        let (min1, max1) = self.clip_bounds(range);
156        if min1 >= max1 {
157            self.data.clear();
158            self.min = 0;
159            self.max = 0;
160            self.check_invariants();
161            return;
162        }
163        let base = base(self.min, self.max);
164        for i in self.min..min1 {
165            self.data[(i - base) as usize] = None;
166        }
167        for i in max1..self.max {
168            self.data[(i - base) as usize] = None;
169        }
170        self.resize(min1, max1);
171        self.check_invariants();
172    }
173
174    /// Iterate over all keys in the buffer in ascending order.
175    pub fn keys(&self) -> impl DoubleEndedIterator<Item = u64> + '_ {
176        self.keys_range(..)
177    }
178
179    /// Iterate over all values in the buffer in ascending order of their keys.
180    pub fn values(&self) -> impl DoubleEndedIterator<Item = &T> + '_ {
181        self.values_range(..)
182    }
183
184    /// Iterate over all values in the buffer in ascending order of their keys.
185    pub fn values_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut T> + '_ {
186        self.values_range_mut(..)
187    }
188
189    /// Iterate over all (index, value) pairs in the buffer in ascending order of their keys.
190    ///
191    /// Values are returned by reference.
192    pub fn iter(&self) -> impl DoubleEndedIterator<Item = (u64, &T)> + '_ {
193        self.iter_range(..)
194    }
195
196    /// Turn into an iterator over all (index, value) pairs in the buffer in ascending order of their keys.
197    ///
198    /// This is an explicit method instead of implementing IntoIterator, so we can return a
199    /// DoubleEndedIterator without having to name the iterator type.
200    #[allow(clippy::should_implement_trait)]
201    pub fn into_iter(self) -> impl DoubleEndedIterator<Item = (u64, T)> {
202        let base = base(self.min, self.max);
203        self.data
204            .into_iter()
205            .enumerate()
206            .filter_map(move |(i, slot)| slot.map(|v| (base + i as u64, v)))
207    }
208
209    /// Convert range bounds into an inclusive start and exclusive end, clipped to the current
210    /// bounds.
211    ///
212    /// The resulting range may be empty, which has to be handled by the caller.
213    #[inline]
214    fn clip_bounds<R: std::ops::RangeBounds<u64>>(&self, range: R) -> (u64, u64) {
215        use std::ops::Bound;
216
217        let start = match range.start_bound() {
218            Bound::Included(&n) => n.max(self.min),
219            Bound::Excluded(&n) => (n + 1).max(self.min),
220            Bound::Unbounded => self.min,
221        };
222        let end = match range.end_bound() {
223            Bound::Included(&n) => (n + 1).min(self.max),
224            Bound::Excluded(&n) => n.min(self.max),
225            Bound::Unbounded => self.max,
226        };
227        (start, end)
228    }
229
230    #[inline]
231    fn resolve_range<R: std::ops::RangeBounds<u64>>(&self, range: R) -> (usize, usize, u64) {
232        let (start, end) = self.clip_bounds(range);
233        if start >= end {
234            return (0, 0, start);
235        }
236        let base = base(self.min, self.max);
237        let buf_start = (start - base) as usize;
238        let buf_end = (end - base) as usize;
239        (buf_start, buf_end, start)
240    }
241
242    /// Get a reference to the value at the given index, if it exists.
243    pub fn get(&self, index: u64) -> Option<&T> {
244        if index < self.min || index >= self.max {
245            return None;
246        }
247        let base = base(self.min, self.max);
248        let offset = (index - base) as usize;
249        self.data[offset].as_ref()
250    }
251
252    /// Insert value at index.
253    pub fn insert(&mut self, index: u64, value: T) {
254        let (min1, max1) = if self.is_empty() {
255            (index, index + 1)
256        } else {
257            (self.min.min(index), self.max.max(index + 1))
258        };
259        self.resize(min1, max1);
260        self.insert0(index, value);
261        self.check_invariants();
262    }
263
264    /// Remove value at index.
265    pub fn remove(&mut self, index: u64) -> Option<T> {
266        if self.is_empty() {
267            return None;
268        }
269        let res = self.remove0(index);
270        if index == self.min {
271            let base = base(self.min, self.max);
272            let start = (self.min - base) as usize;
273            let end = (self.max - base) as usize;
274            // no need to check start, since we just removed that element
275            let skip = self.data[start + 1..end]
276                .iter()
277                .position(|slot| slot.is_some())
278                .map(|p| p + 1)
279                .unwrap_or(end - start);
280            self.resize(self.min + skip as u64, self.max);
281        } else if index + 1 == self.max {
282            let base = base(self.min, self.max);
283            let start = (self.min - base) as usize;
284            let end = (self.max - base) as usize;
285            // no need to check end-1, since we just removed that element
286            let skip = self.data[start..end - 1]
287                .iter()
288                .rev()
289                .position(|slot| slot.is_some())
290                .map(|p| p + 1)
291                .unwrap_or(end - start);
292            self.resize(self.min, self.max - skip as u64);
293        }
294        self.check_invariants();
295        res
296    }
297
298    /// Insert value at index, assuming the buffer already covers that index.
299    ///
300    /// The resulting buffer may violate the invariants.
301    fn insert0(&mut self, index: u64, value: T) {
302        let base = base(self.min, self.max);
303        let offset = (index - base) as usize;
304        self.buf_mut()[offset] = Some(value);
305    }
306
307    /// Remove value at index without resizing the buffer.
308    ///
309    /// The resulting buffer may violate the invariants.
310    fn remove0(&mut self, index: u64) -> Option<T> {
311        if index < self.min || index >= self.max {
312            return None;
313        }
314        let base = base(self.min, self.max);
315        let offset = (index - base) as usize;
316        self.buf_mut()[offset].take()
317    }
318
319    /// Resize the buffer to cover the range [`min1`, `max1`), while preserving existing
320    /// elements.
321    fn resize(&mut self, min1: u64, max1: u64) {
322        if min1 == self.min && max1 == self.max {
323            // nothing to do
324            return;
325        }
326        if min1 >= max1 {
327            // resizing to empty buffer
328            *self = Self::new();
329            return;
330        }
331        let len0 = self.buf().len();
332        let len1 = buf_len(min1, max1);
333        let base0 = base(self.min, self.max);
334        let base1 = base(min1, max1);
335        if len0 == len1 {
336            // just need to move data around within the existing buffer
337            //
338            // we use rotate even though half the buffer is empty, because
339            // otherwise we would require Copy on T.
340            if base1 < base0 {
341                let shift = (base0 - base1) as usize;
342                self.buf_mut().rotate_right(shift);
343            } else if base1 > base0 {
344                let shift = (base1 - base0) as usize;
345                self.buf_mut().rotate_left(shift);
346            }
347        } else if len0 < len1 {
348            // Grow
349            if len0 == 0 {
350                // buffer was empty before.
351                self.data = mk_empty(len1);
352            } else {
353                self.data
354                    .extend(std::iter::repeat_with(|| None).take(len1 - len0));
355
356                let start0 = (self.min - base0) as usize;
357                let start1 = (self.min - base1) as usize;
358                let count = (self.max - self.min) as usize;
359                for i in 0..count {
360                    self.data.swap(start0 + i, start1 + i);
361                }
362            }
363        } else {
364            // Shrink
365            let start0 = (min1 - base0) as usize;
366            let start1 = (min1 - base1) as usize;
367            let count = (max1 - min1) as usize;
368
369            for i in 0..count {
370                self.data.swap(start0 + i, start1 + i);
371            }
372            self.data.truncate(len1);
373        }
374        self.min = min1;
375        self.max = max1;
376    }
377
378    fn buf(&self) -> &[Option<T>] {
379        &self.data
380    }
381
382    fn buf_mut(&mut self) -> &mut [Option<T>] {
383        &mut self.data
384    }
385
386    /// Check that the invariants of the SortedIndexBuffer hold.
387    ///
388    /// This should be called after each public &mut method.
389    ///
390    /// It is a noop in release builds.
391    fn check_invariants(&self) {
392        if self.is_empty() {
393            // for the empty buffer, min and max must be zero
394            debug_assert_eq!(self.min, 0);
395            debug_assert_eq!(self.max, 0);
396        } else {
397            // for a non-empty buffer, elements min and max-1 must be valid
398            debug_assert!(self.min < self.max);
399            debug_assert!(self.get(self.min).is_some() && self.get(self.max - 1).is_some());
400        }
401    }
402
403    /// Same as `check_invariants`, but also checks that the unused parts of the buffer are empty.
404    ///
405    /// This is more expensive, so only used in tests.
406    #[cfg(test)]
407    fn check_invariants_expensive(&self) {
408        self.check_invariants();
409        let base = base(self.min, self.max);
410        let start = (self.min - base) as usize;
411        let end = (self.max - base) as usize;
412        for i in 0..start {
413            debug_assert!(self.data[i].is_none());
414        }
415        for i in end..self.data.len() {
416            debug_assert!(self.data[i].is_none());
417        }
418    }
419}
420
/// Build a vector holding exactly `n` empty (`None`) slots.
///
/// Written without `vec![None; n]`, which would require `T: Clone`.
fn mk_empty<T>(n: usize) -> Vec<Option<T>> {
    (0..n).map(|_| None).collect()
}

/// Compute the buffer length needed to cover `[min, max)`: two pages, each the
/// size of the range rounded up to a power of two. Two pages are needed so the
/// range still fits when it straddles a page boundary.
///
/// An empty range (`min == max`) yields a page size of 1, i.e. a length of 2.
fn buf_len(min: u64, max: u64) -> usize {
    let span = max - min;
    2 * span.next_power_of_two() as usize
}

436/// Compute the base index for the buffer covering [min, max).
437fn base(min: u64, max: u64) -> u64 {
438    let buf_len = buf_len(min, max);
439    let mask = (buf_len as u64) / 2 - 1;
440    min & !mask
441}