1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray, PrimitiveArray};
2use vortex_array::patches::Patches;
3use vortex_array::validity::Validity;
4use vortex_array::{Array, ArrayCanonicalImpl, Canonical};
5use vortex_buffer::buffer;
6use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
7use vortex_error::{VortexError, VortexResult};
8use vortex_scalar::Scalar;
9
10use crate::SparseArray;
11
12impl ArrayCanonicalImpl for SparseArray {
13 fn _to_canonical(&self) -> VortexResult<Canonical> {
14 let resolved_patches = self.resolved_patches()?;
15 if resolved_patches.num_patches() == 0 {
16 return ConstantArray::new(self.fill_scalar().clone(), self.len()).to_canonical();
17 }
18
19 if matches!(self.dtype(), DType::Bool(_)) {
20 canonicalize_sparse_bools(&resolved_patches, self.fill_scalar())
21 } else {
22 let ptype = PType::try_from(resolved_patches.values().dtype())?;
23 match_each_native_ptype!(ptype, |$P| {
24 canonicalize_sparse_primitives::<$P>(
25 &resolved_patches,
26 &self.fill_scalar(),
27 )
28 })
29 }
30 }
31}
32
33fn canonicalize_sparse_bools(patches: &Patches, fill_value: &Scalar) -> VortexResult<Canonical> {
34 let (fill_bool, validity) = if fill_value.is_null() {
35 (false, Validity::AllInvalid)
36 } else {
37 (
38 fill_value.try_into()?,
39 if patches.dtype().nullability() == Nullability::NonNullable {
40 Validity::NonNullable
41 } else {
42 Validity::AllValid
43 },
44 )
45 };
46
47 let bools = BoolArray::new(
48 if fill_bool {
49 BooleanBuffer::new_set(patches.array_len())
50 } else {
51 BooleanBuffer::new_unset(patches.array_len())
52 },
53 validity,
54 );
55
56 bools.patch(patches).map(Canonical::Bool)
57}
58
59fn canonicalize_sparse_primitives<
60 T: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
61>(
62 patches: &Patches,
63 fill_value: &Scalar,
64) -> VortexResult<Canonical> {
65 let (primitive_fill, validity) = if fill_value.is_null() {
66 (T::default(), Validity::AllInvalid)
67 } else {
68 (
69 fill_value.try_into()?,
70 if patches.dtype().nullability() == Nullability::NonNullable {
71 Validity::NonNullable
72 } else {
73 Validity::AllValid
74 },
75 )
76 };
77
78 let parray = PrimitiveArray::new(buffer![primitive_fill; patches.array_len()], validity);
79
80 parray.patch(patches).map(Canonical::Primitive)
81}
82
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{BoolArray, BooleanBufferBuilder, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Canonicalizing a nullable boolean sparse array places the patch values
    /// (including a null patch at index 1) at their indices and the fill value
    /// everywhere else. Covers `true`, `false` and null fills.
    #[rstest]
    #[case(Some(true))]
    #[case(Some(false))]
    #[case(None)]
    fn test_sparse_bool(#[case] fill_value: Option<bool>) {
        let indices = buffer![0u64, 1, 7].into_array();
        let values = bool_array_from_nullable_vec(vec![Some(true), None, Some(false)], fill_value)
            .into_array();
        let sparse_bools =
            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
        assert_eq!(sparse_bools.dtype(), &DType::Bool(Nullability::Nullable));

        let flat_bools = sparse_bools.to_bool().unwrap();
        let expected = bool_array_from_nullable_vec(
            vec![
                Some(true),
                None,
                fill_value,
                fill_value,
                fill_value,
                fill_value,
                fill_value,
                Some(false),
                fill_value,
                fill_value,
            ],
            fill_value,
        );

        assert_eq!(flat_bools.boolean_buffer(), expected.boolean_buffer());
        assert_eq!(flat_bools.validity(), expected.validity());

        assert!(flat_bools.boolean_buffer().value(0));
        assert!(flat_bools.validity().is_valid(0).unwrap());
        assert_eq!(
            flat_bools.boolean_buffer().value(1),
            fill_value.unwrap_or_default()
        );
        assert!(!flat_bools.validity().is_valid(1).unwrap());
        assert_eq!(
            flat_bools.validity().is_valid(2).unwrap(),
            fill_value.is_some()
        );
        assert!(!flat_bools.boolean_buffer().value(7));
        assert!(flat_bools.validity().is_valid(7).unwrap());
    }

    /// Builds a [`BoolArray`] from optional booleans.
    ///
    /// Each `None` entry becomes an invalid slot. Its raw bit is set to
    /// `fill_value`, or `false` when that is also `None`. Both the patch values
    /// and the expected result in `test_sparse_bool` are built with this
    /// helper, so their raw bits are compared consistently.
    fn bool_array_from_nullable_vec(
        bools: Vec<Option<bool>>,
        fill_value: Option<bool>,
    ) -> BoolArray {
        let mut buffer = BooleanBufferBuilder::new(bools.len());
        let mut validity = BooleanBufferBuilder::new(bools.len());
        for maybe_bool in bools {
            buffer.append(maybe_bool.unwrap_or_else(|| fill_value.unwrap_or_default()));
            validity.append(maybe_bool.is_some());
        }
        BoolArray::new(buffer.finish(), Validity::from(validity.finish()))
    }

    /// Canonicalizing a nullable `i32` sparse array places the patch values
    /// (including a null patch at index 1) at their indices and the fill value
    /// everywhere else. Covers zero, negative and null fills.
    #[rstest]
    #[case(Some(0i32))]
    #[case(Some(-1i32))]
    #[case(None)]
    fn test_sparse_primitive(#[case] fill_value: Option<i32>) {
        // `Scalar` is already imported at module level; no local re-import needed.
        let indices = buffer![0u64, 1, 7].into_array();
        let values = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1)]).into_array();
        let sparse_ints =
            SparseArray::try_new(indices, values, 10, Scalar::from(fill_value)).unwrap();
        assert_eq!(
            *sparse_ints.dtype(),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let flat_ints = sparse_ints.to_primitive().unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(0i32),
            None,
            fill_value,
            fill_value,
            fill_value,
            fill_value,
            fill_value,
            Some(1),
            fill_value,
            fill_value,
        ]);

        assert_eq!(flat_ints.byte_buffer(), expected.byte_buffer());
        assert_eq!(flat_ints.validity(), expected.validity());

        assert_eq!(flat_ints.as_slice::<i32>()[0], 0);
        assert!(flat_ints.validity().is_valid(0).unwrap());
        assert_eq!(flat_ints.as_slice::<i32>()[1], 0);
        assert!(!flat_ints.validity().is_valid(1).unwrap());
        assert_eq!(
            flat_ints.as_slice::<i32>()[2],
            fill_value.unwrap_or_default()
        );
        assert_eq!(
            flat_ints.validity().is_valid(2).unwrap(),
            fill_value.is_some()
        );
        assert_eq!(flat_ints.as_slice::<i32>()[7], 1);
        assert!(flat_ints.validity().is_valid(7).unwrap());
    }
}