1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::VortexExpect;
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, FixedSizeListArray,
19 ListArray, NullArray, StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::builder_with_capacity;
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27 fn canonicalize(array: &ConstantArray) -> Canonical {
28 let scalar = array.scalar();
29
30 let validity = match array.dtype().nullability() {
31 Nullability::NonNullable => Validity::NonNullable,
32 Nullability::Nullable => match scalar.is_null() {
33 true => Validity::AllInvalid,
34 false => Validity::AllValid,
35 },
36 };
37
38 match array.dtype() {
39 DType::Null => Canonical::Null(NullArray::new(array.len())),
40 DType::Bool(..) => Canonical::Bool(BoolArray::from_bool_buffer(
41 if BoolScalar::try_from(scalar)
42 .vortex_expect("must be bool")
43 .value()
44 .unwrap_or_default()
45 {
46 BooleanBuffer::new_set(array.len())
47 } else {
48 BooleanBuffer::new_unset(array.len())
49 },
50 validity,
51 )),
52 DType::Primitive(ptype, ..) => {
53 match_each_native_ptype!(ptype, |P| {
54 Canonical::Primitive(PrimitiveArray::new(
55 if scalar.is_valid() {
56 Buffer::full(
57 P::try_from(scalar)
58 .vortex_expect("Couldn't unwrap scalar to primitive"),
59 array.len(),
60 )
61 } else {
62 Buffer::zeroed(array.len())
63 },
64 validity,
65 ))
66 })
67 }
68 DType::Decimal(decimal_type, ..) => {
69 let size = smallest_storage_type(decimal_type);
70 let decimal = scalar.as_decimal();
71 let Some(value) = decimal.decimal_value() else {
72 let all_null = match_each_decimal_value_type!(size, |D| {
73 unsafe {
75 DecimalArray::new_unchecked(
76 Buffer::<D>::zeroed(array.len()),
77 *decimal_type,
78 validity,
79 )
80 }
81 });
82 return Canonical::Decimal(all_null);
83 };
84
85 let decimal_array = match_each_decimal_value!(value, |value| {
86 unsafe {
88 DecimalArray::new_unchecked(
89 Buffer::full(value, array.len()),
90 *decimal_type,
91 validity,
92 )
93 }
94 });
95 Canonical::Decimal(decimal_array)
96 }
97 DType::Utf8(_) => {
98 let value = Utf8Scalar::try_from(scalar)
99 .vortex_expect("Must be a utf8 scalar")
100 .value();
101 let const_value = value.as_ref().map(|v| v.as_bytes());
102 Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
103 }
104 DType::Binary(_) => {
105 let value = BinaryScalar::try_from(scalar)
106 .vortex_expect("must be a binary scalar")
107 .value();
108 let const_value = value.as_ref().map(|v| v.as_slice());
109 Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
110 }
111 DType::Struct(struct_dtype, _) => {
112 let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
113 let fields: Vec<_> = match value.fields() {
114 Some(fields) => fields
115 .into_iter()
116 .map(|s| ConstantArray::new(s, array.len()).into_array())
117 .collect(),
118 None => {
119 assert!(validity.all_invalid(array.len()));
120 struct_dtype
121 .fields()
122 .map(|dt| {
123 let scalar = Scalar::default_value(dt);
124 ConstantArray::new(scalar, array.len()).into_array()
125 })
126 .collect()
127 }
128 };
129 Canonical::Struct(unsafe {
132 StructArray::new_unchecked(fields, struct_dtype.clone(), array.len(), validity)
133 })
134 }
135 DType::List(..) => {
136 let value = ListScalar::try_from(scalar).vortex_expect("must be list");
137 Canonical::List(canonical_list_array(
138 value.elements(),
139 value.element_dtype(),
140 value.dtype().nullability(),
141 array.len(),
142 ))
143 }
144 DType::FixedSizeList(element_dtype, list_size, _) => {
145 let value = ListScalar::try_from(scalar).vortex_expect("must be list");
146
147 Canonical::FixedSizeList(canonical_fixed_size_list_array(
148 value.elements(),
149 element_dtype,
150 *list_size,
151 value.dtype().nullability(),
152 array.len(),
153 ))
154 }
155 DType::Extension(ext_dtype) => {
156 let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
157
158 let storage_scalar = s.storage();
159 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
160 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
161 }
162 }
163 }
164}
165
166fn canonical_byte_view(scalar_bytes: Option<&[u8]>, dtype: &DType, len: usize) -> VarBinViewArray {
167 match scalar_bytes {
168 None => {
169 let views = buffer![BinaryView::from(0_u128); len];
170
171 unsafe {
173 VarBinViewArray::new_unchecked(
174 views,
175 Default::default(),
176 dtype.clone(),
177 Validity::AllInvalid,
178 )
179 }
180 }
181 Some(scalar_bytes) => {
182 let view = BinaryView::make_view(scalar_bytes, 0, 0);
185 let mut buffers = Vec::new();
186 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
187 buffers.push(Buffer::copy_from(scalar_bytes));
188 }
189
190 let views = buffer![view; len];
192
193 unsafe {
195 VarBinViewArray::new_unchecked(
196 views,
197 Arc::from(buffers),
198 dtype.clone(),
199 Validity::from(dtype.nullability()),
200 )
201 }
202 }
203 }
204}
205
206fn canonical_list_array(
207 values: Option<Vec<Scalar>>,
208 element_dtype: &DType,
209 list_nullability: Nullability,
210 len: usize,
211) -> ListArray {
212 match values {
213 None => unsafe {
214 ListArray::new_unchecked(
215 Canonical::empty(element_dtype).into_array(),
216 ConstantArray::new(
217 Scalar::new(
218 DType::Primitive(PType::U64, Nullability::NonNullable),
219 ScalarValue::from(0),
220 ),
221 len + 1,
222 )
223 .into_array(),
224 Validity::AllInvalid,
225 )
226 },
227 Some(values) => {
228 let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
229 for _ in 0..len {
230 for v in &values {
231 elements_builder
232 .append_scalar(v)
233 .vortex_expect("must be a same dtype");
234 }
235 }
236 let offsets = if values.is_empty() {
237 Buffer::zeroed(len + 1)
238 } else {
239 Buffer::from_trusted_len_iter(
240 (0..=len * values.len())
241 .step_by(values.len())
242 .map(|i| i as u64),
243 )
244 };
245
246 unsafe {
247 ListArray::new_unchecked(
248 elements_builder.finish(),
249 offsets.into_array(),
250 Validity::from(list_nullability),
251 )
252 }
253 }
254 }
255}
256
257fn canonical_fixed_size_list_array(
258 values: Option<Vec<Scalar>>,
259 element_dtype: &DType,
260 list_size: u32,
261 list_nullability: Nullability,
262 len: usize,
263) -> FixedSizeListArray {
264 match values {
265 None => {
266 let elements_len = list_size as usize * len;
269 let mut element_builder = builder_with_capacity(element_dtype, elements_len);
270 element_builder.append_defaults(elements_len);
271 let elements = element_builder.finish();
272
273 unsafe {
276 FixedSizeListArray::new_unchecked(elements, list_size, Validity::AllInvalid, len)
277 }
278 }
279 Some(values) => {
280 let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
281
282 for _ in 0..len {
283 for v in &values {
284 elements_builder
285 .append_scalar(v)
286 .vortex_expect("must be a same dtype");
287 }
288 }
289
290 let elements = elements_builder.finish();
291 let validity = Validity::from(list_nullability);
292
293 unsafe { FixedSizeListArray::new_unchecked(elements, list_size, validity, len) }
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        // Check the canonical form.
        let canonical = const_array.to_varbinview();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        // Exactly 12 bytes: the longest value a binary view can hold inline.
        let inlined = "twelve bytes";
        assert_eq!(inlined.len(), 12);
        let canonical = ConstantArray::new(inlined.to_string(), 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i), inlined.into());
        }

        // 13 bytes: too long to inline, so the value must be read from a data buffer.
        let out_of_line = "thirteen byte";
        assert_eq!(out_of_line.len(), 13);
        let canonical = ConstantArray::new(out_of_line.to_string(), 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i), out_of_line.into());
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let null_str = Scalar::null(DType::Utf8(Nullability::Nullable));
        let canonical = ConstantArray::new(null_str, 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert!(canonical.scalar_at(i).is_null());
        }
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            // Skip stats that are not defined for this dtype.
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        // Create a ConstantArray holding the f16 scalar and canonicalize it.
        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive();

        // The canonical value must round-trip to the same scalar.
        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert_eq!(
            canonical_const.elements().to_primitive().as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        // The child keeps its non-nullable dtype even though the parent struct is null.
        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_canonicalize_fixed_size_list_non_null() {
        // A non-null fixed-size list constant [10, 20, 30].
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![
                Scalar::primitive(10i32, Nullability::NonNullable),
                Scalar::primitive(20i32, Nullability::NonNullable),
                Scalar::primitive(30i32, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 4).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        for i in 0..4 {
            // Every row holds the same list.
            let list = canonical.fixed_size_list_elements_at(i);
            let list_primitive = list.to_primitive();
            assert_eq!(list_primitive.as_slice::<i32>(), [10, 20, 30]);
        }
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nullable() {
        // A valid constant of a nullable fixed-size list dtype.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::F64, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1.5f64, Nullability::NonNullable),
                Scalar::primitive(2.5f64, Nullability::NonNullable),
            ],
            Nullability::Nullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 2);
        assert_eq!(canonical.validity(), &Validity::AllValid);

        // The elements are the list repeated once per row.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<f64>(), [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_null() {
        // A null constant of a fixed-size list dtype.
        let fsl_scalar = Scalar::null(DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            4,
            Nullability::Nullable,
        ));

        let const_array = ConstantArray::new(fsl_scalar, 5).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 5);
        assert_eq!(canonical.list_size(), 4);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);

        // Null lists still occupy `list_size * len` default (zero) elements.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 20);
        assert!(elements.as_slice::<u64>().iter().all(|&x| x == 0));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_empty() {
        // A zero-size fixed-size list.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I8, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 10).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 10);
        assert_eq!(canonical.list_size(), 0);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        // Zero-size lists need no element storage.
        assert!(canonical.elements().is_empty());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_nested() {
        // A fixed-size list of strings.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            vec![Scalar::from("hello"), Scalar::from("world")],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 2).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 2);
        assert_eq!(canonical.list_size(), 2);

        // The string elements are repeated once per row.
        let elements = canonical.elements().to_varbinview();
        assert_eq!(elements.scalar_at(0), "hello".into());
        assert_eq!(elements.scalar_at(1), "world".into());
        assert_eq!(elements.scalar_at(2), "hello".into());
        assert_eq!(elements.scalar_at(3), "world".into());
    }

    #[test]
    fn test_canonicalize_fixed_size_list_single_element() {
        // A fixed-size list with a single item.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I16, Nullability::NonNullable)),
            vec![Scalar::primitive(42i16, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1);
        assert_eq!(canonical.list_size(), 1);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i16>(), [42]);
    }

    #[test]
    fn test_canonicalize_fixed_size_list_with_null_elements() {
        // A non-null list whose middle item is null.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)),
            vec![
                Scalar::primitive(100i32, Nullability::Nullable),
                Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
                Scalar::primitive(200i32, Nullability::Nullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 3).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        // Check the first list's physical values; the null item is stored as zero.
        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<i32>()[0], 100);
        assert_eq!(elements.as_slice::<i32>()[1], 0);
        assert_eq!(elements.as_slice::<i32>()[2], 200);

        // The element validity reflects the null item.
        let element_validity = elements.validity();
        // First list.
        assert!(element_validity.is_valid(0));
        assert!(!element_validity.is_valid(1));
        assert!(element_validity.is_valid(2));

        // Second list repeats the same pattern.
        assert!(element_validity.is_valid(3));
        assert!(!element_validity.is_valid(4));
        assert!(element_validity.is_valid(5));
    }

    #[test]
    fn test_canonicalize_fixed_size_list_large() {
        // A larger constant to exercise the repetition loop.
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
            vec![
                Scalar::primitive(1u8, Nullability::NonNullable),
                Scalar::primitive(2u8, Nullability::NonNullable),
                Scalar::primitive(3u8, Nullability::NonNullable),
                Scalar::primitive(4u8, Nullability::NonNullable),
                Scalar::primitive(5u8, Nullability::NonNullable),
            ],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(fsl_scalar, 1000).into_array();
        let canonical = const_array.to_fixed_size_list();

        assert_eq!(canonical.len(), 1000);
        assert_eq!(canonical.list_size(), 5);

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.len(), 5000);

        for i in 0..1000 {
            // Each row holds [1, 2, 3, 4, 5] starting at `i * 5`.
            let base = i * 5;
            assert_eq!(elements.as_slice::<u8>()[base], 1);
            assert_eq!(elements.as_slice::<u8>()[base + 1], 2);
            assert_eq!(elements.as_slice::<u8>()[base + 2], 3);
            assert_eq!(elements.as_slice::<u8>()[base + 3], 4);
            assert_eq!(elements.as_slice::<u8>()[base + 4], 5);
        }
    }
}