Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use kernel::PARENT_KERNELS;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::VTable;
36use vortex_array::vtable::ValidityVTable;
37use vortex_array::vtable::patches_child;
38use vortex_array::vtable::patches_child_name;
39use vortex_array::vtable::patches_nchildren;
40use vortex_buffer::Buffer;
41use vortex_buffer::ByteBufferMut;
42use vortex_error::VortexExpect as _;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_ensure_eq;
47use vortex_error::vortex_panic;
48use vortex_mask::AllOr;
49use vortex_mask::Mask;
50use vortex_session::VortexSession;
51
52use crate::canonical::execute_sparse;
53use crate::rules::RULES;
54
55mod canonical;
56mod compute;
57mod kernel;
58mod ops;
59mod rules;
60mod slice;
61
// Invokes `vortex_array`'s `vtable!` macro for the `Sparse` encoding. What it expands to is
// defined in `vortex_array`. NOTE(review): presumably this is the glue that lets `SparseArray` be
// used as an array (`into_array`, `as_ref`, etc.) — TODO confirm against the macro definition.
62vtable!(Sparse);
63
/// In-memory metadata for a [`SparseArray`]: the patches metadata plus the fill value.
///
/// Only `patches` is protobuf-encoded by `serialize`. The fill value travels in the array's
/// single buffer, and `deserialize` reattaches it.
64#[derive(Debug)]
65pub struct SparseMetadata {
    /// Metadata describing the patch indices/values children.
66    patches: PatchesMetadata,
    /// Value held by every position that is not covered by a patch.
67    fill_value: Scalar,
68}
69
/// Protobuf wire form of the sparse metadata: only the patches metadata, under tag 1.
///
/// The fill value is deliberately absent here; it is stored in the array's buffer instead.
// NOTE(review): the reason for `#[repr(C)]` on a prost message is not visible here — confirm it
// is still needed.
70#[derive(Clone, prost::Message)]
71#[repr(C)]
72pub struct ProstPatchesMetadata {
    // `required` lets this field be a plain `PatchesMetadata` rather than an `Option`.
73    #[prost(message, required, tag = "1")]
74    patches: PatchesMetadata,
75}
76
77impl VTable for Sparse {
78    type Array = SparseArray;
79
80    type Metadata = SparseMetadata;
81    type OperationsVTable = Self;
82    type ValidityVTable = Self;
83
84    fn vtable(_array: &Self::Array) -> &Self {
85        &Sparse
86    }
87
88    fn id(&self) -> ArrayId {
89        Self::ID
90    }
91
92    fn len(array: &SparseArray) -> usize {
93        array.patches.array_len()
94    }
95
96    fn dtype(array: &SparseArray) -> &DType {
97        array.fill_scalar().dtype()
98    }
99
100    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
101        array.stats_set.to_ref(array.as_ref())
102    }
103
104    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
105        array.patches.array_hash(state, precision);
106        array.fill_value.hash(state);
107    }
108
109    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
110        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
111    }
112
113    fn nbuffers(_array: &SparseArray) -> usize {
114        1
115    }
116
117    fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
118        match idx {
119            0 => {
120                let fill_value_buffer =
121                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
122                BufferHandle::new_host(fill_value_buffer)
123            }
124            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
125        }
126    }
127
128    fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
129        match idx {
130            0 => Some("fill_value".to_string()),
131            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
132        }
133    }
134
135    fn nchildren(array: &SparseArray) -> usize {
136        patches_nchildren(array.patches())
137    }
138
139    fn child(array: &SparseArray, idx: usize) -> ArrayRef {
140        patches_child(array.patches(), idx)
141    }
142
143    fn child_name(_array: &SparseArray, idx: usize) -> String {
144        patches_child_name(idx).to_string()
145    }
146
147    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
148        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
149
150        Ok(SparseMetadata {
151            patches,
152            fill_value: array.fill_value.clone(),
153        })
154    }
155
156    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
157        let prost_patches = ProstPatchesMetadata {
158            patches: metadata.patches,
159        };
160
161        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
162        Ok(Some(prost_patches.encode_to_vec()))
163    }
164
165    fn deserialize(
166        bytes: &[u8],
167        dtype: &DType,
168        _len: usize,
169        buffers: &[BufferHandle],
170        session: &VortexSession,
171    ) -> VortexResult<Self::Metadata> {
172        let prost_patches = ProstPatchesMetadata::decode(bytes)?;
173
174        // Once we have the patches metadata, we need to get the fill value from the buffers.
175
176        if buffers.len() != 1 {
177            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
178        }
179        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
180
181        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
182        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
183
184        Ok(SparseMetadata {
185            patches: prost_patches.patches,
186            fill_value,
187        })
188    }
189
190    fn build(
191        dtype: &DType,
192        len: usize,
193        metadata: &Self::Metadata,
194        _buffers: &[BufferHandle],
195        children: &dyn ArrayChildren,
196    ) -> VortexResult<SparseArray> {
197        vortex_ensure_eq!(
198            children.len(),
199            2,
200            "SparseArray expects 2 children for sparse encoding, found {}",
201            children.len()
202        );
203
204        let patch_indices = children.get(
205            0,
206            &metadata.patches.indices_dtype()?,
207            metadata.patches.len()?,
208        )?;
209        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
210
211        SparseArray::try_new_from_patches(
212            Patches::new(
213                len,
214                metadata.patches.offset()?,
215                patch_indices,
216                patch_values,
217                None,
218            )?,
219            metadata.fill_value.clone(),
220        )
221    }
222
223    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
224        vortex_ensure_eq!(
225            children.len(),
226            2,
227            "SparseArray expects 2 children, got {}",
228            children.len()
229        );
230
231        let mut children_iter = children.into_iter();
232        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
233        let patch_values = children_iter.next().vortex_expect("patch_values child");
234
235        array.patches = Patches::new(
236            array.patches.array_len(),
237            array.patches.offset(),
238            patch_indices,
239            patch_values,
240            array.patches.chunk_offsets().clone(),
241        )?;
242
243        Ok(())
244    }
245
246    fn reduce_parent(
247        array: &Self::Array,
248        parent: &ArrayRef,
249        child_idx: usize,
250    ) -> VortexResult<Option<ArrayRef>> {
251        RULES.evaluate(array, parent, child_idx)
252    }
253
254    fn execute_parent(
255        array: &Self::Array,
256        parent: &ArrayRef,
257        child_idx: usize,
258        ctx: &mut ExecutionCtx,
259    ) -> VortexResult<Option<ArrayRef>> {
260        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
261    }
262
263    fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
264        execute_sparse(&array, ctx).map(ExecutionResult::done)
265    }
266}
267
/// An array represented as a single fill value plus a set of patched positions.
///
/// Positions covered by `patches` take the corresponding patch value; every other position takes
/// `fill_value`. The array's length comes from the patches and its dtype from the fill value.
268#[derive(Clone, Debug)]
269pub struct SparseArray {
    /// The explicitly stored positions and their values.
270    patches: Patches,
    /// The value of every position not covered by `patches`.
271    fill_value: Scalar,
    /// Statistics for this array; the constructors start it from `Default::default()`.
272    stats_set: ArrayStats,
273}
274
/// Stateless marker type implementing the vtables for [`SparseArray`].
275#[derive(Clone, Debug)]
276pub struct Sparse;
277
278impl Sparse {
    /// Encoding identifier of sparse arrays, returned by `VTable::id`.
279    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
280}
281
282impl SparseArray {
283    pub fn try_new(
284        indices: ArrayRef,
285        values: ArrayRef,
286        len: usize,
287        fill_value: Scalar,
288    ) -> VortexResult<Self> {
289        vortex_ensure!(
290            indices.len() == values.len(),
291            "Mismatched indices {} and values {} length",
292            indices.len(),
293            values.len()
294        );
295
296        if indices.is_host() {
297            debug_assert_eq!(
298                indices.statistics().compute_is_strict_sorted(),
299                Some(true),
300                "SparseArray: indices must be strict-sorted"
301            );
302
303            // Verify the indices are all in the valid range
304            if !indices.is_empty() {
305                let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
306
307                vortex_ensure!(
308                    last_index < len,
309                    "Array length was {len} but the last index is {last_index}"
310                );
311            }
312        }
313
314        Ok(Self {
315            // TODO(0ax1): handle chunk offsets
316            patches: Patches::new(len, 0, indices, values, None)?,
317            fill_value,
318            stats_set: Default::default(),
319        })
320    }
321
322    /// Build a new SparseArray from an existing set of patches.
323    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
324        vortex_ensure!(
325            fill_value.dtype() == patches.values().dtype(),
326            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
327            fill_value,
328            patches.values().dtype(),
329            fill_value.dtype(),
330        );
331
332        Ok(Self {
333            patches,
334            fill_value,
335            stats_set: Default::default(),
336        })
337    }
338
339    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
340        Self {
341            patches,
342            fill_value,
343            stats_set: Default::default(),
344        }
345    }
346
347    #[inline]
348    pub fn patches(&self) -> &Patches {
349        &self.patches
350    }
351
352    #[inline]
353    pub fn resolved_patches(&self) -> VortexResult<Patches> {
354        let patches = self.patches();
355        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
356        let indices = patches.indices().to_array().binary(
357            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
358            Operator::Sub,
359        )?;
360
361        Patches::new(
362            patches.array_len(),
363            0,
364            indices,
365            patches.values().clone(),
366            // TODO(0ax1): handle chunk offsets
367            None,
368        )
369    }
370
371    #[inline]
372    pub fn fill_scalar(&self) -> &Scalar {
373        &self.fill_value
374    }
375
376    /// Encode given array as a SparseArray.
377    ///
378    /// Optionally provided fill value will be respected if the array is less than 90% null.
379    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
380        if let Some(fill_value) = fill_value.as_ref()
381            && array.dtype() != fill_value.dtype()
382        {
383            vortex_bail!(
384                "Array and fill value types must match. got {} and {}",
385                array.dtype(),
386                fill_value.dtype()
387            )
388        }
389        let mask = array.validity_mask()?;
390
391        if mask.all_false() {
392            // Array is constant NULL
393            return Ok(
394                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
395            );
396        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
397            // Array is dominated by NULL but has non-NULL values
398            // TODO(joe): use exe ctx?
399            let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
400            let non_null_indices = match mask.indices() {
401                AllOr::All => {
402                    // We already know that the mask is 90%+ false
403                    unreachable!("Mask is mostly null")
404                }
405                AllOr::None => {
406                    // we know there are some non-NULL values
407                    unreachable!("Mask is mostly null but not all null")
408                }
409                AllOr::Some(values) => {
410                    let buffer: Buffer<u32> = values
411                        .iter()
412                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
413                        .collect();
414
415                    buffer.into_array()
416                }
417            };
418
419            return Ok(SparseArray::try_new(
420                non_null_indices,
421                non_null_values,
422                array.len(),
423                Scalar::null(array.dtype().clone()),
424            )?
425            .into_array());
426        }
427
428        let fill = if let Some(fill) = fill_value {
429            fill
430        } else {
431            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
432            let (top_pvalue, _) = array
433                .to_primitive()
434                .top_value()?
435                .vortex_expect("Non empty or all null array");
436
437            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
438        };
439
440        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
441        let non_top_mask = Mask::from_buffer(
442            array
443                .to_array()
444                .binary(fill_array.clone(), Operator::NotEq)?
445                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
446                .to_bool()
447                .to_bit_buffer(),
448        );
449
450        let non_top_values = array
451            .filter(non_top_mask.clone())?
452            .to_canonical()?
453            .into_array();
454
455        let indices: Buffer<u64> = match non_top_mask {
456            Mask::AllTrue(count) => {
457                // all true -> complete slice
458                (0u64..count as u64).collect()
459            }
460            Mask::AllFalse(_) => {
461                // All values are equal to the top value
462                return Ok(fill_array);
463            }
464            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
465        };
466
467        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
468            .map(|a| a.into_array())
469    }
470}
471
472impl ValidityVTable<Sparse> for Sparse {
473    fn validity(array: &SparseArray) -> VortexResult<Validity> {
474        let patches = unsafe {
475            Patches::new_unchecked(
476                array.patches.array_len(),
477                array.patches.offset(),
478                array.patches.indices().clone(),
479                array
480                    .patches
481                    .values()
482                    .validity()?
483                    .to_array(array.patches.values().len()),
484                array.patches.chunk_offsets().clone(),
485                array.patches.offset_within_chunk(),
486            )
487        };
488
489        Ok(Validity::Array(
490            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
491                .into_array(),
492        ))
493    }
494}
495
// Unit tests for SparseArray construction, scalar access, slicing, validity and encoding.
496#[cfg(test)]
497mod test {
498    use itertools::Itertools;
499    use vortex_array::DynArray;
500    use vortex_array::IntoArray;
501    use vortex_array::arrays::ConstantArray;
502    use vortex_array::arrays::PrimitiveArray;
503    use vortex_array::assert_arrays_eq;
504    use vortex_array::builtins::ArrayBuiltins;
505    use vortex_array::dtype::DType;
506    use vortex_array::dtype::Nullability;
507    use vortex_array::dtype::PType;
508    use vortex_array::scalar::Scalar;
509    use vortex_array::validity::Validity;
510    use vortex_buffer::buffer;
511    use vortex_error::VortexExpect;
512
513    use super::*;
514
    // Null i32 fill value; sparse arrays built with it have a nullable i32 dtype.
515    fn nullable_fill() -> Scalar {
516        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
517    }
518
    // Non-null i32 fill value (42).
519    fn non_nullable_fill() -> Scalar {
520        Scalar::from(42i32)
521    }
522
    // Length-10 sparse array with values 100, 200, 300 at indices 2, 5, 8. The values are cast to
    // the fill value's dtype so that values and fill agree.
523    fn sparse_array(fill_value: Scalar) -> ArrayRef {
524        // merged array (with a null fill): [null, null, 100, null, null, 200, null, null, 300, null]
525        let mut values = buffer![100i32, 200, 300].into_array();
526        values = values.cast(fill_value.dtype().clone()).unwrap();
527
528        SparseArray::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
529            .unwrap()
530            .into_array()
531    }
532
    // Unpatched positions return the fill value; patched positions return their patch value.
533    #[test]
534    pub fn test_scalar_at() {
535        let array = sparse_array(nullable_fill());
536
537        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());
538        assert_eq!(array.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
539        assert_eq!(array.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
540    }
541
    // Index 10 is one past the end of the length-10 array; the panic must mention "out of bounds".
542    #[test]
543    #[should_panic(expected = "out of bounds")]
544    fn test_scalar_at_oob() {
545        let array = sparse_array(nullable_fill());
546        array.scalar_at(10).unwrap();
547    }
548
    // A single patch (value 1234 at index 10) in a length-100 array with a null fill.
549    #[test]
550    pub fn test_scalar_at_again() {
551        let arr = SparseArray::try_new(
552            ConstantArray::new(10u32, 1).into_array(),
553            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
554            100,
555            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
556        )
557        .unwrap();
558
559        assert_eq!(
560            arr.scalar_at(10)
561                .unwrap()
562                .as_primitive()
563                .typed_value::<u32>(),
564            Some(1234)
565        );
566        assert!(arr.scalar_at(0).unwrap().is_null());
567        assert!(arr.scalar_at(99).unwrap().is_null());
568    }
569
    // Slicing 2..7 moves original index 2 (value 100) to index 0.
570    #[test]
571    pub fn scalar_at_sliced() {
572        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
573        assert_eq!(usize::try_from(&sliced.scalar_at(0).unwrap()).unwrap(), 100);
574    }
575
    // Slice 2..7 covers [100, null, null, 200, null].
576    #[test]
577    pub fn validity_mask_sliced_null_fill() {
578        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
579        assert_eq!(
580            sliced.validity_mask().unwrap(),
581            Mask::from_iter(vec![true, false, false, true, false])
582        );
583    }
584
    // Here the patches (indices 2, 5, 8) are null and the fill (1.0) is valid, so slice 2..7
    // covers [null, 1.0, 1.0, null, 1.0].
585    #[test]
586    pub fn validity_mask_sliced_nonnull_fill() {
587        let sliced = SparseArray::try_new(
588            buffer![2u64, 5, 8].into_array(),
589            ConstantArray::new(
590                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
591                3,
592            )
593            .into_array(),
594            10,
595            Scalar::primitive(1.0f32, Nullability::Nullable),
596        )
597        .unwrap()
598        .slice(2..7)
599        .unwrap();
600
601        assert_eq!(
602            sliced.validity_mask().unwrap(),
603            Mask::from_iter(vec![false, true, true, false, true])
604        );
605    }
606
    // Slices compose: index 1 of slice 1..8 is original index 2 (100), and index 3 of a further
    // slice 1..6 is original index 5 (200).
607    #[test]
608    pub fn scalar_at_sliced_twice() {
609        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
610        assert_eq!(
611            usize::try_from(&sliced_once.scalar_at(1).unwrap()).unwrap(),
612            100
613        );
614
615        let sliced_twice = sliced_once.slice(1..6).unwrap();
616        assert_eq!(
617            usize::try_from(&sliced_twice.scalar_at(3).unwrap()).unwrap(),
618            200
619        );
620    }
621
    // With a null fill, only the patched positions 2, 5 and 8 are valid.
622    #[test]
623    pub fn sparse_validity_mask() {
624        let array = sparse_array(nullable_fill());
625        assert_eq!(
626            array
627                .validity_mask()
628                .unwrap()
629                .to_bit_buffer()
630                .iter()
631                .collect_vec(),
632            [
633                false, false, true, false, false, true, false, false, true, false
634            ]
635        );
636    }
637
    // With a non-null fill and non-null values, every position is valid.
638    #[test]
639    fn sparse_validity_mask_non_null_fill() {
640        let array = sparse_array(non_nullable_fill());
641        assert!(array.validity_mask().unwrap().all_true());
642    }
643
    // The last index (100) is not below the length (100), so try_new errors and unwrap panics.
644    #[test]
645    #[should_panic]
646    fn test_invalid_length() {
647        let values = buffer![15_u32, 135, 13531, 42].into_array();
648        let indices = buffer![10_u64, 11, 50, 100].into_array();
649
650        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
651    }
652
    // Same indices with length 101: the last index (100) is now in range.
653    #[test]
654    fn test_valid_length() {
655        let values = buffer![15_u32, 135, 13531, 42].into_array();
656        let indices = buffer![10_u64, 11, 50, 100].into_array();
657
658        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
659    }
660
    // Encoding a partially-null primitive array must preserve both its validity and its values.
661    #[test]
662    fn encode_with_nulls() {
663        let original = PrimitiveArray::new(
664            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
665            Validity::from_iter(vec![
666                true, true, false, true, false, true, false, true, true, false, true, false,
667            ]),
668        );
669        let sparse = SparseArray::encode(&original.clone().into_array(), None)
670            .vortex_expect("SparseArray::encode should succeed for test data");
671        assert_eq!(
672            sparse.validity_mask().unwrap(),
673            Mask::from_iter(vec![
674                true, true, false, true, false, true, false, true, true, false, true, false,
675            ])
676        );
677        assert_arrays_eq!(sparse.to_primitive(), original);
678    }
679
    // Null patch values (at indices 4 and 6) must be reported invalid even though they are patched.
680    #[test]
681    fn validity_mask_includes_null_values_when_fill_is_null() {
682        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
683        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
684            .into_array();
685        let array =
686            SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
687        let actual = array.validity_mask().unwrap();
688        let expected = Mask::from_iter([
689            true, false, true, false, false, false, false, false, true, false,
690        ]);
691
692        assert_eq!(actual, expected);
693    }
694}