vortex_sparse/
canonical.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::vtable::CanonicalVTable;
5use vortex_array::{Array, Canonical};
6use vortex_buffer::buffer;
7use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
8use vortex_error::{VortexError, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::{SparseArray, SparseVTable};
12
13impl CanonicalVTable<SparseVTable> for SparseVTable {
14    fn canonicalize(array: &SparseArray) -> VortexResult<Canonical> {
15        let resolved_patches = array.resolved_patches()?;
16        if resolved_patches.num_patches() == 0 {
17            return ConstantArray::new(array.fill_scalar().clone(), array.len()).to_canonical();
18        }
19
20        if matches!(array.dtype(), DType::Bool(_)) {
21            canonicalize_sparse_bools(&resolved_patches, array.fill_scalar())
22        } else {
23            let ptype = PType::try_from(resolved_patches.values().dtype())?;
24            match_each_native_ptype!(ptype, |$P| {
25                canonicalize_sparse_primitives::<$P>(
26                    &resolved_patches,
27                    &array.fill_scalar(),
28                )
29            })
30        }
31    }
32}
33
34fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
35    let (fill_bool, validity) = if fill_value.is_null() {
36        (false, Validity::AllInvalid)
37    } else {
38        (
39            fill_value.try_into()?,
40            if patches.dtype().nullability() == Nullability::NonNullable {
41                Validity::NonNullable
42            } else {
43                Validity::AllValid
44            },
45        )
46    };
47
48    let bools = BoolArray::new(
49        if fill_bool {
50            BooleanBuffer::new_set(patches.array_len())
51        } else {
52            BooleanBuffer::new_unset(patches.array_len())
53        },
54        validity,
55    );
56
57    bools.patch(patches).map(Canonical::Bool)
58}
59
60fn canonicalize_sparse_primitives<
61    T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
62>(
63    patches: &Patches,
64    fill_value: &Scalar,
65) -> VortexResult<Canonical> {
66    let (primitive_fill, validity) = if fill_value.is_null() {
67        (T::default(), Validity::AllInvalid)
68    } else {
69        (
70            fill_value.try_into()?,
71            if patches.dtype().nullability() == Nullability::NonNullable {
72                Validity::NonNullable
73            } else {
74                Validity::AllValid
75            },
76        )
77    };
78
79    let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
80
81    parray.patch(patches).map(Canonical::Primitive)
82}
83
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// A length-10 sparse bool array patched at indices 0, 1 and 7 must canonicalize
    /// to the fill value everywhere except the patched slots, for a `true`, `false`
    /// and null fill alike.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
                .into_array();
        let sparse =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(sparse.dtype(), &DType::Bool(Nullability::Nullable));

        let dense = sparse.to_bool().unwrap();

        // Start from "fill everywhere", then override the three patched positions.
        let mut expected_values = vec![fill_value; 10];
        expected_values[0] = Some(true);
        expected_values[1] = None;
        expected_values[7] = Some(false);
        let expected = bool_array_from_nullable_vec(expected_values, fill_value);

        let bits = dense.boolean_buffer();
        let validity = dense.validity();
        assert_eq!(bits, expected.boolean_buffer());
        assert_eq!(validity, expected.validity());

        // Spot-check individual slots: patched-valid, patched-null, unpatched, patched-valid.
        assert!(bits.value(0));
        assert!(validity.is_valid(0).unwrap());

        assert_eq!(bits.value(1), fill_value.unwrap_or_default());
        assert!(!validity.is_valid(1).unwrap());

        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());

        assert!(!bits.value(7));
        assert!(validity.is_valid(7).unwrap());
    }

    /// Builds a nullable [`BoolArray`] from optional booleans. A `None` entry is
    /// marked invalid and its underlying bit is set to `fill_value` (or `false`
    /// when `fill_value` is itself `None`).
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        let null_bit = fill_value.unwrap_or_default();
        let mut bits = BooleanBufferBuilder::new(bools.len());
        let mut valid = BooleanBufferBuilder::new(bools.len());
        for entry in &bools {
            bits.append(entry.unwrap_or(null_bit));
            valid.append(entry.is_some());
        }
        BoolArray::new(bits.finish(), Validity::from(valid.finish()))
    }

    /// A length-10 sparse `i32` array patched at indices 0, 1 and 7 must canonicalize
    /// to the fill value everywhere except the patched slots, for a zero, non-zero
    /// and null fill alike.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(
            *sparse.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let dense = sparse.to_primitive().unwrap();

        // Start from "fill everywhere", then override the three patched positions.
        let mut expected_values = [fill_value; 10];
        expected_values[0] = Some(0i32);
        expected_values[1] = None;
        expected_values[7] = Some(1);
        let expected = PrimitiveArray::from_option_iter(expected_values);

        assert_eq!(dense.byte_buffer(), expected.byte_buffer());
        assert_eq!(dense.validity(), expected.validity());

        let ints = dense.as_slice::<i32>();
        let validity = dense.validity();

        // Spot-check individual slots: patched-valid, patched-null, unpatched, patched-valid.
        assert_eq!(ints[0], 0);
        assert!(validity.is_valid(0).unwrap());

        assert_eq!(ints[1], 0);
        assert!(!validity.is_valid(1).unwrap());

        assert_eq!(ints[2], fill_value.unwrap_or_default());
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());

        assert_eq!(ints[7], 1);
        assert!(validity.is_valid(7).unwrap());
    }
}