1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
8use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15mod canonical;
16mod compute;
17mod ops;
18mod serde;
19
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Sparse` encoding.
// NOTE(review): presumably this declares the `SparseVTable` type implemented below and
// wires it to `SparseArray` / `SparseEncoding` — confirm against the macro definition.
vtable!(Sparse);
21
/// Top-level vtable wiring for the sparse encoding.
///
/// Associates [`SparseVTable`] with its array and encoding types and selects which
/// type provides each sub-vtable.
impl VTable for SparseVTable {
    type Array = SparseArray;
    type Encoding = SparseEncoding;

    // Every sub-vtable except compute is provided by `SparseVTable` itself.
    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    type ValidityVTable = Self;
    type VisitorVTable = Self;
    // The compute sub-vtable is explicitly marked as not supported here.
    type ComputeVTable = NotSupported;
    type EncodeVTable = Self;
    type SerdeVTable = Self;

    /// Returns the encoding identifier, `"vortex.sparse"`.
    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.sparse")
    }

    /// Returns a reference to the (stateless) [`SparseEncoding`].
    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(SparseEncoding.as_ref())
    }
}
43
/// An array in which every position holds `fill_value` unless it is overridden by a patch.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Indices and values of the positions that differ from the fill value, together
    /// with the logical array length and the offset applied to stored indices.
    patches: Patches,
    /// Scalar reported for every position that has no patch; its dtype is the array's dtype.
    fill_value: Scalar,
    /// Statistics holder for this array.
    stats_set: ArrayStats,
}
50
/// Stateless encoding marker for [`SparseArray`]; its identifier is `"vortex.sparse"`.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
53
54impl SparseArray {
55 pub fn try_new(
56 indices: ArrayRef,
57 values: ArrayRef,
58 len: usize,
59 fill_value: Scalar,
60 ) -> VortexResult<Self> {
61 Self::try_new_with_offset(indices, values, len, 0, fill_value)
62 }
63
64 pub(crate) fn try_new_with_offset(
65 indices: ArrayRef,
66 values: ArrayRef,
67 len: usize,
68 indices_offset: usize,
69 fill_value: Scalar,
70 ) -> VortexResult<Self> {
71 if indices.len() != values.len() {
72 vortex_bail!(
73 "Mismatched indices {} and values {} length",
74 indices.len(),
75 values.len()
76 );
77 }
78
79 if !indices.is_empty() {
80 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
81
82 if last_index - indices_offset >= len {
83 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
84 }
85 }
86
87 let patches = Patches::new(len, indices_offset, indices, values);
88
89 Self::try_new_from_patches(patches, fill_value)
90 }
91
92 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
93 if fill_value.dtype() != patches.values().dtype() {
94 vortex_bail!(
95 "fill value, {:?}, should be instance of values dtype, {}",
96 fill_value,
97 patches.values().dtype(),
98 );
99 }
100 Ok(Self {
101 patches,
102 fill_value,
103 stats_set: Default::default(),
104 })
105 }
106
107 #[inline]
108 pub fn patches(&self) -> &Patches {
109 &self.patches
110 }
111
112 #[inline]
113 pub fn resolved_patches(&self) -> VortexResult<Patches> {
114 let (len, offset, indices, values) = self.patches().clone().into_parts();
115 let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
116 let indices = sub_scalar(&indices, indices_offset)?;
117 Ok(Patches::new(len, 0, indices, values))
118 }
119
120 #[inline]
121 pub fn fill_scalar(&self) -> &Scalar {
122 &self.fill_value
123 }
124
125 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
129 if let Some(fill_value) = fill_value.as_ref() {
130 if array.dtype() != fill_value.dtype() {
131 vortex_bail!(
132 "Array and fill value types must match. got {} and {}",
133 array.dtype(),
134 fill_value.dtype()
135 )
136 }
137 }
138 let mask = array.validity_mask()?;
139
140 if mask.all_false() {
141 return Ok(
143 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
144 );
145 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
146 let non_null_values = filter(array, &mask)?;
148 let non_null_indices = match mask.indices() {
149 AllOr::All => {
150 unreachable!("Mask is mostly null")
152 }
153 AllOr::None => {
154 unreachable!("Mask is mostly null but not all null")
156 }
157 AllOr::Some(values) => {
158 let buffer: Buffer<u32> = values
159 .iter()
160 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
161 .collect();
162
163 buffer.into_array()
164 }
165 };
166
167 return Ok(SparseArray::try_new(
168 non_null_indices,
169 non_null_values,
170 array.len(),
171 Scalar::null(array.dtype().clone()),
172 )?
173 .into_array());
174 }
175
176 let fill = if let Some(fill) = fill_value {
177 fill
178 } else {
179 let (top_pvalue, _) = array
181 .to_primitive()?
182 .top_value()?
183 .vortex_expect("Non empty or all null array");
184
185 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
186 };
187
188 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
189 let non_top_mask = Mask::from_buffer(
190 fill_null(
191 &compare(array, &fill_array, Operator::NotEq)?,
192 &Scalar::bool(true, Nullability::NonNullable),
193 )?
194 .to_bool()?
195 .boolean_buffer()
196 .clone(),
197 );
198
199 let non_top_values = filter(array, &non_top_mask)?;
200
201 let indices: Buffer<u64> = match non_top_mask {
202 Mask::AllTrue(count) => {
203 (0u64..count as u64).collect()
205 }
206 Mask::AllFalse(_) => {
207 return Ok(fill_array);
209 }
210 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
211 };
212
213 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
214 .map(|a| a.into_array())
215 }
216}
217
218impl ArrayVTable<SparseVTable> for SparseVTable {
219 fn len(array: &SparseArray) -> usize {
220 array.patches.array_len()
221 }
222
223 fn dtype(array: &SparseArray) -> &DType {
224 array.fill_scalar().dtype()
225 }
226
227 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
228 array.stats_set.to_ref(array.as_ref())
229 }
230}
231
232impl ValidityVTable<SparseVTable> for SparseVTable {
233 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
234 Ok(match array.patches().get_patched(index)? {
235 None => array.fill_scalar().is_valid(),
236 Some(patch_value) => patch_value.is_valid(),
237 })
238 }
239
240 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
241 if array.fill_scalar().is_null() {
242 return Ok(array.patches().values().len() == array.len()
244 && array.patches().values().all_valid()?);
245 }
246
247 array.patches().values().all_valid()
248 }
249
250 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
251 if !array.fill_scalar().is_null() {
252 return Ok(array.patches().values().len() == array.len()
254 && array.patches().values().all_invalid()?);
255 }
256
257 array.patches().values().all_invalid()
258 }
259
260 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
261 let indices = array.patches().indices().to_primitive()?;
262
263 if array.fill_scalar().is_null() {
264 let mut buffer = BooleanBufferBuilder::new(array.len());
266 buffer.append_n(array.len(), false);
268
269 match_each_integer_ptype!(indices.ptype(), |$I| {
270 indices.as_slice::<$I>().into_iter().for_each(|&index| {
271 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - array.patches().offset(), true);
272 });
273 });
274
275 return Ok(Mask::from_buffer(buffer.finish()));
276 }
277
278 let mut buffer = BooleanBufferBuilder::new(array.len());
281 buffer.append_n(array.len(), true);
282
283 let values_validity = array.patches().values().validity_mask()?;
284 match_each_integer_ptype!(indices.ptype(), |$I| {
285 indices.as_slice::<$I>()
286 .into_iter()
287 .enumerate()
288 .for_each(|(patch_idx, &index)| {
289 buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - array.patches().offset(), values_validity.value(patch_idx));
290 })
291 });
292
293 Ok(Mask::from_buffer(buffer.finish()))
294 }
295}
296
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null `i32` scalar, used as a nullable fill value.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-nullable `i32` scalar, used as a non-null fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Builds a length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8.
    /// The values are cast to the dtype of `fill_value` so that construction succeeds.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let values = cast(&buffer![100i32, 200, 300].into_array(), fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        // Reading one past the end must report an out-of-bounds error.
        let error = array.scalar_at(10).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap())
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
        let error = sliced.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Null patch values over a valid fill: only the patched positions are invalid.
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );
        let error = sliced_once.scalar_at(7).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
        let error2 = sliced_twice.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        // The last index (100) is not smaller than the length (100). The constructor must
        // report this as an error; asserting on the `Result` is stricter than
        // `#[should_panic]`, which would also pass on an unrelated panic.
        assert!(SparseArray::try_new(indices, values, 100, 0_u32.into()).is_err());
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        // Encoding must preserve both the validity and the underlying values.
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}