vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayBufferVisitor;
12use vortex_array::ArrayChildVisitor;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::arrays::ConstantArray;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::compute::Operator;
24use vortex_array::compute::compare;
25use vortex_array::compute::fill_null;
26use vortex_array::compute::filter;
27use vortex_array::compute::sub_scalar;
28use vortex_array::patches::Patches;
29use vortex_array::patches::PatchesMetadata;
30use vortex_array::serde::ArrayChildren;
31use vortex_array::stats::ArrayStats;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::ArrayVTable;
37use vortex_array::vtable::ArrayVTableExt;
38use vortex_array::vtable::BaseArrayVTable;
39use vortex_array::vtable::EncodeVTable;
40use vortex_array::vtable::NotSupported;
41use vortex_array::vtable::VTable;
42use vortex_array::vtable::ValidityVTable;
43use vortex_array::vtable::VisitorVTable;
44use vortex_buffer::BitBufferMut;
45use vortex_buffer::Buffer;
46use vortex_buffer::ByteBufferMut;
47use vortex_dtype::DType;
48use vortex_dtype::NativePType;
49use vortex_dtype::Nullability;
50use vortex_dtype::match_each_integer_ptype;
51use vortex_error::VortexExpect as _;
52use vortex_error::VortexResult;
53use vortex_error::vortex_bail;
54use vortex_error::vortex_ensure;
55use vortex_mask::AllOr;
56use vortex_mask::Mask;
57use vortex_scalar::Scalar;
58use vortex_scalar::ScalarValue;
59
60mod canonical;
61mod compute;
62mod ops;
63
64vtable!(Sparse);
65
66#[derive(Clone, prost::Message)]
67#[repr(C)]
68pub struct SparseMetadata {
69    #[prost(message, required, tag = "1")]
70    patches: PatchesMetadata,
71}
72
73impl VTable for SparseVTable {
74    type Array = SparseArray;
75
76    type Metadata = ProstMetadata<SparseMetadata>;
77
78    type ArrayVTable = Self;
79    type CanonicalVTable = Self;
80    type OperationsVTable = Self;
81    type ValidityVTable = Self;
82    type VisitorVTable = Self;
83    type ComputeVTable = NotSupported;
84    type EncodeVTable = Self;
85
86    fn id(&self) -> ArrayId {
87        ArrayId::new_ref("vortex.sparse")
88    }
89
90    fn encoding(_array: &Self::Array) -> ArrayVTable {
91        SparseVTable.as_vtable()
92    }
93
94    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
95        Ok(ProstMetadata(SparseMetadata {
96            patches: array.patches().to_metadata(array.len(), array.dtype())?,
97        }))
98    }
99
100    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
101        Ok(Some(metadata.0.encode_to_vec()))
102    }
103
104    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
105        Ok(ProstMetadata(SparseMetadata::decode(buffer)?))
106    }
107
108    fn build(
109        &self,
110        dtype: &DType,
111        len: usize,
112        metadata: &Self::Metadata,
113        buffers: &[BufferHandle],
114        children: &dyn ArrayChildren,
115    ) -> VortexResult<SparseArray> {
116        if children.len() != 2 {
117            vortex_bail!(
118                "Expected 2 children for sparse encoding, found {}",
119                children.len()
120            )
121        }
122        assert_eq!(
123            metadata.0.patches.offset(),
124            0,
125            "Patches must start at offset 0"
126        );
127
128        let patch_indices = children.get(
129            0,
130            &metadata.0.patches.indices_dtype(),
131            metadata.0.patches.len(),
132        )?;
133        let patch_values = children.get(1, dtype, metadata.0.patches.len())?;
134
135        if buffers.len() != 1 {
136            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
137        }
138        let fill_value = Scalar::new(
139            dtype.clone(),
140            ScalarValue::from_protobytes(&buffers[0].clone().try_to_bytes()?)?,
141        );
142
143        SparseArray::try_new(patch_indices, patch_values, len, fill_value)
144    }
145
146    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
147        vortex_ensure!(
148            children.len() == 2,
149            "SparseArray expects 2 children, got {}",
150            children.len()
151        );
152
153        let mut children_iter = children.into_iter();
154        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
155        let patch_values = children_iter.next().vortex_expect("patch_values child");
156
157        array.patches = Patches::new(
158            array.patches.array_len(),
159            array.patches.offset(),
160            patch_indices,
161            patch_values,
162            array.patches.chunk_offsets().clone(),
163        );
164
165        Ok(())
166    }
167}
168
169#[derive(Clone, Debug)]
170pub struct SparseArray {
171    patches: Patches,
172    fill_value: Scalar,
173    stats_set: ArrayStats,
174}
175
/// Zero-sized marker type on which the vtable traits of the sparse encoding are implemented.
#[derive(Debug)]
pub struct SparseVTable;
178
179impl SparseArray {
180    pub fn try_new(
181        indices: ArrayRef,
182        values: ArrayRef,
183        len: usize,
184        fill_value: Scalar,
185    ) -> VortexResult<Self> {
186        vortex_ensure!(
187            indices.len() == values.len(),
188            "Mismatched indices {} and values {} length",
189            indices.len(),
190            values.len()
191        );
192
193        vortex_ensure!(
194            indices.statistics().compute_is_strict_sorted() == Some(true),
195            "SparseArray: indices must be strict-sorted"
196        );
197
198        // Verify the indices are all in the valid range
199        if !indices.is_empty() {
200            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
201
202            vortex_ensure!(
203                last_index < len,
204                "Array length was {len} but the last index is {last_index}"
205            );
206        }
207
208        Ok(Self {
209            // TODO(0ax1): handle chunk offsets
210            patches: Patches::new(len, 0, indices, values, None),
211            fill_value,
212            stats_set: Default::default(),
213        })
214    }
215
216    /// Build a new SparseArray from an existing set of patches.
217    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
218        vortex_ensure!(
219            fill_value.dtype() == patches.values().dtype(),
220            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
221            fill_value,
222            patches.values().dtype(),
223            fill_value.dtype(),
224        );
225
226        Ok(Self {
227            patches,
228            fill_value,
229            stats_set: Default::default(),
230        })
231    }
232
233    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
234        Self {
235            patches,
236            fill_value,
237            stats_set: Default::default(),
238        }
239    }
240
241    #[inline]
242    pub fn patches(&self) -> &Patches {
243        &self.patches
244    }
245
246    #[inline]
247    pub fn resolved_patches(&self) -> Patches {
248        let patches = self.patches();
249        let indices_offset = Scalar::from(patches.offset())
250            .cast(patches.indices().dtype())
251            .vortex_expect("Patches offset must cast to the indices dtype");
252        let indices = sub_scalar(patches.indices(), indices_offset)
253            .vortex_expect("must be able to subtract offset from indices");
254
255        Patches::new(
256            patches.array_len(),
257            0,
258            indices,
259            patches.values().clone(),
260            // TODO(0ax1): handle chunk offsets
261            None,
262        )
263    }
264
265    #[inline]
266    pub fn fill_scalar(&self) -> &Scalar {
267        &self.fill_value
268    }
269
270    /// Encode given array as a SparseArray.
271    ///
272    /// Optionally provided fill value will be respected if the array is less than 90% null.
273    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
274        if let Some(fill_value) = fill_value.as_ref()
275            && array.dtype() != fill_value.dtype()
276        {
277            vortex_bail!(
278                "Array and fill value types must match. got {} and {}",
279                array.dtype(),
280                fill_value.dtype()
281            )
282        }
283        let mask = array.validity_mask();
284
285        if mask.all_false() {
286            // Array is constant NULL
287            return Ok(
288                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
289            );
290        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
291            // Array is dominated by NULL but has non-NULL values
292            let non_null_values = filter(array, &mask)?;
293            let non_null_indices = match mask.indices() {
294                AllOr::All => {
295                    // We already know that the mask is 90%+ false
296                    unreachable!("Mask is mostly null")
297                }
298                AllOr::None => {
299                    // we know there are some non-NULL values
300                    unreachable!("Mask is mostly null but not all null")
301                }
302                AllOr::Some(values) => {
303                    let buffer: Buffer<u32> = values
304                        .iter()
305                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
306                        .collect();
307
308                    buffer.into_array()
309                }
310            };
311
312            return Ok(SparseArray::try_new(
313                non_null_indices,
314                non_null_values,
315                array.len(),
316                Scalar::null(array.dtype().clone()),
317            )?
318            .into_array());
319        }
320
321        let fill = if let Some(fill) = fill_value {
322            fill
323        } else {
324            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
325            let (top_pvalue, _) = array
326                .to_primitive()
327                .top_value()?
328                .vortex_expect("Non empty or all null array");
329
330            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
331        };
332
333        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
334        let non_top_mask = Mask::from_buffer(
335            fill_null(
336                &compare(array, &fill_array, Operator::NotEq)?,
337                &Scalar::bool(true, Nullability::NonNullable),
338            )?
339            .to_bool()
340            .bit_buffer()
341            .clone(),
342        );
343
344        let non_top_values = filter(array, &non_top_mask)?;
345
346        let indices: Buffer<u64> = match non_top_mask {
347            Mask::AllTrue(count) => {
348                // all true -> complete slice
349                (0u64..count as u64).collect()
350            }
351            Mask::AllFalse(_) => {
352                // All values are equal to the top value
353                return Ok(fill_array);
354            }
355            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
356        };
357
358        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
359            .map(|a| a.into_array())
360    }
361}
362
363impl BaseArrayVTable<SparseVTable> for SparseVTable {
364    fn len(array: &SparseArray) -> usize {
365        array.patches.array_len()
366    }
367
368    fn dtype(array: &SparseArray) -> &DType {
369        array.fill_scalar().dtype()
370    }
371
372    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
373        array.stats_set.to_ref(array.as_ref())
374    }
375
376    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
377        array.patches.array_hash(state, precision);
378        array.fill_value.hash(state);
379    }
380
381    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
382        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
383    }
384}
385
386impl ValidityVTable<SparseVTable> for SparseVTable {
387    fn is_valid(array: &SparseArray, index: usize) -> bool {
388        match array.patches().get_patched(index) {
389            None => array.fill_scalar().is_valid(),
390            Some(patch_value) => patch_value.is_valid(),
391        }
392    }
393
394    fn all_valid(array: &SparseArray) -> bool {
395        if array.fill_scalar().is_null() {
396            // We need _all_ values to be patched, and all patches to be valid
397            return array.patches().values().len() == array.len()
398                && array.patches().values().all_valid();
399        }
400
401        array.patches().values().all_valid()
402    }
403
404    fn all_invalid(array: &SparseArray) -> bool {
405        if !array.fill_scalar().is_null() {
406            // We need _all_ values to be patched, and all patches to be invalid
407            return array.patches().values().len() == array.len()
408                && array.patches().values().all_invalid();
409        }
410
411        array.patches().values().all_invalid()
412    }
413
414    fn validity(array: &SparseArray) -> VortexResult<Validity> {
415        let patches = unsafe {
416            Patches::new_unchecked(
417                array.patches.array_len(),
418                array.patches.offset(),
419                array.patches.indices().clone(),
420                array
421                    .patches
422                    .values()
423                    .validity()?
424                    .to_array(array.patches.values().len()),
425                array.patches.chunk_offsets().clone(),
426                array.patches.offset_within_chunk(),
427            )
428        };
429
430        Ok(Validity::Array(
431            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
432                .into_array(),
433        ))
434    }
435
436    fn validity_mask(array: &SparseArray) -> Mask {
437        let fill_is_valid = array.fill_scalar().is_valid();
438        let values_validity = array.patches().values().validity_mask();
439        let len = array.len();
440
441        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
442            return Mask::AllTrue(len);
443        }
444        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
445            return Mask::AllFalse(len);
446        }
447
448        let mut is_valid_buffer = if fill_is_valid {
449            BitBufferMut::new_set(len)
450        } else {
451            BitBufferMut::new_unset(len)
452        };
453
454        let indices = array.patches().indices().to_primitive();
455        let index_offset = array.patches().offset();
456
457        match_each_integer_ptype!(indices.ptype(), |I| {
458            let indices = indices.as_slice::<I>();
459            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
460        });
461
462        Mask::from_buffer(is_valid_buffer.freeze())
463    }
464}
465
466fn patch_validity<I: NativePType + AsPrimitive<usize>>(
467    is_valid_buffer: &mut BitBufferMut,
468    indices: &[I],
469    index_offset: usize,
470    values_validity: Mask,
471) {
472    let indices = indices.iter().map(|index| index.as_() - index_offset);
473    match values_validity {
474        Mask::AllTrue(_) => {
475            for index in indices {
476                is_valid_buffer.set(index);
477            }
478        }
479        Mask::AllFalse(_) => {
480            for index in indices {
481                is_valid_buffer.unset(index);
482            }
483        }
484        Mask::Values(mask_values) => {
485            let is_valid = mask_values.bit_buffer().iter();
486            for (index, is_valid) in indices.zip_eq(is_valid) {
487                is_valid_buffer.set_to(index, is_valid);
488            }
489        }
490    }
491}
492
493impl EncodeVTable<SparseVTable> for SparseVTable {
494    fn encode(
495        _vtable: &SparseVTable,
496        input: &Canonical,
497        like: Option<&SparseArray>,
498    ) -> VortexResult<Option<SparseArray>> {
499        // Try and cast the "like" fill value into the array's type. This is useful for cases where we narrow the arrays type.
500        let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
501
502        // TODO(ngates): encode should only handle arrays that _can_ be made sparse.
503        Ok(SparseArray::encode(input.as_ref(), fill_value)?
504            .as_opt::<SparseVTable>()
505            .cloned())
506    }
507}
508
509impl VisitorVTable<SparseVTable> for SparseVTable {
510    fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
511        let fill_value_buffer = array
512            .fill_value
513            .value()
514            .to_protobytes::<ByteBufferMut>()
515            .freeze();
516        visitor.visit_buffer(&fill_value_buffer);
517    }
518
519    fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
520        visitor.visit_patches(array.patches())
521    }
522}
523
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;
    use vortex_scalar::PrimitiveScalar;
    use vortex_scalar::Scalar;

    use super::*;

    /// A null fill value of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// A non-null `i32` fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Sparse array of length 10 with patches 100, 200, 300 at positions 2, 5, 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [fill, fill, 100, fill, fill, 200, fill, fill, 300, fill]
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // Position 0 is not patched, positions 2 and 5 are.
        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = PrimitiveScalar::try_from(&arr.scalar_at(10))
            .unwrap()
            .typed_value::<u32>();
        assert_eq!(patched, Some(1234));
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // All patch values are null while the fill value is non-null.
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let validity_bits = array.validity_mask().to_bit_buffer().iter().collect_vec();
        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not smaller than the array length (100).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is smaller than the array length (101).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = Validity::from_iter(vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ]);
        let input =
            PrimitiveArray::new(buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4], validity).into_array();

        let sparse = SparseArray::encode(&input, None)
            .vortex_expect("SparseArray::encode should succeed for test data");
        let canonical = sparse.to_primitive();

        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask();

        assert_eq!(actual, expected);
    }
}