vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use num_traits::AsPrimitive;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
11use vortex_array::patches::Patches;
12use vortex_array::stats::{ArrayStats, StatsSetRef};
13use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
14use vortex_array::{
15    Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, IntoArray, Precision,
16    ToCanonical, vtable,
17};
18use vortex_buffer::{BitBufferMut, Buffer};
19use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
20use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
21use vortex_mask::{AllOr, Mask};
22use vortex_scalar::Scalar;
23
24mod canonical;
25mod compute;
26mod ops;
27mod serde;
28
29vtable!(Sparse);
30
31impl VTable for SparseVTable {
32    type Array = SparseArray;
33    type Encoding = SparseEncoding;
34
35    type ArrayVTable = Self;
36    type CanonicalVTable = Self;
37    type OperationsVTable = Self;
38    type ValidityVTable = Self;
39    type VisitorVTable = Self;
40    type ComputeVTable = NotSupported;
41    type EncodeVTable = Self;
42    type SerdeVTable = Self;
43    type OperatorVTable = NotSupported;
44
45    fn id(_encoding: &Self::Encoding) -> EncodingId {
46        EncodingId::new_ref("vortex.sparse")
47    }
48
49    fn encoding(_array: &Self::Array) -> EncodingRef {
50        EncodingRef::new_ref(SparseEncoding.as_ref())
51    }
52}
53
/// An array that stores explicit values only at selected positions and a single fill value
/// for every other position.
///
/// A position covered by `patches` takes the patched value; every other position takes
/// `fill_value`.
#[derive(Clone, Debug)]
pub struct SparseArray {
    // Positions/values overriding the fill value; `patches.array_len()` is the array length.
    patches: Patches,
    // Value of every non-patched position; its dtype is the dtype of the whole array.
    fill_value: Scalar,
    // Statistics storage exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}

/// Encoding marker for sparse arrays (encoding id `vortex.sparse`).
#[derive(Clone, Debug)]
pub struct SparseEncoding;
63
64impl SparseArray {
65    pub fn try_new(
66        indices: ArrayRef,
67        values: ArrayRef,
68        len: usize,
69        fill_value: Scalar,
70    ) -> VortexResult<Self> {
71        vortex_ensure!(
72            indices.len() == values.len(),
73            "Mismatched indices {} and values {} length",
74            indices.len(),
75            values.len()
76        );
77
78        vortex_ensure!(
79            indices.statistics().compute_is_strict_sorted() == Some(true),
80            "SparseArray: indices must be strict-sorted"
81        );
82
83        // Verify the indices are all in the valid range
84        if !indices.is_empty() {
85            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
86
87            vortex_ensure!(
88                last_index < len,
89                "Array length was {len} but the last index is {last_index}"
90            );
91        }
92
93        Ok(Self {
94            // TODO(0ax1): handle chunk offsets
95            patches: Patches::new(len, 0, indices, values, None),
96            fill_value,
97            stats_set: Default::default(),
98        })
99    }
100
101    /// Build a new SparseArray from an existing set of patches.
102    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
103        vortex_ensure!(
104            fill_value.dtype() == patches.values().dtype(),
105            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
106            fill_value,
107            patches.values().dtype(),
108            fill_value.dtype(),
109        );
110
111        Ok(Self {
112            patches,
113            fill_value,
114            stats_set: Default::default(),
115        })
116    }
117
118    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
119        Self {
120            patches,
121            fill_value,
122            stats_set: Default::default(),
123        }
124    }
125
126    #[inline]
127    pub fn patches(&self) -> &Patches {
128        &self.patches
129    }
130
131    #[inline]
132    pub fn resolved_patches(&self) -> Patches {
133        let patches = self.patches();
134        let indices_offset = Scalar::from(patches.offset())
135            .cast(patches.indices().dtype())
136            .vortex_expect("Patches offset must cast to the indices dtype");
137        let indices = sub_scalar(patches.indices(), indices_offset)
138            .vortex_expect("must be able to subtract offset from indices");
139
140        Patches::new(
141            patches.array_len(),
142            0,
143            indices,
144            patches.values().clone(),
145            // TODO(0ax1): handle chunk offsets
146            None,
147        )
148    }
149
150    #[inline]
151    pub fn fill_scalar(&self) -> &Scalar {
152        &self.fill_value
153    }
154
155    /// Encode given array as a SparseArray.
156    ///
157    /// Optionally provided fill value will be respected if the array is less than 90% null.
158    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
159        if let Some(fill_value) = fill_value.as_ref()
160            && array.dtype() != fill_value.dtype()
161        {
162            vortex_bail!(
163                "Array and fill value types must match. got {} and {}",
164                array.dtype(),
165                fill_value.dtype()
166            )
167        }
168        let mask = array.validity_mask();
169
170        if mask.all_false() {
171            // Array is constant NULL
172            return Ok(
173                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
174            );
175        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
176            // Array is dominated by NULL but has non-NULL values
177            let non_null_values = filter(array, &mask)?;
178            let non_null_indices = match mask.indices() {
179                AllOr::All => {
180                    // We already know that the mask is 90%+ false
181                    unreachable!("Mask is mostly null")
182                }
183                AllOr::None => {
184                    // we know there are some non-NULL values
185                    unreachable!("Mask is mostly null but not all null")
186                }
187                AllOr::Some(values) => {
188                    let buffer: Buffer<u32> = values
189                        .iter()
190                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
191                        .collect();
192
193                    buffer.into_array()
194                }
195            };
196
197            return Ok(SparseArray::try_new(
198                non_null_indices,
199                non_null_values,
200                array.len(),
201                Scalar::null(array.dtype().clone()),
202            )?
203            .into_array());
204        }
205
206        let fill = if let Some(fill) = fill_value {
207            fill
208        } else {
209            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
210            let (top_pvalue, _) = array
211                .to_primitive()
212                .top_value()?
213                .vortex_expect("Non empty or all null array");
214
215            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
216        };
217
218        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
219        let non_top_mask = Mask::from_buffer(
220            fill_null(
221                &compare(array, &fill_array, Operator::NotEq)?,
222                &Scalar::bool(true, Nullability::NonNullable),
223            )?
224            .to_bool()
225            .bit_buffer()
226            .clone(),
227        );
228
229        let non_top_values = filter(array, &non_top_mask)?;
230
231        let indices: Buffer<u64> = match non_top_mask {
232            Mask::AllTrue(count) => {
233                // all true -> complete slice
234                (0u64..count as u64).collect()
235            }
236            Mask::AllFalse(_) => {
237                // All values are equal to the top value
238                return Ok(fill_array);
239            }
240            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
241        };
242
243        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
244            .map(|a| a.into_array())
245    }
246}
247
248impl ArrayVTable<SparseVTable> for SparseVTable {
249    fn len(array: &SparseArray) -> usize {
250        array.patches.array_len()
251    }
252
253    fn dtype(array: &SparseArray) -> &DType {
254        array.fill_scalar().dtype()
255    }
256
257    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
258        array.stats_set.to_ref(array.as_ref())
259    }
260
261    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
262        array.patches.array_hash(state, precision);
263        array.fill_value.hash(state);
264    }
265
266    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
267        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
268    }
269}
270
271impl ValidityVTable<SparseVTable> for SparseVTable {
272    fn is_valid(array: &SparseArray, index: usize) -> bool {
273        match array.patches().get_patched(index) {
274            None => array.fill_scalar().is_valid(),
275            Some(patch_value) => patch_value.is_valid(),
276        }
277    }
278
279    fn all_valid(array: &SparseArray) -> bool {
280        if array.fill_scalar().is_null() {
281            // We need _all_ values to be patched, and all patches to be valid
282            return array.patches().values().len() == array.len()
283                && array.patches().values().all_valid();
284        }
285
286        array.patches().values().all_valid()
287    }
288
289    fn all_invalid(array: &SparseArray) -> bool {
290        if !array.fill_scalar().is_null() {
291            // We need _all_ values to be patched, and all patches to be invalid
292            return array.patches().values().len() == array.len()
293                && array.patches().values().all_invalid();
294        }
295
296        array.patches().values().all_invalid()
297    }
298
299    #[allow(clippy::unnecessary_fallible_conversions)]
300    fn validity_mask(array: &SparseArray) -> Mask {
301        let fill_is_valid = array.fill_scalar().is_valid();
302        let values_validity = array.patches().values().validity_mask();
303        let len = array.len();
304
305        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
306            return Mask::AllTrue(len);
307        }
308        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
309            return Mask::AllFalse(len);
310        }
311
312        let mut is_valid_buffer = if fill_is_valid {
313            BitBufferMut::new_set(len)
314        } else {
315            BitBufferMut::new_unset(len)
316        };
317
318        let indices = array.patches().indices().to_primitive();
319        let index_offset = array.patches().offset();
320
321        match_each_integer_ptype!(indices.ptype(), |I| {
322            let indices = indices.as_slice::<I>();
323            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
324        });
325
326        Mask::from_buffer(is_valid_buffer.freeze())
327    }
328}
329
330fn patch_validity<I: NativePType + AsPrimitive<usize>>(
331    is_valid_buffer: &mut BitBufferMut,
332    indices: &[I],
333    index_offset: usize,
334    values_validity: Mask,
335) {
336    let indices = indices.iter().map(|index| index.as_() - index_offset);
337    match values_validity {
338        Mask::AllTrue(_) => {
339            for index in indices {
340                is_valid_buffer.set(index);
341            }
342        }
343        Mask::AllFalse(_) => {
344            for index in indices {
345                is_valid_buffer.unset(index);
346            }
347        }
348        Mask::Values(mask_values) => {
349            let is_valid = mask_values.bit_buffer().iter();
350            for (index, is_valid) in indices.zip_eq(is_valid) {
351                is_valid_buffer.set_to(index, is_valid);
352            }
353        }
354    }
355}
356
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null fill value of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` fill value of 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with 100, 200, 300 at positions 2, 5, 8 and `fill_value`
    /// elsewhere; the values are cast to the fill value's dtype.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // with a null fill the merged array is:
        // [null, null, 100, null, null, 200, null, null, 300, null]
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions return the fill; patched positions return their value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    // Index 10 is one past the end of the length-10 array.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10);
    }

    // Single patch (1234 at index 10) in a length-100 array built from constant arrays.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    // After slicing 2..7, position 0 maps to original position 2 (value 100).
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    // Slice 2..7 covers [100, null, null, 200, null].
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // Null patches at 2, 5, 8 over a non-null fill; slice 2..7 is [null, 1, 1, null, 1].
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slice offsets must compose: 1..8 then 1..6 is original 2..7.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    // With a null fill, only the patched positions 2, 5, 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array.validity_mask().to_bit_buffer().iter().collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // With a non-null fill and non-null patches, every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    // Last index 100 is not < len 100, so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // Same data with len 101 is accepted.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // 5 of 12 entries are null (not null-dominated), so encoding picks a most-frequent fill;
    // the round trip must preserve both validity and the underlying values.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive();
        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    // Null patch values (at positions 4 and 6) must stay invalid under a null fill.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}