vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
// NOTE(review): presumably generates `SparseVTable` (implemented below) plus the glue that ties it
// to `SparseArray` / `SparseEncoding` — confirm against the `vtable!` macro definition.
vtable!(Sparse);
26
27impl VTable for SparseVTable {
28    type Array = SparseArray;
29    type Encoding = SparseEncoding;
30
31    type ArrayVTable = Self;
32    type CanonicalVTable = Self;
33    type OperationsVTable = Self;
34    type ValidityVTable = Self;
35    type VisitorVTable = Self;
36    type ComputeVTable = NotSupported;
37    type EncodeVTable = Self;
38    type SerdeVTable = Self;
39
40    fn id(_encoding: &Self::Encoding) -> EncodingId {
41        EncodingId::new_ref("vortex.sparse")
42    }
43
44    fn encoding(_array: &Self::Array) -> EncodingRef {
45        EncodingRef::new_ref(SparseEncoding.as_ref())
46    }
47}
48
/// An array in which the positions recorded in `patches` hold explicitly stored values and every
/// other position holds `fill_value`.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Explicitly stored positions and their values; `patches.array_len()` is the logical length
    /// of the array.
    patches: Patches,
    /// Value of every position not covered by `patches`; its dtype is reported as the array's
    /// dtype.
    fill_value: Scalar,
    /// Statistics storage for this array, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
