vortex_sparse/
lib.rs

1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
8use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15mod canonical;
16mod compute;
17mod ops;
18mod serde;
19
// Invokes the `vtable!` macro for the `Sparse` encoding.
// NOTE(review): presumably this declares the `SparseVTable` type that the impls below
// are written against — confirm against the macro definition in `vortex_array`.
vtable!(Sparse);
21
22impl VTable for SparseVTable {
23    type Array = SparseArray;
24    type Encoding = SparseEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = Self;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = Self;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.sparse")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(SparseEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct SparseArray {
46    patches: Patches,
47    fill_value: Scalar,
48    stats_set: ArrayStats,
49}
50
/// Unit marker type standing for the sparse encoding (`"vortex.sparse"`).
#[derive(Debug, Clone)]
pub struct SparseEncoding;
53
54impl SparseArray {
55    pub fn try_new(
56        indices: ArrayRef,
57        values: ArrayRef,
58        len: usize,
59        fill_value: Scalar,
60    ) -> VortexResult<Self> {
61        Self::try_new_with_offset(indices, values, len, 0, fill_value)
62    }
63
64    pub(crate) fn try_new_with_offset(
65        indices: ArrayRef,
66        values: ArrayRef,
67        len: usize,
68        indices_offset: usize,
69        fill_value: Scalar,
70    ) -> VortexResult<Self> {
71        if indices.len() != values.len() {
72            vortex_bail!(
73                "Mismatched indices {} and values {} length",
74                indices.len(),
75                values.len()
76            );
77        }
78
79        if !indices.is_empty() {
80            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
81
82            if last_index - indices_offset >= len {
83                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
84            }
85        }
86
87        let patches = Patches::new(len, indices_offset, indices, values);
88
89        Self::try_new_from_patches(patches, fill_value)
90    }
91
92    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
93        if fill_value.dtype() != patches.values().dtype() {
94            vortex_bail!(
95                "fill value, {:?}, should be instance of values dtype, {} but was {}.",
96                fill_value,
97                patches.values().dtype(),
98                fill_value.dtype(),
99            );
100        }
101        Ok(Self {
102            patches,
103            fill_value,
104            stats_set: Default::default(),
105        })
106    }
107
108    #[inline]
109    pub fn patches(&self) -> &Patches {
110        &self.patches
111    }
112
113    #[inline]
114    pub fn resolved_patches(&self) -> VortexResult<Patches> {
115        let patches = self.patches();
116        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
117        let indices = sub_scalar(patches.indices(), indices_offset)?;
118        Ok(Patches::new(
119            patches.array_len(),
120            0,
121            indices,
122            patches.values().clone(),
123        ))
124    }
125
126    #[inline]
127    pub fn fill_scalar(&self) -> &Scalar {
128        &self.fill_value
129    }
130
131    /// Encode given array as a SparseArray.
132    ///
133    /// Optionally provided fill value will be respected if the array is less than 90% null.
134    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
135        if let Some(fill_value) = fill_value.as_ref() {
136            if array.dtype() != fill_value.dtype() {
137                vortex_bail!(
138                    "Array and fill value types must match. got {} and {}",
139                    array.dtype(),
140                    fill_value.dtype()
141                )
142            }
143        }
144        let mask = array.validity_mask()?;
145
146        if mask.all_false() {
147            // Array is constant NULL
148            return Ok(
149                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
150            );
151        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
152            // Array is dominated by NULL but has non-NULL values
153            let non_null_values = filter(array, &mask)?;
154            let non_null_indices = match mask.indices() {
155                AllOr::All => {
156                    // We already know that the mask is 90%+ false
157                    unreachable!("Mask is mostly null")
158                }
159                AllOr::None => {
160                    // we know there are some non-NULL values
161                    unreachable!("Mask is mostly null but not all null")
162                }
163                AllOr::Some(values) => {
164                    let buffer: Buffer<u32> = values
165                        .iter()
166                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
167                        .collect();
168
169                    buffer.into_array()
170                }
171            };
172
173            return Ok(SparseArray::try_new(
174                non_null_indices,
175                non_null_values,
176                array.len(),
177                Scalar::null(array.dtype().clone()),
178            )?
179            .into_array());
180        }
181
182        let fill = if let Some(fill) = fill_value {
183            fill
184        } else {
185            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
186            let (top_pvalue, _) = array
187                .to_primitive()?
188                .top_value()?
189                .vortex_expect("Non empty or all null array");
190
191            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
192        };
193
194        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
195        let non_top_mask = Mask::from_buffer(
196            fill_null(
197                &compare(array, &fill_array, Operator::NotEq)?,
198                &Scalar::bool(true, Nullability::NonNullable),
199            )?
200            .to_bool()?
201            .boolean_buffer()
202            .clone(),
203        );
204
205        let non_top_values = filter(array, &non_top_mask)?;
206
207        let indices: Buffer<u64> = match non_top_mask {
208            Mask::AllTrue(count) => {
209                // all true -> complete slice
210                (0u64..count as u64).collect()
211            }
212            Mask::AllFalse(_) => {
213                // All values are equal to the top value
214                return Ok(fill_array);
215            }
216            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
217        };
218
219        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
220            .map(|a| a.into_array())
221    }
222}
223
224impl ArrayVTable<SparseVTable> for SparseVTable {
225    fn len(array: &SparseArray) -> usize {
226        array.patches.array_len()
227    }
228
229    fn dtype(array: &SparseArray) -> &DType {
230        array.fill_scalar().dtype()
231    }
232
233    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
234        array.stats_set.to_ref(array.as_ref())
235    }
236}
237
238impl ValidityVTable<SparseVTable> for SparseVTable {
239    fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
240        Ok(match array.patches().get_patched(index)? {
241            None => array.fill_scalar().is_valid(),
242            Some(patch_value) => patch_value.is_valid(),
243        })
244    }
245
246    fn all_valid(array: &SparseArray) -> VortexResult<bool> {
247        if array.fill_scalar().is_null() {
248            // We need _all_ values to be patched, and all patches to be valid
249            return Ok(array.patches().values().len() == array.len()
250                && array.patches().values().all_valid()?);
251        }
252
253        array.patches().values().all_valid()
254    }
255
256    fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
257        if !array.fill_scalar().is_null() {
258            // We need _all_ values to be patched, and all patches to be invalid
259            return Ok(array.patches().values().len() == array.len()
260                && array.patches().values().all_invalid()?);
261        }
262
263        array.patches().values().all_invalid()
264    }
265
266    #[allow(clippy::unnecessary_fallible_conversions)]
267    fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
268        let indices = array.patches().indices().to_primitive()?;
269
270        if array.fill_scalar().is_null() {
271            // If we have a null fill value, then we set each patch value to true.
272            let mut buffer = BooleanBufferBuilder::new(array.len());
273            // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
274            buffer.append_n(array.len(), false);
275
276            match_each_integer_ptype!(indices.ptype(), |I| {
277                indices.as_slice::<I>().iter().for_each(|&index| {
278                    buffer.set_bit(
279                        usize::try_from(index).vortex_expect("Failed to cast to usize")
280                            - array.patches().offset(),
281                        true,
282                    );
283                });
284            });
285
286            return Ok(Mask::from_buffer(buffer.finish()));
287        }
288
289        // If the fill_value is non-null, then the validity is based on the validity of the
290        // patch values.
291        let mut buffer = BooleanBufferBuilder::new(array.len());
292        buffer.append_n(array.len(), true);
293
294        let values_validity = array.patches().values().validity_mask()?;
295        match_each_integer_ptype!(indices.ptype(), |I| {
296            indices
297                .as_slice::<I>()
298                .iter()
299                .enumerate()
300                .for_each(|(patch_idx, &index)| {
301                    buffer.set_bit(
302                        usize::try_from(index).vortex_expect("Failed to cast to usize")
303                            - array.patches().offset(),
304                        values_validity.value(patch_idx),
305                    );
306                })
307        });
308
309        Ok(Mask::from_buffer(buffer.finish()))
310    }
311}
312
313#[cfg(test)]
314mod test {
315    use itertools::Itertools;
316    use vortex_array::IntoArray;
317    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
318    use vortex_array::compute::cast;
319    use vortex_array::validity::Validity;
320    use vortex_buffer::buffer;
321    use vortex_dtype::{DType, Nullability, PType};
322    use vortex_error::{VortexError, VortexUnwrap};
323    use vortex_scalar::{PrimitiveScalar, Scalar};
324
325    use super::*;
326
327    fn nullable_fill() -> Scalar {
328        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
329    }
330
331    fn non_nullable_fill() -> Scalar {
332        Scalar::from(42i32)
333    }
334
335    fn sparse_array(fill_value: Scalar) -> ArrayRef {
336        // merged array: [null, null, 100, null, null, 200, null, null, 300, null]
337        let mut values = buffer![100i32, 200, 300].into_array();
338        values = cast(&values, fill_value.dtype()).unwrap();
339
340        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
341            .unwrap()
342            .into_array()
343    }
344
345    #[test]
346    pub fn test_scalar_at() {
347        let array = sparse_array(nullable_fill());
348
349        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
350        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
351        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
352
353        let error = array.scalar_at(10).err().unwrap();
354        let VortexError::OutOfBounds(i, start, stop, _) = error else {
355            unreachable!()
356        };
357        assert_eq!(i, 10);
358        assert_eq!(start, 0);
359        assert_eq!(stop, 10);
360    }
361
362    #[test]
363    pub fn test_scalar_at_again() {
364        let arr = SparseArray::try_new(
365            ConstantArray::new(10u32, 1).into_array(),
366            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
367            100,
368            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
369        )
370        .unwrap();
371
372        assert_eq!(
373            PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap())
374                .unwrap()
375                .typed_value::<u32>(),
376            Some(1234)
377        );
378        assert!(arr.scalar_at(0).unwrap().is_null());
379        assert!(arr.scalar_at(99).unwrap().is_null());
380    }
381
382    #[test]
383    pub fn scalar_at_sliced() {
384        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
385        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
386        let error = sliced.scalar_at(5).err().unwrap();
387        let VortexError::OutOfBounds(i, start, stop, _) = error else {
388            unreachable!()
389        };
390        assert_eq!(i, 5);
391        assert_eq!(start, 0);
392        assert_eq!(stop, 5);
393    }
394
395    #[test]
396    pub fn validity_mask_sliced_null_fill() {
397        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
398        assert_eq!(
399            sliced.validity_mask().unwrap(),
400            Mask::from_iter(vec![true, false, false, true, false])
401        );
402    }
403
404    #[test]
405    pub fn validity_mask_sliced_nonnull_fill() {
406        let sliced = SparseArray::try_new(
407            buffer![2u64, 5, 8].into_array(),
408            ConstantArray::new(
409                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
410                3,
411            )
412            .into_array(),
413            10,
414            Scalar::primitive(1.0f32, Nullability::Nullable),
415        )
416        .unwrap()
417        .slice(2, 7)
418        .unwrap();
419
420        assert_eq!(
421            sliced.validity_mask().unwrap(),
422            Mask::from_iter(vec![false, true, true, false, true])
423        );
424    }
425
426    #[test]
427    pub fn scalar_at_sliced_twice() {
428        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
429        assert_eq!(
430            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
431            100
432        );
433        let error = sliced_once.scalar_at(7).err().unwrap();
434        let VortexError::OutOfBounds(i, start, stop, _) = error else {
435            unreachable!()
436        };
437        assert_eq!(i, 7);
438        assert_eq!(start, 0);
439        assert_eq!(stop, 7);
440
441        let sliced_twice = sliced_once.slice(1, 6).unwrap();
442        assert_eq!(
443            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
444            200
445        );
446        let error2 = sliced_twice.scalar_at(5).err().unwrap();
447        let VortexError::OutOfBounds(i, start, stop, _) = error2 else {
448            unreachable!()
449        };
450        assert_eq!(i, 5);
451        assert_eq!(start, 0);
452        assert_eq!(stop, 5);
453    }
454
455    #[test]
456    pub fn sparse_validity_mask() {
457        let array = sparse_array(nullable_fill());
458        assert_eq!(
459            array
460                .validity_mask()
461                .unwrap()
462                .to_boolean_buffer()
463                .iter()
464                .collect_vec(),
465            [
466                false, false, true, false, false, true, false, false, true, false
467            ]
468        );
469    }
470
471    #[test]
472    fn sparse_validity_mask_non_null_fill() {
473        let array = sparse_array(non_nullable_fill());
474        assert!(array.validity_mask().unwrap().all_true());
475    }
476
477    #[test]
478    #[should_panic]
479    fn test_invalid_length() {
480        let values = buffer![15_u32, 135, 13531, 42].into_array();
481        let indices = buffer![10_u64, 11, 50, 100].into_array();
482
483        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
484    }
485
486    #[test]
487    fn test_valid_length() {
488        let values = buffer![15_u32, 135, 13531, 42].into_array();
489        let indices = buffer![10_u64, 11, 50, 100].into_array();
490
491        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
492    }
493
494    #[test]
495    fn encode_with_nulls() {
496        let sparse = SparseArray::encode(
497            &PrimitiveArray::new(
498                buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
499                Validity::from_iter(vec![
500                    true, true, false, true, false, true, false, true, true, false, true, false,
501                ]),
502            )
503            .into_array(),
504            None,
505        )
506        .vortex_unwrap();
507        let canonical = sparse.to_primitive().vortex_unwrap();
508        assert_eq!(
509            sparse.validity_mask().unwrap(),
510            Mask::from_iter(vec![
511                true, true, false, true, false, true, false, true, true, false, true, false,
512            ])
513        );
514        assert_eq!(
515            canonical.as_slice::<i32>(),
516            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
517        );
518    }
519}