1use std::collections::{HashMap, HashSet};
2use std::ops::Range;
3
4use itertools::Itertools;
5use rangemap::RangeMap;
6
7#[derive(Default, Debug)]
8pub struct SparseVec<T> {
9 map: RangeMap<u64, usize>,
10 data: HashMap<usize, (Range<u64>, Vec<T>)>,
11 key_counter: usize,
12}
13
14impl<T: Copy> SparseVec<T> {
15 fn assert_invariants(&self) {
16 for (range, key) in self.map.iter() {
17 let (old_range, vec) = &self.data[key];
18 assert_eq!(range, old_range);
19 assert_eq!(vec.len(), (range.end - range.start) as usize);
20 }
21 let mut duplicates = HashMap::new();
22 for (range, key) in self.map.iter() {
23 if let Some(other) = duplicates.get(key) {
24 panic!("{range:?} and {other:?} use key {key}");
25 } else {
26 duplicates.insert(*key, range.clone());
27 }
28 }
29 }
30
31 fn resize_block(
32 data: &mut HashMap<usize, (Range<u64>, Vec<T>)>,
33 key: &usize,
34 range: &Range<u64>,
35 ) {
36 let (old_range, vec) = data.get_mut(key).unwrap();
37 let new_vec_range: Range<usize> = cast_range(sub_range(range, old_range.start));
38 if new_vec_range.start != 0 {
39 vec.copy_within(new_vec_range.clone(), 0);
40 }
41 vec.truncate(new_vec_range.end - new_vec_range.start);
42 *old_range = range.clone();
43 }
44
45 fn collect_garbage(&mut self) {
46 let mut used = HashSet::with_capacity(self.map.len());
47 used.extend(self.map.iter().map(|(_, k)| *k));
48 self.data.retain(|k, _| used.contains(k));
49 }
50
51 pub fn get(&self, range: Range<u64>) -> Option<&[T]> {
52 let (found_range, key) = self.map.get_key_value(&range.start)?;
53 let slice_range = sub_range(&range, found_range.start);
54 self.data[key].1.get(cast_range(slice_range))
55 }
56
57 pub fn get_mut(&mut self, range: Range<u64>) -> Option<&mut [T]> {
58 let (found_range, key) = self.map.get_key_value(&range.start)?;
59 let slice_range = sub_range(&range, found_range.start);
60 self.data
61 .get_mut(key)
62 .unwrap()
63 .1
64 .get_mut(cast_range(slice_range))
65 }
66
67 pub fn overlaps(&self, range: &Range<u64>) -> bool {
68 self.map.overlaps(range)
69 }
70
71 pub fn insert(&mut self, data: Vec<T>, addr: u64) {
72 if data.is_empty() {
73 return;
74 }
75
76 let insert_range = addr..addr + data.len() as u64;
77
78 let start_key = self.map.get(&insert_range.start);
79 if let Some(&key) = start_key {
81 if start_key == self.map.get(&insert_range.end) {
82 let (range, vec) = self.data.get(&key).unwrap();
83 let range = range.clone();
84 let lower_range = range.start..insert_range.start;
85 let upper_range = insert_range.end..range.end;
86
87 if !upper_range.is_empty() {
88 self.map.insert(upper_range.clone(), self.key_counter);
89 let copy_range = sub_range(&upper_range, range.start);
90 self.data.insert(
91 self.key_counter,
92 (upper_range, vec[cast_range(copy_range)].to_vec()),
93 );
94 self.key_counter += 1;
95 }
96
97 if !lower_range.is_empty() {
98 Self::resize_block(&mut self.data, &key, &lower_range);
99 self.map.insert(lower_range, key);
100 }
101 }
102 }
103
104 self.map.insert(insert_range.clone(), self.key_counter);
106 self.data.insert(self.key_counter, (insert_range, data));
107 self.key_counter += 1;
108
109 for (range, key) in self.map.iter() {
111 Self::resize_block(&mut self.data, key, range);
112 }
113
114 loop {
116 let mut mergable = None;
117 for ((range, _), (range2, _)) in self.map.iter().tuple_windows() {
118 if range.end == range2.start {
119 mergable = Some((range.clone(), range2.clone()));
120 break;
121 }
122 }
123
124 if let Some((range, range2)) = mergable {
125 let key1 = *self.map.get(&range.start).unwrap();
126 let key2 = *self.map.get(&range2.start).unwrap();
127
128 let (_, vec2) = self.data.remove(&key2).unwrap();
129 self.map.remove(range2.clone());
130 let (data_range, vec1) = self.data.get_mut(&key1).unwrap();
131 *data_range = range.start..range2.end;
132 vec1.extend_from_slice(&vec2);
133 self.map.insert(data_range.clone(), key1);
134 } else {
135 break;
136 }
137 }
138
139 self.collect_garbage();
140
141 #[cfg(debug_assertions)]
142 self.assert_invariants();
143 }
144
145 pub fn ranges(&self) -> impl Iterator<Item = Range<u64>> + '_ {
146 self.map.iter().map(|(range, _)| range.clone())
147 }
148
149 pub fn stored_len(&self) -> usize {
150 self.map.iter().map(|(_, k)| self.data[k].1.len()).sum()
151 }
152}
153
/// Shifts `range` down by `offset`, i.e. subtracts `offset` from both bounds.
///
/// Both bounds must be at least `offset`; otherwise the subtraction
/// overflows.
fn sub_range(range: &Range<u64>, offset: u64) -> Range<u64> {
    Range {
        start: range.start - offset,
        end: range.end - offset,
    }
}
157
/// Converts both bounds of `range` from `I` to `O` using `TryFrom`.
///
/// Panics if either bound cannot be represented as `O`.
fn cast_range<I, O: TryFrom<I>>(range: Range<I>) -> Range<O>
where
    O::Error: std::fmt::Debug,
{
    let Range { start, end } = range;
    let start = O::try_from(start).unwrap();
    let end = O::try_from(end).unwrap();
    start..end
}
164
/// Inserts a handful of overlapping blocks and checks both the freshly
/// written data and the final merged state.
#[test]
fn sparsevec() {
    let mut map = SparseVec::default();

    // (fill value, length, address)
    let cases: [(u8, usize, u64); 4] = [(1, 20, 0), (2, 10, 20), (3, 10, 15), (4, 5, 5)];
    for &(n, size, addr) in cases.iter() {
        let expected = vec![n; size];
        map.insert(expected.clone(), addr);
        map.assert_invariants();
        assert_eq!(map.get(addr..addr + size as u64).unwrap(), &expected[..]);
    }

    // Everything is contiguous, so it must have been merged into one block
    // in which later inserts overwrote only their own addresses.
    let expected: Vec<u8> = [
        vec![1u8; 5],
        vec![4; 5],
        vec![1; 5],
        vec![3; 10],
        vec![2; 5],
    ]
    .concat();
    assert_eq!(map.ranges().collect::<Vec<_>>(), vec![0..30]);
    assert_eq!(map.stored_len(), 30);
    assert_eq!(map.get(0..30).unwrap(), &expected[..]);

    // Reads reaching past the stored data find nothing.
    assert!(map.get(0..31).is_none());
    assert!(map.get(30..31).is_none());
}
179
/// Randomised insert test over `u8` elements with a fixed seed.
#[test]
fn sparsevec_fuzz() {
    use rand::{Rng, SeedableRng};

    let mut map = SparseVec::default();
    let mut insert_test = |n: u8, size: usize, addr: u64| {
        let vec: Vec<u8> = (0..size).map(|v| (v as u8).wrapping_mul(n)).collect();
        map.insert(vec.clone(), addr);
        map.assert_invariants();
        match map.get(addr..addr + size as u64) {
            Some(stored) => assert_eq!(stored, &vec[..]),
            // An empty insert is a no-op, so `addr` is only readable if an
            // earlier insert happened to cover it.
            None => assert!(vec.is_empty(), "inserted data missing at address {}", addr),
        }
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    for _ in 0..1_000_000 {
        let n = rng.gen_range(0..255);
        let size = rng.gen_range(0..1000);
        let addr = rng.gen_range(0..1000);
        insert_test(n, size, addr);
    }
}
200
/// Randomised insert test over `u64` elements with a fixed seed.
#[test]
fn sparsevec_u64_fuzz() {
    use rand::{Rng, SeedableRng};

    let mut map = SparseVec::default();
    let mut insert_test = |n: u64, size: usize, addr: u64| {
        let vec: Vec<u64> = (0..size).map(|v| (v as u64).wrapping_mul(n)).collect();
        map.insert(vec.clone(), addr);
        map.assert_invariants();
        match map.get(addr..addr + size as u64) {
            Some(stored) => assert_eq!(stored, &vec[..]),
            // An empty insert is a no-op, so `addr` is only readable if an
            // earlier insert happened to cover it.
            None => assert!(vec.is_empty(), "inserted data missing at address {}", addr),
        }
    };

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    for _ in 0..1_000_000 {
        let n = rng.gen_range(0..255);
        let size = rng.gen_range(0..1000);
        let addr = rng.gen_range(0..1000);
        insert_test(n, size, addr);
    }
}