Skip to main content

vortex_sparse/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use kernel::PARENT_KERNELS;
11use prost::Message as _;
12use vortex_array::AnyCanonical;
13use vortex_array::Array;
14use vortex_array::ArrayEq;
15use vortex_array::ArrayHash;
16use vortex_array::ArrayId;
17use vortex_array::ArrayParts;
18use vortex_array::ArrayRef;
19use vortex_array::ArraySlots;
20use vortex_array::ArrayView;
21use vortex_array::Canonical;
22use vortex_array::ExecutionCtx;
23use vortex_array::ExecutionResult;
24use vortex_array::IntoArray;
25use vortex_array::Precision;
26use vortex_array::arrays::BoolArray;
27use vortex_array::arrays::ConstantArray;
28use vortex_array::arrays::Primitive;
29use vortex_array::arrays::PrimitiveArray;
30use vortex_array::arrays::bool::BoolArrayExt;
31use vortex_array::buffer::BufferHandle;
32use vortex_array::builtins::ArrayBuiltins;
33use vortex_array::dtype::DType;
34use vortex_array::dtype::Nullability;
35use vortex_array::patches::PatchSlotIndices;
36use vortex_array::patches::Patches;
37use vortex_array::patches::PatchesData;
38use vortex_array::patches::PatchesMetadata;
39use vortex_array::require_child;
40use vortex_array::require_opt_child;
41use vortex_array::scalar::Scalar;
42use vortex_array::scalar::ScalarValue;
43use vortex_array::scalar_fn::fns::operators::Operator;
44use vortex_array::serde::ArrayChildren;
45use vortex_array::validity::Validity;
46use vortex_array::vtable::VTable;
47use vortex_array::vtable::ValidityVTable;
48use vortex_buffer::Buffer;
49use vortex_buffer::ByteBufferMut;
50use vortex_error::VortexExpect as _;
51use vortex_error::VortexResult;
52use vortex_error::vortex_bail;
53use vortex_error::vortex_ensure;
54use vortex_error::vortex_ensure_eq;
55use vortex_error::vortex_panic;
56use vortex_mask::AllOr;
57use vortex_mask::Mask;
58use vortex_session::VortexSession;
59use vortex_session::registry::CachedId;
60
61use crate::canonical::execute_sparse;
62use crate::rules::RULES;
63
64mod canonical;
65mod compute;
66mod kernel;
67mod ops;
68mod rules;
69mod slice;
70
71use vortex_array::aggregate_fn::AggregateFnVTable as _;
72use vortex_array::aggregate_fn::fns::is_constant::IsConstant;
73use vortex_array::aggregate_fn::fns::min_max::MinMax;
74use vortex_array::aggregate_fn::fns::nan_count::NanCount;
75use vortex_array::aggregate_fn::fns::null_count::NullCount;
76use vortex_array::aggregate_fn::fns::sum::Sum;
77use vortex_array::aggregate_fn::session::AggregateFnSessionExt;
78use vortex_array::session::ArraySessionExt;
79
80/// Initialize Sparse encoding in the given session.
81///
82/// Registers the Sparse array vtable and its aggregate kernels (`IsConstant`, `Sum`,
83/// `MinMax`, `NullCount`, `NanCount`). Compare/between/fill_null pushdown is wired
84/// through `PARENT_KERNELS` (see `kernel.rs`) and does not require registration here.
85pub fn initialize(session: &VortexSession) {
86    session.arrays().register(Sparse);
87
88    let aggregate_fns = session.aggregate_fns();
89    aggregate_fns.register_aggregate_kernel(
90        Sparse.id(),
91        Some(IsConstant.id()),
92        &compute::is_constant::SparseIsConstantKernel,
93    );
94    aggregate_fns.register_aggregate_kernel(
95        Sparse.id(),
96        Some(Sum.id()),
97        &compute::sum::SparseSumKernel,
98    );
99    aggregate_fns.register_aggregate_kernel(
100        Sparse.id(),
101        Some(MinMax.id()),
102        &compute::min_max::SparseMinMaxKernel,
103    );
104    aggregate_fns.register_aggregate_kernel(
105        Sparse.id(),
106        Some(NullCount.id()),
107        &compute::null_count::SparseNullCountKernel,
108    );
109    aggregate_fns.register_aggregate_kernel(
110        Sparse.id(),
111        Some(NanCount.id()),
112        &compute::nan_count::SparseNanCountKernel,
113    );
114}
115
/// A [`Sparse`]-encoded Vortex array.
///
/// This is an alias for the generic [`Array`] wrapper instantiated with the [`Sparse`] vtable.
pub type SparseArray = Array<Sparse>;
118
// Child slots of a Sparse array: the patch indices, the patch values, and the optional
// patch chunk offsets. The `array_slots` attribute presumably generates the
// `SparseSlots::PATCH_*` index constants, `COUNT` and `NAMES` used elsewhere in this
// file — confirm against the macro definition.
#[vortex_array::array_slots(Sparse)]
pub struct SparseSlots {
    // Positions of the patched elements.
    pub patch_indices: ArrayRef,
    // Values stored at `patch_indices`.
    pub patch_values: ArrayRef,
    // Optional chunk offsets for the patches; `None` when absent.
    pub patch_chunk_offsets: Option<ArrayRef>,
}
125
/// Concrete parts of a [`SparseArray`] after iterative execution.
///
/// Produced by [`SparseOwnedExt::into_parts`] and consumed by `execute_sparse`.
pub(crate) struct SparseParts {
    /// Patches rebuilt from the array's slots.
    pub patches: Patches,
    /// Scalar used for every position that is not patched.
    pub fill_value: Scalar,
    /// Logical dtype of the array.
    pub dtype: DType,
    /// Logical length of the array.
    pub len: usize,
}
133
/// Crate-internal extension for consuming a Sparse array into its [`SparseParts`].
pub(crate) trait SparseOwnedExt {
    /// Consume `self` and return its patches, fill value, dtype and length.
    ///
    /// Fails if the patches cannot be rebuilt from the array's slots.
    fn into_parts(self) -> VortexResult<SparseParts>;
}
137
138impl SparseOwnedExt for Array<Sparse> {
139    fn into_parts(self) -> VortexResult<SparseParts> {
140        let patches = Patches::new(
141            self.len(),
142            self.patches().offset(),
143            self.as_ref().slots()[SparseSlots::PATCH_INDICES]
144                .clone()
145                .vortex_expect("indices"),
146            self.as_ref().slots()[SparseSlots::PATCH_VALUES]
147                .clone()
148                .vortex_expect("values"),
149            self.as_ref().slots()[SparseSlots::PATCH_CHUNK_OFFSETS].clone(),
150        )?;
151        Ok(SparseParts {
152            patches,
153            fill_value: self.fill_scalar().clone(),
154            dtype: self.dtype().clone(),
155            len: self.len(),
156        })
157    }
158}
159
/// Protobuf-encoded metadata written by `serialize` and read back by `deserialize`.
///
/// Only the patches metadata is stored here; the fill value travels in the array's
/// single buffer instead.
#[derive(Clone, prost::Message)]
#[repr(C)]
pub struct SparseMetadata {
    /// Metadata describing the patches (field tag 1, required).
    #[prost(message, required, tag = "1")]
    patches: PatchesMetadata,
}
166
167impl ArrayHash for SparseData {
168    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
169        self.array_len.hash(state);
170        self.patches_data.hash(state);
171        self.fill_value.hash(state);
172    }
173}
174
175impl ArrayEq for SparseData {
176    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
177        self.array_len == other.array_len
178            && self.patches_data == other.patches_data
179            && self.fill_value == other.fill_value
180    }
181}
182
183impl VTable for Sparse {
184    type TypedArrayData = SparseData;
185
186    type OperationsVTable = Self;
187    type ValidityVTable = Self;
188
    /// Return the encoding id `"vortex.sparse"`, held in a function-local static.
    fn id(&self) -> ArrayId {
        static ID: CachedId = CachedId::new("vortex.sparse");
        *ID
    }
193
194    fn validate(
195        &self,
196        data: &Self::TypedArrayData,
197        dtype: &DType,
198        len: usize,
199        slots: &[Option<ArrayRef>],
200    ) -> VortexResult<()> {
201        let patches = SparseData::patches_from_slots(data, len, slots);
202        SparseData::validate(&patches, data.fill_scalar(), dtype, len)
203    }
204
205    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
206        1
207    }
208
209    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
210        match idx {
211            0 => {
212                let fill_value_buffer =
213                    ScalarValue::to_proto_bytes::<ByteBufferMut>(array.fill_value.value()).freeze();
214                BufferHandle::new_host(fill_value_buffer)
215            }
216            _ => vortex_panic!("SparseArray buffer index {idx} out of bounds"),
217        }
218    }
219
220    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
221        match idx {
222            0 => Some("fill_value".to_string()),
223            _ => vortex_panic!("SparseArray buffer_name index {idx} out of bounds"),
224        }
225    }
226
227    fn serialize(
228        array: ArrayView<'_, Self>,
229        _session: &VortexSession,
230    ) -> VortexResult<Option<Vec<u8>>> {
231        let patches = array.patches().to_metadata(array.len(), array.dtype())?;
232        let metadata = SparseMetadata { patches };
233
234        // Note that we DO NOT serialize the fill value since that is stored in the buffers.
235        Ok(Some(metadata.encode_to_vec()))
236    }
237
238    fn deserialize(
239        &self,
240        dtype: &DType,
241        len: usize,
242        metadata: &[u8],
243        buffers: &[BufferHandle],
244        children: &dyn ArrayChildren,
245        session: &VortexSession,
246    ) -> VortexResult<ArrayParts<Self>> {
247        let metadata = SparseMetadata::decode(metadata)?;
248
249        // Once we have the patches metadata, we need to get the fill value from the buffers.
250
251        if buffers.len() != 1 {
252            vortex_bail!("Expected 1 buffer, got {}", buffers.len());
253        }
254        let scalar_bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?;
255
256        let scalar_value = ScalarValue::from_proto_bytes(scalar_bytes, dtype, session)?;
257        let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?;
258
259        vortex_ensure_eq!(
260            children.len(),
261            2,
262            "SparseArray expects 2 children for sparse encoding, found {}",
263            children.len()
264        );
265
266        let patch_indices = children.get(
267            0,
268            &metadata.patches.indices_dtype()?,
269            metadata.patches.len()?,
270        )?;
271        let patch_values = children.get(1, dtype, metadata.patches.len()?)?;
272
273        let patches = Patches::new(
274            len,
275            metadata.patches.offset()?,
276            patch_indices,
277            patch_values,
278            None,
279        )?;
280        let slots = SparseData::make_slots(&patches);
281        let data = SparseData::from_patches(&patches, fill_value)?;
282        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
283    }
284
    /// Return the name of slot `idx`, looked up in `SparseSlots::NAMES`.
    ///
    /// Panics if `idx` is out of range for that table.
    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
        SparseSlots::NAMES[idx].to_string()
    }

    /// Try to rewrite `parent` using this crate's reduce rules (`RULES`).
    fn reduce_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        RULES.evaluate(array, parent, child_idx)
    }

    /// Try to execute `parent` directly on the sparse representation using
    /// `PARENT_KERNELS`.
    fn execute_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
    }
