vortex_sparse/
lib.rs

1use std::fmt::Debug;
2
3use vortex_array::arrays::BooleanBufferBuilder;
4use vortex_array::compute::{scalar_at, sub_scalar};
5use vortex_array::patches::Patches;
6use vortex_array::stats::{ArrayStats, Stat, StatsSet, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
9use vortex_array::{
10    Array, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl, Encoding, EncodingId,
11    RkyvMetadata, ToCanonical, try_from_array_ref,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::Mask;
16use vortex_scalar::Scalar;
17
18use crate::serde::SparseMetadata;
19
20mod canonical;
21mod compute;
22mod serde;
23mod variants;
24
25#[derive(Clone, Debug)]
26pub struct SparseArray {
27    patches: Patches,
28    fill_value: Scalar,
29    stats_set: ArrayStats,
30}
31
32try_from_array_ref!(SparseArray);
33
34pub struct SparseEncoding;
35impl Encoding for SparseEncoding {
36    type Array = SparseArray;
37    type Metadata = RkyvMetadata<SparseMetadata>;
38}
39
40impl EncodingVTable for SparseEncoding {
41    fn id(&self) -> EncodingId {
42        EncodingId::new_ref("vortex.sparse")
43    }
44}
45
46impl SparseArray {
47    pub fn try_new(
48        indices: ArrayRef,
49        values: ArrayRef,
50        len: usize,
51        fill_value: Scalar,
52    ) -> VortexResult<Self> {
53        Self::try_new_with_offset(indices, values, len, 0, fill_value)
54    }
55
56    pub(crate) fn try_new_with_offset(
57        indices: ArrayRef,
58        values: ArrayRef,
59        len: usize,
60        indices_offset: usize,
61        fill_value: Scalar,
62    ) -> VortexResult<Self> {
63        if indices.len() != values.len() {
64            vortex_bail!(
65                "Mismatched indices {} and values {} length",
66                indices.len(),
67                values.len()
68            );
69        }
70
71        if !indices.is_empty() {
72            let last_index = usize::try_from(&scalar_at(&indices, indices.len() - 1)?)?;
73
74            if last_index - indices_offset >= len {
75                vortex_bail!("Array length was set to {len} but the last index is {last_index}");
76            }
77        }
78
79        let patches = Patches::new(len, indices_offset, indices, values);
80
81        Self::try_new_from_patches(patches, fill_value)
82    }
83
84    pub fn try_new_from_patches(patches: Patches, fill_value: Scalar) -> VortexResult<Self> {
85        if fill_value.dtype() != patches.values().dtype() {
86            vortex_bail!(
87                "fill value, {:?}, should be instance of values dtype, {}",
88                fill_value,
89                patches.values().dtype(),
90            );
91        }
92        Ok(Self {
93            patches,
94            fill_value,
95            stats_set: Default::default(),
96        })
97    }
98
99    #[inline]
100    pub fn patches(&self) -> &Patches {
101        &self.patches
102    }
103
104    #[inline]
105    pub fn resolved_patches(&self) -> VortexResult<Patches> {
106        let (len, offset, indices, values) = self.patches().clone().into_parts();
107        let indices_offset = Scalar::from(offset).cast(indices.dtype())?;
108        let indices = sub_scalar(&indices, indices_offset)?;
109        Ok(Patches::new(len, 0, indices, values))
110    }
111
112    #[inline]
113    pub fn fill_scalar(&self) -> &Scalar {
114        &self.fill_value
115    }
116}
117
118impl ArrayImpl for SparseArray {
119    type Encoding = SparseEncoding;
120
121    fn _len(&self) -> usize {
122        self.patches.array_len()
123    }
124
125    fn _dtype(&self) -> &DType {
126        self.fill_value.dtype()
127    }
128
129    fn _vtable(&self) -> VTableRef {
130        VTableRef::new_ref(&SparseEncoding)
131    }
132}
133
134impl ArrayStatisticsImpl for SparseArray {
135    fn _stats_ref(&self) -> StatsSetRef<'_> {
136        self.stats_set.to_ref(self)
137    }
138}
139
140impl ArrayValidityImpl for SparseArray {
141    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
142        Ok(match self.patches().get_patched(index)? {
143            None => self.fill_scalar().is_valid(),
144            Some(patch_value) => patch_value.is_valid(),
145        })
146    }
147
148    fn _all_valid(&self) -> VortexResult<bool> {
149        if self.fill_scalar().is_null() {
150            // We need _all_ values to be patched, and all patches to be valid
151            return Ok(self.patches().values().len() == self.len()
152                && self.patches().values().all_valid()?);
153        }
154
155        self.patches().values().all_valid()
156    }
157
158    fn _all_invalid(&self) -> VortexResult<bool> {
159        if !self.fill_scalar().is_null() {
160            // We need _all_ values to be patched, and all patches to be invalid
161            return Ok(self.patches().values().len() == self.len()
162                && self.patches().values().all_invalid()?);
163        }
164
165        self.patches().values().all_invalid()
166    }
167
168    fn _validity_mask(&self) -> VortexResult<Mask> {
169        let indices = self.patches().indices().to_primitive()?;
170
171        if self.fill_scalar().is_null() {
172            // If we have a null fill value, then we set each patch value to true.
173            let mut buffer = BooleanBufferBuilder::new(self.len());
174            // TODO(ngates): use vortex-buffer::BitBufferMut when it exists.
175            buffer.append_n(self.len(), false);
176
177            match_each_integer_ptype!(indices.ptype(), |$I| {
178                indices.as_slice::<$I>().into_iter().for_each(|&index| {
179                    buffer.set_bit(index.try_into().vortex_expect("Failed to cast to usize"), true);
180                });
181            });
182
183            return Ok(Mask::from_buffer(buffer.finish()));
184        }
185
186        // If the fill_value is non-null, then the validity is based on the validity of the
187        // patch values.
188        let mut buffer = BooleanBufferBuilder::new(self.len());
189        buffer.append_n(self.len(), true);
190
191        let values_validity = self.patches().values().validity_mask()?;
192        match_each_integer_ptype!(indices.ptype(), |$I| {
193            indices.as_slice::<$I>()
194                .into_iter()
195                .enumerate()
196                .for_each(|(patch_idx, &index)| {
197                    buffer.set_bit(index.try_into().vortex_expect("failed to cast to usize"), values_validity.value(patch_idx));
198                })
199        });
200
201        Ok(Mask::from_buffer(buffer.finish()))
202    }
203}
204
205impl StatisticsVTable<&SparseArray> for SparseEncoding {
206    fn compute_statistics(&self, array: &SparseArray, stat: Stat) -> VortexResult<StatsSet> {
207        let values = array.patches().clone().into_values();
208        let stats = values.statistics().compute_all(&[stat])?;
209        if array.len() == values.len() {
210            return Ok(stats);
211        }
212
213        let fill_len = array.len() - values.len();
214        let fill_stats = if array.fill_scalar().is_null() {
215            StatsSet::nulls(fill_len)
216        } else {
217            StatsSet::constant(array.fill_scalar().clone(), fill_len)
218        };
219
220        if values.is_empty() {
221            return Ok(fill_stats);
222        }
223
224        Ok(stats.merge_unordered(&fill_stats, array.dtype()))
225    }
226}
227
#[cfg(test)]
mod test {
    use itertools::Itertools;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::{slice, try_cast};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::{DType, PType};
    use vortex_error::VortexError;
    use vortex_scalar::{PrimitiveScalar, Scalar};

