1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
8use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15mod canonical;
16mod compute;
17mod ops;
18mod serde;
19
20vtable!(Sparse);
21
22impl VTable for SparseVTable {
23 type Array = SparseArray;
24 type Encoding = SparseEncoding;
25
26 type ArrayVTable = Self;
27 type CanonicalVTable = Self;
28 type OperationsVTable = Self;
29 type ValidityVTable = Self;
30 type VisitorVTable = Self;
31 type ComputeVTable = NotSupported;
32 type EncodeVTable = Self;
33 type SerdeVTable = Self;
34
35 fn id(_encoding: &Self::Encoding) -> EncodingId {
36 EncodingId::new_ref("vortex.sparse")
37 }
38
39 fn encoding(_array: &Self::Array) -> EncodingRef {
40 EncodingRef::new_ref(SparseEncoding.as_ref())
41 }
42}
43
/// An array in which every position holds `fill_value` except for the positions listed in
/// `patches`, which hold their own explicitly stored values.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicitly stored positions and values. The stored indices may be shifted by an offset
    /// (see `Patches::offset`), which is subtracted to obtain logical positions.
    patches: Patches,
    /// Value of every position not covered by `patches`; its dtype is the array's dtype.
    fill_value: Scalar,
    /// Statistics set associated with this array.
    stats_set: ArrayStats,
}

