Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::ArrayEq;
10use vortex_array::ArrayHash;
11use vortex_array::ArrayRef;
12use vortex_array::DynArray;
13use vortex_array::ExecutionCtx;
14use vortex_array::ExecutionStep;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ToCanonical;
18use vortex_array::arrays::ConstantArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::builtins::ArrayBuiltins;
21use vortex_array::dtype::DType;
22use vortex_array::dtype::Nullability;
23use vortex_array::patches::Patches;
24use vortex_array::patches::PatchesMetadata;
25use vortex_array::scalar::Scalar;
26use vortex_array::scalar::ScalarValue;
27use vortex_array::scalar_fn::fns::operators::Operator;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSetRef;
31use vortex_array::validity::Validity;
32use vortex_array::vtable;
33use vortex_array::vtable::ArrayId;
34use vortex_array::vtable::VTable;
35use vortex_array::vtable::ValidityVTable;
36use vortex_array::vtable::patches_child;
37use vortex_array::vtable::patches_child_name;
38use vortex_array::vtable::patches_nchildren;
39use vortex_buffer::Buffer;
40use vortex_buffer::ByteBufferMut;
41use vortex_error::VortexExpect as _;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_ensure_eq;
46use vortex_error::vortex_panic;
47use vortex_mask::AllOr;
48use vortex_mask::Mask;
49use vortex_session::VortexSession;
50
51use crate::canonical::execute_sparse;
52use crate::rules::RULES;
53
54mod canonical;
55mod compute;
56mod kernel;
57mod ops;
58mod rules;
59mod slice;
60
61vtable!(Sparse);
62
/// Metadata describing a [`SparseArray`]: the metadata of its patches plus its fill value.
///
/// Only `patches` is written by `serialize`; the fill value is carried in the array's single
/// buffer and is re-attached to this struct by `deserialize`.
#[derive(Debug)]
pub struct SparseMetadata {
    /// Metadata of the patches (as produced by `Patches::to_metadata`).
    patches: PatchesMetadata,
    /// The scalar reported for every position that has no patch.
    fill_value: Scalar,
}
68
/// Protobuf message used as the serialized form of [`SparseMetadata`].
///
/// It wraps only the patches metadata: `serialize` encodes this message and `deserialize`
/// decodes it, while the fill value travels separately in the array's buffer.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct ProstPatchesMetadata {
    /// The patches metadata, stored as required field 1 of the message.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