    use super::*;

    /// Null fill scalar with a nullable `i32` dtype.
    fn nullable_fill() -> Scalar {
        let dtype = DType::Primitive(PType::I32, Nullable);
        Scalar::null(dtype)
    }

    /// Non-null `i32` fill scalar with value 42.
    fn non_nullable_fill() -> Scalar {
        42i32.into()
    }

    /// Ten-element sparse array holding 100, 200 and 300 at positions 2, 5 and 8; every other
    /// position holds `fill_value`. The patch values are cast to the dtype of `fill_value`.
    fn sparse_array(fill_value: Scalar) -> ArrayRef {
        let raw_values = buffer![100i32, 200, 300].into_array();
        let values = try_cast(&raw_values, fill_value.dtype()).unwrap();
        let indices = buffer![2u64, 5, 8].into_array();

        let sparse = SparseArray::try_new(indices, values, 10, fill_value).unwrap();
        sparse.into_array()
    }

    #[test]
    pub fn test_scalar_at() {
        let array = sparse_array(nullable_fill());

        // Unpatched position yields the fill, patched positions yield their values.
        assert_eq!(scalar_at(&array, 0).unwrap(), nullable_fill());
        assert_eq!(scalar_at(&array, 2).unwrap(), Scalar::from(Some(100_i32)));
        assert_eq!(scalar_at(&array, 5).unwrap(), Scalar::from(Some(200_i32)));

        // One past the end must be reported as out of bounds for the range [0, 10).
        let Err(VortexError::OutOfBounds(idx, lo, hi, _)) = scalar_at(&array, 10) else {
            unreachable!()
        };
        assert_eq!((idx, lo, hi), (10, 0, 10));
    }

