vortex_sparse/
lib.rs

1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
8use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13use vortex_scalar::Scalar;
14
15mod canonical;
16mod compute;
17mod ops;
18mod serde;
19
20vtable!(Sparse);
21
22impl VTable for SparseVTable {
23    type Array = SparseArray;
24    type Encoding = SparseEncoding;
25
26    type ArrayVTable = Self;
27    type CanonicalVTable = Self;
28    type OperationsVTable = Self;
29    type ValidityVTable = Self;
30    type VisitorVTable = Self;
31    type ComputeVTable = NotSupported;
32    type EncodeVTable = Self;
33    type SerdeVTable = Self;
34
35    fn id(_encoding: &Self::Encoding) -> EncodingId {
36        EncodingId::new_ref("vortex.sparse")
37    }
38
39    fn encoding(_array: &Self::Array) -> EncodingRef {
40        EncodingRef::new_ref(SparseEncoding.as_ref())
41    }
42}
43
44#[derive(Clone, Debug)]
45pub struct SparseArray {
46    patches: Patches,
47    fill_value: Scalar,
48    stats_set: ArrayStats,
49}
50
/// Unit type used as the `Encoding` associated type of the sparse vtable.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
53
54impl SparseArray {
55    pub fn try_new(
56        indices: ArrayRef,
57        values: ArrayRef,
58        len: usize,
59        fill_value: Scalar,
60    ) -> VortexResult<Self> {
61        Self::try_new_with_offset(indices, values, len, 0, fill_value)
62    }
63
64    pub(crate) fn try_new_with_offset(
65        indices: ArrayRef,
66        values: ArrayRef,
67        len: usize,
68        indices_offset: usize,
69        fill_value: Scalar,
70    ) -> VortexResult<Self> {
71        if indices.len() != values.len() {
72            vortex_bail!(
73                "Mismatched indices {} and values {} length",
74                indices.len(),
75                values.len()
76            );
77        }
78
79        if !indices.is_empty() {
80            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
81
82            if last_index - indices_offset >= len {
83                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
84            }
85        }
86
87        let patches = Patches::new(len, indices_offset, indices, values);
88
89        Self::try_new_from_patches(patches, fill_value)
90    }
91
92    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
93        if fill_value.dtype() != patches.values().dtype() {
94            vortex_bail!(
95                "fill value, {:?}, should be instance of values dtype, {}",
96                fill_value,
97                patches.values().dtype(),
98            );
99        }
100        Ok(Self {
101            patches,
102            fill_value,
103            stats_set: Default::default(),
104        })
105    }
106
107    #[inline]
108    pub fn patches(&self) -> &Patches {
109        &self.patches
110    }
111
112    #[inline]
113    pub fn resolved_patches(&self) -> VortexResult<Patches> {
114        let (len, offset, indices, values) = self.patches().clone().into_parts();
115        let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
116        let indices = sub_scalar(&indices, indices_offset)?;
117        Ok(Patches::new(len, 0, indices, values))
118    }
119
120    #[inline]
121    pub fn fill_scalar(&self) -> &Scalar {
122        &self.fill_value
123    }
124
125    /// Encode given array as a SparseArray.
126    ///
127    /// Optionally provided fill value will be respected if the array is less than 90% null.
128    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
129        if let Some(fill_value) = fill_value.as_ref() {
130            if array.dtype() != fill_value.dtype() {
131                vortex_bail!(
132                    "Array and fill value types must match. got {} and {}",
133                    array.dtype(),
134                    fill_value.dtype()
135                )
136            }
137        }
138        let mask = array.validity_mask()?;
139
140        if mask.all_false() {
141            // Array is constant NULL
142            return Ok(
143                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
144            );
145        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
146            // Array is dominated by NULL but has non-NULL values
147            let non_null_values = filter(array, &mask)?;
148            let non_null_indices = match mask.indices() {
149                AllOr::All => {
150                    // We already know that the mask is 90%+ false
151                    unreachable!("Mask is mostly null")
152                }
153                AllOr::None => {
154                    // we know there are some non-NULL values
155                    unreachable!("Mask is mostly null but not all null")
156                }
157                AllOr::Some(values) => {
158                    let buffer: Buffer<u32> = values
159                        .iter()
160                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
161                        .collect();
162
163                    buffer.into_array()
164                }
165            };
166
167            return Ok(SparseArray::try_new(
168                non_null_indices,
169                non_null_values,
170                array.len(),
171                Scalar::null(array.dtype().clone()),
172            )?
173            .into_array());
174        }
175
176        let fill = if let Some(fill) = fill_value {
177            fill
178        } else {
179            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
180            let (top_pvalue, _) = array
181                .to_primitive()?
182                .top_value()?
183                .vortex_expect("Non empty or all null array");
184
185            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
186        };
187
188        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
189        let non_top_mask = Mask::from_buffer(
190            fill_null(
191                &compare(array, &fill_array, Operator::NotEq)?,
192                &Scalar::bool(true, Nullability::NonNullable),
193            )?
194            .to_bool()?
195            .boolean_buffer()
196            .clone(),
197        );
198
199        let non_top_values = filter(array, &non_top_mask)?;
200
201        let indices: Buffer<u64> = match non_top_mask {
202            Mask::AllTrue(count) => {
203                // all true -> complete slice
204                (0u64..count as u64).collect()
205            }
206            Mask::AllFalse(_) => {
207                // All values are equal to the top value
208                return Ok(fill_array);
209            }
210            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
211        };
212
213        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
214            .map(|a| a.into_array())
215    }
216}
217
218impl ArrayVTable<SparseVTable> for SparseVTable {
219    fn len(array: &SparseArray) -> usize {
220        array.patches.array_len()
221    }
222
223    fn dtype(array: &SparseArray) -> &DType {
224        array.fill_scalar().dtype()
225    }
226
227    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
228        array.stats_set.to_ref(array.as_ref())
229    }
230}
231
232impl ValidityVTable<SparseVTable> for SparseVTable {
233    fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
234        Ok(match array.patches().get_patched(index)? {
235            None => array.fill_scalar().is_valid(),
236            Some(patch_value) => patch_value.is_valid(),
237        })
238    }
239
240    fn all_valid(array: &SparseArray) -> VortexResult<bool> {
241        if array.fill_scalar().is_null() {
242            // We need _all_ values to be patched, and all patches to be valid
243            return Ok(array.patches().values().len() == array.len()
244                && array.patches().values().all_valid()?);
245        }
246
247        array.patches().values().all_valid()
248    }
249
250    fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
251        if !array.fill_scalar().is_null() {
252            // We need _all_ values to be patched, and all patches to be invalid
253            return Ok(array.patches().values().len() == array.len()
254                && array.patches().values().all_invalid()?);
255        }
256
257        array.patches().values().all_invalid()
258    }
259
260    fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
261        let indices = array.patches().indices().to_primitive()?;
262
263        if array.fill_scalar().is_null() {
264            // If we have a null fill value, then we set each patch value to true.
265            let mut buffer = BooleanBufferBuilder::new(array.len());
266            // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
267            buffer.append_n(array.len(), false);
268
269            match_each_integer_ptype!(indices.ptype(), |$I| {
270                indices.as_slice::<$I>().into_iter().for_each(|&index| {
271                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - array.patches().offset(), true);
272                });
273            });
274
275            return Ok(Mask::from_buffer(buffer.finish()));
276        }
277
278        // If the fill_value is non-null, then the validity is based on the validity of the
279        // patch values.
280        let mut buffer = BooleanBufferBuilder::new(array.len());
281        buffer.append_n(array.len(), true);
282
283        let values_validity = array.patches().values().validity_mask()?;
284        match_each_integer_ptype!(indices.ptype(), |$I| {
285            indices.as_slice::<$I>()
286                .into_iter()
287                .enumerate()
288                .for_each(|(patch_idx, &index)| {
289                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - array.patches().offset(), values_validity.value(patch_idx));
290                })
291        });
292
293        Ok(Mask::from_buffer(buffer.finish()))
294    }
295}
296
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Asserts that `$result` is an error of the `VortexError::OutOfBounds` variant
    /// carrying the given index, start and stop.
    macro_rules! assert_out_of_bounds {
        ($result:expr, $index:expr, $start:expr, $stop:expr) => {{
            let error = $result.err().unwrap();
            let VortexError::OutOfBounds(i, start, stop, _) = error else {
                unreachable!()
            };
            assert_eq!(i, $index);
            assert_eq!(start, $start);
            assert_eq!(stop, $stop);
        }};
    }

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// The non-null `i32` scalar `42`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A sparse array of length 10 with patches 100, 200, 300 at positions 2, 5, 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [fill, fill, 100, fill, fill, 200, fill, fill, 300, fill]
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // An unpatched position yields the fill, patched positions yield the patch.
        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));

        assert_out_of_bounds!(array.scalar_at(10), 10, 0, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        assert_out_of_bounds!(sliced.scalar_at(5), 5, 0, 5);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7).unwrap();
        let expected = Mask::from_iter(vec![true, false, false, true, false]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // All three patch values are null while the fill is a valid 1.0.
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2, 7)
        .unwrap();
        let expected = Mask::from_iter(vec![false, true, true, false, true]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1, 8).unwrap();
        let once_value = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);
        assert_out_of_bounds!(sliced_once.scalar_at(7), 7, 0, 7);

        let sliced_twice = sliced_once.slice(1, 6).unwrap();
        let twice_value = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);
        assert_out_of_bounds!(sliced_twice.scalar_at(5), 5, 0, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let bits = array
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not smaller than the declared length (100).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) fits once the declared length is 101.
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        assert_eq!(sparse.validity_mask().unwrap(), Mask::from_iter(validity));
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}