Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::Array;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DeserializeMetadata;
14use vortex_array::ExecutionCtx;
15use vortex_array::IntoArray;
16use vortex_array::Precision;
17use vortex_array::ProstMetadata;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityVTable;
37use vortex_array::vtable::patches_child;
38use vortex_array::vtable::patches_child_name;
39use vortex_array::vtable::patches_nchildren;
40use vortex_buffer::Buffer;
41use vortex_buffer::ByteBufferMut;
42use vortex_error::VortexExpect as _;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_ensure_eq;
47use vortex_error::vortex_panic;
48use vortex_mask::AllOr;
49use vortex_mask::Mask;
50use vortex_session::VortexSession;
51
52use crate::canonical::execute_sparse;
53use crate::rules::RULES;
54
55mod canonical;
56mod compute;
57mod kernel;
58mod ops;
59mod rules;
60mod slice;
61
62vtable!(Sparse);
63
/// In-memory metadata for a [`SparseArray`].
///
/// Produced by `VTable::metadata` and `VTable::deserialize` below. Only the `patches`
/// part is written by `VTable::serialize`; the fill value is exposed through buffer 0
/// of the array and read back from there on deserialization.
#[derive(Debug)]
pub struct SparseMetadata {
    /// Metadata describing the patches (indices/values children) of the array.
    patches: PatchesMetadata,
    /// The scalar used for the array's fill value.
    fill_value: Scalar,
}
69
/// Protobuf message wrapping the [`PatchesMetadata`] of a sparse array.
///
/// This is the message that `VTable::serialize` encodes and `VTable::deserialize` decodes;
/// it intentionally carries no fill value.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct ProstPatchesMetadata {
    /// The patches metadata, stored as required field 1 of the message.
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
76
77impl VTable for SparseVTable {
78    type Array = SparseArray;
79
80    type Metadata = SparseMetadata;
81    type OperationsVTable = Self;
82    type ValidityVTable = Self;
83
84    fn id(_array: &Self::Array) -> ArrayId {
85        Self::ID
86    }
87
88    fn len(array: &SparseArray) -> usize {
89        array.patches.array_len()
90    }
91
92    fn dtype(array: &SparseArray) -> &DType {
93        array.fill_scalar().dtype()
94    }
95
96    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
97        array.stats_set.to_ref(array.as_ref())
98    }
99
100    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
101        array.patches.array_hash(state, precision);
102        array.fill_value.hash(state);
103    }
104
105    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
106        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
107    }
108
109    fn nbuffers(_array: &SparseArray) -> usize {
110        1
111    }
112
113    fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
114        match idx {
115            0 => {
116                let fill_value_buffer =
117                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
118                BufferHandle::new_host(fill_value_buffer)
119            }
120            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
121        }
122    }
123
124    fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
125        match idx {
126            0 => Some("fill_value".to_string()),
127            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
128        }
129    }
130
131    fn nchildren(array: &SparseArray) -> usize {
132        patches_nchildren(array.patches())
133    }
134
135    fn child(array: &SparseArray, idx: usize) -> ArrayRef {
136        patches_child(array.patches(), idx)
137    }
138
139    fn child_name(_array: &SparseArray, idx: usize) -> String {
140        patches_child_name(idx).to_string()
141    }
142
143    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
144        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
145
146        Ok(SparseMetadata {
147            patches,
148            fill_value: array.fill_value.clone(),
149        })
150    }
151
152    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
153        let prost_patches = ProstPatchesMetadata {
154            patches: metadata.patches,
155        };
156
157        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
158        Ok(Some(prost_patches.encode_to_vec()))
159    }
160
161    fn deserialize(
162        bytes: &[u8],
163        dtype: &DType,
164        _len: usize,
165        buffers: &[BufferHandle],
166        _session: &VortexSession,
167    ) -> VortexResult<Self::Metadata> {
168        let prost_patches =
169            <ProstMetadata<ProstPatchesMetadata> as DeserializeMetadata>::deserialize(bytes)?;
170
171        // Once we have the patches metadata, we need to get the fill value from the buffers.
172
173        if buffers.len() != 1 {
174            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
175        }
176        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
177
178        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype)?;
179        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
180
181        Ok(SparseMetadata {
182            patches: prost_patches.patches,
183            fill_value,
184        })
185    }
186
187    fn build(
188        dtype: &DType,
189        len: usize,
190        metadata: &Self::Metadata,
191        _buffers: &[BufferHandle],
192        children: &dyn ArrayChildren,
193    ) -> VortexResult<SparseArray> {
194        vortex_ensure_eq!(
195            children.len(),
196            2,
197            "SparseArray expects 2 children for sparse encoding, found {}",
198            children.len()
199        );
200        vortex_ensure_eq!(
201            metadata.patches.offset()?,
202            0,
203            "Patches must start at offset 0"
204        );
205
206        let patch_indices = children.get(
207            0,
208            &metadata.patches.indices_dtype()?,
209            metadata.patches.len()?,
210        )?;
211        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
212
213        SparseArray::try_new(
214            patch_indices,
215            patch_values,
216            len,
217            metadata.fill_value.clone(),
218        )
219    }
220
221    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
222        vortex_ensure_eq!(
223            children.len(),
224            2,
225            "SparseArray expects 2 children, got {}",
226            children.len()
227        );
228
229        let mut children_iter = children.into_iter();
230        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
231        let patch_values = children_iter.next().vortex_expect("patch_values child");
232
233        array.patches = Patches::new(
234            array.patches.array_len(),
235            array.patches.offset(),
236            patch_indices,
237            patch_values,
238            array.patches.chunk_offsets().clone(),
239        )?;
240
241        Ok(())
242    }
243
244    fn reduce_parent(
245        array: &Self::Array,
246        parent: &ArrayRef,
247        child_idx: usize,
248    ) -> VortexResult<Option<ArrayRef>> {
249        RULES.evaluate(array, parent, child_idx)
250    }
251
252    fn execute_parent(
253        array: &Self::Array,
254        parent: &ArrayRef,
255        child_idx: usize,
256        ctx: &mut ExecutionCtx,
257    ) -> VortexResult<Option<ArrayRef>> {
258        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
259    }
260
261    fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
262        execute_sparse(array, ctx)
263    }
264}
265
/// An array made of a fill scalar plus a set of patches (indices and values).
///
/// The array's length comes from the patches and its dtype from the fill value
/// (see the `VTable` implementation for `SparseVTable`).
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// The patch indices and values, together with the array length and offset.
    patches: Patches,
    /// The fill scalar; its dtype is the dtype of the whole array.
    fill_value: Scalar,
    /// Statistics for this array.
    stats_set: ArrayStats,
}
272
/// Marker type carrying the vtable implementations for [`SparseArray`].
#[derive(Debug)]
pub struct SparseVTable;
275
impl SparseVTable {
    /// Encoding identifier of the sparse array, `"vortex.sparse"`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
279
280impl SparseArray {
281    pub fn try_new(
282        indices: ArrayRef,
283        values: ArrayRef,
284        len: usize,
285        fill_value: Scalar,
286    ) -> VortexResult<Self> {
287        vortex_ensure!(
288            indices.len() == values.len(),
289            "Mismatched indices {} and values {} length",
290            indices.len(),
291            values.len()
292        );
293
294        if indices.is_host() {
295            debug_assert_eq!(
296                indices.statistics().compute_is_strict_sorted(),
297                Some(true),
298                "SparseArray: indices must be strict-sorted"
299            );
300
301            // Verify the indices are all in the valid range
302            if !indices.is_empty() {
303                let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
304
305                vortex_ensure!(
306                    last_index < len,
307                    "Array length was {len} but the last index is {last_index}"
308                );
309            }
310        }
311
312        Ok(Self {
313            // TODO(0ax1): handle chunk offsets
314            patches: Patches::new(len, 0, indices, values, None)?,
315            fill_value,
316            stats_set: Default::default(),
317        })
318    }
319
320    /// Build a new SparseArray from an existing set of patches.
321    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
322        vortex_ensure!(
323            fill_value.dtype() == patches.values().dtype(),
324            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
325            fill_value,
326            patches.values().dtype(),
327            fill_value.dtype(),
328        );
329
330        Ok(Self {
331            patches,
332            fill_value,
333            stats_set: Default::default(),
334        })
335    }
336
337    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
338        Self {
339            patches,
340            fill_value,
341            stats_set: Default::default(),
342        }
343    }
344
345    #[inline]
346    pub fn patches(&self) -> &Patches {
347        &self.patches
348    }
349
350    #[inline]
351    pub fn resolved_patches(&self) -> VortexResult<Patches> {
352        let patches = self.patches();
353        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
354        let indices = patches.indices().to_array().binary(
355            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
356            Operator::Sub,
357        )?;
358
359        Patches::new(
360            patches.array_len(),
361            0,
362            indices,
363            patches.values().clone(),
364            // TODO(0ax1): handle chunk offsets
365            None,
366        )
367    }
368
369    #[inline]
370    pub fn fill_scalar(&self) -> &Scalar {
371        &self.fill_value
372    }
373
374    /// Encode given array as a SparseArray.
375    ///
376    /// Optionally provided fill value will be respected if the array is less than 90% null.
377    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
378        if let Some(fill_value) = fill_value.as_ref()
379            && array.dtype() != fill_value.dtype()
380        {
381            vortex_bail!(
382                "Array and fill value types must match. got {} and {}",
383                array.dtype(),
384                fill_value.dtype()
385            )
386        }
387        let mask = array.validity_mask()?;
388
389        if mask.all_false() {
390            // Array is constant NULL
391            return Ok(
392                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
393            );
394        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
395            // Array is dominated by NULL but has non-NULL values
396            // TODO(joe): use exe ctx?
397            let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
398            let non_null_indices = match mask.indices() {
399                AllOr::All => {
400                    // We already know that the mask is 90%+ false
401                    unreachable!("Mask is mostly null")
402                }
403                AllOr::None => {
404                    // we know there are some non-NULL values
405                    unreachable!("Mask is mostly null but not all null")
406                }
407                AllOr::Some(values) => {
408                    let buffer: Buffer<u32> = values
409                        .iter()
410                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
411                        .collect();
412
413                    buffer.into_array()
414                }
415            };
416
417            return Ok(SparseArray::try_new(
418                non_null_indices,
419                non_null_values,
420                array.len(),
421                Scalar::null(array.dtype().clone()),
422            )?
423            .into_array());
424        }
425
426        let fill = if let Some(fill) = fill_value {
427            fill
428        } else {
429            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
430            let (top_pvalue, _) = array
431                .to_primitive()
432                .top_value()?
433                .vortex_expect("Non empty or all null array");
434
435            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
436        };
437
438        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
439        let non_top_mask = Mask::from_buffer(
440            array
441                .to_array()
442                .binary(fill_array.clone(), Operator::NotEq)?
443                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
444                .to_bool()
445                .to_bit_buffer(),
446        );
447
448        let non_top_values = array
449            .filter(non_top_mask.clone())?
450            .to_canonical()?
451            .into_array();
452
453        let indices: Buffer<u64> = match non_top_mask {
454            Mask::AllTrue(count) => {
455                // all true -> complete slice
456                (0u64..count as u64).collect()
457            }
458            Mask::AllFalse(_) => {
459                // All values are equal to the top value
460                return Ok(fill_array);
461            }
462            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
463        };
464
465        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
466            .map(|a| a.into_array())
467    }
468}
469
470impl ValidityVTable<SparseVTable> for SparseVTable {
471    fn validity(array: &SparseArray) -> VortexResult<Validity> {
472        let patches = unsafe {
473            Patches::new_unchecked(
474                array.patches.array_len(),
475                array.patches.offset(),
476                array.patches.indices().clone(),
477                array
478                    .patches
479                    .values()
480                    .validity()?
481                    .to_array(array.patches.values().len()),
482                array.patches.chunk_offsets().clone(),
483                array.patches.offset_within_chunk(),
484            )
485        };
486
487        Ok(Validity::Array(
488            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
489                .into_array(),
490        ))
491    }
492}
493
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null fill scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// A non-null `i32` fill scalar.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Ten-element sparse array with patches 100, 200, 300 at positions 2, 5, 8.
    ///
    /// With a null fill the merged array is
    /// `[null, null, 100, null, null, 200, null, null, 300, null]`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // The patch values are cast so that their dtype matches the fill value's dtype.
        let patch_values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (position, expected) in expectations {
            assert_eq!(array.scalar_at(position).unwrap(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        // Index 10 is one past the end of the ten-element array.
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let single_index = ConstantArray::new(10u32, 1).into_array();
        let single_value =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let null_fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));

        let arr = SparseArray::try_new(single_index, single_value, 100, null_fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        // Everything outside the single patch takes the null fill value.
        for unpatched in [0, 99] {
            assert!(arr.scalar_at(unpatched).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();

        // Original positions 2..7: patches at 2 and 5 are valid, the rest is null fill.
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let valid_fill = Scalar::primitive(1.0f32, Nullability::Nullable);

        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            valid_fill,
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        // Original positions 2..7: patches at 2 and 5 are null, the fill is valid.
        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let after_first_slice = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&after_first_slice).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let after_second_slice = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&after_second_slice).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());

        let validity_bits = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());

        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not smaller than the requested length (100).
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is smaller than the requested length (101).
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity = Validity::from_iter(vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ]);
        let original =
            PrimitiveArray::new(buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4], validity);

        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        // The encoded array must keep both the validity and the values of the input.
        let expected_mask = Mask::from_iter(vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ]);
        assert_eq!(sparse.validity_mask().unwrap(), expected_mask);
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();

        let array =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::null_native::<i16>())
                .unwrap();

        // Only the non-null patches (positions 0, 2 and 8) are valid.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask().unwrap();

        assert_eq!(actual, expected);
    }
}
691}