1use std::sync::Arc;
5
6use vortex_buffer::{BitBuffer, Buffer, buffer};
7use vortex_dtype::{
8 DType, DecimalType, Nullability, match_each_decimal_value, match_each_decimal_value_type,
9 match_each_native_ptype,
10};
11use vortex_error::VortexExpect;
12use vortex_scalar::{
13 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, StructScalar, Utf8Scalar,
14};
15use vortex_vector::binaryview::BinaryView;
16
17use crate::arrays::constant::ConstantArray;
18use crate::arrays::primitive::PrimitiveArray;
19use crate::arrays::{
20 BoolArray, ConstantVTable, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray,
21 NullArray, StructArray, VarBinViewArray,
22};
23use crate::builders::builder_with_capacity;
24use crate::validity::Validity;
25use crate::vtable::CanonicalVTable;
26use crate::{Canonical, IntoArray};
27
28impl CanonicalVTable<ConstantVTable> for ConstantVTable {
29 fn canonicalize(array: &ConstantArray) -> Canonical {
30 let scalar = array.scalar();
31
32 let validity = match array.dtype().nullability() {
33 Nullability::NonNullable => Validity::NonNullable,
34 Nullability::Nullable => match scalar.is_null() {
35 true => Validity::AllInvalid,
36 false => Validity::AllValid,
37 },
38 };
39
40 match array.dtype() {
41 DType::Null => Canonical::Null(NullArray::new(array.len())),
42 DType::Bool(..) => Canonical::Bool(BoolArray::from_bit_buffer(
43 if BoolScalar::try_from(scalar)
44 .vortex_expect("must be bool")
45 .value()
46 .unwrap_or_default()
47 {
48 BitBuffer::new_set(array.len())
49 } else {
50 BitBuffer::new_unset(array.len())
51 },
52 validity,
53 )),
54 DType::Primitive(ptype, ..) => {
55 match_each_native_ptype!(ptype, |P| {
56 Canonical::Primitive(PrimitiveArray::new(
57 if scalar.is_valid() {
58 Buffer::full(
59 P::try_from(scalar)
60 .vortex_expect("Couldn't unwrap scalar to primitive"),
61 array.len(),
62 )
63 } else {
64 Buffer::zeroed(array.len())
65 },
66 validity,
67 ))
68 })
69 }
70 DType::Decimal(decimal_type, ..) => {
71 let size = DecimalType::smallest_decimal_value_type(decimal_type);
72 let decimal = scalar.as_decimal();
73 let Some(value) = decimal.decimal_value() else {
74 let all_null = match_each_decimal_value_type!(size, |D| {
75 unsafe {
77 DecimalArray::new_unchecked(
78 Buffer::<D>::zeroed(array.len()),
79 *decimal_type,
80 validity,
81 )
82 }
83 });
84 return Canonical::Decimal(all_null);
85 };
86
87 let decimal_array = match_each_decimal_value!(value, |value| {
88 unsafe {
90 DecimalArray::new_unchecked(
91 Buffer::full(value, array.len()),
92 *decimal_type,
93 validity,
94 )
95 }
96 });
97 Canonical::Decimal(decimal_array)
98 }
99 DType::Utf8(_) => {
100 let value = Utf8Scalar::try_from(scalar)
101 .vortex_expect("Must be a utf8 scalar")
102 .value();
103 let const_value = value.as_ref().map(|v| v.as_bytes());
104 Canonical::VarBinView(constant_canonical_byte_view(
105 const_value,
106 array.dtype(),
107 array.len(),
108 ))
109 }
110 DType::Binary(_) => {
111 let value = BinaryScalar::try_from(scalar)
112 .vortex_expect("must be a binary scalar")
113 .value();
114 let const_value = value.as_ref().map(|v| v.as_slice());
115 Canonical::VarBinView(constant_canonical_byte_view(
116 const_value,
117 array.dtype(),
118 array.len(),
119 ))
120 }
121 DType::Struct(struct_dtype, _) => {
122 let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
123 let fields: Vec<_> = match value.fields() {
124 Some(fields) => fields
125 .into_iter()
126 .map(|s| ConstantArray::new(s, array.len()).into_array())
127 .collect(),
128 None => {
129 assert!(validity.all_invalid(array.len()));
130 struct_dtype
131 .fields()
132 .map(|dt| {
133 let scalar = Scalar::default_value(dt);
134 ConstantArray::new(scalar, array.len()).into_array()
135 })
136 .collect()
137 }
138 };
139 Canonical::Struct(unsafe {
142 StructArray::new_unchecked(fields, struct_dtype.clone(), array.len(), validity)
143 })
144 }
145 DType::List(..) => Canonical::List(constant_canonical_list_array(scalar, array.len())),
146 DType::FixedSizeList(element_dtype, list_size, _) => {
147 let value = ListScalar::try_from(scalar).vortex_expect("must be list");
148
149 Canonical::FixedSizeList(constant_canonical_fixed_size_list_array(
150 value.elements(),
151 element_dtype,
152 *list_size,
153 value.dtype().nullability(),
154 array.len(),
155 ))
156 }
157 DType::Extension(ext_dtype) => {
158 let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
159
160 let storage_scalar = s.storage();
161 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
162 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
163 }
164 }
165 }
166}
167
168fn constant_canonical_byte_view(
169 scalar_bytes: Option<&[u8]>,
170 dtype: &DType,
171 len: usize,
172) -> VarBinViewArray {
173 match scalar_bytes {
174 None => {
175 let views = buffer![BinaryView::empty_view(); len];
176
177 unsafe {
179 VarBinViewArray::new_unchecked(
180 views,
181 Default::default(),
182 dtype.clone(),
183 Validity::AllInvalid,
184 )
185 }
186 }
187 Some(scalar_bytes) => {
188 let view = BinaryView::make_view(scalar_bytes, 0, 0);
191 let mut buffers = Vec::new();
192 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
193 buffers.push(Buffer::copy_from(scalar_bytes));
194 }
195
196 let views = buffer![view; len];
198
199 unsafe {
201 VarBinViewArray::new_unchecked(
202 views,
203 Arc::from(buffers),
204 dtype.clone(),
205 Validity::from(dtype.nullability()),
206 )
207 }
208 }
209 }
210}
211
212fn constant_canonical_list_array(scalar: &Scalar, len: usize) -> ListViewArray {
217 let list = ListScalar::try_from(scalar).vortex_expect("must be list");
218
219 let elements = if let Some(elements) = list.elements() {
222 let mut builder = builder_with_capacity(
224 list.dtype()
225 .as_list_element_opt()
226 .vortex_expect("list scalar somehow did not have a list DType"),
227 list.len(),
228 );
229 for scalar in &elements {
230 builder
231 .append_scalar(scalar)
232 .vortex_expect("list element scalar was invalid");
233 }
234 builder.finish()
235 } else {
236 Canonical::empty(list.element_dtype()).into_array()
238 };
239
240 let validity = if scalar.dtype().is_nullable() {
241 if list.is_null() {
242 Validity::AllInvalid
243 } else {
244 Validity::AllValid
245 }
246 } else {
247 debug_assert!(!list.is_null());
248 Validity::NonNullable
249 };
250
251 let offsets = ConstantArray::new::<u64>(0, len).into_array();
253 let sizes = ConstantArray::new::<u64>(list.len() as u64, len).into_array();
254
255 debug_assert!(!offsets.dtype().is_nullable());
256 debug_assert!(!sizes.dtype().is_nullable());
257
258 unsafe { ListViewArray::new_unchecked(elements, offsets, sizes, validity) }
262}
263
264fn constant_canonical_fixed_size_list_array(
265 values: Option<Vec<Scalar>>,
266 element_dtype: &DType,
267 list_size: u32,
268 list_nullability: Nullability,
269 len: usize,
270) -> FixedSizeListArray {
271 match values {
272 None => {
273 let elements_len = list_size as usize * len;
276 let mut element_builder = builder_with_capacity(element_dtype, elements_len);
277 element_builder.append_defaults(elements_len);
278 let elements = element_builder.finish();
279
280 unsafe {
283 FixedSizeListArray::new_unchecked(elements, list_size, Validity::AllInvalid, len)
284 }
285 }
286 Some(values) => {
287 let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
288
289 for _ in 0..len {
290 for v in &values {
291 elements_builder
292 .append_scalar(v)
293 .vortex_expect("must be a same dtype");
294 }
295 }
296
297 let elements = elements_builder.finish();
298 let validity = Validity::from(list_nullability);
299
300 unsafe { FixedSizeListArray::new_unchecked(elements, list_size, validity, len) }
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{ConstantArray, ListViewRebuildMode};
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        let canonical = const_array.to_varbinview();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    /// Strings of 11, 12 and 13 bytes straddle the 12-byte inline-view limit; all of them
    /// must round-trip unchanged whether they are inlined or stored in a data buffer.
    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        for value in ["hello world", "hello world!", "hello world!!"] {
            let canonical = ConstantArray::new(value.to_string(), 3).to_varbinview();
            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i), value.into());
            }
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let dtype = DType::Utf8(Nullability::Nullable);
        let canonical = ConstantArray::new(Scalar::null(dtype.clone()), 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.valid_count(), 0);
        assert_eq!(canonical.scalar_at(1), Scalar::null(dtype));
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive();

        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        let list_array = canonical_const.rebuild(ListViewRebuildMode::MakeZeroCopyToList);
        assert_eq!(
            list_array.elements().to_primitive().as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            list_array.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 2]
        );
        assert_eq!(
            list_array.sizes().to_primitive().as_slice::<u64>(),
            [2u64, 2]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
        assert_eq!(
            canonical_const.sizes().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_listview();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
        assert_eq!(
            canonical_const.sizes().to_primitive().as_slice::<u64>(),
            [0u64, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_canonicalize_fixed_size_list_non_null() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![
                Scalar::primitive(10i32, Nullability::NonNullable),
                Scalar::primitive(20i32, Nullability::NonNullable),
                Scalar::primitive(30i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 4).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        for i in 0..4 {
            let list = canonical.fixed_size_list_elements_at(i);
            let list_primitive = list.to_primitive();
            assert_eq!(list_primitive.as_slice::<i32>(), [10, 20, 30]);
        }
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nullable() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::F64, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1.5f64, Nullability::NonNullable),
                Scalar::primitive(2.5f64, Nullability::NonNullable),
            ],
            Nullability::Nullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 2);
        assert_eq!(canonical.validity(), &Validity::AllValid);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<f64>(), [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_null() {
        let fsl_scalar = Scalar::null(DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            4,
            Nullability::Nullable,
        ));

        let const_array = ConstantArray::new(fsl_scalar, 5).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 5);
        assert_eq!(canonical.list_size(), 4);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);

        // A null constant still materializes `list_size * len` default (zero) elements.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 20);
        assert!(elements.as_slice::<u64>().iter().all(|&x| x == 0));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_empty() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I8, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 10).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 10);
        assert_eq!(canonical.list_size(), 0);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        assert!(canonical.elements().is_empty());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nested() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            vec![Scalar::from("hello"), Scalar::from("world")],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 2).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 2);
        assert_eq!(canonical.list_size(), 2);

        let elements = canonical.elements().to_varbinview();
        assert_eq!(elements.scalar_at(0), "hello".into());
        assert_eq!(elements.scalar_at(1), "world".into());
        assert_eq!(elements.scalar_at(2), "hello".into());
        assert_eq!(elements.scalar_at(3), "world".into());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_single_element() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I16, Nullability::NonNullable)),
            vec![Scalar::primitive(42i16, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1);
        assert_eq!(canonical.list_size(), 1);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i16>(), [42]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_with_null_elements() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            vec![
                Scalar::primitive(100i32, Nullability::Nullable),
                Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
                Scalar::primitive(200i32, Nullability::Nullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        // The null element is physically stored as 0 and masked out by element validity.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i32>()[0], 100);
        assert_eq!(elements.as_slice::<i32>()[1], 0);
        assert_eq!(elements.as_slice::<i32>()[2], 200);

        let element_validity = elements.validity();
        // First row.
        assert!(element_validity.is_valid(0));
        assert!(!element_validity.is_valid(1));
        assert!(element_validity.is_valid(2));

        // Second row repeats the same validity pattern.
        assert!(element_validity.is_valid(3));
        assert!(!element_validity.is_valid(4));
        assert!(element_validity.is_valid(5));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_large() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1u8, Nullability::NonNullable),
                Scalar::primitive(2u8, Nullability::NonNullable),
                Scalar::primitive(3u8, Nullability::NonNullable),
                Scalar::primitive(4u8, Nullability::NonNullable),
                Scalar::primitive(5u8, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1000).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1000);
        assert_eq!(canonical.list_size(), 5);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 5000);

        // Every row must hold the full constant list, in order.
        for row in elements.as_slice::<u8>().chunks_exact(5) {
            assert_eq!(row, [1, 2, 3, 4, 5]);
        }
    }
}