vortex_array/arrays/constant/
canonical.rs1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
8};
9
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
14 StructArray, VarBinViewArray, smallest_storage_type,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::vtable::CanonicalVTable;
19use crate::{Canonical, IntoArray};
20
21impl CanonicalVTable<ConstantVTable> for ConstantVTable {
22 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
23 let scalar = array.scalar();
24
25 let validity = match array.dtype().nullability() {
26 Nullability::NonNullable => Validity::NonNullable,
27 Nullability::Nullable => match scalar.is_null() {
28 true => Validity::AllInvalid,
29 false => Validity::AllValid,
30 },
31 };
32
33 Ok(match array.dtype() {
34 DType::Null => Canonical::Null(NullArray::new(array.len())),
35 DType::Bool(..) => Canonical::Bool(BoolArray::new(
36 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37 BooleanBuffer::new_set(array.len())
38 } else {
39 BooleanBuffer::new_unset(array.len())
40 },
41 validity,
42 )),
43 DType::Primitive(ptype, ..) => {
44 match_each_native_ptype!(ptype, |P| {
45 Canonical::Primitive(PrimitiveArray::new(
46 if scalar.is_valid() {
47 Buffer::full(
48 P::try_from(scalar)
49 .vortex_expect("Couldn't unwrap scalar to primitive"),
50 array.len(),
51 )
52 } else {
53 Buffer::zeroed(array.len())
54 },
55 validity,
56 ))
57 })
58 }
59 DType::Decimal(decimal_type, ..) => {
60 let size = smallest_storage_type(decimal_type);
61 let decimal = scalar.as_decimal();
62 let Some(value) = decimal.decimal_value() else {
63 let all_null = match_each_decimal_value_type!(size, |D| {
64 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
65 });
66 return Ok(Canonical::Decimal(all_null));
67 };
68
69 let decimal_array = match_each_decimal_value!(value, |value| {
70 DecimalArray::new(Buffer::full(*value, array.len()), *decimal_type, validity)
71 });
72 Canonical::Decimal(decimal_array)
73 }
74 DType::Utf8(_) => {
75 let value = Utf8Scalar::try_from(scalar)?.value();
76 let const_value = value.as_ref().map(|v| v.as_bytes());
77 Canonical::VarBinView(canonical_byte_view(
78 const_value,
79 array.dtype(),
80 array.len(),
81 )?)
82 }
83 DType::Binary(_) => {
84 let value = BinaryScalar::try_from(scalar)?.value();
85 let const_value = value.as_ref().map(|v| v.as_slice());
86 Canonical::VarBinView(canonical_byte_view(
87 const_value,
88 array.dtype(),
89 array.len(),
90 )?)
91 }
92 DType::Struct(struct_dtype, _) => {
93 let value = StructScalar::try_from(scalar)?;
94 let fields = value.fields().map(|fields| {
95 fields
96 .into_iter()
97 .map(|s| ConstantArray::new(s, array.len()).into_array())
98 .collect::<Vec<_>>()
99 });
100 Canonical::Struct(StructArray::try_new_with_dtype(
101 fields.unwrap_or_default(),
102 struct_dtype.clone(),
103 array.len(),
104 validity,
105 )?)
106 }
107 DType::List(..) => {
108 let value = ListScalar::try_from(scalar)?;
109 Canonical::List(canonical_list_array(
110 value.elements(),
111 value.element_dtype(),
112 value.dtype().nullability(),
113 array.len(),
114 )?)
115 }
116 DType::Extension(ext_dtype) => {
117 let s = ExtScalar::try_from(scalar)?;
118
119 let storage_scalar = s.storage();
120 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
121 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
122 }
123 })
124 }
125}
126
127fn canonical_byte_view(
128 scalar_bytes: Option<&[u8]>,
129 dtype: &DType,
130 len: usize,
131) -> VortexResult<VarBinViewArray> {
132 match scalar_bytes {
133 None => {
134 let views = buffer![BinaryView::from(0_u128); len];
135
136 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
137 }
138 Some(scalar_bytes) => {
139 let view = BinaryView::make_view(scalar_bytes, 0, 0);
142 let mut buffers = Vec::new();
143 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
144 buffers.push(Buffer::copy_from(scalar_bytes));
145 }
146
147 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
151 for _ in 0..len {
152 views.push(view);
153 }
154
155 VarBinViewArray::try_new(
156 views.freeze(),
157 buffers,
158 dtype.clone(),
159 Validity::from(dtype.nullability()),
160 )
161 }
162 }
163}
164
165fn canonical_list_array(
166 values: Option<Vec<Scalar>>,
167 element_dtype: &DType,
168 list_nullability: Nullability,
169 len: usize,
170) -> VortexResult<ListArray> {
171 match values {
172 None => ListArray::try_new(
173 Canonical::empty(element_dtype).into_array(),
174 ConstantArray::new(
175 Scalar::new(
176 DType::Primitive(PType::U64, Nullability::NonNullable),
177 ScalarValue::from(0),
178 ),
179 len + 1,
180 )
181 .into_array(),
182 Validity::AllInvalid,
183 ),
184 Some(vs) => {
185 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
186 for _ in 0..len {
187 for v in &vs {
188 elements_builder.append_scalar(v)?;
189 }
190 }
191 let offsets = if vs.is_empty() {
192 Buffer::zeroed(len + 1)
193 } else {
194 (0..=len * vs.len())
195 .step_by(vs.len())
196 .map(|i| i as u64)
197 .collect::<Buffer<_>>()
198 };
199
200 ListArray::try_new(
201 elements_builder.finish(),
202 offsets.into_array(),
203 Validity::from(list_nullability),
204 )
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};
    use crate::{Array, IntoArray};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant of dtype `Null` canonicalizes to a null array of equal length.
    #[test]
    fn test_canonicalize_null() {
        let null_scalar = Scalar::null(DType::Null);
        let canonical = ConstantArray::new(null_scalar, 42).to_null().unwrap();

        assert_eq!(canonical.len(), 42);
        assert_eq!(canonical.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let canonical = ConstantArray::new("four".to_string(), 4)
            .to_varbinview()
            .unwrap();

        assert_eq!(canonical.len(), 4);

        for row in 0..4 {
            assert_eq!(canonical.scalar_at(row).unwrap(), "four".into());
        }
    }

    /// Every stat applicable to the dtype agrees between the constant array,
    /// its canonical form, and a reference constant stats set.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = const_array.statistics().to_owned();

        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // Skip stats that have no dtype for this array's dtype.
            let Some(stat_dtype) = stat.dtype(canonical.as_ref().dtype()) else {
                continue;
            };

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);
            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose stored value was built as `u8` but whose dtype is `f16`
    /// canonicalizes to a value equal to both that scalar and the expected
    /// `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let mismatched = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );

        let canonical = ConstantArray::new(mismatched.clone(), 1)
            .into_array()
            .to_primitive()
            .unwrap();

        assert_eq!(canonical.scalar_at(0).unwrap(), mismatched);
        assert_eq!(canonical.scalar_at(0).unwrap(), expected_f16);
    }

    /// A two-element constant list repeated twice yields four elements and
    /// offsets stepping by two.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let canonical = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// An empty constant list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let canonical = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null constant list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let canonical = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}