vortex_sparse/
lib.rs

1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10    Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, IntoArray,
11    ProstMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_mask::{AllOr, Mask};
17use vortex_scalar::Scalar;
18
19use crate::serde::SparseMetadata;
20
21mod canonical;
22mod compute;
23mod serde;
24mod variants;
25
26#[derive(Clone, Debug)]
27pub struct SparseArray {
28    patches: Patches,
29    fill_value: Scalar,
30    stats_set: ArrayStats,
31}
32
33try_from_array_ref!(SparseArray);
34
35#[derive(Debug)]
36pub struct SparseEncoding;
37impl Encoding for SparseEncoding {
38    type Array = SparseArray;
39    type Metadata = ProstMetadata<SparseMetadata>;
40}
41
42impl SparseArray {
43    pub fn try_new(
44        indices: ArrayRef,
45        values: ArrayRef,
46        len: usize,
47        fill_value: Scalar,
48    ) -> VortexResult<Self> {
49        Self::try_new_with_offset(indices, values, len, 0, fill_value)
50    }
51
52    pub(crate) fn try_new_with_offset(
53        indices: ArrayRef,
54        values: ArrayRef,
55        len: usize,
56        indices_offset: usize,
57        fill_value: Scalar,
58    ) -> VortexResult<Self> {
59        if indices.len() != values.len() {
60            vortex_bail!(
61                "Mismatched indices {} and values {} length",
62                indices.len(),
63                values.len()
64            );
65        }
66
67        if !indices.is_empty() {
68            let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
69
70            if last_index - indices_offset >= len {
71                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
72            }
73        }
74
75        let patches = Patches::new(len, indices_offset, indices, values);
76
77        Self::try_new_from_patches(patches, fill_value)
78    }
79
80    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
81        if fill_value.dtype() != patches.values().dtype() {
82            vortex_bail!(
83                "fill value, {:?}, should be instance of values dtype, {}",
84                fill_value,
85                patches.values().dtype(),
86            );
87        }
88        Ok(Self {
89            patches,
90            fill_value,
91            stats_set: Default::default(),
92        })
93    }
94
95    #[inline]
96    pub fn patches(&self) -> &Patches {
97        &self.patches
98    }
99
100    #[inline]
101    pub fn resolved_patches(&self) -> VortexResult<Patches> {
102        let (len, offset, indices, values) = self.patches().clone().into_parts();
103        let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
104        let indices = sub_scalar(&indices, indices_offset)?;
105        Ok(Patches::new(len, 0, indices, values))
106    }
107
108    #[inline]
109    pub fn fill_scalar(&self) -> &Scalar {
110        &self.fill_value
111    }
112
113    /// Encode given array as a SparseArray.
114    ///
115    /// Optionally provided fill value will be respected if the array is less than 90% null.
116    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
117        if let Some(fill_value) = fill_value.as_ref() {
118            if array.dtype() != fill_value.dtype() {
119                vortex_bail!(
120                    "Array and fill value types must match. got {} and {}",
121                    array.dtype(),
122                    fill_value.dtype()
123                )
124            }
125        }
126        let mask = array.validity_mask()?;
127
128        if mask.all_false() {
129            // Array is constant NULL
130            return Ok(
131                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
132            );
133        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
134            // Array is dominated by NULL but has non-NULL values
135            let non_null_values = filter(array, &mask)?;
136            let non_null_indices = match mask.indices() {
137                AllOr::All => {
138                    // We already know that the mask is 90%+ false
139                    unreachable!("Mask is mostly null")
140                }
141                AllOr::None => {
142                    // we know there are some non-NULL values
143                    unreachable!("Mask is mostly null but not all null")
144                }
145                AllOr::Some(values) => {
146                    let buffer: Buffer<u32> = values
147                        .iter()
148                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
149                        .collect();
150
151                    buffer.into_array()
152                }
153            };
154
155            return Ok(SparseArray::try_new(
156                non_null_indices,
157                non_null_values,
158                array.len(),
159                Scalar::null(array.dtype().clone()),
160            )?
161            .into_array());
162        }
163
164        let fill = if let Some(fill) = fill_value {
165            fill
166        } else {
167            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
168            let (top_pvalue, _) = array
169                .to_primitive()?
170                .top_value()?
171                .vortex_expect("Non empty or all null array");
172
173            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
174        };
175
176        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
177        let non_top_mask = Mask::from_buffer(
178            fill_null(
179                &compare(array, &fill_array, Operator::NotEq)?,
180                Scalar::bool(true, Nullability::NonNullable),
181            )?
182            .to_bool()?
183            .boolean_buffer()
184            .clone(),
185        );
186
187        let non_top_values = filter(array, &non_top_mask)?;
188
189        let indices: Buffer<u64> = match non_top_mask {
190            Mask::AllTrue(count) => {
191                // all true -> complete slice
192                (0u64..count as u64).collect()
193            }
194            Mask::AllFalse(_) => {
195                // All values are equal to the top value
196                return Ok(fill_array);
197            }
198            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
199        };
200
201        SparseArray::try_new(
202            indices.into_array(),
203            non_top_values.into_array(),
204            array.len(),
205            fill,
206        )
207        .map(|a| a.into_array())
208    }
209}
210
211impl ArrayImpl for SparseArray {
212    type Encoding = SparseEncoding;
213
214    fn _len(&self) -> usize {
215        self.patches.array_len()
216    }
217
218    fn _dtype(&self) -> &DType {
219        self.fill_value.dtype()
220    }
221
222    fn _vtable(&self) -> VTableRef {
223        VTableRef::new_ref(&SparseEncoding)
224    }
225
226    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
227        let patch_indices = children[0].clone();
228        let patch_values = children[1].clone();
229
230        let patches = Patches::new(
231            self.patches().array_len(),
232            self.patches().offset(),
233            patch_indices,
234            patch_values,
235        );
236
237        Self::try_new_from_patches(patches, self.fill_value.clone())
238    }
239}
240
241impl ArrayStatisticsImpl for SparseArray {
242    fn _stats_ref(&self) -> StatsSetRef<'_> {
243        self.stats_set.to_ref(self)
244    }
245}
246
247impl ArrayValidityImpl for SparseArray {
248    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
249        Ok(match self.patches().get_patched(index)? {
250            None => self.fill_scalar().is_valid(),
251            Some(patch_value) => patch_value.is_valid(),
252        })
253    }
254
255    fn _all_valid(&self) -> VortexResult<bool> {
256        if self.fill_scalar().is_null() {
257            // We need _all_ values to be patched, and all patches to be valid
258            return Ok(self.patches().values().len() == self.len()
259                && self.patches().values().all_valid()?);
260        }
261
262        self.patches().values().all_valid()
263    }
264
265    fn _all_invalid(&self) -> VortexResult<bool> {
266        if !self.fill_scalar().is_null() {
267            // We need _all_ values to be patched, and all patches to be invalid
268            return Ok(self.patches().values().len() == self.len()
269                && self.patches().values().all_invalid()?);
270        }
271
272        self.patches().values().all_invalid()
273    }
274
275    fn _validity_mask(&self) -> VortexResult<Mask> {
276        let indices = self.patches().indices().to_primitive()?;
277
278        if self.fill_scalar().is_null() {
279            // If we have a null fill value, then we set each patch value to true.
280            let mut buffer = BooleanBufferBuilder::new(self.len());
281            // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
282            buffer.append_n(self.len(), false);
283
284            match_each_integer_ptype!(indices.ptype(), |$I| {
285                indices.as_slice::<$I>().into_iter().for_each(|&index| {
286                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), true);
287                });
288            });
289
290            return Ok(Mask::from_buffer(buffer.finish()));
291        }
292
293        // If the fill_value is non-null, then the validity is based on the validity of the
294        // patch values.
295        let mut buffer = BooleanBufferBuilder::new(self.len());
296        buffer.append_n(self.len(), true);
297
298        let values_validity = self.patches().values().validity_mask()?;
299        match_each_integer_ptype!(indices.ptype(), |$I| {
300            indices.as_slice::<$I>()
301                .into_iter()
302                .enumerate()
303                .for_each(|(patch_idx, &index)| {
304                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), values_validity.value(patch_idx));
305                })
306        });
307
308        Ok(Mask::from_buffer(buffer.finish()))
309    }
310}
311
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{cast, slice};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null fill value for a nullable `i32` array.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// Non-null `i32` fill value.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array with 100, 200, 300 patched in at positions 2, 5, 8; every other
    /// position holds `fill_value`. With a null fill it reads:
    /// [null, null, 100, null, null, 200, null, null, 300, null]
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let patch_values =
            cast(&buffer![100i32, 200, 300].into_array(), fill_value.dtype()).unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    /// Asserts that `error` is an out-of-bounds error for `expected_index` in `0..expected_len`.
    fn assert_out_of_bounds(error: VortexError, expected_index: usize, expected_len: usize) {
        match error {
            VortexError::OutOfBounds(index, start, stop, _) => {
                assert_eq!(index, expected_index);
                assert_eq!(start, 0);
                assert_eq!(stop, expected_len);
            }
            _ => unreachable!(),
        }
    }

    /// Four `u32` patches whose highest index is 100, in an array of length `len`.
    fn four_patches(len: usize) -> SparseArray {
        SparseArray::try_new(
            buffer![10_u64, 11, 50, 100].into_array(),
            buffer![15_u32, 135, 13531, 42].into_array(),
            len,
            0_u32.into(),
        )
        .unwrap()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (index, expected) in expectations {
            assert_eq!(scalar_at(&array, index).unwrap(), expected);
        }

        assert_out_of_bounds(scalar_at(&array, 10).unwrap_err(), 10, 10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let patch_indices = ConstantArray::new(10u32, 1).into_array();
        let patch_values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(patch_indices, patch_values, 100, fill).unwrap();

        let patched = scalar_at(&arr, 10).unwrap();
        let typed = PrimitiveScalar::try_from(&patched)
            .unwrap()
            .typed_value::<u32>();
        assert_eq!(typed, Some(1234));

        for index in [0, 99] {
            assert!(scalar_at(&arr, index).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        assert_out_of_bounds(scalar_at(&sliced, 5).unwrap_err(), 5, 5);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();
        let expected = Mask::from_iter([true, false, false, true, false]);

        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .into_array();
        let sliced = slice(&sparse, 2, 7).unwrap();

        let expected = Mask::from_iter([false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();
        let second = scalar_at(&sliced_once, 1).unwrap();
        assert_eq!(usize::try_from(&second).unwrap(), 100);
        assert_out_of_bounds(scalar_at(&sliced_once, 7).unwrap_err(), 7, 7);

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();
        let fourth = scalar_at(&sliced_twice, 3).unwrap();
        assert_eq!(usize::try_from(&fourth).unwrap(), 200);
        assert_out_of_bounds(scalar_at(&sliced_twice, 5).unwrap_err(), 5, 5);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let validity: Vec<bool> = sparse_array(nullable_fill())
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();
        let expected = [
            false, false, true, false, false, true, false, false, true, false,
        ];

        assert_eq!(validity, expected);
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) does not fit in an array of length 100.
        four_patches(100);
    }

    #[test]
    fn test_valid_length() {
        four_patches(101);
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = [
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}
525}