1use std::fmt::Debug;
2
3use vortex_array::arrays::BooleanBufferBuilder;
4use vortex_array::compute::{scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, Stat, StatsSet, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
9use vortex_array::{
10 Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, EncodingId,
11 RkyvMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::Mask;
16use vortex_scalar::Scalar;
17
18use crate::serde::SparseMetadata;
19
20mod canonical;
21mod compute;
22mod serde;
23mod variants;
24
25#[derive(Clone, Debug)]
26pub struct SparseArray {
27 patches: Patches,
28 fill_value: Scalar,
29 stats_set: ArrayStats,
30}
31
32try_from_array_ref!(SparseArray);
33
34pub struct SparseEncoding;
35impl Encoding for SparseEncoding {
36 type Array = SparseArray;
37 type Metadata = RkyvMetadata<SparseMetadata>;
38}
39
40impl EncodingVTable for SparseEncoding {
41 fn id(&self) -> EncodingId {
42 EncodingId::new_ref("vortex.sparse")
43 }
44}
45
46impl SparseArray {
47 pub fn try_new(
48 indices: ArrayRef,
49 values: ArrayRef,
50 len: usize,
51 fill_value: Scalar,
52 ) -> VortexResult<Self> {
53 Self::try_new_with_offset(indices, values, len, 0, fill_value)
54 }
55
56 pub(crate) fn try_new_with_offset(
57 indices: ArrayRef,
58 values: ArrayRef,
59 len: usize,
60 indices_offset: usize,
61 fill_value: Scalar,
62 ) -> VortexResult<Self> {
63 if indices.len() != values.len() {
64 vortex_bail!(
65 "Mismatched indices {} and values {} length",
66 indices.len(),
67 values.len()
68 );
69 }
70
71 if !indices.is_empty() {
72 let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
73
74 if last_index - indices_offset >= len {
75 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
76 }
77 }
78
79 let patches = Patches::new(len, indices_offset, indices, values);
80
81 Self::try_new_from_patches(patches, fill_value)
82 }
83
84 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
85 if fill_value.dtype() != patches.values().dtype() {
86 vortex_bail!(
87 "fill value, {:?}, should be instance of values dtype, {}",
88 fill_value,
89 patches.values().dtype(),
90 );
91 }
92 Ok(Self {
93 patches,
94 fill_value,
95 stats_set: Default::default(),
96 })
97 }
98
99 #[inline]
100 pub fn patches(&self) -> &Patches {
101 &self.patches
102 }
103
104 #[inline]
105 pub fn resolved_patches(&self) -> VortexResult<Patches> {
106 let (len, offset, indices, values) = self.patches().clone().into_parts();
107 let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
108 let indices = sub_scalar(&indices, indices_offset)?;
109 Ok(Patches::new(len, 0, indices, values))
110 }
111
112 #[inline]
113 pub fn fill_scalar(&self) -> &Scalar {
114 &self.fill_value
115 }
116}
117
118impl ArrayImpl for SparseArray {
119 type Encoding = SparseEncoding;
120
121 fn _len(&self) -> usize {
122 self.patches.array_len()
123 }
124
125 fn _dtype(&self) -> &DType {
126 self.fill_value.dtype()
127 }
128
129 fn _vtable(&self) -> VTableRef {
130 VTableRef::new_ref(&SparseEncoding)
131 }
132}
133
134impl ArrayStatisticsImpl for SparseArray {
135 fn _stats_ref(&self) -> StatsSetRef<'_> {
136 self.stats_set.to_ref(self)
137 }
138}
139
140impl ArrayValidityImpl for SparseArray {
141 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
142 Ok(match self.patches().get_patched(index)? {
143 None => self.fill_scalar().is_valid(),
144 Some(patch_value) => patch_value.is_valid(),
145 })
146 }
147
148 fn _all_valid(&self) -> VortexResult<bool> {
149 if self.fill_scalar().is_null() {
150 return Ok(self.patches().values().len() == self.len()
152 && self.patches().values().all_valid()?);
153 }
154
155 self.patches().values().all_valid()
156 }
157
158 fn _all_invalid(&self) -> VortexResult<bool> {
159 if !self.fill_scalar().is_null() {
160 return Ok(self.patches().values().len() == self.len()
162 && self.patches().values().all_invalid()?);
163 }
164
165 self.patches().values().all_invalid()
166 }
167
168 fn _validity_mask(&self) -> VortexResult<Mask> {
169 let indices = self.patches().indices().to_primitive()?;
170
171 if self.fill_scalar().is_null() {
172 let mut buffer = BooleanBufferBuilder::new(self.len());
174 buffer.append_n(self.len(), false);
176
177 match_each_integer_ptype!(indices.ptype(), |$I| {
178 indices.as_slice::<$I>().into_iter().for_each(|&index| {
179 buffer.set_bit(index.try_into().vortex_expect("Failed to cast to usize"), true);
180 });
181 });
182
183 return Ok(Mask::from_buffer(buffer.finish()));
184 }
185
186 let mut buffer = BooleanBufferBuilder::new(self.len());
189 buffer.append_n(self.len(), true);
190
191 let values_validity = self.patches().values().validity_mask()?;
192 match_each_integer_ptype!(indices.ptype(), |$I| {
193 indices.as_slice::<$I>()
194 .into_iter()
195 .enumerate()
196 .for_each(|(patch_idx, &index)| {
197 buffer.set_bit(index.try_into().vortex_expect("failed to cast to usize"), values_validity.value(patch_idx));
198 })
199 });
200
201 Ok(Mask::from_buffer(buffer.finish()))
202 }
203}
204
205impl StatisticsVTable<&SparseArray> for SparseEncoding {
206 fn compute_statistics(&self, array: &SparseArray, stat: Stat) -> VortexResult<StatsSet> {
207 let values = array.patches().clone().into_values();
208 let stats = values.statistics().compute_all(&[stat])?;
209 if array.len() == values.len() {
210 return Ok(stats);
211 }
212
213 let fill_len = array.len() - values.len();
214 let fill_stats = if array.fill_scalar().is_null() {
215 StatsSet::nulls(fill_len)
216 } else {
217 StatsSet::constant(array.fill_scalar().clone(), fill_len)
218 };
219
220 if values.is_empty() {
221 return Ok(fill_stats);
222 }
223
224 Ok(stats.merge_unordered(&fill_stats, array.dtype()))
225 }
226}
227
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::{slice, try_cast};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::{DType, PType};
    use vortex_error::VortexError;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullable))
    }

    /// A non-null `i32` scalar holding `42`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8.
    /// The values are cast to the dtype of `fill_value` before construction.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = try_cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(scalar_at(&array, 0).unwrap(), nullable_fill());
        assert_eq!(scalar_at(&array, 2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(scalar_at(&array, 5).unwrap(), Scalar::from(Some(200_i32)));

        // Reading one past the end must report an out-of-bounds error.
        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&array, 10) else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values = ConstantArray::new(Scalar::primitive(1234u32, Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = scalar_at(&arr, 10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(scalar_at(&arr, 0).unwrap().is_null());
        assert!(scalar_at(&arr, 99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced, 5) else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();

        let once_value = scalar_at(&sliced_once, 1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced_once, 7) else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();

        let twice_value = scalar_at(&sliced_twice, 3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced_twice, 5) else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let mask = array.validity_mask().unwrap();
        let bits = mask.to_boolean_buffer().iter().collect_vec();

        // Only the patched positions 2, 5 and 8 are valid.
        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 does not fit in an array of length 100, so construction must fail.
        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 fits in an array of length 101.
        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }
}