vortex_sparse/
canonical.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::{Array, ArrayCanonicalImpl, Canonical};
5use vortex_buffer::buffer;
6use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
7use vortex_error::{VortexError, VortexResult};
8use vortex_scalar::Scalar;
9
10use crate::SparseArray;
11
12impl ArrayCanonicalImpl for SparseArray {
13    fn _to_canonical(&self) -> VortexResult<Canonical> {
14        let resolved_patches = self.resolved_patches()?;
15        if resolved_patches.num_patches() == 0 {
16            return ConstantArray::new(self.fill_scalar().clone(), self.len()).to_canonical();
17        }
18
19        if matches!(self.dtype(), DType::Bool(_)) {
20            canonicalize_sparse_bools(&resolved_patches, self.fill_scalar())
21        } else {
22            let ptype = PType::try_from(resolved_patches.values().dtype())?;
23            match_each_native_ptype!(ptype, |$P| {
24                canonicalize_sparse_primitives::<$P>(
25                    &resolved_patches,
26                    &self.fill_scalar(),
27                )
28            })
29        }
30    }
31}
32
33fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
34    let (fill_bool, validity) = if fill_value.is_null() {
35        (false, Validity::AllInvalid)
36    } else {
37        (
38            fill_value.try_into()?,
39            if patches.dtype().nullability() == Nullability::NonNullable {
40                Validity::NonNullable
41            } else {
42                Validity::AllValid
43            },
44        )
45    };
46
47    let bools = BoolArray::new(
48        if fill_bool {
49            BooleanBuffer::new_set(patches.array_len())
50        } else {
51            BooleanBuffer::new_unset(patches.array_len())
52        },
53        validity,
54    );
55
56    bools.patch(patches).map(Canonical::Bool)
57}
58
59fn canonicalize_sparse_primitives<
60    T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
61>(
62    patches: &Patches,
63    fill_value: &Scalar,
64) -> VortexResult<Canonical> {
65    let (primitive_fill, validity) = if fill_value.is_null() {
66        (T::default(), Validity::AllInvalid)
67    } else {
68        (
69            fill_value.try_into()?,
70            if patches.dtype().nullability() == Nullability::NonNullable {
71                Validity::NonNullable
72            } else {
73                Validity::AllValid
74            },
75        )
76    };
77
78    let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
79
80    parray.patch(patches).map(Canonical::Primitive)
81}
82
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Builds a `BoolArray` from optional booleans.
    ///
    /// `None` entries are marked invalid; the bit stored behind them is the
    /// fill value, or `false` when the fill itself is `None`.
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        let placeholder = fill_value.unwrap_or_default();
        let mut bits = BooleanBufferBuilder::new(bools.len());
        let mut valid = BooleanBufferBuilder::new(bools.len());
        for slot in bools {
            bits.append(slot.unwrap_or(placeholder));
            valid.append(slot.is_some());
        }
        BoolArray::new(bits.finish(), Validity::from(valid.finish()))
    }

    /// Canonicalizing a sparse boolean array yields the fill value everywhere
    /// except at the patched indices 0, 1 and 7.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
                .into_array();
        let sparse_bools =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullability::Nullable));

        let flat_bools = sparse_bools.to_bool().unwrap();

        // Start from ten fill slots and overwrite the three patched positions.
        let mut expected_slots = vec![fill_value; 10];
        expected_slots[0] = Some(true);
        expected_slots[1] = None;
        expected_slots[7] = Some(false);
        let expected = bool_array_from_nullable_vec(expected_slots, fill_value);

        let bits = flat_bools.boolean_buffer();
        let validity = flat_bools.validity();
        assert_eq!(bits, expected.boolean_buffer());
        assert_eq!(validity, expected.validity());

        // Index 0: patched with a valid `true`.
        assert!(bits.value(0));
        assert!(validity.is_valid(0).unwrap());

        // Index 1: patched with a null.
        assert_eq!(bits.value(1), fill_value.unwrap_or_default());
        assert!(!validity.is_valid(1).unwrap());

        // Index 2: untouched, so valid exactly when the fill is non-null.
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());

        // Index 7: patched with a valid `false`.
        assert!(!bits.value(7));
        assert!(validity.is_valid(7).unwrap());
    }

    /// Canonicalizing a sparse `i32` array yields the fill value everywhere
    /// except at the patched indices 0, 1 and 7.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        let patch_indices = buffer![0u64, 1, 7].into_array();
        let patch_values =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse_ints =
            SparseArray::try_new(patch_indices, patch_values, 10, Scalar::from(fill_value))
                .unwrap();
        assert_eq!(
            *sparse_ints.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let flat_ints = sparse_ints.to_primitive().unwrap();

        // Start from ten fill slots and overwrite the three patched positions.
        let mut expected_slots = [fill_value; 10];
        expected_slots[0] = Some(0i32);
        expected_slots[1] = None;
        expected_slots[7] = Some(1);
        let expected = PrimitiveArray::from_option_iter(expected_slots);

        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
        assert_eq!(flat_ints.validity(), expected.validity());

        let values = flat_ints.as_slice::<i32>();
        let validity = flat_ints.validity();

        // Index 0: patched with a valid 0.
        assert_eq!(values[0], 0);
        assert!(validity.is_valid(0).unwrap());

        // Index 1: patched with a null.
        assert_eq!(values[1], 0);
        assert!(!validity.is_valid(1).unwrap());

        // Index 2: untouched, so it carries the fill (or the default for a null fill).
        assert_eq!(values[2], fill_value.unwrap_or_default());
        assert_eq!(validity.is_valid(2).unwrap(), fill_value.is_some());

        // Index 7: patched with a valid 1.
        assert_eq!(values[7], 1);
        assert!(validity.is_valid(7).unwrap());
    }
}