75
76impl VTable for SparseVTable {
77    type Array = SparseArray;
78
79    type Metadata = SparseMetadata;
80    type OperationsVTable = Self;
81    type ValidityVTable = Self;
82
83    fn id(_array: &Self::Array) -> ArrayId {
84        Self::ID
85    }
86
87    fn len(array: &SparseArray) -> usize {
88        array.patches.array_len()
89    }
90
91    fn dtype(array: &SparseArray) -> &DType {
92        array.fill_scalar().dtype()
93    }
94
95    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
96        array.stats_set.to_ref(array.as_ref())
97    }
98
99    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
100        array.patches.array_hash(state, precision);
101        array.fill_value.hash(state);
102    }
103
104    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
105        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
106    }
107
108    fn nbuffers(_array: &SparseArray) -> usize {
109        1
110    }
111
112    fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
113        match idx {
114            0 => {
115                let fill_value_buffer =
116                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
117                BufferHandle::new_host(fill_value_buffer)
118            }
119            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
120        }
121    }
122
123    fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
124        match idx {
125            0 => Some("fill_value".to_string()),
126            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
127        }
128    }
129
130    fn nchildren(array: &SparseArray) -> usize {
131        patches_nchildren(array.patches())
132    }
133
134    fn child(array: &SparseArray, idx: usize) -> ArrayRef {
135        patches_child(array.patches(), idx)
136    }
137
138    fn child_name(_array: &SparseArray, idx: usize) -> String {
139        patches_child_name(idx).to_string()
140    }
141
142    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
143        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
144
145        Ok(SparseMetadata {
146            patches,
147            fill_value: array.fill_value.clone(),
148        })
149    }
150
151    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
152        let prost_patches = ProstPatchesMetadata {
153            patches: metadata.patches,
154        };
155
156        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
157        Ok(Some(prost_patches.encode_to_vec()))
158    }
159
160    fn deserialize(
161        bytes: &[u8],
162        dtype: &DType,
163        _len: usize,
164        buffers: &[BufferHandle],
165        session: &VortexSession,
166    ) -> VortexResult<Self::Metadata> {
167        let prost_patches = ProstPatchesMetadata::decode(bytes)?;
168
169        // Once we have the patches metadata, we need to get the fill value from the buffers.
170
171        if buffers.len() != 1 {
172            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
173        }
174        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
175
176        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
177        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
178
179        Ok(SparseMetadata {
180            patches: prost_patches.patches,
181            fill_value,
182        })
183    }
184
185    fn build(
186        dtype: &DType,
187        len: usize,
188        metadata: &Self::Metadata,
189        _buffers: &[BufferHandle],
190        children: &dyn ArrayChildren,
191    ) -> VortexResult<SparseArray> {
192        vortex_ensure_eq!(
193            children.len(),
194            2,
195            "SparseArray expects 2 children for sparse encoding, found {}",
196            children.len()
197        );
198        vortex_ensure_eq!(
199            metadata.patches.offset()?,
200            0,
201            "Patches must start at offset 0"
202        );
203
204        let patch_indices = children.get(
205            0,
206            &metadata.patches.indices_dtype()?,
207            metadata.patches.len()?,
208        )?;
209        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
210
211        SparseArray::try_new(
212            patch_indices,
213            patch_values,
214            len,
215            metadata.fill_value.clone(),
216        )
217    }
218
219    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
220        vortex_ensure_eq!(
221            children.len(),
222            2,
223            "SparseArray expects 2 children, got {}",
224            children.len()
225        );
226
227        let mut children_iter = children.into_iter();
228        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
229        let patch_values = children_iter.next().vortex_expect("patch_values child");
230
231        array.patches = Patches::new(
232            array.patches.array_len(),
233            array.patches.offset(),
234            patch_indices,
235            patch_values,
236            array.patches.chunk_offsets().clone(),
237        )?;
238
239        Ok(())
240    }
241
242    fn reduce_parent(
243        array: &Self::Array,
244        parent: &ArrayRef,
245        child_idx: usize,
246    ) -> VortexResult<Option<ArrayRef>> {
247        RULES.evaluate(array, parent, child_idx)
248    }
249
250    fn execute_parent(
251        array: &Self::Array,
252        parent: &ArrayRef,
253        child_idx: usize,
254        ctx: &mut ExecutionCtx,
255    ) -> VortexResult<Option<ArrayRef>> {
256        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
257    }
258
259    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
260        execute_sparse(array, ctx).map(ExecutionStep::Done)
261    }
262}
263
/// An array in which most positions share a single fill value and the remaining positions
/// are stored explicitly as patches (an indices array plus a values array).
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// The explicitly stored positions and their values; also records the logical length.
    patches: Patches,
    /// The scalar reported for every position without a patch; its dtype is the array's dtype.
    fill_value: Scalar,
    /// Statistics for this array, exposed through `VTable::stats`.
    stats_set: ArrayStats,
}
270
/// Zero-sized type that carries the vtable implementations for [`SparseArray`].
#[derive(Debug)]
pub struct SparseVTable;
273
impl SparseVTable {
    /// The encoding identifier of sparse arrays, `"vortex.sparse"`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
277
278impl SparseArray {
279    pub fn try_new(
280        indices: ArrayRef,
281        values: ArrayRef,
282        len: usize,
283        fill_value: Scalar,
284    ) -> VortexResult<Self> {
285        vortex_ensure!(
286            indices.len() == values.len(),
287            "Mismatched indices {} and values {} length",
288            indices.len(),
289            values.len()
290        );
291
292        if indices.is_host() {
293            debug_assert_eq!(
294                indices.statistics().compute_is_strict_sorted(),
295                Some(true),
296                "SparseArray: indices must be strict-sorted"
297            );
298
299            // Verify the indices are all in the valid range
300            if !indices.is_empty() {
301                let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
302
303                vortex_ensure!(
304                    last_index < len,
305                    "Array length was {len} but the last index is {last_index}"
306                );
307            }
308        }
309
310        Ok(Self {
311            // TODO(0ax1): handle chunk offsets
312            patches: Patches::new(len, 0, indices, values, None)?,
313            fill_value,
314            stats_set: Default::default(),
315        })
316    }
317
318    /// Build a new SparseArray from an existing set of patches.
319    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
320        vortex_ensure!(
321            fill_value.dtype() == patches.values().dtype(),
322            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
323            fill_value,
324            patches.values().dtype(),
325            fill_value.dtype(),
326        );
327
328        Ok(Self {
329            patches,
330            fill_value,
331            stats_set: Default::default(),
332        })
333    }
334
335    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
336        Self {
337            patches,
338            fill_value,
339            stats_set: Default::default(),
340        }
341    }
342
343    #[inline]
344    pub fn patches(&self) -> &Patches {
345        &self.patches
346    }
347
348    #[inline]
349    pub fn resolved_patches(&self) -> VortexResult<Patches> {
350        let patches = self.patches();
351        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
352        let indices = patches.indices().to_array().binary(
353            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
354            Operator::Sub,
355        )?;
356
357        Patches::new(
358            patches.array_len(),
359            0,
360            indices,
361            patches.values().clone(),
362            // TODO(0ax1): handle chunk offsets
363            None,
364        )
365    }
366
367    #[inline]
368    pub fn fill_scalar(&self) -> &Scalar {
369        &self.fill_value
370    }
371
372    /// Encode given array as a SparseArray.
373    ///
374    /// Optionally provided fill value will be respected if the array is less than 90% null.
375    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
376        if let Some(fill_value) = fill_value.as_ref()
377            && array.dtype() != fill_value.dtype()
378        {
379            vortex_bail!(
380                "Array and fill value types must match. got {} and {}",
381                array.dtype(),
382                fill_value.dtype()
383            )
384        }
385        let mask = array.validity_mask()?;
386
387        if mask.all_false() {
388            // Array is constant NULL
389            return Ok(
390                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
391            );
392        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
393            // Array is dominated by NULL but has non-NULL values
394            // TODO(joe): use exe ctx?
395            let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
396            let non_null_indices = match mask.indices() {
397                AllOr::All => {
398                    // We already know that the mask is 90%+ false
399                    unreachable!("Mask is mostly null")
400                }
401                AllOr::None => {
402                    // we know there are some non-NULL values
403                    unreachable!("Mask is mostly null but not all null")
404                }
405                AllOr::Some(values) => {
406                    let buffer: Buffer<u32> = values
407                        .iter()
408                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
409                        .collect();
410
411                    buffer.into_array()
412                }
413            };
414
415            return Ok(SparseArray::try_new(
416                non_null_indices,
417                non_null_values,
418                array.len(),
419                Scalar::null(array.dtype().clone()),
420            )?
421            .into_array());
422        }
423
424        let fill = if let Some(fill) = fill_value {
425            fill
426        } else {
427            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
428            let (top_pvalue, _) = array
429                .to_primitive()
430                .top_value()?
431                .vortex_expect("Non empty or all null array");
432
433            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
434        };
435
436        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
437        let non_top_mask = Mask::from_buffer(
438            array
439                .to_array()
440                .binary(fill_array.clone(), Operator::NotEq)?
441                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
442                .to_bool()
443                .to_bit_buffer(),
444        );
445
446        let non_top_values = array
447            .filter(non_top_mask.clone())?
448            .to_canonical()?
449            .into_array();
450
451        let indices: Buffer<u64> = match non_top_mask {
452            Mask::AllTrue(count) => {
453                // all true -> complete slice
454                (0u64..count as u64).collect()
455            }
456            Mask::AllFalse(_) => {
457                // All values are equal to the top value
458                return Ok(fill_array);
459            }
460            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
461        };
462
463        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
464            .map(|a| a.into_array())
465    }
466}
467
468impl ValidityVTable<SparseVTable> for SparseVTable {
469    fn validity(array: &SparseArray) -> VortexResult<Validity> {
470        let patches = unsafe {
471            Patches::new_unchecked(
472                array.patches.array_len(),
473                array.patches.offset(),
474                array.patches.indices().clone(),
475                array
476                    .patches
477                    .values()
478                    .validity()?
479                    .to_array(array.patches.values().len()),
480                array.patches.chunk_offsets().clone(),
481                array.patches.offset_within_chunk(),
482            )
483        };
484
485        Ok(Validity::Array(
486            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
487                .into_array(),
488        ))
489    }
490}
491
/// Unit tests for constructing, slicing, encoding and querying [`SparseArray`].
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null fill value of nullable `i32` type.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill value of `i32` type.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8 and
    /// `fill_value` everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [null, null, 100, null, null, 200, null, null, 300, null]
        let mut values = buffer![100i32, 200, 300].into_array();
        values = values.cast(fill_value.dtype().clone()).unwrap();

        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let arr = SparseArray::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            arr.scalar_at(10)
                .unwrap()
                .as_primitive()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(arr.scalar_at(0).unwrap().is_null());
        assert!(arr.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        assert_eq!(
            sliced.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        assert_eq!(
            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
            100
        );

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        assert_eq!(
            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
            200
        );
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity_mask()
                .unwrap()
                .to_bit_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(array.validity_mask().unwrap().all_true());
    }

    #[test]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        // Index 100 is out of range for length 100. Assert on the returned error directly
        // rather than relying on a bare `should_panic`, which would also accept any
        // unrelated panic.
        let result = SparseArray::try_new(indices, values, 100, 0_u32.into());
        assert!(
            result.is_err(),
            "an index equal to the array length must be rejected"
        );
    }

    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ]),
        );
        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array =
            SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
        let actual = array.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}