305
    /// Execute the sparse array.
    ///
    /// Steps, in order:
    /// 1. If the patches carry a non-zero offset, rewrite the indices as
    ///    `indices - offset` and rebuild the array with offset 0 and no chunk offsets.
    /// 2. Require the indices, values and (optional) chunk-offsets children to have been
    ///    executed to the requested forms via `require_child!` / `require_opt_child!`.
    /// 3. Hand the resulting parts to `execute_sparse`.
    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        // Resolve offset first: wrap indices in Binary(indices, offset, Sub) and
        // reassemble with offset=0. Uses slot children (not data) since the executor
        // may have updated slots via reduce_parent/execute_parent.
        let array = if array.patches().offset() != 0 {
            let offset = array.patches().offset();
            let indices = array.patch_indices();
            let values = array.patch_values().clone();
            let len = array.len();
            // The offset is cast to the indices dtype so the subtraction is well-typed.
            let offset_scalar = Scalar::from(offset).cast(indices.dtype())?;
            let resolved_indices = indices.binary(
                ConstantArray::new(offset_scalar, indices.len()).into_array(),
                Operator::Sub,
            )?;
            // Chunk offsets are dropped here (`None`); the slot is cleared to match below.
            let patches = Patches::new(len, 0, resolved_indices.clone(), values, None)?;
            // Decompose, update in place, and reassemble without re-validation.
            match array.try_into_parts() {
                Ok(mut parts) => {
                    parts.data.patches_data = PatchesData::from_patches(&patches);
                    parts.slots[SparseSlots::PATCH_INDICES] = Some(resolved_indices);
                    parts.slots[SparseSlots::PATCH_CHUNK_OFFSETS] = None;
                    // NOTE(review): no validation runs here; this presumably relies on
                    // `Patches::new` above having accepted the same indices and values,
                    // with the data and slots updated to agree with it — confirm.
                    unsafe { Array::from_parts_unchecked(parts) }
                }
                // Taken when the array could not be decomposed; a fresh array is built
                // from the resolved patches and the existing fill value instead.
                // NOTE(review): same unchecked assumption as the branch above — confirm.
                Err(array) => unsafe {
                    Sparse::new_unchecked(patches, array.fill_scalar().clone())
                },
            }
        } else {
            array
        };

        // Require children to be executed through the scheduler,
        // enabling cross-step optimization via reduce_parent rules.
        let array = require_child!(
            array, array.patch_indices(), SparseSlots::PATCH_INDICES => Primitive
        );
        let array = require_child!(
            array, array.patch_values(), SparseSlots::PATCH_VALUES => AnyCanonical
        );
        // NOTE(review): unlike the two calls above, this result is not rebound to
        // `array`; presumably the macro only early-returns when the child still needs
        // executing — confirm against the macro definition.
        require_opt_child!(
            array,
            array.patch_chunk_offsets(),
            SparseSlots::PATCH_CHUNK_OFFSETS => Primitive
        );

        let parts = array.into_parts()?;
        // TODO(joe): remove ctx from execute_sparse since all slots should be canonical.
        execute_sparse(parts, ctx).map(ExecutionResult::done)
    }
