vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, IntegerPType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Sparse` encoding.
// NOTE(review): presumably this is what declares `SparseVTable`, which is used below but not
// defined anywhere in this file — confirm against the macro definition.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28    type Array = SparseArray;
29    type Encoding = SparseEncoding;
30
31    type ArrayVTable = Self;
32    type CanonicalVTable = Self;
33    type OperationsVTable = Self;
34    type ValidityVTable = Self;
35    type VisitorVTable = Self;
36    type ComputeVTable = NotSupported;
37    type EncodeVTable = Self;
38    type SerdeVTable = Self;
39    type PipelineVTable = NotSupported;
40
41    fn id(_encoding: &Self::Encoding) -> EncodingId {
42        EncodingId::new_ref("vortex.sparse")
43    }
44
45    fn encoding(_array: &Self::Array) -> EncodingRef {
46        EncodingRef::new_ref(SparseEncoding.as_ref())
47    }
48}
49
/// An array that stores a single fill value plus a set of patches (index/value pairs) that
/// override the fill value at specific positions.
///
/// The array's dtype is the dtype of the fill value, and its length is the patches' array length.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Positions that deviate from the fill value, together with their values.
    patches: Patches,
    /// Value reported for every position that has no patch.
    fill_value: Scalar,
    /// Statistics storage for this array; starts out as `Default::default()` in every constructor.
    stats_set: ArrayStats,
}
56
/// Unit marker type used as the `Encoding` associated type of `SparseVTable`.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
59
60impl SparseArray {
61    pub fn try_new(
62        indices: ArrayRef,
63        values: ArrayRef,
64        len: usize,
65        fill_value: Scalar,
66    ) -> VortexResult<Self> {
67        vortex_ensure!(
68            indices.len() == values.len(),
69            "Mismatched indices {} and values {} length",
70            indices.len(),
71            values.len()
72        );
73
74        vortex_ensure!(
75            indices.statistics().compute_is_strict_sorted() == Some(true),
76            "SparseArray: indices must be strict-sorted"
77        );
78
79        // Verify the indices are all in the valid range
80        if !indices.is_empty() {
81            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
82
83            vortex_ensure!(
84                last_index < len,
85                "Array length was {len} but the last index is {last_index}"
86            );
87        }
88
89        Ok(Self {
90            // TODO(0ax1): handle chunk offsets
91            patches: Patches::new(len, 0, indices, values, None),
92            fill_value,
93            stats_set: Default::default(),
94        })
95    }
96
97    /// Build a new SparseArray from an existing set of patches.
98    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
99        vortex_ensure!(
100            fill_value.dtype() == patches.values().dtype(),
101            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
102            fill_value,
103            patches.values().dtype(),
104            fill_value.dtype(),
105        );
106
107        Ok(Self {
108            patches,
109            fill_value,
110            stats_set: Default::default(),
111        })
112    }
113
114    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
115        Self {
116            patches,
117            fill_value,
118            stats_set: Default::default(),
119        }
120    }
121
122    #[inline]
123    pub fn patches(&self) -> &Patches {
124        &self.patches
125    }
126
127    #[inline]
128    pub fn resolved_patches(&self) -> Patches {
129        let patches = self.patches();
130        let indices_offset = Scalar::from(patches.offset())
131            .cast(patches.indices().dtype())
132            .vortex_expect("Patches offset must cast to the indices dtype");
133        let indices = sub_scalar(patches.indices(), indices_offset)
134            .vortex_expect("must be able to subtract offset from indices");
135
136        Patches::new(
137            patches.array_len(),
138            0,
139            indices,
140            patches.values().clone(),
141            // TODO(0ax1): handle chunk offsets
142            None,
143        )
144    }
145
146    #[inline]
147    pub fn fill_scalar(&self) -> &Scalar {
148        &self.fill_value
149    }
150
151    /// Encode given array as a SparseArray.
152    ///
153    /// Optionally provided fill value will be respected if the array is less than 90% null.
154    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
155        if let Some(fill_value) = fill_value.as_ref()
156            && array.dtype() != fill_value.dtype()
157        {
158            vortex_bail!(
159                "Array and fill value types must match. got {} and {}",
160                array.dtype(),
161                fill_value.dtype()
162            )
163        }
164        let mask = array.validity_mask();
165
166        if mask.all_false() {
167            // Array is constant NULL
168            return Ok(
169                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
170            );
171        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
172            // Array is dominated by NULL but has non-NULL values
173            let non_null_values = filter(array, &mask)?;
174            let non_null_indices = match mask.indices() {
175                AllOr::All => {
176                    // We already know that the mask is 90%+ false
177                    unreachable!("Mask is mostly null")
178                }
179                AllOr::None => {
180                    // we know there are some non-NULL values
181                    unreachable!("Mask is mostly null but not all null")
182                }
183                AllOr::Some(values) => {
184                    let buffer: Buffer<u32> = values
185                        .iter()
186                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
187                        .collect();
188
189                    buffer.into_array()
190                }
191            };
192
193            return Ok(SparseArray::try_new(
194                non_null_indices,
195                non_null_values,
196                array.len(),
197                Scalar::null(array.dtype().clone()),
198            )?
199            .into_array());
200        }
201
202        let fill = if let Some(fill) = fill_value {
203            fill
204        } else {
205            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
206            let (top_pvalue, _) = array
207                .to_primitive()
208                .top_value()?
209                .vortex_expect("Non empty or all null array");
210
211            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
212        };
213
214        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
215        let non_top_mask = Mask::from_buffer(
216            fill_null(
217                &compare(array, &fill_array, Operator::NotEq)?,
218                &Scalar::bool(true, Nullability::NonNullable),
219            )?
220            .to_bool()
221            .boolean_buffer()
222            .clone(),
223        );
224
225        let non_top_values = filter(array, &non_top_mask)?;
226
227        let indices: Buffer<u64> = match non_top_mask {
228            Mask::AllTrue(count) => {
229                // all true -> complete slice
230                (0u64..count as u64).collect()
231            }
232            Mask::AllFalse(_) => {
233                // All values are equal to the top value
234                return Ok(fill_array);
235            }
236            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
237        };
238
239        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
240            .map(|a| a.into_array())
241    }
242}
243
244impl ArrayVTable<SparseVTable> for SparseVTable {
245    fn len(array: &SparseArray) -> usize {
246        array.patches.array_len()
247    }
248
249    fn dtype(array: &SparseArray) -> &DType {
250        array.fill_scalar().dtype()
251    }
252
253    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
254        array.stats_set.to_ref(array.as_ref())
255    }
256}
257
258impl ValidityVTable<SparseVTable> for SparseVTable {
259    fn is_valid(array: &SparseArray, index: usize) -> bool {
260        match array.patches().get_patched(index) {
261            None => array.fill_scalar().is_valid(),
262            Some(patch_value) => patch_value.is_valid(),
263        }
264    }
265
266    fn all_valid(array: &SparseArray) -> bool {
267        if array.fill_scalar().is_null() {
268            // We need _all_ values to be patched, and all patches to be valid
269            return array.patches().values().len() == array.len()
270                && array.patches().values().all_valid();
271        }
272
273        array.patches().values().all_valid()
274    }
275
276    fn all_invalid(array: &SparseArray) -> bool {
277        if !array.fill_scalar().is_null() {
278            // We need _all_ values to be patched, and all patches to be invalid
279            return array.patches().values().len() == array.len()
280                && array.patches().values().all_invalid();
281        }
282
283        array.patches().values().all_invalid()
284    }
285
286    #[allow(clippy::unnecessary_fallible_conversions)]
287    fn validity_mask(array: &SparseArray) -> Mask {
288        let fill_is_valid = array.fill_scalar().is_valid();
289        let values_validity = array.patches().values().validity_mask();
290        let len = array.len();
291
292        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
293            return Mask::AllTrue(len);
294        }
295        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
296            return Mask::AllFalse(len);
297        }
298
299        // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
300        let mut is_valid_buffer = BooleanBufferBuilder::new(len);
301        is_valid_buffer.append_n(len, fill_is_valid);
302
303        let indices = array.patches().indices().to_primitive();
304        let index_offset = array.patches().offset();
305
306        match_each_integer_ptype!(indices.ptype(), |I| {
307            let indices = indices.as_slice::<I>();
308            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
309        });
310
311        Mask::from_buffer(is_valid_buffer.finish())
312    }
313}
314
315fn patch_validity<I: IntegerPType>(
316    is_valid_buffer: &mut BooleanBufferBuilder,
317    indices: &[I],
318    index_offset: usize,
319    values_validity: Mask,
320) {
321    let indices = indices.iter().map(|index| {
322        let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
323        index - index_offset
324    });
325    match values_validity {
326        Mask::AllTrue(_) => {
327            for index in indices {
328                is_valid_buffer.set_bit(index, true);
329            }
330        }
331        Mask::AllFalse(_) => {
332            for index in indices {
333                is_valid_buffer.set_bit(index, false);
334            }
335        }
336        Mask::Values(mask_values) => {
337            let is_valid = mask_values.boolean_buffer().iter();
338            for (index, is_valid) in indices.zip_eq(is_valid) {
339                is_valid_buffer.set_bit(index, is_valid);
340            }
341        }
342    }
343}
344
// Unit tests for construction, scalar access, slicing, validity masks and `encode`.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    // A null fill value of dtype nullable i32.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    // A non-null i32 fill value (42).
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    // Builds a sparse array of length 10 with patches 100, 200, 300 at positions 2, 5, 8.
    // The patch values are cast to the dtype of the fill value before construction.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array (with a null fill): [null, null, 100, null, null, 200, null, null, 300, null]
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions yield the fill value, patched positions yield the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    // Accessing index 10 of a length-10 array must panic with an out-of-bounds message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let _ = array.scalar_at(10);
    }

    // Same as `test_scalar_at`, but with constant arrays as indices/values and u32 data.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    // After slicing 2..7, the patch originally at position 2 is found at position 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    // Null fill, sliced 2..7: only the patched positions (originally 2 and 5) are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7);
        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // Non-null fill with null patch values, sliced 2..7: only the patched positions are invalid.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7);

        assert_eq!(
            sliced.validity_mask(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slicing an already sliced array keeps the patches at the expected relative positions.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    // Null fill, unsliced: exactly the patched positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // Non-null fill and non-null patch values: every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().all_true());
    }

    // The last index (100) equals the array length (100), so construction must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // The last index (100) is smaller than the array length (101), so construction succeeds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // Encoding an array with some nulls and no explicit fill value must preserve both the
    // validity mask and the underlying primitive values.
    #[test]
    fn encode_with_nulls() {
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive();
        assert_eq!(
            sparse.validity_mask(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    // With a null fill, a position is valid only if it is patched AND its patch value is valid:
    // the null patch values at positions 4 and 6 must stay invalid.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}
535}