vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use prost::Message as _;
10use vortex_array::Array;
11use vortex_array::ArrayBufferVisitor;
12use vortex_array::ArrayChildVisitor;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayRef;
16use vortex_array::Canonical;
17use vortex_array::IntoArray;
18use vortex_array::Precision;
19use vortex_array::ProstMetadata;
20use vortex_array::ToCanonical;
21use vortex_array::arrays::ConstantArray;
22use vortex_array::compute::Operator;
23use vortex_array::compute::compare;
24use vortex_array::compute::fill_null;
25use vortex_array::compute::filter;
26use vortex_array::compute::sub_scalar;
27use vortex_array::patches::Patches;
28use vortex_array::patches::PatchesMetadata;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::vtable;
33use vortex_array::vtable::ArrayId;
34use vortex_array::vtable::ArrayVTable;
35use vortex_array::vtable::ArrayVTableExt;
36use vortex_array::vtable::BaseArrayVTable;
37use vortex_array::vtable::EncodeVTable;
38use vortex_array::vtable::NotSupported;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityVTable;
41use vortex_array::vtable::VisitorVTable;
42use vortex_buffer::BitBufferMut;
43use vortex_buffer::Buffer;
44use vortex_buffer::BufferHandle;
45use vortex_buffer::ByteBufferMut;
46use vortex_dtype::DType;
47use vortex_dtype::NativePType;
48use vortex_dtype::Nullability;
49use vortex_dtype::match_each_integer_ptype;
50use vortex_error::VortexExpect as _;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_mask::AllOr;
55use vortex_mask::Mask;
56use vortex_scalar::Scalar;
57use vortex_scalar::ScalarValue;
58
59mod canonical;
60mod compute;
61mod ops;
62
63vtable!(Sparse);
64
65#[derive(Clone, prost::Message)]
66#[repr(C)]
67pub struct SparseMetadata {
68    #[prost(message, required, tag = "1")]
69    patches: PatchesMetadata,
70}
71
72impl VTable for SparseVTable {
73    type Array = SparseArray;
74
75    type Metadata = ProstMetadata<SparseMetadata>;
76
77    type ArrayVTable = Self;
78    type CanonicalVTable = Self;
79    type OperationsVTable = Self;
80    type ValidityVTable = Self;
81    type VisitorVTable = Self;
82    type ComputeVTable = NotSupported;
83    type EncodeVTable = Self;
84
85    fn id(&self) -> ArrayId {
86        ArrayId::new_ref("vortex.sparse")
87    }
88
89    fn encoding(_array: &Self::Array) -> ArrayVTable {
90        SparseVTable.as_vtable()
91    }
92
93    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
94        Ok(ProstMetadata(SparseMetadata {
95            patches: array.patches().to_metadata(array.len(), array.dtype())?,
96        }))
97    }
98
99    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
100        Ok(Some(metadata.0.encode_to_vec()))
101    }
102
103    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
104        Ok(ProstMetadata(SparseMetadata::decode(buffer)?))
105    }
106
107    fn build(
108        &self,
109        dtype: &DType,
110        len: usize,
111        metadata: &Self::Metadata,
112        buffers: &[BufferHandle],
113        children: &dyn ArrayChildren,
114    ) -> VortexResult<SparseArray> {
115        if children.len() != 2 {
116            vortex_bail!(
117                "Expected 2 children for sparse encoding, found {}",
118                children.len()
119            )
120        }
121        assert_eq!(
122            metadata.0.patches.offset(),
123            0,
124            "Patches must start at offset 0"
125        );
126
127        let patch_indices = children.get(
128            0,
129            &metadata.0.patches.indices_dtype(),
130            metadata.0.patches.len(),
131        )?;
132        let patch_values = children.get(1, dtype, metadata.0.patches.len())?;
133
134        if buffers.len() != 1 {
135            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
136        }
137        let fill_value = Scalar::new(
138            dtype.clone(),
139            ScalarValue::from_protobytes(&buffers[0].clone().try_to_bytes()?)?,
140        );
141
142        SparseArray::try_new(patch_indices, patch_values, len, fill_value)
143    }
144}
145
146#[derive(Clone, Debug)]
147pub struct SparseArray {
148    patches: Patches,
149    fill_value: Scalar,
150    stats_set: ArrayStats,
151}
152
/// Zero-sized marker type on which the vtable traits of the sparse encoding are implemented.
#[derive(Debug)]
pub struct SparseVTable;
155
156impl SparseArray {
157    pub fn try_new(
158        indices: ArrayRef,
159        values: ArrayRef,
160        len: usize,
161        fill_value: Scalar,
162    ) -> VortexResult<Self> {
163        vortex_ensure!(
164            indices.len() == values.len(),
165            "Mismatched indices {} and values {} length",
166            indices.len(),
167            values.len()
168        );
169
170        vortex_ensure!(
171            indices.statistics().compute_is_strict_sorted() == Some(true),
172            "SparseArray: indices must be strict-sorted"
173        );
174
175        // Verify the indices are all in the valid range
176        if !indices.is_empty() {
177            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
178
179            vortex_ensure!(
180                last_index < len,
181                "Array length was {len} but the last index is {last_index}"
182            );
183        }
184
185        Ok(Self {
186            // TODO(0ax1): handle chunk offsets
187            patches: Patches::new(len, 0, indices, values, None),
188            fill_value,
189            stats_set: Default::default(),
190        })
191    }
192
193    /// Build a new SparseArray from an existing set of patches.
194    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
195        vortex_ensure!(
196            fill_value.dtype() == patches.values().dtype(),
197            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
198            fill_value,
199            patches.values().dtype(),
200            fill_value.dtype(),
201        );
202
203        Ok(Self {
204            patches,
205            fill_value,
206            stats_set: Default::default(),
207        })
208    }
209
210    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
211        Self {
212            patches,
213            fill_value,
214            stats_set: Default::default(),
215        }
216    }
217
218    #[inline]
219    pub fn patches(&self) -> &Patches {
220        &self.patches
221    }
222
223    #[inline]
224    pub fn resolved_patches(&self) -> Patches {
225        let patches = self.patches();
226        let indices_offset = Scalar::from(patches.offset())
227            .cast(patches.indices().dtype())
228            .vortex_expect("Patches offset must cast to the indices dtype");
229        let indices = sub_scalar(patches.indices(), indices_offset)
230            .vortex_expect("must be able to subtract offset from indices");
231
232        Patches::new(
233            patches.array_len(),
234            0,
235            indices,
236            patches.values().clone(),
237            // TODO(0ax1): handle chunk offsets
238            None,
239        )
240    }
241
242    #[inline]
243    pub fn fill_scalar(&self) -> &Scalar {
244        &self.fill_value
245    }
246
247    /// Encode given array as a SparseArray.
248    ///
249    /// Optionally provided fill value will be respected if the array is less than 90% null.
250    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
251        if let Some(fill_value) = fill_value.as_ref()
252            && array.dtype() != fill_value.dtype()
253        {
254            vortex_bail!(
255                "Array and fill value types must match. got {} and {}",
256                array.dtype(),
257                fill_value.dtype()
258            )
259        }
260        let mask = array.validity_mask();
261
262        if mask.all_false() {
263            // Array is constant NULL
264            return Ok(
265                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
266            );
267        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
268            // Array is dominated by NULL but has non-NULL values
269            let non_null_values = filter(array, &mask)?;
270            let non_null_indices = match mask.indices() {
271                AllOr::All => {
272                    // We already know that the mask is 90%+ false
273                    unreachable!("Mask is mostly null")
274                }
275                AllOr::None => {
276                    // we know there are some non-NULL values
277                    unreachable!("Mask is mostly null but not all null")
278                }
279                AllOr::Some(values) => {
280                    let buffer: Buffer<u32> = values
281                        .iter()
282                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
283                        .collect();
284
285                    buffer.into_array()
286                }
287            };
288
289            return Ok(SparseArray::try_new(
290                non_null_indices,
291                non_null_values,
292                array.len(),
293                Scalar::null(array.dtype().clone()),
294            )?
295            .into_array());
296        }
297
298        let fill = if let Some(fill) = fill_value {
299            fill
300        } else {
301            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
302            let (top_pvalue, _) = array
303                .to_primitive()
304                .top_value()?
305                .vortex_expect("Non empty or all null array");
306
307            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
308        };
309
310        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
311        let non_top_mask = Mask::from_buffer(
312            fill_null(
313                &compare(array, &fill_array, Operator::NotEq)?,
314                &Scalar::bool(true, Nullability::NonNullable),
315            )?
316            .to_bool()
317            .bit_buffer()
318            .clone(),
319        );
320
321        let non_top_values = filter(array, &non_top_mask)?;
322
323        let indices: Buffer<u64> = match non_top_mask {
324            Mask::AllTrue(count) => {
325                // all true -> complete slice
326                (0u64..count as u64).collect()
327            }
328            Mask::AllFalse(_) => {
329                // All values are equal to the top value
330                return Ok(fill_array);
331            }
332            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
333        };
334
335        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
336            .map(|a| a.into_array())
337    }
338}
339
340impl BaseArrayVTable<SparseVTable> for SparseVTable {
341    fn len(array: &SparseArray) -> usize {
342        array.patches.array_len()
343    }
344
345    fn dtype(array: &SparseArray) -> &DType {
346        array.fill_scalar().dtype()
347    }
348
349    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
350        array.stats_set.to_ref(array.as_ref())
351    }
352
353    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
354        array.patches.array_hash(state, precision);
355        array.fill_value.hash(state);
356    }
357
358    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
359        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
360    }
361}
362
363impl ValidityVTable<SparseVTable> for SparseVTable {
364    fn is_valid(array: &SparseArray, index: usize) -> bool {
365        match array.patches().get_patched(index) {
366            None => array.fill_scalar().is_valid(),
367            Some(patch_value) => patch_value.is_valid(),
368        }
369    }
370
371    fn all_valid(array: &SparseArray) -> bool {
372        if array.fill_scalar().is_null() {
373            // We need _all_ values to be patched, and all patches to be valid
374            return array.patches().values().len() == array.len()
375                && array.patches().values().all_valid();
376        }
377
378        array.patches().values().all_valid()
379    }
380
381    fn all_invalid(array: &SparseArray) -> bool {
382        if !array.fill_scalar().is_null() {
383            // We need _all_ values to be patched, and all patches to be invalid
384            return array.patches().values().len() == array.len()
385                && array.patches().values().all_invalid();
386        }
387
388        array.patches().values().all_invalid()
389    }
390
391    fn validity_mask(array: &SparseArray) -> Mask {
392        let fill_is_valid = array.fill_scalar().is_valid();
393        let values_validity = array.patches().values().validity_mask();
394        let len = array.len();
395
396        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
397            return Mask::AllTrue(len);
398        }
399        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
400            return Mask::AllFalse(len);
401        }
402
403        let mut is_valid_buffer = if fill_is_valid {
404            BitBufferMut::new_set(len)
405        } else {
406            BitBufferMut::new_unset(len)
407        };
408
409        let indices = array.patches().indices().to_primitive();
410        let index_offset = array.patches().offset();
411
412        match_each_integer_ptype!(indices.ptype(), |I| {
413            let indices = indices.as_slice::<I>();
414            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
415        });
416
417        Mask::from_buffer(is_valid_buffer.freeze())
418    }
419}
420
421fn patch_validity<I: NativePType + AsPrimitive<usize>>(
422    is_valid_buffer: &mut BitBufferMut,
423    indices: &[I],
424    index_offset: usize,
425    values_validity: Mask,
426) {
427    let indices = indices.iter().map(|index| index.as_() - index_offset);
428    match values_validity {
429        Mask::AllTrue(_) => {
430            for index in indices {
431                is_valid_buffer.set(index);
432            }
433        }
434        Mask::AllFalse(_) => {
435            for index in indices {
436                is_valid_buffer.unset(index);
437            }
438        }
439        Mask::Values(mask_values) => {
440            let is_valid = mask_values.bit_buffer().iter();
441            for (index, is_valid) in indices.zip_eq(is_valid) {
442                is_valid_buffer.set_to(index, is_valid);
443            }
444        }
445    }
446}
447
448impl EncodeVTable<SparseVTable> for SparseVTable {
449    fn encode(
450        _vtable: &SparseVTable,
451        input: &Canonical,
452        like: Option<&SparseArray>,
453    ) -> VortexResult<Option<SparseArray>> {
454        // Try and cast the "like" fill value into the array's type. This is useful for cases where we narrow the arrays type.
455        let fill_value = like.and_then(|arr| arr.fill_scalar().cast(input.as_ref().dtype()).ok());
456
457        // TODO(ngates): encode should only handle arrays that _can_ be made sparse.
458        Ok(SparseArray::encode(input.as_ref(), fill_value)?
459            .as_opt::<SparseVTable>()
460            .cloned())
461    }
462}
463
464impl VisitorVTable<SparseVTable> for SparseVTable {
465    fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
466        let fill_value_buffer = array
467            .fill_value
468            .value()
469            .to_protobytes::<ByteBufferMut>()
470            .freeze();
471        visitor.visit_buffer(&fill_value_buffer);
472    }
473
474    fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
475        visitor.visit_patches(array.patches())
476    }
477}
478
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexUnwrap;
    use vortex_scalar::PrimitiveScalar;
    use vortex_scalar::Scalar;

    use super::*;

    /// A null fill value of nullable `i32` type.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill value: `42i32`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Ten-element sparse array holding 100, 200 and 300 at positions 2, 5 and 8; every other
    /// position holds `fill_value`. The patch values are cast to the fill value's dtype.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let typed_values = cast(&raw_values, fill_value.dtype()).unwrap();
        let positions = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(positions, typed_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let sparse = sparse_array(nullable_fill());

        // Position 0 is unpatched, positions 2 and 5 are patched.
        assert_eq!(sparse.scalar_at(0), nullable_fill());
        assert_eq!(sparse.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(sparse.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let sparse = sparse_array(nullable_fill());
        sparse.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let patch_indices = ConstantArray::new(10u32, 1).into_array();
        let patch_values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let null_fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let sparse = SparseArray::try_new(patch_indices, patch_values, 100, null_fill).unwrap();

        let patched = sparse.scalar_at(10);
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(sparse.scalar_at(0).is_null());
        assert!(sparse.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        let first = window.scalar_at(0);
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let window = sparse_array(nullable_fill()).slice(2..7);
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Here the patches are null and the fill is valid, the reverse of the other tests.
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let window = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(window.validity_mask(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let outer = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&outer.scalar_at(1)).unwrap(), 100);

        let inner = outer.slice(1..6);
        assert_eq!(usize::try_from(&inner.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let sparse = sparse_array(nullable_fill());
        let validity_bits = sparse.validity_mask().to_bit_buffer().iter().collect_vec();
        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let sparse = sparse_array(non_nullable_fill());
        assert!(sparse.validity_mask().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not smaller than the length (100), so `try_new` fails.
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // With length 101 the last index (100) is in range.
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive();

        assert_eq!(sparse.validity_mask(), Mask::from_iter(validity));
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let sparse =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_typed::<i16>())
                .unwrap();

        // Only positions 0, 2 and 8 carry a non-null patch; everything else is null.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = sparse.validity_mask();

        assert_eq!(actual, expected);
    }
}
669}