355}
356
/// Maps the three patch components (indices, values, chunk offsets) to their slot
/// positions in a Sparse array; passed to `PatchesData::patches_from_slots`.
const PATCH_SLOTS: PatchSlotIndices = PatchSlotIndices {
    indices: SparseSlots::PATCH_INDICES,
    values: SparseSlots::PATCH_VALUES,
    chunk_offsets: SparseSlots::PATCH_CHUNK_OFFSETS,
};
362
/// Non-child state of a Sparse array.
///
/// The patch indices/values themselves live in the array's slots; this struct keeps
/// only what is needed to rebuild a [`Patches`] from those slots, plus the fill value.
#[derive(Clone, Debug)]
pub struct SparseData {
    /// The total length of the sparse array.
    array_len: usize,
    /// Patch metadata (offset, offset_within_chunk) for reconstructing Patches from slots.
    patches_data: PatchesData,
    /// Scalar used for every position that has no patch.
    fill_value: Scalar,
}
371
372impl Display for SparseData {
373    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
374        write!(f, "fill_value: {}", self.fill_value)
375    }
376}
377
/// Marker type for the Sparse encoding; it implements the array vtable and hosts the
/// [`SparseArray`] constructors.
#[derive(Clone, Debug)]
pub struct Sparse;
380
381impl Sparse {
382    /// Construct a new [`SparseArray`] from indices, values, length, and fill value.
383    pub fn try_new(
384        indices: ArrayRef,
385        values: ArrayRef,
386        len: usize,
387        fill_value: Scalar,
388    ) -> VortexResult<SparseArray> {
389        let dtype = fill_value.dtype().clone();
390        let patches = Patches::new(len, 0, indices, values, None)?;
391        let slots = SparseData::make_slots(&patches);
392        let data = SparseData::from_patches(&patches, fill_value)?;
393        Ok(unsafe {
394            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
395        })
396    }
397
398    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<SparseArray> {
399        let dtype = fill_value.dtype().clone();
400        let len = patches.array_len();
401        let slots = SparseData::make_slots(&patches);
402        let data = SparseData::from_patches(&patches, fill_value)?;
403        Ok(unsafe {
404            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
405        })
406    }
407
408    pub(crate) unsafe fn new_unchecked(patches: Patches, fill_value: Scalar) -> SparseArray {
409        let dtype = fill_value.dtype().clone();
410        let len = patches.array_len();
411        let slots = SparseData::make_slots(&patches);
412        let data = SparseData::from_patches_unchecked(&patches, fill_value);
413        unsafe {
414            Array::from_parts_unchecked(ArrayParts::new(Sparse, dtype, len, data).with_slots(slots))
415        }
416    }
417
418    /// Encode the given array as a [`SparseArray`].
419    pub fn encode(
420        array: &ArrayRef,
421        fill_value: Option<Scalar>,
422        ctx: &mut ExecutionCtx,
423    ) -> VortexResult<ArrayRef> {
424        SparseData::encode(array, fill_value, ctx)
425    }
426}
427
428impl SparseData {
429    fn normalize_patches_dtype(patches: Patches, fill_value: &Scalar) -> VortexResult<Patches> {
430        let fill_dtype = fill_value.dtype();
431        let values_dtype = patches.values().dtype();
432
433        vortex_ensure!(
434            values_dtype.eq_ignore_nullability(fill_dtype),
435            "fill value, {:?}, should be instance of values dtype, {} but was {}.",
436            fill_value,
437            values_dtype,
438            fill_dtype,
439        );
440
441        if values_dtype == fill_dtype {
442            Ok(patches)
443        } else {
444            patches.cast_values(fill_dtype)
445        }
446    }
447
448    pub fn validate(
449        patches: &Patches,
450        fill_value: &Scalar,
451        dtype: &DType,
452        len: usize,
453    ) -> VortexResult<()> {
454        vortex_ensure!(
455            fill_value.dtype() == dtype,
456            "fill value dtype {} does not match array dtype {}",
457            fill_value.dtype(),
458            dtype,
459        );
460        vortex_ensure!(
461            patches.array_len() == len,
462            "patches length {} does not match array length {}",
463            patches.array_len(),
464            len
465        );
466        vortex_ensure!(
467            patches.values().dtype() == dtype,
468            "patch values dtype {} does not match array dtype {}",
469            patches.values().dtype(),
470            dtype,
471        );
472        Ok(())
473    }
474
    /// Build the slot list for a Sparse array from `patches`.
    ///
    /// Capacity is reserved for `SparseSlots::COUNT` entries and the slots are filled by
    /// `PatchesData::push_slots`.
    fn make_slots(patches: &Patches) -> ArraySlots {
        let mut slots = ArraySlots::with_capacity(SparseSlots::COUNT);
        PatchesData::push_slots(&mut slots, Some(patches));
        slots
    }

    /// Reconstruct a [`Patches`] from the stored metadata and the array's slots.
    ///
    /// Panics (via `vortex_expect`) when the reconstruction yields no patches, i.e. when
    /// the required patch slots are missing.
    fn patches_from_slots(data: &SparseData, len: usize, slots: &[Option<ArrayRef>]) -> Patches {
        PatchesData::patches_from_slots(Some(&data.patches_data), len, slots, PATCH_SLOTS)
            .vortex_expect("SparseArray patch slots must be present")
    }
486
    /// Build a new SparseData from an existing set of patches, normalizing dtypes.
    ///
    /// `patches` is consumed. The normalized patches are used only to derive the stored
    /// metadata and are not returned, so a caller that also needs the cast values must
    /// normalize separately.
    ///
    /// Returns an error if the values' dtype differs from the fill value's dtype by more
    /// than nullability.
    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
        let patches = Self::normalize_patches_dtype(patches, &fill_value)?;
        Ok(Self::from_patches_unchecked(&patches, fill_value))
    }

    /// Extract metadata from patches to create SparseData, with dtype normalization.
    ///
    /// Normalization runs on a clone: the caller's `patches` (and any slots built from
    /// it) are left untouched, and the normalized copy is dropped after the metadata has
    /// been extracted.
    fn from_patches(patches: &Patches, fill_value: Scalar) -> VortexResult<Self> {
        let patches = Self::normalize_patches_dtype(patches.clone(), &fill_value)?;
        Ok(Self::from_patches_unchecked(&patches, fill_value))
    }

    /// Extract metadata from patches to create SparseData, without validation.
    ///
    /// Records the patches' array length and their `PatchesData`, and stores
    /// `fill_value` as given.
    fn from_patches_unchecked(patches: &Patches, fill_value: Scalar) -> Self {
        Self {
            array_len: patches.array_len(),
            patches_data: PatchesData::from_patches(patches),
            fill_value,
        }
    }
