1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, match_each_native_ptype};
9use vortex_error::VortexExpect;
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, StructScalar,
12 Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::binary_view::BinaryView;
16use crate::arrays::constant::ConstantArray;
17use crate::arrays::primitive::PrimitiveArray;
18use crate::arrays::{
19 BoolArray, ConstantVTable, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray,
20 NullArray, StructArray, VarBinViewArray, smallest_decimal_value_type,
21};
22use crate::builders::builder_with_capacity;
23use crate::validity::Validity;
24use crate::vtable::CanonicalVTable;
25use crate::{Canonical, IntoArray};
26
27impl CanonicalVTable<ConstantVTable> for ConstantVTable {
28 fn canonicalize(array: &ConstantArray) -> Canonical {
29 let scalar = array.scalar();
30
31 let validity = match array.dtype().nullability() {
32 Nullability::NonNullable => Validity::NonNullable,
33 Nullability::Nullable => match scalar.is_null() {
34 true => Validity::AllInvalid,
35 false => Validity::AllValid,
36 },
37 };
38
39 match array.dtype() {
40 DType::Null => Canonical::Null(NullArray::new(array.len())),
41 DType::Bool(..) => Canonical::Bool(BoolArray::from_bool_buffer(
42 if BoolScalar::try_from(scalar)
43 .vortex_expect("must be bool")
44 .value()
45 .unwrap_or_default()
46 {
47 BooleanBuffer::new_set(array.len())
48 } else {
49 BooleanBuffer::new_unset(array.len())
50 },
51 validity,
52 )),
53 DType::Primitive(ptype, ..) => {
54 match_each_native_ptype!(ptype, |P| {
55 Canonical::Primitive(PrimitiveArray::new(
56 if scalar.is_valid() {
57 Buffer::full(
58 P::try_from(scalar)
59 .vortex_expect("Couldn't unwrap scalar to primitive"),
60 array.len(),
61 )
62 } else {
63 Buffer::zeroed(array.len())
64 },
65 validity,
66 ))
67 })
68 }
69 DType::Decimal(decimal_type, ..) => {
70 let size = smallest_decimal_value_type(decimal_type);
71 let decimal = scalar.as_decimal();
72 let Some(value) = decimal.decimal_value() else {
73 let all_null = match_each_decimal_value_type!(size, |D| {
74 unsafe {
76 DecimalArray::new_unchecked(
77 Buffer::<D>::zeroed(array.len()),
78 *decimal_type,
79 validity,
80 )
81 }
82 });
83 return Canonical::Decimal(all_null);
84 };
85
86 let decimal_array = match_each_decimal_value!(value, |value| {
87 unsafe {
89 DecimalArray::new_unchecked(
90 Buffer::full(value, array.len()),
91 *decimal_type,
92 validity,
93 )
94 }
95 });
96 Canonical::Decimal(decimal_array)
97 }
98 DType::Utf8(_) => {
99 let value = Utf8Scalar::try_from(scalar)
100 .vortex_expect("Must be a utf8 scalar")
101 .value();
102 let const_value = value.as_ref().map(|v| v.as_bytes());
103 Canonical::VarBinView(constant_canonical_byte_view(
104 const_value,
105 array.dtype(),
106 array.len(),
107 ))
108 }
109 DType::Binary(_) => {
110 let value = BinaryScalar::try_from(scalar)
111 .vortex_expect("must be a binary scalar")
112 .value();
113 let const_value = value.as_ref().map(|v| v.as_slice());
114 Canonical::VarBinView(constant_canonical_byte_view(
115 const_value,
116 array.dtype(),
117 array.len(),
118 ))
119 }
120 DType::Struct(struct_dtype, _) => {
121 let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
122 let fields: Vec<_> = match value.fields() {
123 Some(fields) => fields
124 .into_iter()
125 .map(|s| ConstantArray::new(s, array.len()).into_array())
126 .collect(),
127 None => {
128 assert!(validity.all_invalid(array.len()));
129 struct_dtype
130 .fields()
131 .map(|dt| {
132 let scalar = Scalar::default_value(dt);
133 ConstantArray::new(scalar, array.len()).into_array()
134 })
135 .collect()
136 }
137 };
138 Canonical::Struct(unsafe {
141 StructArray::new_unchecked(fields, struct_dtype.clone(), array.len(), validity)
142 })
143 }
144 DType::List(..) => Canonical::List(constant_canonical_list_array(scalar, array.len())),
145 DType::FixedSizeList(element_dtype, list_size, _) => {
146 let value = ListScalar::try_from(scalar).vortex_expect("must be list");
147
148 Canonical::FixedSizeList(constant_canonical_fixed_size_list_array(
149 value.elements(),
150 element_dtype,
151 *list_size,
152 value.dtype().nullability(),
153 array.len(),
154 ))
155 }
156 DType::Extension(ext_dtype) => {
157 let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
158
159 let storage_scalar = s.storage();
160 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
161 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
162 }
163 }
164 }
165}
166
167fn constant_canonical_byte_view(
168 scalar_bytes: Option<&[u8]>,
169 dtype: &DType,
170 len: usize,
171) -> VarBinViewArray {
172 match scalar_bytes {
173 None => {
174 let views = buffer![BinaryView::from(0_u128); len];
175
176 unsafe {
178 VarBinViewArray::new_unchecked(
179 views,
180 Default::default(),
181 dtype.clone(),
182 Validity::AllInvalid,
183 )
184 }
185 }
186 Some(scalar_bytes) => {
187 let view = BinaryView::make_view(scalar_bytes, 0, 0);
190 let mut buffers = Vec::new();
191 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
192 buffers.push(Buffer::copy_from(scalar_bytes));
193 }
194
195 let views = buffer![view; len];
197
198 unsafe {
200 VarBinViewArray::new_unchecked(
201 views,
202 Arc::from(buffers),
203 dtype.clone(),
204 Validity::from(dtype.nullability()),
205 )
206 }
207 }
208 }
209}
210
211fn constant_canonical_list_array(scalar: &Scalar, len: usize) -> ListViewArray {
216 let list = ListScalar::try_from(scalar).vortex_expect("must be list");
217
218 let elements = if let Some(elements) = list.elements() {
221 let mut builder = builder_with_capacity(
223 list.dtype()
224 .as_list_element_opt()
225 .vortex_expect("list scalar somehow did not have a list DType"),
226 list.len(),
227 );
228 for scalar in &elements {
229 builder
230 .append_scalar(scalar)
231 .vortex_expect("list element scalar was invalid");
232 }
233 builder.finish()
234 } else {
235 Canonical::empty(list.element_dtype()).into_array()
237 };
238
239 let validity = if scalar.dtype().is_nullable() {
240 if list.is_null() {
241 Validity::AllInvalid
242 } else {
243 Validity::AllValid
244 }
245 } else {
246 debug_assert!(!list.is_null());
247 Validity::NonNullable
248 };
249
250 let offsets = ConstantArray::new::<u64>(0, len).into_array();
252 let sizes = ConstantArray::new::<u64>(list.len() as u64, len).into_array();
253
254 debug_assert!(!offsets.dtype().is_nullable());
255 debug_assert!(!sizes.dtype().is_nullable());
256
257 unsafe { ListViewArray::new_unchecked(elements, offsets, sizes, validity) }
261}
262
263fn constant_canonical_fixed_size_list_array(
264 values: Option<Vec<Scalar>>,
265 element_dtype: &DType,
266 list_size: u32,
267 list_nullability: Nullability,
268 len: usize,
269) -> FixedSizeListArray {
270 match values {
271 None => {
272 let elements_len = list_size as usize * len;
275 let mut element_builder = builder_with_capacity(element_dtype, elements_len);
276 element_builder.append_defaults(elements_len);
277 let elements = element_builder.finish();
278
279 unsafe {
282 FixedSizeListArray::new_unchecked(elements, list_size, Validity::AllInvalid, len)
283 }
284 }
285 Some(values) => {
286 let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
287
288 for _ in 0..len {
289 for v in &values {
290 elements_builder
291 .append_scalar(v)
292 .vortex_expect("must be a same dtype");
293 }
294 }
295
296 let elements = elements_builder.finish();
297 let validity = Validity::from(list_nullability);
298
299 unsafe { FixedSizeListArray::new_unchecked(elements, list_size, validity, len) }
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::binary_view::BinaryView;
    use crate::arrays::{ConstantArray, ListViewRebuildMode};
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        let canonical = const_array.to_varbinview();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    /// Strings at and just past the inlined-view size limit must round-trip unchanged.
    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        let at_limit = "a".repeat(BinaryView::MAX_INLINED_SIZE);
        let past_limit = "b".repeat(BinaryView::MAX_INLINED_SIZE + 1);

        for value in [at_limit, past_limit] {
            let canonical = ConstantArray::new(value.clone(), 3).to_varbinview();
            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i), value.as_str().into());
            }
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let const_array =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);
        let canonical = const_array.to_varbinview();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);
        for i in 0..3 {
            assert!(canonical.scalar_at(i).is_null());
        }
    }

    #[test]
    fn test_canonicalize_null_primitive() {
        let const_array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            4,
        );
        let canonical = const_array.to_primitive();

        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);
        // Null constants are backed by zeroes.
        assert_eq!(canonical.as_slice::<i32>(), [0, 0, 0, 0]);
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive();

        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        let list_array = canonical_const.rebuild(ListViewRebuildMode::MakeZeroCopyToList);
        assert_eq!(
            list_array.elements().to_primitive().as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            list_array.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 2]
        );
        assert_eq!(
            list_array.sizes().to_primitive().as_slice::<u64>(),
            [2u64, 2]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
        assert_eq!(
            canonical_const.sizes().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
        assert_eq!(
            canonical_const.sizes().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_canonicalize_fixed_size_list_non_null() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![
                Scalar::primitive(10i32, Nullability::NonNullable),
                Scalar::primitive(20i32, Nullability::NonNullable),
                Scalar::primitive(30i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 4).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        for i in 0..4 {
            let list = canonical.fixed_size_list_elements_at(i);
            let list_primitive = list.to_primitive();
            assert_eq!(list_primitive.as_slice::<i32>(), [10, 20, 30]);
        }
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nullable() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::F64, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1.5f64, Nullability::NonNullable),
                Scalar::primitive(2.5f64, Nullability::NonNullable),
            ],
            Nullability::Nullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 2);
        assert_eq!(canonical.validity(), &Validity::AllValid);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<f64>(), [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_null() {
        let fsl_scalar = Scalar::null(DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            4,
            Nullability::Nullable,
        ));

        let const_array = ConstantArray::new(fsl_scalar, 5).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 5);
        assert_eq!(canonical.list_size(), 4);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);

        // Null lists are still backed by `list_size * len` default elements.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 20);
        assert!(elements.as_slice::<u64>().iter().all(|&x| x == 0));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_empty() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I8, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 10).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 10);
        assert_eq!(canonical.list_size(), 0);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        assert!(canonical.elements().is_empty());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nested() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            vec![Scalar::from("hello"), Scalar::from("world")],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 2).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 2);
        assert_eq!(canonical.list_size(), 2);

        let elements = canonical.elements().to_varbinview();
        assert_eq!(elements.scalar_at(0), "hello".into());
        assert_eq!(elements.scalar_at(1), "world".into());
        assert_eq!(elements.scalar_at(2), "hello".into());
        assert_eq!(elements.scalar_at(3), "world".into());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_single_element() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I16, Nullability::NonNullable)),
            vec![Scalar::primitive(42i16, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1);
        assert_eq!(canonical.list_size(), 1);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i16>(), [42]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_with_null_elements() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            vec![
                Scalar::primitive(100i32, Nullability::Nullable),
                Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
                Scalar::primitive(200i32, Nullability::Nullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i32>()[0], 100);
        // The null element is backed by a zero.
        assert_eq!(elements.as_slice::<i32>()[1], 0);
        assert_eq!(elements.as_slice::<i32>()[2], 200);

        let element_validity = elements.validity();
        // First list.
        assert!(element_validity.is_valid(0));
        assert!(!element_validity.is_valid(1));
        assert!(element_validity.is_valid(2));

        // Second list repeats the same validity pattern.
        assert!(element_validity.is_valid(3));
        assert!(!element_validity.is_valid(4));
        assert!(element_validity.is_valid(5));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_large() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1u8, Nullability::NonNullable),
                Scalar::primitive(2u8, Nullability::NonNullable),
                Scalar::primitive(3u8, Nullability::NonNullable),
                Scalar::primitive(4u8, Nullability::NonNullable),
                Scalar::primitive(5u8, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1000).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1000);
        assert_eq!(canonical.list_size(), 5);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 5000);

        for i in 0..1000 {
            let base = i * 5;
            assert_eq!(elements.as_slice::<u8>()[base], 1);
            assert_eq!(elements.as_slice::<u8>()[base + 1], 2);
            assert_eq!(elements.as_slice::<u8>()[base + 2], 3);
            assert_eq!(elements.as_slice::<u8>()[base + 3], 4);
            assert_eq!(elements.as_slice::<u8>()[base + 4], 5);
        }
    }
}