1use std::fmt::Debug;
2
3use vortex_array::arrays::{BooleanBufferBuilder, ConstantArray};
4use vortex_array::compute::{Operator, compare, fill_null, filter, scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10    Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, IntoArray,
11    RkyvMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_buffer::Buffer;
14use vortex_dtype::{DType, Nullability, match_each_integer_ptype};
15use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
16use vortex_mask::{AllOr, Mask};
17use vortex_scalar::Scalar;
18
19use crate::serde::SparseMetadata;
20
21mod canonical;
22mod compute;
23mod serde;
24mod variants;
25
26#[derive(Clone, Debug)]
27pub struct SparseArray {
28    patches: Patches,
29    fill_value: Scalar,
30    stats_set: ArrayStats,
31}
32
33try_from_array_ref!(SparseArray);
34
35pub struct SparseEncoding;
36impl Encoding for SparseEncoding {
37    type Array = SparseArray;
38    type Metadata = RkyvMetadata<SparseMetadata>;
39}
40
41impl SparseArray {
42    pub fn try_new(
43        indices: ArrayRef,
44        values: ArrayRef,
45        len: usize,
46        fill_value: Scalar,
47    ) -> VortexResult<Self> {
48        Self::try_new_with_offset(indices, values, len, 0, fill_value)
49    }
50
51    pub(crate) fn try_new_with_offset(
52        indices: ArrayRef,
53        values: ArrayRef,
54        len: usize,
55        indices_offset: usize,
56        fill_value: Scalar,
57    ) -> VortexResult<Self> {
58        if indices.len() != values.len() {
59            vortex_bail!(
60                "Mismatched indices {} and values {} length",
61                indices.len(),
62                values.len()
63            );
64        }
65
66        if !indices.is_empty() {
67            let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
68
69            if last_index - indices_offset >= len {
70                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
71            }
72        }
73
74        let patches = Patches::new(len, indices_offset, indices, values);
75
76        Self::try_new_from_patches(patches, fill_value)
77    }
78
79    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
80        if fill_value.dtype() != patches.values().dtype() {
81            vortex_bail!(
82                "fill value, {:?}, should be instance of values dtype, {}",
83                fill_value,
84                patches.values().dtype(),
85            );
86        }
87        Ok(Self {
88            patches,
89            fill_value,
90            stats_set: Default::default(),
91        })
92    }
93
94    #[inline]
95    pub fn patches(&self) -> &Patches {
96        &self.patches
97    }
98
99    #[inline]
100    pub fn resolved_patches(&self) -> VortexResult<Patches> {
101        let (len, offset, indices, values) = self.patches().clone().into_parts();
102        let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
103        let indices = sub_scalar(&indices, indices_offset)?;
104        Ok(Patches::new(len, 0, indices, values))
105    }
106
107    #[inline]
108    pub fn fill_scalar(&self) -> &Scalar {
109        &self.fill_value
110    }
111
112    pub fn encode(array: &dyn Array, fill_value: Option<Scalar>) -> VortexResult<ArrayRef> {
116        if let Some(fill_value) = fill_value.as_ref() {
117            if array.dtype() != fill_value.dtype() {
118                vortex_bail!(
119                    "Array and fill value types must match. got {} and {}",
120                    array.dtype(),
121                    fill_value.dtype()
122                )
123            }
124        }
125        let mask = array.validity_mask()?;
126
127        if mask.all_false() {
128            return Ok(
130                ConstantArray::new(Scalar::null(array.dtype().clone()), array.len()).into_array(),
131            );
132        } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) {
133            let non_null_values = filter(array, &mask)?;
135            let non_null_indices = match mask.indices() {
136                AllOr::All => {
137                    unreachable!("Mask is mostly null")
139                }
140                AllOr::None => {
141                    unreachable!("Mask is mostly null but not all null")
143                }
144                AllOr::Some(values) => {
145                    let buffer: Buffer<u32> = values
146                        .iter()
147                        .map(|&v| v.try_into().vortex_expect("indices must fit in u32"))
148                        .collect();
149
150                    buffer.into_array()
151                }
152            };
153
154            return Ok(SparseArray::try_new(
155                non_null_indices,
156                non_null_values,
157                array.len(),
158                Scalar::null(array.dtype().clone()),
159            )?
160            .into_array());
161        }
162
163        let fill = if let Some(fill) = fill_value {
164            fill
165        } else {
166            let (top_pvalue, _) = array
168                .to_primitive()?
169                .top_value()?
170                .vortex_expect("Non empty or all null array");
171
172            Scalar::primitive_value(top_pvalue, top_pvalue.ptype(), array.dtype().nullability())
173        };
174
175        let fill_array = ConstantArray::new(fill.clone(), array.len()).into_array();
176        let non_top_mask = Mask::from_buffer(
177            fill_null(
178                &compare(array, &fill_array, Operator::NotEq)?,
179                Scalar::bool(true, Nullability::NonNullable),
180            )?
181            .to_bool()?
182            .boolean_buffer()
183            .clone(),
184        );
185
186        let non_top_values = filter(array, &non_top_mask)?;
187
188        let indices: Buffer<u64> = match non_top_mask {
189            Mask::AllTrue(count) => {
190                (0u64..count as u64).collect()
192            }
193            Mask::AllFalse(_) => {
194                return Ok(fill_array);
196            }
197            Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(),
198        };
199
200        SparseArray::try_new(
201            indices.into_array(),
202            non_top_values.into_array(),
203            array.len(),
204            fill,
205        )
206        .map(|a| a.into_array())
207    }
208}
209
210impl ArrayImpl for SparseArray {
211    type Encoding = SparseEncoding;
212
213    fn _len(&self) -> usize {
214        self.patches.array_len()
215    }
216
217    fn _dtype(&self) -> &DType {
218        self.fill_value.dtype()
219    }
220
221    fn _vtable(&self) -> VTableRef {
222        VTableRef::new_ref(&SparseEncoding)
223    }
224
225    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
226        let patch_indices = children[0].clone();
227        let patch_values = children[1].clone();
228
229        let patches = Patches::new(
230            self.patches().array_len(),
231            self.patches().offset(),
232            patch_indices,
233            patch_values,
234        );
235
236        Self::try_new_from_patches(patches, self.fill_value.clone())
237    }
238}
239
240impl ArrayStatisticsImpl for SparseArray {
241    fn _stats_ref(&self) -> StatsSetRef<'_> {
242        self.stats_set.to_ref(self)
243    }
244}
245
246impl ArrayValidityImpl for SparseArray {
247    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
248        Ok(match self.patches().get_patched(index)? {
249            None => self.fill_scalar().is_valid(),
250            Some(patch_value) => patch_value.is_valid(),
251        })
252    }
253
254    fn _all_valid(&self) -> VortexResult<bool> {
255        if self.fill_scalar().is_null() {
256            return Ok(self.patches().values().len() == self.len()
258                && self.patches().values().all_valid()?);
259        }
260
261        self.patches().values().all_valid()
262    }
263
264    fn _all_invalid(&self) -> VortexResult<bool> {
265        if !self.fill_scalar().is_null() {
266            return Ok(self.patches().values().len() == self.len()
268                && self.patches().values().all_invalid()?);
269        }
270
271        self.patches().values().all_invalid()
272    }
273
274    fn _validity_mask(&self) -> VortexResult<Mask> {
275        let indices = self.patches().indices().to_primitive()?;
276
277        if self.fill_scalar().is_null() {
278            let mut buffer = BooleanBufferBuilder::new(self.len());
280            buffer.append_n(self.len(), false);
282
283            match_each_integer_ptype!(indices.ptype(), |$I| {
284                indices.as_slice::<$I>().into_iter().for_each(|&index| {
285                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), true);
286                });
287            });
288
289            return Ok(Mask::from_buffer(buffer.finish()));
290        }
291
292        let mut buffer = BooleanBufferBuilder::new(self.len());
295        buffer.append_n(self.len(), true);
296
297        let values_validity = self.patches().values().validity_mask()?;
298        match_each_integer_ptype!(indices.ptype(), |$I| {
299            indices.as_slice::<$I>()
300                .into_iter()
301                .enumerate()
302                .for_each(|(patch_idx, &index)| {
303                    buffer.set_bit(usize::try_from(index).vortex_expect("Failed to cast to usize") - self.patches().offset(), values_validity.value(patch_idx));
304                })
305        });
306
307        Ok(Mask::from_buffer(buffer.finish()))
308    }
309}
310
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::{slice, try_cast};
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::{VortexError, VortexUnwrap};
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null `i32` scalar, used as a nullable fill value.
    fn nullable_fill() -> Scalar {
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
    }

    /// Non-null `i32` scalar, used as a fill value that is always valid.
    fn non_nullable_fill() -> Scalar {
        Scalar::from(42i32)
    }

    /// Length-10 sparse array holding 100, 200 and 300 at positions 2, 5 and 8, with the
    /// patch values cast to the dtype of `fill_value`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = try_cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        SparseArray::try_new(indices, values, 10, fill_value)
            .unwrap()
            .into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        for (position, expected) in [
            (0, nullable_fill()),
            (2, Scalar::from(Some(100_i32))),
            (5, Scalar::from(Some(200_i32))),
        ] {
            assert_eq!(scalar_at(&array, position).unwrap(), expected);
        }

        // One past the end must be reported as out of bounds for the range 0..10.
        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&array, 10) else {
            unreachable!()
        };
        assert_eq!((i, start, stop), (10, 0, 10));
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values =
            ConstantArray::new(Scalar::primitive(1234u32, Nullability::Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = scalar_at(&arr, 10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );

        for unpatched in [0, 99] {
            assert!(scalar_at(&arr, unpatched).unwrap().is_null());
        }
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced, 5) else {
            unreachable!()
        };
        assert_eq!((i, start, stop), (5, 0, 5));
    }

    #[test]
    pub fn validity_mask_sliced_null_fill() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        // Original positions 2 and 5 land on positions 0 and 3 of the slice.
        let expected = Mask::from_iter(vec![true, false, false, true, false]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn validity_mask_sliced_nonnull_fill() {
        let null_patches = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable)),
            3,
        )
        .into_array();
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            null_patches,
            10,
            Scalar::primitive(1.0f32, Nullability::Nullable),
        )
        .unwrap()
        .into_array();

        let sliced = slice(&sparse, 2, 7).unwrap();

        // Valid fill, null patches: only the patched positions are invalid.
        let expected = Mask::from_iter(vec![false, true, true, false, true]);
        assert_eq!(sliced.validity_mask().unwrap(), expected);
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();

        let value_once = scalar_at(&sliced_once, 1).unwrap();
        assert_eq!(usize::try_from(&value_once).unwrap(), 100);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced_once, 7) else {
            unreachable!()
        };
        assert_eq!((i, start, stop), (7, 0, 7));

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();

        let value_twice = scalar_at(&sliced_twice, 3).unwrap();
        assert_eq!(usize::try_from(&value_twice).unwrap(), 200);

        let Err(VortexError::OutOfBounds(i, start, stop, _)) = scalar_at(&sliced_twice, 5) else {
            unreachable!()
        };
        assert_eq!((i, start, stop), (5, 0, 5));
    }

    #[test]
    pub fn sparse_validity_mask() {
        let validity = sparse_array(nullable_fill()).validity_mask().unwrap();
        let bits = validity.to_boolean_buffer().iter().collect_vec();

        // Only the patched positions 2, 5 and 8 are valid.
        assert_eq!(
            bits,
            [
                false, false, true, false, false, true, false, false, true, false
            ]
        );
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let validity = sparse_array(non_nullable_fill())
            .validity_mask()
            .unwrap();
        assert!(validity.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 does not fit in an array of length 100.
        SparseArray::try_new(indices, values, 100, Scalar::from(0_u32)).unwrap();
    }

    #[test]
    fn test_valid_length() {
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        // Index 100 is the last position of an array of length 101.
        SparseArray::try_new(indices, values, 101, Scalar::from(0_u32)).unwrap();
    }

    #[test]
    fn encode_with_nulls() {
        let validity_bits = vec![
            true, true, false, true, false, true, false, true, true, false, true, false,
        ];
        let input = PrimitiveArray::new(
            buffer![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4],
            Validity::from_iter(validity_bits.clone()),
        )
        .into_array();

        let sparse = SparseArray::encode(&input, None).vortex_unwrap();
        let canonical = sparse.to_primitive().vortex_unwrap();

        // Both validity and values must come back unchanged.
        assert_eq!(
            sparse.validity_mask().unwrap(),
            Mask::from_iter(validity_bits)
        );
        assert_eq!(
            canonical.as_slice::<i32>(),
            vec![0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4]
        );
    }
}