507
    /// Returns the length of the array.
    #[inline]
    pub fn len(&self) -> usize {
        self.array_len
    }

    /// Returns whether the array is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.array_len == 0
    }

    /// Returns the logical data type of the array.
    ///
    /// This is the dtype of the fill value.
    #[inline]
    pub fn dtype(&self) -> &DType {
        self.fill_scalar().dtype()
    }

    /// Returns the offset of the patches within the parent array.
    #[inline]
    pub fn offset(&self) -> usize {
        self.patches_data.offset()
    }

    /// Returns the fill value: the scalar used for positions that have no patch.
    #[inline]
    pub fn fill_scalar(&self) -> &Scalar {
        &self.fill_value
    }
536
537    /// Encode given array as a SparseArray.
538    ///
539    /// Optionally provided fill value will be respected if the array is less than 90% null.
540    pub fn encode(
541        array: &ArrayRef,
542        fill_value: Option<Scalar>,
543        ctx: &mut ExecutionCtx,
544    ) -> VortexResult<ArrayRef> {
545        if let Some(fill_value) = fill_value.as_ref()
546            && !array.dtype().eq_ignore_nullability(fill_value.dtype())
547        {
548            vortex_bail!(
549                "Array and fill value types must have the same base type. got {} and {}",
550                array.dtype(),
551                fill_value.dtype()
552            )
553        }
554        let mask = array.validity()?.execute_mask(array.len(), ctx)?;
555
556        if mask.all_false() {
557            // Array is constant NULL
558            return Ok(
559                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
560            );
561        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
562            // Array is dominated by NULL but has non-NULL values
563            let non_null_values = array
564                .filter(mask.clone())?
565                .execute::<Canonical>(ctx)?
566                .into_array();
567            let non_null_indices = match mask.indices() {
568                AllOr::All => {
569                    // We already know that the mask is 90%+ false
570                    unreachable!("Mask is mostly null")
571                }
572                AllOr::None => {
573                    // we know there are some non-NULL values
574                    unreachable!("Mask is mostly null but not all null")
575                }
576                AllOr::Some(values) => {
577                    let buffer: Buffer<u32> = values
578                        .iter()
579                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
580                        .collect();
581
582                    buffer.into_array()
583                }
584            };
585
586            return Sparse::try_new(
587                non_null_indices,
588                non_null_values,
589                array.len(),
590                Scalar::null(array.dtype().clone()),
591            )
592            .map(IntoArray::into_array);
593        }
594
595        let fill = if let Some(fill) = fill_value {
596            fill.cast(array.dtype())?
597        } else {
598            // TODO(robert): Support other dtypes, only thing missing is getting most common value out of the array
599            let primitive = array.clone().execute::<PrimitiveArray>(ctx)?;
600            let (top_pvalue, _) = primitive
601                .top_value()?
602                .vortex_expect("Non empty or all null array");
603
604            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
605        };
606
607        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
608        let non_top_bool = array
609            .binary(fill_array.clone(), Operator::NotEq)?
610            .fill_null(Scalar::bool(true, Nullability::NonNullable))?
611            .execute::<BoolArray>(ctx)?;
612        let non_top_mask = Mask::from_buffer(non_top_bool.to_bit_buffer());
613
614        let non_top_values = array
615            .filter(non_top_mask.clone())?
616            .execute::<Canonical>(ctx)?
617            .into_array();
618
619        let indices: Buffer<u64> = match non_top_mask {
620            Mask::AllTrue(count) => {
621                // all true -> complete slice
622                (0u64..count as u64).collect()
623            }
624            Mask::AllFalse(_) => {
625                // All values are equal to the top value
626                return Ok(fill_array);
627            }
628            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
629        };
630
631        Sparse::try_new(indices.into_array(), non_top_values, array.len(), fill)
632            .map(IntoArray::into_array)
633    }
634}
635
636/// Extension trait for accessing patches on [`SparseArray`] and [`ArrayView<'_, Sparse>`].
637///
638/// Patches are reconstructed from the array's slots and stored metadata on each call.
639pub trait SparseExt {
640    /// Reconstruct patches from the array's slots and metadata.
641    fn patches(&self) -> Patches;
642
643    /// Return patches with offset-resolved indices (offset subtracted from each index).
644    fn resolved_patches(&self) -> VortexResult<Patches> {
645        let patches = self.patches();
646        let indices_offset = Scalar::from(patches.offset()).cast(patches.indices().dtype())?;
647        let indices = patches.indices().binary(
648            ConstantArray::new(indices_offset, patches.indices().len()).into_array(),
649            Operator::Sub,
650        )?;
651
652        Patches::new(
653            patches.array_len(),
654            0,
655            indices,
656            patches.values().clone(),
657            // TODO(0ax1): handle chunk offsets
658            None,
659        )
660    }
661}
662
impl SparseExt for ArrayView<'_, Sparse> {
    /// Rebuild the patches from this view's data, length and slots.
    fn patches(&self) -> Patches {
        SparseData::patches_from_slots(self.data(), self.len(), self.slots())
    }
}