/// Marker type identifying the sparse encoding.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
53
54impl SparseArray {
55 pub fn try_new(
56 indices: ArrayRef,
57 values: ArrayRef,
58 len: usize,
59 fill_value: Scalar,
60 ) -> VortexResult<Self> {
61 Self::try_new_with_offset(indices, values, len, 0, fill_value)
62 }
63
64 pub(crate) fn try_new_with_offset(
65 indices: ArrayRef,
66 values: ArrayRef,
67 len: usize,
68 indices_offset: usize,
69 fill_value: Scalar,
70 ) -> VortexResult<Self> {
71 if indices.len() != values.len() {
72 vortex_bail!(
73 "Mismatched indices {} and values {} length",
74 indices.len(),
75 values.len()
76 );
77 }
78
79 if !indices.is_empty() {
80 let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
81
82 if last_index - indices_offset >= len {
83 vortex_bail!("Array length was set to {len} but the last index is {last_index}");
84 }
85 }
86
87 let patches = Patches::new(len, indices_offset, indices, values);
88
89 Self::try_new_from_patches(patches, fill_value)
90 }
91
92 pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
93 if fill_value.dtype() != patches.values().dtype() {
94 vortex_bail!(
95 "fill value, {:?}, should be instance of values dtype, {}",
96 fill_value,
97 patches.values().dtype(),
98 );
99 }
100 Ok(Self {
101 patches,
102 fill_value,
103 stats_set: Default::default(),
104 })
105 }
106
107 #[inline]
108 pub fn patches(&self) -> &Patches {
109 &self.patches
110 }
111
112 #[inline]
113 pub fn resolved_patches(&self) -> VortexResult<Patches> {
114 let (len, offset, indices, values) = self.patches().clone().into_parts();
115 let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
116 let indices = sub_scalar(&indices, indices_offset)?;
117 Ok(Patches::new(len, 0, indices, values))
118 }
119
120 #[inline]
121 pub fn fill_scalar(&self) -> &Scalar {
122 &self.fill_value
123 }
124
125 pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
129 if let Some(fill_value) = fill_value.as_ref() {
130 if array.dtype() != fill_value.dtype() {
131 vortex_bail!(
132 "Array and fill value types must match. got {} and {}",
133 array.dtype(),
134 fill_value.dtype()
135 )
136 }
137 }
138 let mask = array.validity_mask()?;
139
140 if mask.all_false() {
141 return Ok(
143 ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
144 );
145 } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
146 let non_null_values = filter(array, &mask)?;
148 let non_null_indices = match mask.indices() {
149 AllOr::All => {
150 unreachable!("Mask is mostly null")
152 }
153 AllOr::None => {
154 unreachable!("Mask is mostly null but not all null")
156 }
157 AllOr::Some(values) => {
158 let buffer: Buffer<u32> = values
159 .iter()
160 .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
161 .collect();
162
163 buffer.into_array()
164 }
165 };
166
167 return Ok(SparseArray::try_new(
168 non_null_indices,
169 non_null_values,
170 array.len(),
171 Scalar::null(array.dtype().clone()),
172 )?
173 .into_array());
174 }
175
176 let fill = if let Some(fill) = fill_value {
177 fill
178 } else {
179 let (top_pvalue, _) = array
181 .to_primitive()?
182 .top_value()?
183 .vortex_expect("Non empty or all null array");
184
185 Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
186 };
187
188 let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
189 let non_top_mask = Mask::from_buffer(
190 fill_null(
191 &compare(array, &fill_array, Operator::NotEq)?,
192 &Scalar::bool(true, Nullability::NonNullable),
193 )?
194 .to_bool()?
195 .boolean_buffer()
196 .clone(),
197 );
198
199 let non_top_values = filter(array, &non_top_mask)?;
200
201 let indices: Buffer<u64> = match non_top_mask {
202 Mask::AllTrue(count) => {
203 (0u64..count as u64).collect()
205 }
206 Mask::AllFalse(_) => {
207 return Ok(fill_array);
209 }
210 Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
211 };
212
213 SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
214 .map(|a| a.into_array())
215 }
216}
217
218impl ArrayVTable<SparseVTable> for SparseVTable {
219 fn len(array: &SparseArray) -> usize {
220 array.patches.array_len()
221 }
222
223 fn dtype(array: &SparseArray) -> &DType {
224 array.fill_scalar().dtype()
225 }
226
227 fn stats(array: &SparseArray) -> StatsSetRef<'_> {
228 array.stats_set.to_ref(array.as_ref())
229 }
230}
231
232impl ValidityVTable<SparseVTable> for SparseVTable {
233 fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
234 Ok(match array.patches().get_patched(index)? {
235 None => array.fill_scalar().is_valid(),
236 Some(patch_value) => patch_value.is_valid(),
237 })
238 }
239
240 fn all_valid(array: &SparseArray) -> VortexResult<bool> {
241 if array.fill_scalar().is_null() {
242 return Ok(array.patches().values().len() == array.len()
244 && array.patches().values().all_valid()?);
245 }
246
247 array.patches().values().all_valid()
248 }
249
250 fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
251 if !array.fill_scalar().is_null() {
252 return Ok(array.patches().values().len() == array.len()
254 && array.patches().values().all_invalid()?);
255 }
256
257 array.patches().values().all_invalid()
258 }
259
260 #[allow(clippy::unnecessary_fallible_conversions)]
261 fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
262 let indices = array.patches().indices().to_primitive()?;
263
264 if array.fill_scalar().is_null() {
265 let mut buffer = BooleanBufferBuilder::new(array.len());
267 buffer.append_n(array.len(), false);
269
270 match_each_integer_ptype!(indices.ptype(), |I| {
271 indices.as_slice::<I>().iter().for_each(|&index| {
272 buffer.set_bit(
273 usize::try_from(index).vortex_expect("Failed to cast to usize")
274 - array.patches().offset(),
275 true,
276 );
277 });
278 });
279
280 return Ok(Mask::from_buffer(buffer.finish()));
281 }
282
283 let mut buffer = BooleanBufferBuilder::new(array.len());
286 buffer.append_n(array.len(), true);
287
288 let values_validity = array.patches().values().validity_mask()?;
289 match_each_integer_ptype!(indices.ptype(), |I| {
290 indices
291 .as_slice::<I>()
292 .iter()
293 .enumerate()
294 .for_each(|(patch_idx, &index)| {
295 buffer.set_bit(
296 usize::try_from(index).vortex_expect("Failed to cast to usize")
297 - array.patches().offset(),
298 values_validity.value(patch_idx),
299 );
300 })
301 });
302
303 Ok(Mask::from_buffer(buffer.finish()))
304 }
305}
306
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null fill value of dtype nullable `i32`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value of 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array holding 100, 200, 300 at positions 2, 5, 8 and `fill_value`
    /// everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let mut values = buffer![100i32, 200, 300].into_array();
        // The patch values must share the fill value's dtype (including nullability).
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Unpatched positions return the fill, patched positions their value, and reading past
    /// the end is an `OutOfBounds` error.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        let error = array.scalar_at(10).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 10);
        assert_eq!(start, 0);
        assert_eq!(stop, 10);
    }

    /// Indices and values given as constant arrays: the single patch sits at position 10.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap())
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    /// Slicing [2, 7) shifts positions: logical 0 is the original position 2 (value 100), and
    /// the slice has length 5.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
        let error = sliced.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// With a null fill, only the patched positions (original 2 and 5) are valid in the slice.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    /// With a valid fill and all-null patch values, the patched positions are the invalid ones.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    /// Slicing composes: [1, 8) then [1, 6) views the original positions [2, 7).
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );
        let error = sliced_once.scalar_at(7).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, 7);
        assert_eq!(start, 0);
        assert_eq!(stop, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
        let error2 = sliced_twice.scalar_at(5).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
            unreachable!()
        };
        assert_eq!(i, 5);
        assert_eq!(start, 0);
        assert_eq!(stop, 5);
    }

    /// With a null fill, exactly the patched positions 2, 5, 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    /// With a non-null fill and valid patch values, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    /// The last index (100) equals the length, so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    /// The same indices are accepted once the length exceeds the last index.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    /// Encoding an array with interspersed nulls (below the mostly-null threshold) round-trips
    /// both the validity mask and the physical values of the canonical form.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}