vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
25vtable!(Sparse);
26
27impl VTable for SparseVTable {
28    type Array = SparseArray;
29    type Encoding = SparseEncoding;
30
31    type ArrayVTable = Self;
32    type CanonicalVTable = Self;
33    type OperationsVTable = Self;
34    type ValidityVTable = Self;
35    type VisitorVTable = Self;
36    type ComputeVTable = NotSupported;
37    type EncodeVTable = Self;
38    type SerdeVTable = Self;
39
40    fn id(_encoding: &Self::Encoding) -> EncodingId {
41        EncodingId::new_ref("vortex.sparse")
42    }
43
44    fn encoding(_array: &Self::Array) -> EncodingRef {
45        EncodingRef::new_ref(SparseEncoding.as_ref())
46    }
47}
48
49#[derive(Clone, Debug)]
50pub struct SparseArray {
51    patches: Patches,
52    fill_value: Scalar,
53    stats_set: ArrayStats,
54}
55
/// Stateless marker type for the sparse encoding, identified as `"vortex.sparse"`.
#[derive(Debug, Clone)]
pub struct SparseEncoding;
58
59impl SparseArray {
60    pub fn try_new(
61        indices: ArrayRef,
62        values: ArrayRef,
63        len: usize,
64        fill_value: Scalar,
65    ) -> VortexResult<Self> {
66        Self::try_new_with_offset(indices, values, len, 0, fill_value)
67    }
68
69    pub(crate) fn try_new_with_offset(
70        indices: ArrayRef,
71        values: ArrayRef,
72        len: usize,
73        indices_offset: usize,
74        fill_value: Scalar,
75    ) -> VortexResult<Self> {
76        if indices.len() != values.len() {
77            vortex_bail!(
78                "Mismatched indices {} and values {} length",
79                indices.len(),
80                values.len()
81            );
82        }
83
84        if !indices.is_empty() {
85            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
86
87            if last_index - indices_offset >= len {
88                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
89            }
90        }
91
92        let patches = Patches::new(len, indices_offset, indices, values);
93
94        Self::try_new_from_patches(patches, fill_value)
95    }
96
97    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
98        if fill_value.dtype() != patches.values().dtype() {
99            vortex_bail!(
100                "fill value, {:?}, should be instance of values dtype, {} but was {}.",
101                fill_value,
102                patches.values().dtype(),
103                fill_value.dtype(),
104            );
105        }
106        Ok(Self {
107            patches,
108            fill_value,
109            stats_set: Default::default(),
110        })
111    }
112
113    #[inline]
114    pub fn patches(&self) -> &Patches {
115        &self.patches
116    }
117
118    #[inline]
119    pub fn resolved_patches(&self) -> VortexResult<Patches> {
120        let patches = self.patches();
121        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
122        let indices = sub_scalar(patches.indices(), indices_offset)?;
123        Ok(Patches::new(
124            patches.array_len(),
125            0,
126            indices,
127            patches.values().clone(),
128        ))
129    }
130
131    #[inline]
132    pub fn fill_scalar(&self) -> &Scalar {
133        &self.fill_value
134    }
135
136    /// Encode given array as a SparseArray.
137    ///
138    /// Optionally provided fill value will be respected if the array is less than 90% null.
139    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
140        if let Some(fill_value) = fill_value.as_ref() {
141            if array.dtype() != fill_value.dtype() {
142                vortex_bail!(
143                    "Array and fill value types must match. got {} and {}",
144                    array.dtype(),
145                    fill_value.dtype()
146                )
147            }
148        }
149        let mask = array.validity_mask()?;
150
151        if mask.all_false() {
152            // Array is constant NULL
153            return Ok(
154                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
155            );
156        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
157            // Array is dominated by NULL but has non-NULL values
158            let non_null_values = filter(array, &mask)?;
159            let non_null_indices = match mask.indices() {
160                AllOr::All => {
161                    // We already know that the mask is 90%+ false
162                    unreachable!("Mask is mostly null")
163                }
164                AllOr::None => {
165                    // we know there are some non-NULL values
166                    unreachable!("Mask is mostly null but not all null")
167                }
168                AllOr::Some(values) => {
169                    let buffer: Buffer<u32> = values
170                        .iter()
171                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
172                        .collect();
173
174                    buffer.into_array()
175                }
176            };
177
178            return Ok(SparseArray::try_new(
179                non_null_indices,
180                non_null_values,
181                array.len(),
182                Scalar::null(array.dtype().clone()),
183            )?
184            .into_array());
185        }
186
187        let fill = if let Some(fill) = fill_value {
188            fill
189        } else {
190            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
191            let (top_pvalue, _) = array
192                .to_primitive()?
193                .top_value()?
194                .vortex_expect("Non empty or all null array");
195
196            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
197        };
198
199        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
200        let non_top_mask = Mask::from_buffer(
201            fill_null(
202                &compare(array, &fill_array, Operator::NotEq)?,
203                &Scalar::bool(true, Nullability::NonNullable),
204            )?
205            .to_bool()?
206            .boolean_buffer()
207            .clone(),
208        );
209
210        let non_top_values = filter(array, &non_top_mask)?;
211
212        let indices: Buffer<u64> = match non_top_mask {
213            Mask::AllTrue(count) => {
214                // all true -> complete slice
215                (0u64..count as u64).collect()
216            }
217            Mask::AllFalse(_) => {
218                // All values are equal to the top value
219                return Ok(fill_array);
220            }
221            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
222        };
223
224        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
225            .map(|a| a.into_array())
226    }
227}
228
229impl ArrayVTable<SparseVTable> for SparseVTable {
230    fn len(array: &SparseArray) -> usize {
231        array.patches.array_len()
232    }
233
234    fn dtype(array: &SparseArray) -> &DType {
235        array.fill_scalar().dtype()
236    }
237
238    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
239        array.stats_set.to_ref(array.as_ref())
240    }
241}
242
243impl ValidityVTable<SparseVTable> for SparseVTable {
244    fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
245        Ok(match array.patches().get_patched(index)? {
246            None => array.fill_scalar().is_valid(),
247            Some(patch_value) => patch_value.is_valid(),
248        })
249    }
250
251    fn all_valid(array: &SparseArray) -> VortexResult<bool> {
252        if array.fill_scalar().is_null() {
253            // We need _all_ values to be patched, and all patches to be valid
254            return Ok(array.patches().values().len() == array.len()
255                && array.patches().values().all_valid()?);
256        }
257
258        array.patches().values().all_valid()
259    }
260
261    fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
262        if !array.fill_scalar().is_null() {
263            // We need _all_ values to be patched, and all patches to be invalid
264            return Ok(array.patches().values().len() == array.len()
265                && array.patches().values().all_invalid()?);
266        }
267
268        array.patches().values().all_invalid()
269    }
270
271    #[allow(clippy::unnecessary_fallible_conversions)]
272    fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
273        let fill_is_valid = array.fill_scalar().is_valid();
274        let values_validity = array.patches().values().validity_mask()?;
275        let len = array.len();
276
277        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
278            return Ok(Mask::AllTrue(len));
279        }
280        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
281            return Ok(Mask::AllFalse(len));
282        }
283
284        // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
285        let mut is_valid_buffer = BooleanBufferBuilder::new(len);
286        is_valid_buffer.append_n(len, fill_is_valid);
287
288        let indices = array.patches().indices().to_primitive()?;
289        let index_offset = array.patches().offset();
290
291        match_each_integer_ptype!(indices.ptype(), |I| {
292            let indices = indices.as_slice::<I>();
293            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
294        });
295
296        Ok(Mask::from_buffer(is_valid_buffer.finish()))
297    }
298}
299
300fn patch_validity<I: NativePType>(
301    is_valid_buffer: &mut BooleanBufferBuilder,
302    indices: &[I],
303    index_offset: usize,
304    values_validity: Mask,
305) {
306    let indices = indices.iter().map(|index| {
307        let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
308        index - index_offset
309    });
310    match values_validity {
311        Mask::AllTrue(_) => {
312            for index in indices {
313                is_valid_buffer.set_bit(index, true);
314            }
315        }
316        Mask::AllFalse(_) => {
317            for index in indices {
318                is_valid_buffer.set_bit(index, false);
319            }
320        }
321        Mask::Values(mask_values) => {
322            let is_valid = mask_values.boolean_buffer().iter();
323            for (index, is_valid) in indices.zip_eq(is_valid) {
324                is_valid_buffer.set_bit(index, is_valid);
325            }
326        }
327    }
328}
329
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null `i32` scalar, used as a nullable fill value.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null `i32` scalar, used as a non-nullable fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Ten-element sparse array with patches 100, 200 and 300 at positions 2, 5 and 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [fill, fill, 100, fill, fill, 200, fill, fill, 300, fill]
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Asserts that reading `index` fails with an out-of-bounds error for range `0..len`.
    fn assert_out_of_bounds(array: &ArrayRef, index: usize, len: usize) {
        let error = array.scalar_at(index).err().unwrap();
        let VortexError::OutOfBounds(i, start, stop, _) = error else {
            unreachable!()
        };
        assert_eq!(i, index);
        assert_eq!(start, 0);
        assert_eq!(stop, len);
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        assert_out_of_bounds(&array, 10, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));

        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        assert_out_of_bounds(&sliced, 5, 5);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        let expected = Mask::from_iter(vec![true, false, false, true, false]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();

        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        let after_one_slice = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&after_one_slice).unwrap(), 100);
        assert_out_of_bounds(&sliced_once, 7, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        let after_two_slices = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&after_two_slices).unwrap(), 200);
        assert_out_of_bounds(&sliced_twice, 5, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let bits = array
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not below the requested length (100).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is below the requested length (101).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = Validity::from_iter(vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ]);
        let input =
            PrimitiveArray::new(buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4], validity).into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::<i16>()).unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask().unwrap();

        assert_eq!(actual, expected);
    }
}
550}