impl SparseExt for Array<Sparse> {
    /// Rebuild the patches from this array's data, length and slots.
    fn patches(&self) -> Patches {
        SparseData::patches_from_slots(self.data(), self.as_array().len(), self.slots())
    }
}
674
675impl ValidityVTable<Sparse> for Sparse {
676    fn validity(array: ArrayView<'_, Sparse>) -> VortexResult<Validity> {
677        let orig_patches = array.patches();
678        let validity_patches = unsafe {
679            Patches::new_unchecked(
680                orig_patches.array_len(),
681                orig_patches.offset(),
682                orig_patches.indices().clone(),
683                orig_patches
684                    .values()
685                    .validity()?
686                    .to_array(orig_patches.values().len()),
687                orig_patches.chunk_offsets().clone(),
688                orig_patches.offset_within_chunk(),
689            )
690        };
691
692        Ok(Validity::Array(
693            unsafe { Sparse::new_unchecked(validity_patches, array.fill_value.is_valid().into()) }
694                .into_array(),
695        ))
696    }
697}
698
// Unit tests for construction, scalar access, slicing, validity and encoding of
// Sparse arrays.
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;

    use super::*;
    use crate::Sparse;

    /// A null scalar of type nullable `i32`.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// The non-nullable `i32` scalar `42`.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// A length-10 sparse array with patches 100, 200, 300 at positions 2, 5, 8 and the
    /// given fill value everywhere else.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        // merged array: [null, null, 100, null, null, 200, null, null, 300, null]
        let mut values = buffer![100i32, 200, 300].into_array();
        // Cast the values so their dtype (including nullability) matches the fill value.
        values = values.cast(fill_value.dtype().clone()).unwrap();

        Sparse::try_new(buffer![2u64, 5, 8].into_array(), values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    // Unpatched positions yield the fill value; patched positions yield the patch value.
    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        assert_eq!(
            array
                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            nullable_fill()
        );
        assert_eq!(
            array
                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::from(Some(100_i32))
        );
        assert_eq!(
            array
                .execute_scalar(5, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Scalar::from(Some(200_i32))
        );
    }

    // Index 10 is one past the end of the length-10 array.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_scalar_at_oob() {
        let array = sparse_array(nullable_fill());
        array
            .execute_scalar(10, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
    }

    // Same check with constant-array children: a single patch at index 10 of 100.
    #[test]
    pub fn test_scalar_at_again() {
        let arr = Sparse::try_new(
            ConstantArray::new(10u32, 1).into_array(),
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array(),
            100,
            Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
        )
        .unwrap();

        assert_eq!(
            arr.execute_scalar(10, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .as_primitive()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(
            arr.execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .is_null()
        );
        assert!(
            arr.execute_scalar(99, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .is_null()
        );
    }

    // After slicing 2..7, the patch originally at index 2 is at index 0.
    #[test]
    pub fn scalar_at_sliced() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            usize::try_from(
                &sliced
                    .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
                    .unwrap()
            )
            .unwrap(),
            100
        );
    }

    // With a null fill, only the patched positions inside the slice are valid.
    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = sparse_array(nullable_fill()).slice(2..7).unwrap();
        assert_eq!(
            sliced
                .validity()
                .unwrap()
                .execute_mask(sliced.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from_iter(vec![true, false, false, true, false])
        );
    }

    // With a non-null fill and null patches, the patched positions are the invalid ones.
    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let sliced = Sparse::try_new(
            buffer![2u64, 5, 8].into_array(),
            ConstantArray::new(
                Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
                3,
            )
            .into_array(),
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .slice(2..7)
        .unwrap();

        assert_eq!(
            sliced
                .validity()
                .unwrap()
                .execute_mask(sliced.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap(),
            Mask::from_iter(vec![false, true, true, false, true])
        );
    }

    // Slicing an already-sliced array keeps patch positions consistent.
    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = sparse_array(nullable_fill()).slice(1..8).unwrap();
        assert_eq!(
            usize::try_from(
                &sliced_once
                    .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
                    .unwrap()
            )
            .unwrap(),
            100
        );

        let sliced_twice = sliced_once.slice(1..6).unwrap();
        assert_eq!(
            usize::try_from(
                &sliced_twice
                    .execute_scalar(3, &mut LEGACY_SESSION.create_execution_ctx())
                    .unwrap()
            )
            .unwrap(),
            200
        );
    }

    // Unsliced array with a null fill: exactly positions 2, 5 and 8 are valid.
    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());
        assert_eq!(
            array
                .validity()
                .unwrap()
                .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .to_bit_buffer()
                .iter()
                .collect_vec(),
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    // Non-nullable fill and values: every position is valid.
    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        assert!(
            array
                .validity()
                .unwrap()
                .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .all_true()
        );
    }

    // The largest index (100) is not below the declared length (100), so construction
    // must fail.
    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    // With length 101 the same indices are accepted.
    #[test]
    fn test_valid_length() {
        let values = buffer![15_u32, 135, 13531, 42].into_array();
        let indices = buffer![10_u64, 11, 50, 100].into_array();

        Sparse::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }

    // Encoding without an explicit fill value must round-trip both validity and values.
    #[test]
    fn encode_with_nulls() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let original = PrimitiveArray::new(
            buffer![0i32, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ]),
        );
        let sparse = Sparse::encode(&original.clone().into_array(), None, &mut ctx)
            .vortex_expect("Sparse::encode should succeed for test data");
        assert_eq!(
            sparse
                .validity()
                .unwrap()
                .execute_mask(sparse.len(), &mut ctx)
                .unwrap(),
            Mask::from_iter(vec![
                true, true, false, true, false, true, false, true, true, false, true, false,
            ])
        );
        let sparse_primitive = sparse.execute::<PrimitiveArray>(&mut ctx).unwrap();
        assert_arrays_eq!(sparse_primitive, original);
    }

    // Null patch values must show up as invalid even though they sit at patched positions.
    #[test]
    fn validity_mask_includes_null_values_when_fill_is_null() {
        let indices = buffer![0u8, 2, 4, 6, 8].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)])
            .into_array();
        let array = Sparse::try_new(indices, values, 10, Scalar::null_native::<i16>()).unwrap();
        let actual = array
            .validity()
            .unwrap()
            .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        let expected = Mask::from_iter([
            true, false, true, false, false, false, false, false, true, false,
        ]);

        assert_eq!(actual, expected);
    }
}
968}