Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use kernel::PARENT_KERNELS;
8use prost::Message as _;
9use vortex_array::Array;
10use vortex_array::ArrayBufferVisitor;
11use vortex_array::ArrayChildVisitor;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayRef;
15use vortex_array::ExecutionCtx;
16use vortex_array::IntoArray;
17use vortex_array::Precision;
18use vortex_array::ProstMetadata;
19use vortex_array::ToCanonical;
20use vortex_array::arrays::ConstantArray;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::builtins::ArrayBuiltins;
23use vortex_array::compute::Operator;
24use vortex_array::compute::compare;
25use vortex_array::compute::filter;
26use vortex_array::compute::sub_scalar;
27use vortex_array::patches::Patches;
28use vortex_array::patches::PatchesMetadata;
29use vortex_array::scalar::Scalar;
30use vortex_array::scalar::ScalarValue;
31use vortex_array::serde::ArrayChildren;
32use vortex_array::stats::ArrayStats;
33use vortex_array::stats::StatsSetRef;
34use vortex_array::validity::Validity;
35use vortex_array::vtable;
36use vortex_array::vtable::ArrayId;
37use vortex_array::vtable::BaseArrayVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_array::vtable::VisitorVTable;
41use vortex_buffer::Buffer;
42use vortex_buffer::ByteBufferMut;
43use vortex_dtype::DType;
44use vortex_dtype::Nullability;
45use vortex_error::VortexExpect as _;
46use vortex_error::VortexResult;
47use vortex_error::vortex_bail;
48use vortex_error::vortex_ensure;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_session::VortexSession;
52
53use crate::canonical::execute_sparse;
54use crate::rules::RULES;
55
56mod canonical;
57mod compute;
58mod kernel;
59mod ops;
60mod rules;
61mod slice;
62
63vtable!(Sparse);
64
/// Protobuf-encoded metadata persisted alongside a serialized [`SparseArray`].
///
/// It is produced from the array's patches when the array is serialized and read back
/// when the array is rebuilt from its children and buffers.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata describing the patches (offset, number of patches, indices dtype).
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
71
72impl VTable for SparseVTable {
73    type Array = SparseArray;
74
75    type Metadata = ProstMetadata<SparseMetadata>;
76
77    type ArrayVTable = Self;
78    type OperationsVTable = Self;
79    type ValidityVTable = Self;
80    type VisitorVTable = Self;
81
82    fn id(_array: &Self::Array) -> ArrayId {
83        Self::ID
84    }
85
86    fn metadata(array: &SparseArray) -> VortexResult<Self::Metadata> {
87        Ok(ProstMetadata(SparseMetadata {
88            patches: array.patches().to_metadata(array.len(), array.dtype())?,
89        }))
90    }
91
92    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
93        Ok(Some(metadata.0.encode_to_vec()))
94    }
95
96    fn deserialize(
97        bytes: &[u8],
98        _dtype: &DType,
99        _len: usize,
100        _buffers: &[BufferHandle],
101        _session: &VortexSession,
102    ) -> VortexResult<Self::Metadata> {
103        Ok(ProstMetadata(SparseMetadata::decode(bytes)?))
104    }
105
106    fn build(
107        dtype: &DType,
108        len: usize,
109        metadata: &Self::Metadata,
110        buffers: &[BufferHandle],
111        children: &dyn ArrayChildren,
112    ) -> VortexResult<SparseArray> {
113        if children.len() != 2 {
114            vortex_bail!(
115                "Expected 2 children for sparse encoding, found {}",
116                children.len()
117            )
118        }
119        vortex_ensure!(
120            metadata.0.patches.offset()? == 0,
121            "Patches must start at offset 0"
122        );
123
124        let patch_indices = children.get(
125            0,
126            &metadata.0.patches.indices_dtype()?,
127            metadata.0.patches.len()?,
128        )?;
129        let patch_values = children.get(1, dtype, metadata.0.patches.len()?)?;
130
131        if buffers.len() != 1 {
132            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
133        }
134
135        let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
136        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
137
138        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
139
140        SparseArray::try_new(patch_indices, patch_values, len, fill_value)
141    }
142
143    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
144        vortex_ensure!(
145            children.len() == 2,
146            "SparseArray expects 2 children, got {}",
147            children.len()
148        );
149
150        let mut children_iter = children.into_iter();
151        let patch_indices = children_iter.next().vortex_expect("patch_indices child");
152        let patch_values = children_iter.next().vortex_expect("patch_values child");
153
154        array.patches = Patches::new(
155            array.patches.array_len(),
156            array.patches.offset(),
157            patch_indices,
158            patch_values,
159            array.patches.chunk_offsets().clone(),
160        )?;
161
162        Ok(())
163    }
164
165    fn reduce_parent(
166        array: &Self::Array,
167        parent: &ArrayRef,
168        child_idx: usize,
169    ) -> VortexResult<Option<ArrayRef>> {
170        RULES.evaluate(array, parent, child_idx)
171    }
172
173    fn execute_parent(
174        array: &Self::Array,
175        parent: &ArrayRef,
176        child_idx: usize,
177        ctx: &mut ExecutionCtx,
178    ) -> VortexResult<Option<ArrayRef>> {
179        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
180    }
181
182    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
183        execute_sparse(array)
184    }
185}
186
/// An array in which every position holds `fill_value` except for the positions
/// listed in `patches`, which hold the corresponding patch value.
#[derive(Clone, Debug)]
pub struct SparseArray {
    /// Indices and values of the positions that differ from the fill value.
    patches: Patches,
    /// Scalar used for every position not covered by a patch; its dtype is reported as
    /// the array's dtype.
    fill_value: Scalar,
    /// Statistics holder for this array.
    stats_set: ArrayStats,
}
193
/// Marker type carrying the vtable implementations for the sparse encoding.
#[derive(Debug)]
pub struct SparseVTable;
196
impl SparseVTable {
    /// Identifier of the sparse encoding (`"vortex.sparse"`).
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sparse");
}
200
201impl SparseArray {
202    pub fn try_new(
203        indices: ArrayRef,
204        values: ArrayRef,
205        len: usize,
206        fill_value: Scalar,
207    ) -> VortexResult<Self> {
208        vortex_ensure!(
209            indices.len() == values.len(),
210            "Mismatched indices {} and values {} length",
211            indices.len(),
212            values.len()
213        );
214
215        if indices.is_host() {
216            debug_assert_eq!(
217                indices.statistics().compute_is_strict_sorted(),
218                Some(true),
219                "SparseArray: indices must be strict-sorted"
220            );
221
222            // Verify the indices are all in the valid range
223            if !indices.is_empty() {
224                let last_index = usize::try_from(&indices.scalar_at(indices.len() - 1)?)?;
225
226                vortex_ensure!(
227                    last_index < len,
228                    "Array length was {len} but the last index is {last_index}"
229                );
230            }
231        }
232
233        Ok(Self {
234            // TODO(0ax1): handle chunk offsets
235            patches: Patches::new(len, 0, indices, values, None)?,
236            fill_value,
237            stats_set: Default::default(),
238        })
239    }
240
241    /// Build a new SparseArray from an existing set of patches.
242    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
243        vortex_ensure!(
244            fill_value.dtype() == patches.values().dtype(),
245            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
246            fill_value,
247            patches.values().dtype(),
248            fill_value.dtype(),
249        );
250
251        Ok(Self {
252            patches,
253            fill_value,
254            stats_set: Default::default(),
255        })
256    }
257
258    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> Self {
259        Self {
260            patches,
261            fill_value,
262            stats_set: Default::default(),
263        }
264    }
265
    /// Returns the patches (indices and values) stored by this array.
    #[inline]
    pub fn patches(&self) -> &Patches {
        &self.patches
    }
