Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use kernel::PARENT_KERNELS;
11use prost::Message as _;
12use vortex_array::Array;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayId;
16use vortex_array::ArrayParts;
17use vortex_array::ArrayRef;
18use vortex_array::ArrayView;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::IntoArray;
22use vortex_array::Precision;
23use vortex_array::ToCanonical;
24use vortex_array::arrays::ConstantArray;
25use vortex_array::arrays::bool::BoolArrayExt;
26use vortex_array::buffer::BufferHandle;
27use vortex_array::builtins::ArrayBuiltins;
28use vortex_array::dtype::DType;
29use vortex_array::dtype::Nullability;
30use vortex_array::patches::Patches;
31use vortex_array::patches::PatchesMetadata;
32use vortex_array::scalar::Scalar;
33use vortex_array::scalar::ScalarValue;
34use vortex_array::scalar_fn::fns::operators::Operator;
35use vortex_array::serde::ArrayChildren;
36use vortex_array::validity::Validity;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_buffer::Buffer;
40use vortex_buffer::ByteBufferMut;
41use vortex_error::VortexExpect as _;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_ensure_eq;
46use vortex_error::vortex_panic;
47use vortex_mask::AllOr;
48use vortex_mask::Mask;
49use vortex_session::VortexSession;
50
51use crate::canonical::execute_sparse;
52use crate::rules::RULES;
53
54mod canonical;
55mod compute;
56mod kernel;
57mod ops;
58mod rules;
59mod slice;
60
61/// A [`Sparse`]-encoded Vortex array.
62pub type SparseArray = Array<Sparse>;
63
64#[derive(Clone, prost::Message)]
65#[repr(C)]
66pub struct SparseMetadata {
67    #[prost(message, required, tag = "1")]
68    patches: PatchesMetadata,
69}
70
71impl ArrayHash for SparseData {
72    fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
73        self.patches.array_hash(state, precision);
74        self.fill_value.hash(state);
75    }
76}
77
78impl ArrayEq for SparseData {
79    fn array_eq(&self, other: &Self, precision: Precision) -> bool {
80        self.patches.array_eq(&other.patches, precision) && self.fill_value == other.fill_value
81    }
82}
83
84impl VTable for Sparse {
85    type ArrayData = SparseData;
86
87    type OperationsVTable = Self;
88    type ValidityVTable = Self;
89
90    fn id(&self) -> ArrayId {
91        Self::ID
92    }
93
94    fn validate(
95        &self,
96        data: &Self::ArrayData,
97        dtype: &DType,
98        len: usize,
99        _slots: &[Option<ArrayRef>],
100    ) -> VortexResult<()> {
101        SparseData::validate(data.patches(), data.fill_scalar(), dtype, len)
102    }
103
104    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
105        1
106    }
107
108    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
109        match idx {
110            0 => {
111                let fill_value_buffer =
112                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
113                BufferHandle::new_host(fill_value_buffer)
114            }
115            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
116        }
117    }
118
119    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
120        match idx {
121            0 => Some("fill_value".to_string()),
122            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
123        }
124    }
125
126    fn serialize(
127        array: ArrayView<'_, Self>,
128        _session: &VortexSession,
129    ) -> VortexResult<Option<Vec<u8>>> {
130        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
131        let metadata = SparseMetadata { patches };
132
133        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
134        Ok(Some(metadata.encode_to_vec()))
135    }
136
137    fn deserialize(
138        &self,
139        dtype: &DType,
140        len: usize,
141        metadata: &[u8],
142        buffers: &[BufferHandle],
143        children: &dyn ArrayChildren,
144        session: &VortexSession,
145    ) -> VortexResult<ArrayParts<Self>> {
146        let metadata = SparseMetadata::decode(metadata)?;
147
148        // Once we have the patches metadata, we need to get the fill value from the buffers.
149
150        if buffers.len() != 1 {
151            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
152        }
153        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
154
155        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
156        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
157
158        vortex_ensure_eq!(
159            children.len(),
160            2,
161            "SparseArray expects 2 children for sparse encoding, found {}",
162            children.len()
163        );
164
165        let patch_indices = children.get(
166            0,
167            &metadata.patches.indices_dtype()?,
168            metadata.patches.len()?,
169        )?;
170        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
171
172        let patches = Patches::new(
173            len,
174            metadata.patches.offset()?,
175            patch_indices,
176            patch_values,
177            None,
178        )?;
179        let slots = SparseData::make_slots(&patches);
180        let data = SparseData::try_new_from_patches(patches, fill_value)?;
181        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
182    }
183
184    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
185        SLOT_NAMES[idx].to_string()
186    }
187
188    fn reduce_parent(
189        array: ArrayView<'_, Self>,
190        parent: &ArrayRef,
191        child_idx: usize,
192    ) -> VortexResult<Option<ArrayRef>> {
193        RULES.evaluate(array, parent, child_idx)
194    }
195
196    fn execute_parent(
197        array: ArrayView<'_, Self>,
198        parent: &ArrayRef,
199        child_idx: usize,
200        ctx: &mut ExecutionCtx,
201    ) -> VortexResult<Option<ArrayRef>> {
202        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
203    }
204
205    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
206        execute_sparse(&array, ctx).map(ExecutionResult::done)
207    }
208}
209
/// Number of child slots a sparse array carries.
pub(crate) const NUM_SLOTS: usize = 3;