    #[test]
    pub fn test_scalar_at_again() {
        let indices = ConstantArray::new(10u32, 1).into_array();
        let values = ConstantArray::new(Scalar::primitive(1234u32, Nullable), 1).into_array();
        let fill = Scalar::null(DType::Primitive(PType::U32, Nullable));
        let arr = SparseArray::try_new(indices, values, 100, fill).unwrap();

        let patched = scalar_at(&arr, 10).unwrap();
        assert_eq!(
            PrimitiveScalar::try_from(&patched)
                .unwrap()
                .typed_value::<u32>(),
            Some(1234)
        );
        assert!(scalar_at(&arr, 0).unwrap().is_null());
        assert!(scalar_at(&arr, 99).unwrap().is_null());
    }

    #[test]
    pub fn scalar_at_sliced() {
        let sliced = slice(&sparse_array(nullable_fill()), 2, 7).unwrap();

        // Original position 2 is now position 0.
        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(usize::try_from(&first).unwrap(), 100);

        let Err(VortexError::OutOfBounds(idx, lo, hi, _)) = scalar_at(&sliced, 5) else {
            unreachable!()
        };
        assert_eq!((idx, lo, hi), (5, 0, 5));
    }

    #[test]
    pub fn scalar_at_sliced_twice() {
        let sliced_once = slice(&sparse_array(nullable_fill()), 1, 8).unwrap();

        // Original position 2 is now position 1.
        let once_value = scalar_at(&sliced_once, 1).unwrap();
        assert_eq!(usize::try_from(&once_value).unwrap(), 100);

        let Err(VortexError::OutOfBounds(idx, lo, hi, _)) = scalar_at(&sliced_once, 7) else {
            unreachable!()
        };
        assert_eq!((idx, lo, hi), (7, 0, 7));

        let sliced_twice = slice(&sliced_once, 1, 6).unwrap();

        // Original position 5 is now position 3.
        let twice_value = scalar_at(&sliced_twice, 3).unwrap();
        assert_eq!(usize::try_from(&twice_value).unwrap(), 200);

        let Err(VortexError::OutOfBounds(idx, lo, hi, _)) = scalar_at(&sliced_twice, 5) else {
            unreachable!()
        };
        assert_eq!((idx, lo, hi), (5, 0, 5));
    }

    #[test]
    pub fn sparse_validity_mask() {
        let array = sparse_array(nullable_fill());

        // Only the patched positions 2, 5 and 8 are valid when the fill is null.
        let expected = [
            false, false, true, false, false, true, false, false, true, false,
        ];
        let actual = array
            .validity_mask()
            .unwrap()
            .to_boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(actual, expected);
    }

    #[test]
    fn sparse_validity_mask_non_null_fill() {
        let array = sparse_array(non_nullable_fill());
        let mask = array.validity_mask().unwrap();
        assert!(mask.all_true());
    }

    #[test]
    #[should_panic]
    fn test_invalid_length() {
        // Last index 100 does not fit in an array of length 100.
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 100, 0_u32.into()).unwrap();
    }

    #[test]
    fn test_valid_length() {
        // Last index 100 fits in an array of length 101.
        let indices = buffer![10_u64, 11, 50, 100].into_array();
        let values = buffer![15_u32, 135, 13531, 42].into_array();

        SparseArray::try_new(indices, values, 101, 0_u32.into()).unwrap();
    }
}