55
/// Encoding marker for [`SparseArray`]; `SparseVTable::id` reports it as `"vortex.sparse"`.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
58
59impl SparseArray {
60    pub fn try_new(
61        indices: ArrayRef,
62        values: ArrayRef,
63        len: usize,
64        fill_value: Scalar,
65    ) -> VortexResult<Self> {
66        vortex_ensure!(
67            indices.len() == values.len(),
68            "Mismatched indices {} and values {} length",
69            indices.len(),
70            values.len()
71        );
72
73        vortex_ensure!(
74            indices.statistics().compute_is_strict_sorted() == Some(true),
75            "SparseArray: indices must be strict-sorted"
76        );
77
78        // Verify the indices are all in the valid range
79        if !indices.is_empty() {
80            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
81
82            vortex_ensure!(
83                last_index < len,
84                "Array length was {len} but the last index is {last_index}"
85            );
86        }
87
88        let patches = Patches::new(len, 0, indices, values);
89
90        Ok(Self {
91            patches,
92            fill_value,
93            stats_set: Default::default(),
94        })
95    }
96
97    /// Build a new SparseArray from an existing set of patches.
98    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
99        vortex_ensure!(
100            fill_value.dtype() == patches.values().dtype(),
101            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
102            fill_value,
103            patches.values().dtype(),
104            fill_value.dtype(),
105        );
106
107        Ok(Self {
108            patches,
109            fill_value,
110            stats_set: Default::default(),
111        })
112    }
113
114    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
115        Self {
116            patches,
117            fill_value,
118            stats_set: Default::default(),
119        }
120    }
121
122    #[inline]
123    pub fn patches(&self) -> &Patches {
124        &self.patches
125    }
126
127    #[inline]
128    pub fn resolved_patches(&self) -> VortexResult<Patches> {
129        let patches = self.patches();
130        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
131        let indices = sub_scalar(patches.indices(), indices_offset)?;
132        Ok(Patches::new(
133            patches.array_len(),
134            0,
135            indices,
136            patches.values().clone(),
137        ))
138    }
139
140    #[inline]
141    pub fn fill_scalar(&self) -> &Scalar {
142        &self.fill_value
143    }
144
145    /// Encode given array as a SparseArray.
146    ///
147    /// Optionally provided fill value will be respected if the array is less than 90% null.
148    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
149        if let Some(fill_value) = fill_value.as_ref()
150            && array.dtype() != fill_value.dtype()
151        {
152            vortex_bail!(
153                "Array and fill value types must match. got {} and {}",
154                array.dtype(),
155                fill_value.dtype()
156            )
157        }
158        let mask = array.validity_mask()?;
159
160        if mask.all_false() {
161            // Array is constant NULL
162            return Ok(
163                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
164            );
165        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
166            // Array is dominated by NULL but has non-NULL values
167            let non_null_values = filter(array, &mask)?;
168            let non_null_indices = match mask.indices() {
169                AllOr::All => {
170                    // We already know that the mask is 90%+ false
171                    unreachable!("Mask is mostly null")
172                }
173                AllOr::None => {
174                    // we know there are some non-NULL values
175                    unreachable!("Mask is mostly null but not all null")
176                }
177                AllOr::Some(values) => {
178                    let buffer: Buffer<u32> = values
179                        .iter()
180                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
181                        .collect();
182
183                    buffer.into_array()
184                }
185            };
186
187            return Ok(SparseArray::try_new(
188                non_null_indices,
189                non_null_values,
190                array.len(),
191                Scalar::null(array.dtype().clone()),
192            )?
193            .into_array());
194        }
195
196        let fill = if let Some(fill) = fill_value {
197            fill
198        } else {
199            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
200            let (top_pvalue, _) = array
201                .to_primitive()?
202                .top_value()?
203                .vortex_expect("Non empty or all null array");
204
205            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
206        };
207
208        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
209        let non_top_mask = Mask::from_buffer(
210            fill_null(
211                &compare(array, &fill_array, Operator::NotEq)?,
212                &Scalar::bool(true, Nullability::NonNullable),
213            )?
214            .to_bool()?
215            .boolean_buffer()
216            .clone(),
217        );
218
219        let non_top_values = filter(array, &non_top_mask)?;
220
221        let indices: Buffer<u64> = match non_top_mask {
222            Mask::AllTrue(count) => {
223                // all true -> complete slice
224                (0u64..count as u64).collect()
225            }
226            Mask::AllFalse(_) => {
227                // All values are equal to the top value
228                return Ok(fill_array);
229            }
230            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
231        };
232
233        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
234            .map(|a| a.into_array())
235    }
236}
237
238impl ArrayVTable<SparseVTable> for SparseVTable {
239    fn len(array: &SparseArray) -> usize {
240        array.patches.array_len()
241    }
242
243    fn dtype(array: &SparseArray) -> &DType {
244        array.fill_scalar().dtype()
245    }
246
247    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
248        array.stats_set.to_ref(array.as_ref())
249    }
250}
251
252impl ValidityVTable<SparseVTable> for SparseVTable {
253    fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
254        Ok(match array.patches().get_patched(index) {
255            None => array.fill_scalar().is_valid(),
256            Some(patch_value) => patch_value.is_valid(),
257        })
258    }
259
260    fn all_valid(array: &SparseArray) -> VortexResult<bool> {
261        if array.fill_scalar().is_null() {
262            // We need _all_ values to be patched, and all patches to be valid
263            return Ok(array.patches().values().len() == array.len()
264                && array.patches().values().all_valid()?);
265        }
266
267        array.patches().values().all_valid()
268    }
269
270    fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
271        if !array.fill_scalar().is_null() {
272            // We need _all_ values to be patched, and all patches to be invalid
273            return Ok(array.patches().values().len() == array.len()
274                && array.patches().values().all_invalid()?);
275        }
276
277        array.patches().values().all_invalid()
278    }
279
280    #[allow(clippy::unnecessary_fallible_conversions)]
281    fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
282        let fill_is_valid = array.fill_scalar().is_valid();
283        let values_validity = array.patches().values().validity_mask()?;
284        let len = array.len();
285
286        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
287            return Ok(Mask::AllTrue(len));
288        }
289        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
290            return Ok(Mask::AllFalse(len));
291        }
292
293        // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
294        let mut is_valid_buffer = BooleanBufferBuilder::new(len);
295        is_valid_buffer.append_n(len, fill_is_valid);
296
297        let indices = array.patches().indices().to_primitive()?;
298        let index_offset = array.patches().offset();
299
300        match_each_integer_ptype!(indices.ptype(), |I| {
301            let indices = indices.as_slice::<I>();
302            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
303        });
304
305        Ok(Mask::from_buffer(is_valid_buffer.finish()))
306    }
307}
308
309fn patch_validity<I: NativePType>(
310    is_valid_buffer: &mut BooleanBufferBuilder,
311    indices: &[I],
312    index_offset: usize,
313    values_validity: Mask,
314) {
315    let indices = indices.iter().map(|index| {
316        let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
317        index - index_offset
318    });
319    match values_validity {
320        Mask::AllTrue(_) => {
321            for index in indices {
322                is_valid_buffer.set_bit(index, true);
323            }
324        }
325        Mask::AllFalse(_) => {
326            for index in indices {
327                is_valid_buffer.set_bit(index, false);
328            }
329        }
330        Mask::Values(mask_values) => {
331            let is_valid = mask_values.boolean_buffer().iter();
332            for (index, is_valid) in indices.zip_eq(is_valid) {
333                is_valid_buffer.set_bit(index, is_valid);
334            }
335        }
336    }
337}
338
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null fill scalar with a nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// Valid `i32` fill scalar with value 42.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8; every other position
    /// holds `fill_value`. The values are cast to the fill's dtype so the two agree.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // With `nullable_fill()` the logical array is:
        // [null, null, 100, null, null, 200, null, null, 300, null]
        let mut values = buffer![100i32, 200, 300].into_array();
        values = cast(&values, fill_value.dtype()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // Unpatched positions yield the fill; patched positions yield the stored value.
        assert_eq!(array.scalar_at(0), nullable_fill());
        assert_eq!(array.scalar_at(2), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5), Scalar::from(Some(200_i32)));
    }

    // Index 10 equals the array length, so access must panic.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        let _ = array.scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        // A single patch (1234 at index 10) in a length-100 array with a null fill.
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            PrimitiveScalar::try_from(&arr.scalar_at(10))
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).is_null());
        assert!(arr.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        // Slicing [2, 7) makes original index 2 (value 100) the first element.
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(usize::try_from(&sliced.scalar_at(0)).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        // Slice [2, 7) covers original positions 2..=6; only 2 and 5 are patched (valid).
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Here the patch values are all null and the fill is valid, so the patched positions
        // (originally 2 and 5 within the slice) are the invalid ones.
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7);

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        // [1, 8) maps local 1 to original 2 (value 100).
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8);
        assert_eq!(usize::try_from(&sliced_once.scalar_at(1)).unwrap(), 100);

        // A further [1, 6) maps local 3 to original 5 (value 200).
        let sliced_twice = sliced_once.slice(1, 6);
        assert_eq!(usize::try_from(&sliced_twice.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        // With a null fill, only the patched positions 2, 5 and 8 are valid.
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_boolean_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        // Valid fill and valid patches: every position is valid.
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    // The last index (100) is not less than the length (100), so `try_new` errors and
    // `unwrap` panics.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // Same input with length 101 puts the last index in bounds.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        // 5 of 12 values are null (below the 90% threshold), so `encode` takes the
        // most-common-value path; the round trip must preserve both values and validity.
        let sparse = SparseArray::encode(
            &PrimitiveArray::new(
                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
                Validity::from_iter(vec![
                    true, true, false, true, false, true, false, true, true, false, true, false,
                ]),
            )
            .into_array(),
            None,
        )
        .vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        // Null patch values must stay invalid even though the fill is also null; only the
        // patches at 0, 2 and 8 hold non-null values.
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}