vortex_array/arrays/constant/
canonical.rs1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19 StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::{ArrayBuilderExt, builder_with_capacity};
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
28 let scalar = array.scalar();
29
30 let validity = match array.dtype().nullability() {
31 Nullability::NonNullable => Validity::NonNullable,
32 Nullability::Nullable => match scalar.is_null() {
33 true => Validity::AllInvalid,
34 false => Validity::AllValid,
35 },
36 };
37
38 Ok(match array.dtype() {
39 DType::Null => Canonical::Null(NullArray::new(array.len())),
40 DType::Bool(..) => Canonical::Bool(BoolArray::new(
41 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
42 BooleanBuffer::new_set(array.len())
43 } else {
44 BooleanBuffer::new_unset(array.len())
45 },
46 validity,
47 )),
48 DType::Primitive(ptype, ..) => {
49 match_each_native_ptype!(ptype, |P| {
50 Canonical::Primitive(PrimitiveArray::new(
51 if scalar.is_valid() {
52 Buffer::full(
53 P::try_from(scalar)
54 .vortex_expect("Couldn't unwrap scalar to primitive"),
55 array.len(),
56 )
57 } else {
58 Buffer::zeroed(array.len())
59 },
60 validity,
61 ))
62 })
63 }
64 DType::Decimal(decimal_type, ..) => {
65 let size = smallest_storage_type(decimal_type);
66 let decimal = scalar.as_decimal();
67 let Some(value) = decimal.decimal_value() else {
68 let all_null = match_each_decimal_value_type!(size, |D| {
69 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
70 });
71 return Ok(Canonical::Decimal(all_null));
72 };
73
74 let decimal_array = match_each_decimal_value!(value, |value| {
75 DecimalArray::new(Buffer::full(value, array.len()), *decimal_type, validity)
76 });
77 Canonical::Decimal(decimal_array)
78 }
79 DType::Utf8(_) => {
80 let value = Utf8Scalar::try_from(scalar)?.value();
81 let const_value = value.as_ref().map(|v| v.as_bytes());
82 Canonical::VarBinView(canonical_byte_view(
83 const_value,
84 array.dtype(),
85 array.len(),
86 )?)
87 }
88 DType::Binary(_) => {
89 let value = BinaryScalar::try_from(scalar)?.value();
90 let const_value = value.as_ref().map(|v| v.as_slice());
91 Canonical::VarBinView(canonical_byte_view(
92 const_value,
93 array.dtype(),
94 array.len(),
95 )?)
96 }
97 DType::Struct(struct_dtype, _) => {
98 let value = StructScalar::try_from(scalar)?;
99 let fields: Vec<_> = match value.fields() {
100 Some(fields) => fields
101 .into_iter()
102 .map(|s| ConstantArray::new(s, array.len()).into_array())
103 .collect(),
104 None => {
105 assert!(validity.all_invalid()?);
106 struct_dtype
107 .fields()
108 .map(|dt| {
109 let scalar = Scalar::default_value(dt);
110 ConstantArray::new(scalar, array.len()).into_array()
111 })
112 .collect()
113 }
114 };
115 Canonical::Struct(StructArray::try_new_with_dtype(
116 fields,
117 struct_dtype.clone(),
118 array.len(),
119 validity,
120 )?)
121 }
122 DType::List(..) => {
123 let value = ListScalar::try_from(scalar)?;
124 Canonical::List(canonical_list_array(
125 value.elements(),
126 value.element_dtype(),
127 value.dtype().nullability(),
128 array.len(),
129 )?)
130 }
131 DType::Extension(ext_dtype) => {
132 let s = ExtScalar::try_from(scalar)?;
133
134 let storage_scalar = s.storage();
135 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
136 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
137 }
138 })
139 }
140}
141
142fn canonical_byte_view(
143 scalar_bytes: Option<&[u8]>,
144 dtype: &DType,
145 len: usize,
146) -> VortexResult<VarBinViewArray> {
147 match scalar_bytes {
148 None => {
149 let views = buffer![BinaryView::from(0_u128); len];
150
151 unsafe {
153 Ok(VarBinViewArray::new_unchecked(
154 views,
155 Default::default(),
156 dtype.clone(),
157 Validity::AllInvalid,
158 ))
159 }
160 }
161 Some(scalar_bytes) => {
162 let view = BinaryView::make_view(scalar_bytes, 0, 0);
165 let mut buffers = Vec::new();
166 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
167 buffers.push(Buffer::copy_from(scalar_bytes));
168 }
169
170 let views = buffer![view; len];
172
173 unsafe {
175 Ok(VarBinViewArray::new_unchecked(
176 views,
177 Arc::from(buffers),
178 dtype.clone(),
179 Validity::from(dtype.nullability()),
180 ))
181 }
182 }
183 }
184}
185
186fn canonical_list_array(
187 values: Option<Vec<Scalar>>,
188 element_dtype: &DType,
189 list_nullability: Nullability,
190 len: usize,
191) -> VortexResult<ListArray> {
192 match values {
193 None => ListArray::try_new(
194 Canonical::empty(element_dtype).into_array(),
195 ConstantArray::new(
196 Scalar::new(
197 DType::Primitive(PType::U64, Nullability::NonNullable),
198 ScalarValue::from(0),
199 ),
200 len + 1,
201 )
202 .into_array(),
203 Validity::AllInvalid,
204 ),
205 Some(vs) => {
206 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
207 for _ in 0..len {
208 for v in &vs {
209 elements_builder.append_scalar(v)?;
210 }
211 }
212 let offsets = if vs.is_empty() {
213 Buffer::zeroed(len + 1)
214 } else {
215 Buffer::from_trusted_len_iter(
216 (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
217 )
218 };
219
220 ListArray::try_new(
221 elements_builder.finish(),
222 offsets.into_array(),
223 Validity::from(list_nullability),
224 )
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BinaryView, ConstantArray};
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null().unwrap();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    /// Strings just below, at, and just above the view inlining threshold must all
    /// round-trip through canonicalization.
    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        let max = BinaryView::MAX_INLINED_SIZE;
        for width in [max - 1, max, max + 1] {
            let value = "x".repeat(width);
            let canonical = ConstantArray::new(value.clone(), 3)
                .to_varbinview()
                .unwrap();

            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i), value.as_str().into());
            }
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let const_null = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);
        let canonical = const_null.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.valid_count().unwrap(), 0);
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive().unwrap();

        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert_eq!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct().unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count().unwrap(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}