sparse_vec/
lib.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Range;
3
4use itertools::Itertools;
5use rangemap::RangeMap;
6
7#[derive(Default, Debug)]
8pub struct SparseVec<T> {
9    map: RangeMap<u64, usize>,
10    data: HashMap<usize, (Range<u64>, Vec<T>)>,
11    key_counter: usize,
12}
13
14impl<T: Copy> SparseVec<T> {
15    fn assert_invariants(&self) {
16        for (range, key) in self.map.iter() {
17            let (old_range, vec) = &self.data[key];
18            assert_eq!(range, old_range);
19            assert_eq!(vec.len(), (range.end - range.start) as usize);
20        }
21        let mut duplicates = HashMap::new();
22        for (range, key) in self.map.iter() {
23            if let Some(other) = duplicates.get(key) {
24                panic!("{range:?} and {other:?} use key {key}");
25            } else {
26                duplicates.insert(*key, range.clone());
27            }
28        }
29    }
30
31    fn resize_block(
32        data: &mut HashMap<usize, (Range<u64>, Vec<T>)>,
33        key: &usize,
34        range: &Range<u64>,
35    ) {
36        let (old_range, vec) = data.get_mut(key).unwrap();
37        let new_vec_range: Range<usize> = cast_range(sub_range(range, old_range.start));
38        if new_vec_range.start != 0 {
39            vec.copy_within(new_vec_range.clone(), 0);
40        }
41        vec.truncate(new_vec_range.end - new_vec_range.start);
42        *old_range = range.clone();
43    }
44
45    fn collect_garbage(&mut self) {
46        let mut used = HashSet::with_capacity(self.map.len());
47        used.extend(self.map.iter().map(|(_, k)| *k));
48        self.data.retain(|k, _| used.contains(k));
49    }
50
51    pub fn get(&self, range: Range<u64>) -> Option<&[T]> {
52        let (found_range, key) = self.map.get_key_value(&range.start)?;
53        let slice_range = sub_range(&range, found_range.start);
54        self.data[key].1.get(cast_range(slice_range))
55    }
56
57    pub fn get_mut(&mut self, range: Range<u64>) -> Option<&mut [T]> {
58        let (found_range, key) = self.map.get_key_value(&range.start)?;
59        let slice_range = sub_range(&range, found_range.start);
60        self.data
61            .get_mut(key)
62            .unwrap()
63            .1
64            .get_mut(cast_range(slice_range))
65    }
66
67    pub fn overlaps(&self, range: &Range<u64>) -> bool {
68        self.map.overlaps(range)
69    }
70
71    pub fn insert(&mut self, data: Vec<T>, addr: u64) {
72        if data.is_empty() {
73            return;
74        }
75
76        let insert_range = addr..addr + data.len() as u64;
77
78        let start_key = self.map.get(&insert_range.start);
79        // Will create duplicate key
80        if let Some(&key) = start_key {
81            if start_key == self.map.get(&insert_range.end) {
82                let (range, vec) = self.data.get(&key).unwrap();
83                let range = range.clone();
84                let lower_range = range.start..insert_range.start;
85                let upper_range = insert_range.end..range.end;
86
87                if !upper_range.is_empty() {
88                    self.map.insert(upper_range.clone(), self.key_counter);
89                    let copy_range = sub_range(&upper_range, range.start);
90                    self.data.insert(
91                        self.key_counter,
92                        (upper_range, vec[cast_range(copy_range)].to_vec()),
93                    );
94                    self.key_counter += 1;
95                }
96
97                if !lower_range.is_empty() {
98                    Self::resize_block(&mut self.data, &key, &lower_range);
99                    self.map.insert(lower_range, key);
100                }
101            }
102        }
103
104        // Insert
105        self.map.insert(insert_range.clone(), self.key_counter);
106        self.data.insert(self.key_counter, (insert_range, data));
107        self.key_counter += 1;
108
109        // Resize
110        for (range, key) in self.map.iter() {
111            Self::resize_block(&mut self.data, key, range);
112        }
113
114        // Merge
115        loop {
116            let mut mergable = None;
117            for ((range, _), (range2, _)) in self.map.iter().tuple_windows() {
118                if range.end == range2.start {
119                    mergable = Some((range.clone(), range2.clone()));
120                    break;
121                }
122            }
123
124            if let Some((range, range2)) = mergable {
125                let key1 = *self.map.get(&range.start).unwrap();
126                let key2 = *self.map.get(&range2.start).unwrap();
127
128                let (_, vec2) = self.data.remove(&key2).unwrap();
129                self.map.remove(range2.clone());
130                let (data_range, vec1) = self.data.get_mut(&key1).unwrap();
131                *data_range = range.start..range2.end;
132                vec1.extend_from_slice(&vec2);
133                self.map.insert(data_range.clone(), key1);
134            } else {
135                break;
136            }
137        }
138
139        self.collect_garbage();
140
141        #[cfg(debug_assertions)]
142        self.assert_invariants();
143    }
144
145    pub fn ranges(&self) -> impl Iterator<Item = Range<u64>> + '_ {
146        self.map.iter().map(|(range, _)| range.clone())
147    }
148
149    pub fn stored_len(&self) -> usize {
150        self.map.iter().map(|(_, k)| self.data[k].1.len()).sum()
151    }
152}
153
/// Shifts `range` down by `offset`, turning absolute addresses into
/// positions relative to `offset`.
///
/// `offset` must not exceed either bound of `range`; otherwise the
/// subtraction overflows.
fn sub_range(range: &Range<u64>, offset: u64) -> Range<u64> {
    let start = range.start - offset;
    let end = range.end - offset;
    Range { start, end }
}
157
/// Converts both bounds of `range` from `I` to `O`.
///
/// # Panics
///
/// Panics if either bound cannot be represented as an `O`.
fn cast_range<I, O: TryFrom<I>>(range: Range<I>) -> Range<O>
where
    O::Error: std::fmt::Debug,
{
    let Range { start, end } = range;
    let start = O::try_from(start).unwrap();
    let end = O::try_from(end).unwrap();
    start..end
}
164
/// Inserts a few overlapping and adjacent blocks and checks after each
/// insert that the invariants hold and the new data can be read back.
#[test]
fn sparsevec() {
    let mut sparse = SparseVec::default();

    // (fill value, element count, start address)
    let cases: [(u8, usize, u64); 4] = [(1, 20, 0), (2, 10, 20), (3, 10, 15), (4, 5, 5)];

    for (fill, len, addr) in cases {
        sparse.insert(vec![fill; len], addr);
        sparse.assert_invariants();
        assert_eq!(
            sparse.get(addr..addr + len as u64).unwrap(),
            &vec![fill; len]
        );
    }
}
179
/// Fuzzes `SparseVec<u8>` with a fixed seed: after every random insert the
/// invariants must hold and the inserted data must be readable.
#[test]
fn sparsevec_fuzz() {
    use rand::{Rng, SeedableRng};

    let mut map = SparseVec::default();
    let mut insert_test = |n: u8, size: usize, addr: u64| {
        let vec: Vec<u8> = (0..size).map(|v| (v as u8).wrapping_mul(n)).collect();
        map.insert(vec.clone(), addr);
        map.assert_invariants();

        let stored = map.get(addr..addr + size as u64);
        if size == 0 {
            // An empty insert is a no-op, so `addr` may not be stored at all;
            // if it is, the empty range must read back as an empty slice.
            assert!(stored.map_or(true, |slice| slice.is_empty()));
        } else {
            assert_eq!(stored.unwrap(), &vec);
        }
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    for _ in 0..1_000_000 {
        let n = rng.gen_range(0..255);
        let size = rng.gen_range(0..1000);
        let addr = rng.gen_range(0..1000);
        insert_test(n, size, addr);
    }
}
200
/// Fuzzes `SparseVec<u64>` with a fixed seed: after every random insert the
/// invariants must hold and the inserted data must be readable.
#[test]
fn sparsevec_u64_fuzz() {
    use rand::{Rng, SeedableRng};

    let mut map = SparseVec::default();
    let mut insert_test = |n: u64, size: usize, addr: u64| {
        let vec: Vec<u64> = (0..size).map(|v| (v as u64).wrapping_mul(n)).collect();
        map.insert(vec.clone(), addr);
        map.assert_invariants();

        let stored = map.get(addr..addr + size as u64);
        if size == 0 {
            // An empty insert is a no-op, so `addr` may not be stored at all;
            // if it is, the empty range must read back as an empty slice.
            assert!(stored.map_or(true, |slice| slice.is_empty()));
        } else {
            assert_eq!(stored.unwrap(), &vec);
        }
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    for _ in 0..1_000_000 {
        let n = rng.gen_range(0..255);
        let size = rng.gen_range(0..1000);
        let addr = rng.gen_range(0..1000);
        insert_test(n, size, addr);
    }
}
220}