vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use itertools::Itertools as _;
7use num_traits::NumCast;
8use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
9use vortex_array::compute::{Operator, compare, fill_null, filter, sub_scalar};
10use vortex_array::patches::Patches;
11use vortex_array::stats::{ArrayStats, StatsSetRef};
12use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
13use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
14use vortex_buffer::Buffer;
15use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
17use vortex_mask::{AllOr, Mask};
18use vortex_scalar::Scalar;
19
20mod canonical;
21mod compute;
22mod ops;
23mod serde;
24
25vtable!(Sparse);
26
27impl VTable for SparseVTable {
28    type Array = SparseArray;
29    type Encoding = SparseEncoding;
30
31    type ArrayVTable = Self;
32    type CanonicalVTable = Self;
33    type OperationsVTable = Self;
34    type ValidityVTable = Self;
35    type VisitorVTable = Self;
36    type ComputeVTable = NotSupported;
37    type EncodeVTable = Self;
38    type SerdeVTable = Self;
39    type PipelineVTable = NotSupported;
40
41    fn id(_encoding: &Self::Encoding) -> EncodingId {
42        EncodingId::new_ref("vortex.sparse")
43    }
44
45    fn encoding(_array: &Self::Array) -> EncodingRef {
46        EncodingRef::new_ref(SparseEncoding.as_ref())
47    }
48}
49
50#[derive(Clone, Debug)]
51pub struct SparseArray {
52    patches: Patches,
53    fill_value: Scalar,
54    stats_set: ArrayStats,
55}
56
/// Encoding marker for [`SparseArray`], identified as `"vortex.sparse"`.
#[derive(Clone, Debug)]
pub struct SparseEncoding;
59
60impl SparseArray {
61    pub fn try_new(
62        indices: ArrayRef,
63        values: ArrayRef,
64        len: usize,
65        fill_value: Scalar,
66    ) -> VortexResult<Self> {
67        vortex_ensure!(
68            indices.len() == values.len(),
69            "Mismatched indices {} and values {} length",
70            indices.len(),
71            values.len()
72        );
73
74        vortex_ensure!(
75            indices.statistics().compute_is_strict_sorted() == Some(true),
76            "SparseArray: indices must be strict-sorted"
77        );
78
79        // Verify the indices are all in the valid range
80        if !indices.is_empty() {
81            let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1))?;
82
83            vortex_ensure!(
84                last_index < len,
85                "Array length was {len} but the last index is {last_index}"
86            );
87        }
88
89        let patches = Patches::new(len, 0, indices, values);
90
91        Ok(Self {
92            patches,
93            fill_value,
94            stats_set: Default::default(),
95        })
96    }
97
98    /// Build a new SparseArray from an existing set of patches.
99    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
100        vortex_ensure!(
101            fill_value.dtype() == patches.values().dtype(),
102            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
103            fill_value,
104            patches.values().dtype(),
105            fill_value.dtype(),
106        );
107
108        Ok(Self {
109            patches,
110            fill_value,
111            stats_set: Default::default(),
112        })
113    }
114
115    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
116        Self {
117            patches,
118            fill_value,
119            stats_set: Default::default(),
120        }
121    }
122
123    #[inline]
124    pub fn patches(&self) -> &Patches {
125        &self.patches
126    }
127
128    #[inline]
129    pub fn resolved_patches(&self) -> VortexResult<Patches> {
130        let patches = self.patches();
131        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
132        let indices = sub_scalar(patches.indices(), indices_offset)?;
133        Ok(Patches::new(
134            patches.array_len(),
135            0,
136            indices,
137            patches.values().clone(),
138        ))
139    }
140
141    #[inline]
142    pub fn fill_scalar(&self) -> &Scalar {
143        &self.fill_value
144    }
145
146    /// Encode given array as a SparseArray.
147    ///
148    /// Optionally provided fill value will be respected if the array is less than 90% null.
149    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
150        if let Some(fill_value) = fill_value.as_ref()
151            && array.dtype() != fill_value.dtype()
152        {
153            vortex_bail!(
154                "Array and fill value types must match. got {} and {}",
155                array.dtype(),
156                fill_value.dtype()
157            )
158        }
159        let mask = array.validity_mask()?;
160
161        if mask.all_false() {
162            // Array is constant NULL
163            return Ok(
164                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
165            );
166        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
167            // Array is dominated by NULL but has non-NULL values
168            let non_null_values = filter(array, &mask)?;
169            let non_null_indices = match mask.indices() {
170                AllOr::All => {
171                    // We already know that the mask is 90%+ false
172                    unreachable!("Mask is mostly null")
173                }
174                AllOr::None => {
175                    // we know there are some non-NULL values
176                    unreachable!("Mask is mostly null but not all null")
177                }
178                AllOr::Some(values) => {
179                    let buffer: Buffer<u32> = values
180                        .iter()
181                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
182                        .collect();
183
184                    buffer.into_array()
185                }
186            };
187
188            return Ok(SparseArray::try_new(
189                non_null_indices,
190                non_null_values,
191                array.len(),
192                Scalar::null(array.dtype().clone()),
193            )?
194            .into_array());
195        }
196
197        let fill = if let Some(fill) = fill_value {
198            fill
199        } else {
200            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
201            let (top_pvalue, _) = array
202                .to_primitive()?
203                .top_value()?
204                .vortex_expect("Non empty or all null array");
205
206            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
207        };
208
209        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
210        let non_top_mask = Mask::from_buffer(
211            fill_null(
212                &compare(array, &fill_array, Operator::NotEq)?,
213                &Scalar::bool(true, Nullability::NonNullable),
214            )?
215            .to_bool()?
216            .boolean_buffer()
217            .clone(),
218        );
219
220        let non_top_values = filter(array, &non_top_mask)?;
221
222        let indices: Buffer<u64> = match non_top_mask {
223            Mask::AllTrue(count) => {
224                // all true -> complete slice
225                (0u64..count as u64).collect()
226            }
227            Mask::AllFalse(_) => {
228                // All values are equal to the top value
229                return Ok(fill_array);
230            }
231            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
232        };
233
234        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
235            .map(|a| a.into_array())
236    }
237}
238
239impl ArrayVTable<SparseVTable> for SparseVTable {
240    fn len(array: &SparseArray) -> usize {
241        array.patches.array_len()
242    }
243
244    fn dtype(array: &SparseArray) -> &DType {
245        array.fill_scalar().dtype()
246    }
247
248    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
249        array.stats_set.to_ref(array.as_ref())
250    }
251}
252
253impl ValidityVTable<SparseVTable> for SparseVTable {
254    fn is_valid(array: &SparseArray, index: usize) -> VortexResult<bool> {
255        Ok(match array.patches().get_patched(index) {
256            None => array.fill_scalar().is_valid(),
257            Some(patch_value) => patch_value.is_valid(),
258        })
259    }
260
261    fn all_valid(array: &SparseArray) -> VortexResult<bool> {
262        if array.fill_scalar().is_null() {
263            // We need _all_ values to be patched, and all patches to be valid
264            return Ok(array.patches().values().len() == array.len()
265                && array.patches().values().all_valid()?);
266        }
267
268        array.patches().values().all_valid()
269    }
270
271    fn all_invalid(array: &SparseArray) -> VortexResult<bool> {
272        if !array.fill_scalar().is_null() {
273            // We need _all_ values to be patched, and all patches to be invalid
274            return Ok(array.patches().values().len() == array.len()
275                && array.patches().values().all_invalid()?);
276        }
277
278        array.patches().values().all_invalid()
279    }
280
281    #[allow(clippy::unnecessary_fallible_conversions)]
282    fn validity_mask(array: &SparseArray) -> VortexResult<Mask> {
283        let fill_is_valid = array.fill_scalar().is_valid();
284        let values_validity = array.patches().values().validity_mask()?;
285        let len = array.len();
286
287        if matches!(values_validity, Mask::AllTrue(_)) && fill_is_valid {
288            return Ok(Mask::AllTrue(len));
289        }
290        if matches!(values_validity, Mask::AllFalse(_)) && !fill_is_valid {
291            return Ok(Mask::AllFalse(len));
292        }
293
294        // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
295        let mut is_valid_buffer = BooleanBufferBuilder::new(len);
296        is_valid_buffer.append_n(len, fill_is_valid);
297
298        let indices = array.patches().indices().to_primitive()?;
299        let index_offset = array.patches().offset();
300
301        match_each_integer_ptype!(indices.ptype(), |I| {
302            let indices = indices.as_slice::<I>();
303            patch_validity(&mut is_valid_buffer, indices, index_offset, values_validity);
304        });
305
306        Ok(Mask::from_buffer(is_valid_buffer.finish()))
307    }
308}
309
310fn patch_validity<I: NativePType>(
311    is_valid_buffer: &mut BooleanBufferBuilder,
312    indices: &[I],
313    index_offset: usize,
314    values_validity: Mask,
315) {
316    let indices = indices.iter().map(|index| {
317        let index = <usize as NumCast>::from(*index).vortex_expect("Failed to cast to usize");
318        index - index_offset
319    });
320    match values_validity {
321        Mask::AllTrue(_) => {
322            for index in indices {
323                is_valid_buffer.set_bit(index, true);
324            }
325        }
326        Mask::AllFalse(_) => {
327            for index in indices {
328                is_valid_buffer.set_bit(index, false);
329            }
330        }
331        Mask::Values(mask_values) => {
332            let is_valid = mask_values.boolean_buffer().iter();
333            for (index, is_valid) in indices.zip_eq(is_valid) {
334                is_valid_buffer.set_bit(index, is_valid);
335            }
336        }
337    }
338}
339
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::cast;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// A null fill of type `i32?`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill of type `i32`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 array holding 100, 200 and 300 at positions 2, 5 and 8 and `fill_value`
    /// elsewhere. With [`nullable_fill`] it reads
    /// `[null, null, 100, null, null, 200, null, null, 300, null]`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let patch_values =
            cast(&buffer![100i32, 200, 300].into_array(), fill_value.dtype()).unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let sparse = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (index, expected) in expectations {
            assert_eq!(sparse.scalar_at(index), expected);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        // Index 10 is one past the end of the length-10 array.
        let _ = sparse_array(nullable_fill()).scalar_at(10);
    }

    #[test]
    pub fn test_scalar_at_again() {
        let patch_indices = ConstantArray::new(10u32, 1).into_array();
        let patch_values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let sparse = SparseArray::try_new(patch_indices, patch_values, 100, fill).unwrap();

        let patched = PrimitiveScalar::try_from(&sparse.scalar_at(10))
            .unwrap()
            .typed_value::<u32>();
        assert_eq!(patched, Some(1234));

        // Both ends are unpatched and therefore read the null fill.
        assert!(sparse.scalar_at(0).is_null());
        assert!(sparse.scalar_at(99).is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        // Slicing [2, 7) moves the patch at index 2 to the front.
        let first = sparse_array(nullable_fill()).slice(2, 7).scalar_at(0);
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2, 7);
        let expected = Mask::from_iter([true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Every patch is null while the fill (1.0) is valid, so validity is the inverse of the
        // patch positions.
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let fill = Scalar::primitive(1.0f32, Nullability::Nullable);
        let sliced =
            SparseArray::try_new(buffer![2u64, 5, 8].into_array(), null_patches, 10, fill)
                .unwrap()
                .slice(2, 7);

        let expected = Mask::from_iter([false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let outer = sparse_array(nullable_fill()).slice(1, 8);
        assert_eq!(usize::try_from(&outer.scalar_at(1)).unwrap(), 100);

        // Offsets compose: outer[1..6] starts at original index 2.
        let inner = outer.slice(1, 6);
        assert_eq!(usize::try_from(&inner.scalar_at(3)).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let validity_bits = sparse_array(nullable_fill())
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // Index 100 is out of bounds for a length-100 array.
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // Same data as `test_invalid_length`, but the length now leaves room for index 100.
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = [
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        // Patches at the even positions; the patches at 4 and 6 are themselves null.
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let sparse =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_typed::<i16>())
                .unwrap();

        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        assert_eq!(sparse.validity_mask().unwrap(), expected);
    }
}