// vortex_sparse/lib.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use kernel::PARENT_KERNELS;
9use prost::Message as _;
10use vortex_array::ArrayEq;
11use vortex_array::ArrayHash;
12use vortex_array::ArrayRef;
13use vortex_array::DynArray;
14use vortex_array::ExecutionCtx;
15use vortex_array::ExecutionResult;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ToCanonical;
19use vortex_array::arrays::ConstantArray;
20use vortex_array::buffer::BufferHandle;
21use vortex_array::builtins::ArrayBuiltins;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::Nullability;
24use vortex_array::patches::Patches;
25use vortex_array::patches::PatchesMetadata;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::scalar_fn::fns::operators::Operator;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::Array;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityVTable;
38use vortex_array::vtable::patches_child;
39use vortex_array::vtable::patches_child_name;
40use vortex_array::vtable::patches_nchildren;
41use vortex_buffer::Buffer;
42use vortex_buffer::ByteBufferMut;
43use vortex_error::VortexExpect as _;
44use vortex_error::VortexResult;
45use vortex_error::vortex_bail;
46use vortex_error::vortex_ensure;
47use vortex_error::vortex_ensure_eq;
48use vortex_error::vortex_panic;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_session::VortexSession;
52
53use crate::canonical::execute_sparse;
54use crate::rules::RULES;
55
56mod canonical;
57mod compute;
58mod kernel;
59mod ops;
60mod rules;
61mod slice;
62
63vtable!(Sparse);
64
65#[derive(Debug)]
66pub struct SparseMetadata {
67    patches: PatchesMetadata,
68    fill_value: Scalar,
69}
70
71#[derive(Clone, prost::Message)]
72#[repr(C)]
73pub struct ProstPatchesMetadata {
74    #[prost(message, required, tag = "1")]
75    patches: PatchesMetadata,
76}
77
78impl VTable for Sparse {
79    type Array = SparseArray;
80
81    type Metadata = SparseMetadata;
82    type OperationsVTable = Self;
83    type ValidityVTable = Self;
84
85    fn vtable(_array: &Self::Array) -> &Self {
86        &Sparse
87    }
88
89    fn id(&self) -> ArrayId {
90        Self::ID
91    }
92
93    fn len(array: &SparseArray) -> usize {
94        array.patches.array_len()
95    }
96
97    fn dtype(array: &SparseArray) -> &DType {
98        array.fill_scalar().dtype()
99    }
100
101    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
102        array.stats_set.to_ref(array.as_ref())
103    }
104
105    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
106        array.patches.array_hash(state, precision);
107        array.fill_value.hash(state);
108    }
109
110    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
111        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
112    }
113
114    fn nbuffers(_array: &SparseArray) -> usize {
115        1
116    }
117
118    fn buffer(array: &SparseArray, idx: usize) -> BufferHandle {
119        match idx {
120            0 => {
121                let fill_value_buffer =
122                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
123                BufferHandle::new_host(fill_value_buffer)
124            }
125            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
126        }
127    }
128
129    fn buffer_name(_array: &SparseArray, idx: usize) -> Option<String> {
130        match idx {
131            0 => Some("fill_value".to_string()),
132            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
133        }
134    }
135
136    fn nchildren(array: &SparseArray) -> usize {
137        patches_nchildren(array.patches())
138    }
139
140    fn child(array: &SparseArray, idx: usize) -> ArrayRef {
141        patches_child(array.patches(), idx)
142    }
143
144    fn child_name(_array: &SparseArray, idx: usize) -> String {
145        patches_child_name(idx).to_string()
146    }
147
148    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
149        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
150
151        Ok(SparseMetadata {
152            patches,
153            fill_value: array.fill_value.clone(),
154        })
155    }
156
157    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
158        let prost_patches = ProstPatchesMetadata {
159            patches: metadata.patches,
160        };
161
162        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
163        Ok(Some(prost_patches.encode_to_vec()))
164    }
165
166    fn deserialize(
167        bytes: &[u8],
168        dtype: &DType,
169        _len: usize,
170        buffers: &[BufferHandle],
171        session: &VortexSession,
172    ) -> VortexResult<Self::Metadata> {
173        let prost_patches = ProstPatchesMetadata::decode(bytes)?;
174
175        // Once we have the patches metadata, we need to get the fill value from the buffers.
176
177        if buffers.len() != 1 {
178            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
179        }
180        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
181
182        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
183        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
184
185        Ok(SparseMetadata {
186            patches: prost_patches.patches,
187            fill_value,
188        })
189    }
190
191    fn build(
192        dtype: &DType,
193        len: usize,
194        metadata: &Self::Metadata,
195        _buffers: &[BufferHandle],
196        children: &dyn ArrayChildren,
197    ) -> VortexResult<SparseArray> {
198        vortex_ensure_eq!(
199            children.len(),
200            2,
201            "SparseArray expects 2 children for sparse encoding, found {}",
202            children.len()
203        );
204
205        let patch_indices = children.get(
206            0,
207            &metadata.patches.indices_dtype()?,
208            metadata.patches.len()?,
209        )?;
210        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
211
212        SparseArray::try_new_from_patches(
213            Patches::new(
214                len,
215                metadata.patches.offset()?,
216                patch_indices,
217                patch_values,
218                None,
219            )?,
220            metadata.fill_value.clone(),
221        )
222    }
223
224    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
225        vortex_ensure_eq!(
226            children.len(),
227            2,
228            "SparseArray expects 2 children, got {}",
229            children.len()
230        );
231
232        let mut children_iter = children.into_iter();
233        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
234        let patch_values = children_iter.next().vortex_expect("patch_values child");
235
236        array.patches = Patches::new(
237            array.patches.array_len(),
238            array.patches.offset(),
239            patch_indices,
240            patch_values,
241            array.patches.chunk_offsets().clone(),
242        )?;
243
244        Ok(())
245    }
246
247    fn reduce_parent(
248        array: &Array<Self>,
249        parent: &ArrayRef,
250        child_idx: usize,
251    ) -> VortexResult<Option<ArrayRef>> {
252        RULES.evaluate(array, parent, child_idx)
253    }
254
255    fn execute_parent(
256        array: &Array<Self>,
257        parent: &ArrayRef,
258        child_idx: usize,
259        ctx: &mut ExecutionCtx,
260    ) -> VortexResult<Option<ArrayRef>> {
261        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
262    }
263
264    fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
265        execute_sparse(&array, ctx).map(ExecutionResult::done)
266    }
267}
268
269#[derive(Clone, Debug)]
270pub struct SparseArray {
271    patches: Patches,
272    fill_value: Scalar,
273    stats_set: ArrayStats,
274}
275
/// Unit marker type that carries the vtable implementations for the sparse encoding.
#[derive(Clone, Debug)]
pub struct Sparse;
278
279impl Sparse {
280    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
281}
282
283impl SparseArray {
284    pub fn try_new(
285        indices: ArrayRef,
286        values: ArrayRef,
287        len: usize,
288        fill_value: Scalar,
289    ) -> VortexResult<Self> {
290        vortex_ensure!(
291            indices.len() == values.len(),
292            "Mismatched indices {} and values {} length",
293            indices.len(),
294            values.len()
295        );
296
297        if indices.is_host() {
298            debug_assert_eq!(
299                indices.statistics().compute_is_strict_sorted(),
300                Some(true),
301                "SparseArray: indices must be strict-sorted"
302            );
303
304            // Verify the indices are all in the valid range
305            if !indices.is_empty() {
306                let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
307
308                vortex_ensure!(
309                    last_index < len,
310                    "Array length was {len} but the last index is {last_index}"
311                );
312            }
313        }
314
315        Ok(Self {
316            // TODO(0ax1): handle chunk offsets
317            patches: Patches::new(len, 0, indices, values, None)?,
318            fill_value,
319            stats_set: Default::default(),
320        })
321    }
322
323    /// Build a new SparseArray from an existing set of patches.
324    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
325        vortex_ensure!(
326            fill_value.dtype() == patches.values().dtype(),
327            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
328            fill_value,
329            patches.values().dtype(),
330            fill_value.dtype(),
331        );
332
333        Ok(Self {
334            patches,
335            fill_value,
336            stats_set: Default::default(),
337        })
338    }
339
340    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
341        Self {
342            patches,
343            fill_value,
344            stats_set: Default::default(),
345        }
346    }
347
348    #[inline]
349    pub fn patches(&self) -> &Patches {
350        &self.patches
351    }
352
353    #[inline]
354    pub fn resolved_patches(&self) -> VortexResult<Patches> {
355        let patches = self.patches();
356        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
357        let indices = patches.indices().to_array().binary(
358            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
359            Operator::Sub,
360        )?;
361
362        Patches::new(
363            patches.array_len(),
364            0,
365            indices,
366            patches.values().clone(),
367            // TODO(0ax1): handle chunk offsets
368            None,
369        )
370    }
371
372    #[inline]
373    pub fn fill_scalar(&self) -> &Scalar {
374        &self.fill_value
375    }
376
377    /// Encode given array as a SparseArray.
378    ///
379    /// Optionally provided fill value will be respected if the array is less than 90% null.
380    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
381        if let Some(fill_value) = fill_value.as_ref()
382            && array.dtype() != fill_value.dtype()
383        {
384            vortex_bail!(
385                "Array and fill value types must match. got {} and {}",
386                array.dtype(),
387                fill_value.dtype()
388            )
389        }
390        let mask = array.validity_mask()?;
391
392        if mask.all_false() {
393            // Array is constant NULL
394            return Ok(
395                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
396            );
397        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
398            // Array is dominated by NULL but has non-NULL values
399            // TODO(joe): use exe ctx?
400            let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
401            let non_null_indices = match mask.indices() {
402                AllOr::All => {
403                    // We already know that the mask is 90%+ false
404                    unreachable!("Mask is mostly null")
405                }
406                AllOr::None => {
407                    // we know there are some non-NULL values
408                    unreachable!("Mask is mostly null but not all null")
409                }
410                AllOr::Some(values) => {
411                    let buffer: Buffer<u32> = values
412                        .iter()
413                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
414                        .collect();
415
416                    buffer.into_array()
417                }
418            };
419
420            return Ok(SparseArray::try_new(
421                non_null_indices,
422                non_null_values,
423                array.len(),
424                Scalar::null(array.dtype().clone()),
425            )?
426            .into_array());
427        }
428
429        let fill = if let Some(fill) = fill_value {
430            fill
431        } else {
432            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
433            let (top_pvalue, _) = array
434                .to_primitive()
435                .top_value()?
436                .vortex_expect("Non empty or all null array");
437
438            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
439        };
440
441        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
442        let non_top_mask = Mask::from_buffer(
443            array
444                .to_array()
445                .binary(fill_array.clone(), Operator::NotEq)?
446                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
447                .to_bool()
448                .to_bit_buffer(),
449        );
450
451        let non_top_values = array
452            .filter(non_top_mask.clone())?
453            .to_canonical()?
454            .into_array();
455
456        let indices: Buffer<u64> = match non_top_mask {
457            Mask::AllTrue(count) => {
458                // all true -> complete slice
459                (0u64..count as u64).collect()
460            }
461            Mask::AllFalse(_) => {
462                // All values are equal to the top value
463                return Ok(fill_array);
464            }
465            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
466        };
467
468        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
469            .map(|a| a.into_array())
470    }
471}
472
473impl ValidityVTable<Sparse> for Sparse {
474    fn validity(array: &SparseArray) -> VortexResult<Validity> {
475        let patches = unsafe {
476            Patches::new_unchecked(
477                array.patches.array_len(),
478                array.patches.offset(),
479                array.patches.indices().clone(),
480                array
481                    .patches
482                    .values()
483                    .validity()?
484                    .to_array(array.patches.values().len()),
485                array.patches.chunk_offsets().clone(),
486                array.patches.offset_within_chunk(),
487            )
488        };
489
490        Ok(Validity::Array(
491            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
492                .into_array(),
493        ))
494    }
495}
496
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::DynArray;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// The non-null `i32` scalar `42`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with values 100, 200, 300 at positions 2, 5, 8.
    ///
    /// With a null fill the merged array is
    /// `[null, null, 100, null, null, 200, null, null, 300, null]`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // An unpatched position yields the fill value.
        assert_eq!(array.scalar_at(0).unwrap(), nullable_fill());

        // Patched positions yield the patch values.
        for (position, expected) in [(2, 100_i32), (5, 200_i32)] {
            assert_eq!(
                array.scalar_at(position).unwrap(),
                Scalar::from(Some(expected))
            );
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));

        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        for unpatched in [0, 99] {
            assert!(arr.scalar_at(unpatched).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Null patch values on top of a valid fill: validity is the inverse of the patch layout.
        let null_values = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();

        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_values,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let after_first_slice = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&after_first_slice).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let after_second_slice = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&after_second_slice).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let validity_bits = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            validity_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not smaller than the array length (100).
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is smaller than the array length (101).
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        );

        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array =
            SparseArray::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();

        // Only patched positions holding a non-null value are valid.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask().unwrap();

        assert_eq!(actual, expected);
    }
}