270
271    #[inline]
272    pub fn resolved_patches(&self) -> VortexResult<Patches> {
273        let patches = self.patches();
274        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
275        let indices = sub_scalar(patches.indices(), indices_offset)?;
276
277        Patches::new(
278            patches.array_len(),
279            0,
280            indices,
281            patches.values().clone(),
282            // TODO(0ax1): handle chunk offsets
283            None,
284        )
285    }
286
    /// Returns the scalar held by every position that is not covered by a patch.
    #[inline]
    pub fn fill_scalar(&self) -> &Scalar {
        &self.fill_value
    }
291
292    /// Encode given array as a SparseArray.
293    ///
294    /// Optionally provided fill value will be respected if the array is less than 90% null.
295    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
296        if let Some(fill_value) = fill_value.as_ref()
297            && array.dtype() != fill_value.dtype()
298        {
299            vortex_bail!(
300                "Array and fill value types must match. got {} and {}",
301                array.dtype(),
302                fill_value.dtype()
303            )
304        }
305        let mask = array.validity_mask()?;
306
307        if mask.all_false() {
308            // Array is constant NULL
309            return Ok(
310                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
311            );
312        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
313            // Array is dominated by NULL but has non-NULL values
314            let non_null_values = filter(array, &mask)?;
315            let non_null_indices = match mask.indices() {
316                AllOr::All => {
317                    // We already know that the mask is 90%+ false
318                    unreachable!("Mask is mostly null")
319                }
320                AllOr::None => {
321                    // we know there are some non-NULL values
322                    unreachable!("Mask is mostly null but not all null")
323                }
324                AllOr::Some(values) => {
325                    let buffer: Buffer<u32> = values
326                        .iter()
327                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
328                        .collect();
329
330                    buffer.into_array()
331                }
332            };
333
334            return Ok(SparseArray::try_new(
335                non_null_indices,
336                non_null_values,
337                array.len(),
338                Scalar::null(array.dtype().clone()),
339            )?
340            .into_array());
341        }
342
343        let fill = if let Some(fill) = fill_value {
344            fill
345        } else {
346            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
347            let (top_pvalue, _) = array
348                .to_primitive()
349                .top_value()?
350                .vortex_expect("Non empty or all null array");
351
352            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
353        };
354
355        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
356        let non_top_mask = Mask::from_buffer(
357            compare(array, &fill_array, Operator::NotEq)?
358                .fill_null(Scalar::bool(true, Nullability::NonNullable))?
359                .to_bool()
360                .to_bit_buffer(),
361        );
362
363        let non_top_values = filter(array, &non_top_mask)?;
364
365        let indices: Buffer<u64> = match non_top_mask {
366            Mask::AllTrue(count) => {
367                // all true -> complete slice
368                (0u64..count as u64).collect()
369            }
370            Mask::AllFalse(_) => {
371                // All values are equal to the top value
372                return Ok(fill_array);
373            }
374            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
375        };
376
377        SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill)
378            .map(|a| a.into_array())
379    }
380}
381
382impl BaseArrayVTable<SparseVTable> for SparseVTable {
383    fn len(array: &SparseArray) -> usize {
384        array.patches.array_len()
385    }
386
387    fn dtype(array: &SparseArray) -> &DType {
388        array.fill_scalar().dtype()
389    }
390
391    fn stats(array: &SparseArray) -> StatsSetRef<'_> {
392        array.stats_set.to_ref(array.as_ref())
393    }
394
395    fn array_hash<H: std::hash::Hasher>(array: &SparseArray, state: &mut H, precision: Precision) {
396        array.patches.array_hash(state, precision);
397        array.fill_value.hash(state);
398    }
399
400    fn array_eq(array: &SparseArray, other: &SparseArray, precision: Precision) -> bool {
401        array.patches.array_eq(&other.patches, precision) && array.fill_value == other.fill_value
402    }
403}
404
405impl ValidityVTable<SparseVTable> for SparseVTable {
406    fn validity(array: &SparseArray) -> VortexResult<Validity> {
407        let patches = unsafe {
408            Patches::new_unchecked(
409                array.patches.array_len(),
410                array.patches.offset(),
411                array.patches.indices().clone(),
412                array
413                    .patches
414                    .values()
415                    .validity()?
416                    .to_array(array.patches.values().len()),
417                array.patches.chunk_offsets().clone(),
418                array.patches.offset_within_chunk(),
419            )
420        };
421
422        Ok(Validity::Array(
423            unsafe { SparseArray::new_unchecked(patches, array.fill_value.is_valid().into()) }
424                .into_array(),
425        ))
426    }
427}
428
429impl VisitorVTable<SparseVTable> for SparseVTable {
430    fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) {
431        let fill_value_buffer =
432            ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
433        visitor.visit_buffer_handle("fill_value", &BufferHandle::new_host(fill_value_buffer));
434    }
435
436    fn nbuffers(_array: &SparseArray) -> usize {
437        1
438    }
439
440    fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
441        visitor.visit_patches(array.patches())
442    }
443
444    fn nchildren(array: &SparseArray) -> usize {
445        // patches have indices + values + optional chunk_offsets
446        2 + array.patches().chunk_offsets().is_some() as usize
447    }
448}
449
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect;

    use super::*;

    /// A null scalar of nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        Scalar::null(dtype)
    }

    /// A non-nullable `i32` scalar.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Ten-element sparse array with patches 100, 200, 300 at positions 2, 5, 8.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [fill, fill, 100, fill, fill, 200, fill, fill, 300, fill]
        let patch_values = buffer![100i32, 200, 300]
            .into_array()
            .cast(fill_value.dtype().clone())
            .unwrap();
        let patch_indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(patch_indices, patch_values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        let expectations = [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ];
        for (position, expected) in expectations {
            assert_eq!(array.scalar_at(position).unwrap(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        // Position 10 is one past the end of the ten-element array.
        array.scalar_at(10).unwrap();
    }

    #[test]
    pub fn test_scalar_at_again() {
        let null_fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let patch_indices = ConstantArray::new(10u32, 1).into_array();
        let patch_values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let arr = SparseArray::try_new(patch_indices, patch_values, 100, null_fill).unwrap();

        let patched = arr.scalar_at(10).unwrap();
        assert_eq!(patched.as_primitive().typed_value::<u32>(), Some(1234));

        for position in [0, 99] {
            assert!(arr.scalar_at(position).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        let expected = Mask::from_iter([true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        // Here the patches are null and the fill value is valid.
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let valid_fill = Scalar::primitive(1.0f32, Nullability::Nullable);

        let sliced = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            valid_fill,
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        let expected = Mask::from_iter([false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        let after_first_slice = sliced_once.scalar_at(1).unwrap();
        assert_eq!(usize::try_from(&after_first_slice).unwrap(), 100);

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        let after_second_slice = sliced_twice.scalar_at(3).unwrap();
        assert_eq!(usize::try_from(&after_second_slice).unwrap(), 200);
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        let valid_bits = array
            .validity_mask()
            .unwrap()
            .to_bit_buffer()
            .iter()
            .collect_vec();

        // Only the patched positions 2, 5 and 8 are valid.
        assert_eq!(
            valid_bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let mask = sparse_array(non_nullable_fill()).validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // The last index (100) is not below the requested length (100).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // The last index (100) is below the requested length (101).
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        );

        let sparse = SparseArray::encode(&original.clone().into_array(), None)
            .vortex_expect("SparseArray::encode should succeed for test data");

        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_arrays_eq!(sparse.to_primitive(), original);
    }

    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let patch_indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
                .into_array();
        let array = SparseArray::try_new(
            patch_indices,
            patch_values,
            10,
            Scalar::null_native::<i16>(),
        )
        .unwrap();

        // Valid only where a patch exists and the patch value itself is non-null.
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);
        let actual = array.validity_mask().unwrap();

        assert_eq!(actual, expected);
    }
}
643        ]);
644
645        assert_eq!(actual, expected);
646    }
647}