/// Names of the child slots, in slot order: patch indices, patch values, and the
/// optional patch chunk offsets.
pub(crate) const SLOT_NAMES: [&str; NUM_SLOTS] =
    ["patch_indices", "patch_values", "patch_chunk_offsets"];
213
214#[derive(Clone, Debug)]
215pub struct SparseData {
216    patches: Patches,
217    fill_value: Scalar,
218}
219
220impl Display for SparseData {
221    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
222        write!(f, "fill_value: {}", self.fill_value)
223    }
224}
225
/// Marker type for the sparse encoding; it carries the encoding's vtable
/// implementations and constructors.
#[derive(Clone, Debug)]
pub struct Sparse;
228
229impl Sparse {
230    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
231
232    /// Construct a new [`SparseArray`] from indices, values, length, and fill value.
233    pub fn try_new(
234        indices: ArrayRef,
235        values: ArrayRef,
236        len: usize,
237        fill_value: Scalar,
238    ) -> VortexResult<SparseArray> {
239        let dtype = fill_value.dtype().clone();
240        let patches = Patches::new(len, 0, indices, values, None)?;
241        let slots = SparseData::make_slots(&patches);
242        let data = SparseData::try_new_from_patches(patches, fill_value)?;
243        Ok(unsafe {
244            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
245        })
246    }
247
248    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<SparseArray> {
249        let dtype = fill_value.dtype().clone();
250        let len = patches.array_len();
251        let slots = SparseData::make_slots(&patches);
252        let data = SparseData::try_new_from_patches(patches, fill_value)?;
253        Ok(unsafe {
254            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
255        })
256    }
257
258    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> SparseArray {
259        let dtype = fill_value.dtype().clone();
260        let len = patches.array_len();
261        let slots = SparseData::make_slots(&patches);
262        let data = unsafe { SparseData::new_unchecked(patches, fill_value) };
263        unsafe {
264            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
265        }
266    }
267
268    /// Encode the given array as a [`SparseArray`].
269    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
270        SparseData::encode(array, fill_value)
271    }
272}
273
274impl SparseData {
275    fn normalize_patches_dtype(patches: Patches, fill_value: &Scalar) -> VortexResult<Patches> {
276        let fill_dtype = fill_value.dtype();
277        let values_dtype = patches.values().dtype();
278
279        vortex_ensure!(
280            values_dtype.eq_ignore_nullability(fill_dtype),
281            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
282            fill_value,
283            values_dtype,
284            fill_dtype,
285        );
286
287        if values_dtype == fill_dtype {
288            Ok(patches)
289        } else {
290            patches.cast_values(fill_dtype)
291        }
292    }
293
294    pub fn validate(
295        patches: &Patches,
296        fill_value: &Scalar,
297        dtype: &DType,
298        len: usize,
299    ) -> VortexResult<()> {
300        vortex_ensure!(
301            fill_value.dtype() == dtype,
302            "fill value dtype {} does not match array dtype {}",
303            fill_value.dtype(),
304            dtype,
305        );
306        vortex_ensure!(
307            patches.array_len() == len,
308            "patches length {} does not match array length {}",
309            patches.array_len(),
310            len
311        );
312        vortex_ensure!(
313            patches.values().dtype() == dtype,
314            "patch values dtype {} does not match array dtype {}",
315            patches.values().dtype(),
316            dtype,
317        );
318        Ok(())
319    }
320
321    fn make_slots(patches: &Patches) -> Vec<Option<ArrayRef>> {
322        vec![
323            Some(patches.indices().clone()),
324            Some(patches.values().clone()),
325            patches.chunk_offsets().clone(),
326        ]
327    }
328
329    /// Build a new SparseArray from an existing set of patches.
330    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
331        let patches = Self::normalize_patches_dtype(patches, &fill_value)?;
332        Ok(Self {
333            patches,
334            fill_value,
335        })
336    }
337
338    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
339        Self {
340            patches,
341            fill_value,
342        }
343    }
344
345    /// Returns the length of the array.
346    #[inline]
347    pub fn len(&self) -> usize {
348        self.patches.array_len()
349    }
350
351    /// Returns whether the array is empty.
352    #[inline]
353    pub fn is_empty(&self) -> bool {
354        self.patches.array_len() == 0
355    }
356
357    /// Returns the logical data type of the array.
358    #[inline]
359    pub fn dtype(&self) -> &DType {
360        self.fill_scalar().dtype()
361    }
362
363    #[inline]
364    pub fn patches(&self) -> &Patches {
365        &self.patches
366    }
367
368    #[inline]
369    pub fn resolved_patches(&self) -> VortexResult<Patches> {
370        let patches = self.patches();
371        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
372        let indices = patches.indices().binary(
373            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
374            Operator::Sub,
375        )?;
376
377        Patches::new(
378            patches.array_len(),
379            0,
380            indices,
381            patches.values().clone(),
382            // TODO(0ax1): handle chunk offsets
383            None,
384        )
385    }
386
387    #[inline]
388    pub fn fill_scalar(&self) -> &Scalar {
389        &self.fill_value
390    }
391
392    /// Encode given array as a SparseArray.
393    ///
394    /// Optionally provided fill value will be respected if the array is less than 90% null.
395    pub fn encode(array: &ArrayRef, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
396        if let Some(fill_value) = fill_value.as_ref()
397            && !array.dtype().eq_ignore_nullability(fill_value.dtype())
398        {
399            vortex_bail!(
400                "Array and fill value types must have the same base type. got {} and {}",
401                array.dtype(),
402                fill_value.dtype()
403            )
404        }
405        let mask = array.validity_mask()?;
406
407        if mask.all_false() {
408            // Array is constant NULL
409            return Ok(
410                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
411            );
412        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
413            // Array is dominated by NULL but has non-NULL values
414            // TODO(joe): use exe ctx?
415            let non_null_values = array.filter(mask.clone())?.to_canonical()?.into_array();
416            let non_null_indices = match mask.indices() {
417                AllOr::All => {
418                    // We already know that the mask is 90%+ false
419                    unreachable!("Mask is mostly null")
420                }
421                AllOr::None => {
422                    // we know there are some non-NULL values
423                    unreachable!("Mask is mostly null but not all null")
424                }
425                AllOr::Some(values) => {
426                    let buffer: Buffer<u32> = values
427                        .iter()
428                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
429                        .collect();
430
431                    buffer.into_array()
432                }
433            };
434
435            return Sparse::try_new(
436                non_null_indices,
437                non_null_values,
438                array.len(),
439                Scalar::null(array.dtype().clone()),
440            )
441            .map(IntoArray::into_array);
442        }
443
444        let fill = if let Some(fill) = fill_value {
445            fill.cast(array.dtype())?
446        } else {
447            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
448            let (top_pvalue, _) = array
449                .to_primitive()
450                .top_value()?
451                .vortex_expect("Non empty or all null array");
452
453            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
454        };
455
456        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
457        let non_top_mask = Mask::from_buffer(
458            array
459                .binary(fill_array.clone(), Operator::NotEq)?
460                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
461                .to_bool()
462                .to_bit_buffer(),
463        );
464
465        let non_top_values = array
466            .filter(non_top_mask.clone())?
467            .to_canonical()?
468            .into_array();
469
470        let indices: Buffer<u64> = match non_top_mask {
471            Mask::AllTrue(count) => {
472                // all true -> complete slice
473                (0u64..count as u64).collect()
474            }
475            Mask::AllFalse(_) => {
476                // All values are equal to the top value
477                return Ok(fill_array);
478            }
479            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
480        };
481
482        Sparse::try_new(indices.into_array(), non_top_values, array.len(), fill)
483            .map(IntoArray::into_array)
484    }
485}
486
487impl ValidityVTable<Sparse> for Sparse {
488    fn validity(array: ArrayView<'_, Sparse>) -> VortexResult<Validity> {
489        let patches = unsafe {
490            Patches::new_unchecked(
491                array.patches.array_len(),
492                array.patches.offset(),
493                array.patches.indices().clone(),
494                array
495                    .patches
496                    .values()
497                    .validity()?
498                    .to_array(array.patches.values().len()),
499                array.patches.chunk_offsets().clone(),
500                array.patches.offset_within_chunk(),
501            )
502        };
503
504        Ok(Validity::Array(
505            unsafe { Sparse::new_unchecked(patches, array.fill_value.is_valid().into()) }
506                .into_array(),
507        ))
508    }
509}
510
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::Sparse;

    /// A null fill value of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// A non-null fill value (`42i32`).
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Builds a length-10 sparse array with patches 100, 200, 300 at indices 2, 5, 8
    /// and the given fill value everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [null, null, 100, null, null, 200, null, null, 300, null]
        let patch_values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        Sparse::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let sparse = sparse_array(nullable_fill());

        assert_eq!(sparse.scalar_at(0).unwrap(), nullable_fill());
        assert_eq!(sparse.scalar_at(2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(sparse.scalar_at(5).unwrap(), Scalar::from(Some(200_i32)));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let sparse = sparse_array(nullable_fill());
        sparse.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let patch_indices = ConstantArray::new(10u32, 1).into_array();
        let patch_values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let null_fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));

        let sparse = Sparse::try_new(patch_indices, patch_values, 100, null_fill).unwrap();

        assert_eq!(
            sparse
                .scalar_at(10)
                .unwrap()
                .as_primitive()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(sparse.scalar_at(0).unwrap().is_null());
        assert!(sparse.scalar_at(99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let window = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(usize::try_from(&window.scalar_at(0).unwrap()).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let window = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            window.validity_mask().unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Here the patches are null and the fill is valid, so validity is inverted
        // relative to the null-fill case.
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();

        let window = Sparse::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        assert_eq!(
            window.validity_mask().unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let first_window = sparse_array(nullable_fill()).slice(1..8).unwrap();
        assert_eq!(
            usize::try_from(&first_window.scalar_at(1).unwrap()).unwrap(),
            100
        );

        let second_window = first_window.slice(1..6).unwrap();
        assert_eq!(
            usize::try_from(&second_window.scalar_at(3).unwrap()).unwrap(),
            200
        );
    }

    #[test]
    pub fn sparse_validity_mask() {
        let sparse = sparse_array(nullable_fill());
        let valid_bits = sparse
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        assert_eq!(
            valid_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let sparse = sparse_array(non_nullable_fill());
        assert!(sparse.validity_mask().unwrap().all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last patch index (100) is not below the array length (100).
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(patch_indices, patch_values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let patch_values = buffer![15_u32, 135, 13531, 42].into_array();
        let patch_indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(patch_indices, patch_values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        );

        let sparse = Sparse::encode(&original.clone().into_array(), None)
            .vortex_expect("Sparse::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let sparse =
            Sparse::try_new(patch_indices, patch_values, 10, Scalar::null_native::<i16>()).unwrap();

        let actual = sparse.validity_mask().unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}