vortex_array/arrays/constant/
canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::VortexExpect;
10use vortex_scalar::{
11    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, FixedSizeListArray,
19    ListArray, NullArray, StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::builder_with_capacity;
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27    fn canonicalize(array: &ConstantArray) -> Canonical {
28        let scalar = array.scalar();
29
30        let validity = match array.dtype().nullability() {
31            Nullability::NonNullable => Validity::NonNullable,
32            Nullability::Nullable => match scalar.is_null() {
33                true => Validity::AllInvalid,
34                false => Validity::AllValid,
35            },
36        };
37
38        match array.dtype() {
39            DType::Null => Canonical::Null(NullArray::new(array.len())),
40            DType::Bool(..) => Canonical::Bool(BoolArray::from_bool_buffer(
41                if BoolScalar::try_from(scalar)
42                    .vortex_expect("must be bool")
43                    .value()
44                    .unwrap_or_default()
45                {
46                    BooleanBuffer::new_set(array.len())
47                } else {
48                    BooleanBuffer::new_unset(array.len())
49                },
50                validity,
51            )),
52            DType::Primitive(ptype, ..) => {
53                match_each_native_ptype!(ptype, |P| {
54                    Canonical::Primitive(PrimitiveArray::new(
55                        if scalar.is_valid() {
56                            Buffer::full(
57                                P::try_from(scalar)
58                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
59                                array.len(),
60                            )
61                        } else {
62                            Buffer::zeroed(array.len())
63                        },
64                        validity,
65                    ))
66                })
67            }
68            DType::Decimal(decimal_type, ..) => {
69                let size = smallest_storage_type(decimal_type);
70                let decimal = scalar.as_decimal();
71                let Some(value) = decimal.decimal_value() else {
72                    let all_null = match_each_decimal_value_type!(size, |D| {
73                        // SAFETY: All-null decimal arrays with zeroed buffers and matching validity.
74                        unsafe {
75                            DecimalArray::new_unchecked(
76                                Buffer::<D>::zeroed(array.len()),
77                                *decimal_type,
78                                validity,
79                            )
80                        }
81                    });
82                    return Canonical::Decimal(all_null);
83                };
84
85                let decimal_array = match_each_decimal_value!(value, |value| {
86                    // SAFETY: Constant decimal values with correct type and validity.
87                    unsafe {
88                        DecimalArray::new_unchecked(
89                            Buffer::full(value, array.len()),
90                            *decimal_type,
91                            validity,
92                        )
93                    }
94                });
95                Canonical::Decimal(decimal_array)
96            }
97            DType::Utf8(_) => {
98                let value = Utf8Scalar::try_from(scalar)
99                    .vortex_expect("Must be a utf8 scalar")
100                    .value();
101                let const_value = value.as_ref().map(|v| v.as_bytes());
102                Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
103            }
104            DType::Binary(_) => {
105                let value = BinaryScalar::try_from(scalar)
106                    .vortex_expect("must be a binary scalar")
107                    .value();
108                let const_value = value.as_ref().map(|v| v.as_slice());
109                Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
110            }
111            DType::Struct(struct_dtype, _) => {
112                let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
113                let fields: Vec<_> = match value.fields() {
114                    Some(fields) => fields
115                        .into_iter()
116                        .map(|s| ConstantArray::new(s, array.len()).into_array())
117                        .collect(),
118                    None => {
119                        assert!(validity.all_invalid(array.len()));
120                        struct_dtype
121                            .fields()
122                            .map(|dt| {
123                                let scalar = Scalar::default_value(dt);
124                                ConstantArray::new(scalar, array.len()).into_array()
125                            })
126                            .collect()
127                    }
128                };
129                // SAFETY: Fields are constructed from the same struct scalar, all have same
130                // length, dtypes match by construction.
131                Canonical::Struct(unsafe {
132                    StructArray::new_unchecked(fields, struct_dtype.clone(), array.len(), validity)
133                })
134            }
135            DType::List(..) => {
136                let value = ListScalar::try_from(scalar).vortex_expect("must be list");
137                Canonical::List(canonical_list_array(
138                    value.elements(),
139                    value.element_dtype(),
140                    value.dtype().nullability(),
141                    array.len(),
142                ))
143            }
144            DType::FixedSizeList(element_dtype, list_size, _) => {
145                let value = ListScalar::try_from(scalar).vortex_expect("must be list");
146
147                Canonical::FixedSizeList(canonical_fixed_size_list_array(
148                    value.elements(),
149                    element_dtype,
150                    *list_size,
151                    value.dtype().nullability(),
152                    array.len(),
153                ))
154            }
155            DType::Extension(ext_dtype) => {
156                let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
157
158                let storage_scalar = s.storage();
159                let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
160                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
161            }
162        }
163    }
164}
165
166fn canonical_byte_view(scalar_bytes: Option<&[u8]>, dtype: &DType, len: usize) -> VarBinViewArray {
167    match scalar_bytes {
168        None => {
169            let views = buffer![BinaryView::from(0_u128); len];
170
171            // SAFETY: for all-null the views and buffers are just zeroed, never accessed.
172            unsafe {
173                VarBinViewArray::new_unchecked(
174                    views,
175                    Default::default(),
176                    dtype.clone(),
177                    Validity::AllInvalid,
178                )
179            }
180        }
181        Some(scalar_bytes) => {
182            // Create a view to hold the scalar bytes.
183            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
184            let view = BinaryView::make_view(scalar_bytes, 0, 0);
185            let mut buffers = Vec::new();
186            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
187                buffers.push(Buffer::copy_from(scalar_bytes));
188            }
189
190            // Clone our constant view `len` times.
191            let views = buffer![view; len];
192
193            // SAFETY: all the views are identical and point to a constant value.
194            unsafe {
195                VarBinViewArray::new_unchecked(
196                    views,
197                    Arc::from(buffers),
198                    dtype.clone(),
199                    Validity::from(dtype.nullability()),
200                )
201            }
202        }
203    }
204}
205
206fn canonical_list_array(
207    values: Option<Vec<Scalar>>,
208    element_dtype: &DType,
209    list_nullability: Nullability,
210    len: usize,
211) -> ListArray {
212    match values {
213        None => unsafe {
214            ListArray::new_unchecked(
215                Canonical::empty(element_dtype).into_array(),
216                ConstantArray::new(
217                    Scalar::new(
218                        DType::Primitive(PType::U64, Nullability::NonNullable),
219                        ScalarValue::from(0),
220                    ),
221                    len + 1,
222                )
223                .into_array(),
224                Validity::AllInvalid,
225            )
226        },
227        Some(values) => {
228            let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
229            for _ in 0..len {
230                for v in &values {
231                    elements_builder
232                        .append_scalar(v)
233                        .vortex_expect("must be a same dtype");
234                }
235            }
236            let offsets = if values.is_empty() {
237                Buffer::zeroed(len + 1)
238            } else {
239                Buffer::from_trusted_len_iter(
240                    (0..=len * values.len())
241                        .step_by(values.len())
242                        .map(|i| i as u64),
243                )
244            };
245
246            unsafe {
247                ListArray::new_unchecked(
248                    elements_builder.finish(),
249                    offsets.into_array(),
250                    Validity::from(list_nullability),
251                )
252            }
253        }
254    }
255}
256
257fn canonical_fixed_size_list_array(
258    values: Option<Vec<Scalar>>,
259    element_dtype: &DType,
260    list_size: u32,
261    list_nullability: Nullability,
262    len: usize,
263) -> FixedSizeListArray {
264    match values {
265        None => {
266            // Even though the scalar is null, we still have to allocate the correct amount of space
267            // for the given `DType`.
268            let elements_len = list_size as usize * len;
269            let mut element_builder = builder_with_capacity(element_dtype, elements_len);
270            element_builder.append_defaults(elements_len);
271            let elements = element_builder.finish();
272
273            // SAFETY: The elements array has a length that is a multiple of `list_size`, and the
274            // validity is `AllInvalid` so we don't care about the length.
275            unsafe {
276                FixedSizeListArray::new_unchecked(elements, list_size, Validity::AllInvalid, len)
277            }
278        }
279        Some(values) => {
280            let mut elements_builder = builder_with_capacity(element_dtype, len * values.len());
281
282            for _ in 0..len {
283                for v in &values {
284                    elements_builder
285                        .append_scalar(v)
286                        .vortex_expect("must be a same dtype");
287                }
288            }
289
290            let elements = elements_builder.finish();
291            let validity = Validity::from(list_nullability);
292
293            // SAFETY: The elements array has a length that is a multiple of `list_size`, and the
294            // validity is either `NonNullable` or `AllValid` so we don't care about the length.
295            unsafe { FixedSizeListArray::new_unchecked(elements, list_size, validity, len) }
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, IntoArray};

    /// Shorthand for a shared, non-nullable primitive dtype.
    fn non_null(ptype: PType) -> Arc<DType> {
        Arc::new(DType::Primitive(ptype, Nullability::NonNullable))
    }

    /// A null constant canonicalizes to a `NullArray` of the same length.
    #[test]
    fn test_canonicalize_null() {
        let null_scalar = Scalar::null(DType::Null);
        let canonical = ConstantArray::new(null_scalar.clone(), 42).to_null();

        assert_eq!(canonical.len(), 42);
        assert_eq!(canonical.scalar_at(33), null_scalar);
    }

    /// Every position of a canonicalized string constant holds the original string.
    #[test]
    fn test_canonicalize_const_str() {
        let canonical = ConstantArray::new("four".to_string(), 4).to_varbinview();

        assert_eq!(canonical.len(), 4);

        for idx in 0..4 {
            assert_eq!(canonical.scalar_at(idx), "four".into());
        }
    }

    /// Statistics computed on the constant array are visible on its canonical form.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let constant =
            ConstantArray::new(Scalar::bool(true, Nullability::NonNullable), 4).into_array();
        let every_stat = all::<Stat>().collect_vec();
        let computed = constant.statistics().compute_all(&every_stat).unwrap();

        let canonical = constant.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();
        let expected = computed.as_typed_ref(canonical.as_ref().dtype());

        // Only compare stats that are defined for the canonical dtype.
        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_some() {
                assert_eq!(
                    canonical_stats.get(stat),
                    expected.get(stat),
                    "stat mismatch {stat}"
                );
            }
        }
    }

    /// A primitive (f16) scalar value survives canonicalization unchanged.
    #[test]
    fn test_canonicalize_scalar_values() {
        let scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);

        let canonical = ConstantArray::new(scalar.clone(), 1)
            .into_array()
            .to_primitive();

        assert_eq!(canonical.scalar_at(0), scalar);
    }

    /// A constant list repeats its elements and produces evenly spaced offsets.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            non_null(PType::U64),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let canonical = ConstantArray::new(list_scalar, 2).into_array().to_list();

        let elements = canonical.elements().to_primitive();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = canonical.offsets().to_primitive();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// A constant empty list has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(non_null(PType::U64), vec![], Nullability::NonNullable);
        let canonical = ConstantArray::new(list_scalar, 2).into_array().to_list();

        assert!(canonical.elements().to_primitive().is_empty());

        let offsets = canonical.offsets().to_primitive();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null list constant has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_dtype = DType::List(non_null(PType::U64), Nullability::Nullable);
        let canonical = ConstantArray::new(Scalar::null(list_dtype), 2)
            .into_array()
            .to_list();

        assert!(canonical.elements().to_primitive().is_empty());

        let offsets = canonical.offsets().to_primitive();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null struct constant keeps its non-nullable field dtype while being all-invalid.
    #[test]
    fn test_canonicalize_nullable_struct() {
        let field_dtype = DType::Primitive(PType::I8, Nullability::NonNullable);
        let struct_dtype = DType::struct_(
            [("non_null_field", field_dtype.clone())],
            Nullability::Nullable,
        );

        let struct_array = ConstantArray::new(Scalar::null(struct_dtype), 3).to_struct();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();
        assert_eq!(field.dtype(), &field_dtype);
    }

    /// Non-null fixed-size list constant: every row is the same list.
    #[test]
    fn test_canonicalize_fixed_size_list_non_null() {
        let items = [10i32, 20, 30]
            .into_iter()
            .map(|v| Scalar::primitive(v, Nullability::NonNullable))
            .collect::<Vec<_>>();
        let fsl_scalar =
            Scalar::fixed_size_list(non_null(PType::I32), items, Nullability::NonNullable);

        let canonical = ConstantArray::new(fsl_scalar, 4)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        // Each of the four lists must be [10, 20, 30].
        (0..4).for_each(|row| {
            let row_values = canonical.fixed_size_list_elements_at(row).to_primitive();
            assert_eq!(row_values.as_slice::<i32>(), [10, 20, 30]);
        });
    }

    /// Nullable-but-present fixed-size list constant: validity is `AllValid`.
    #[test]
    fn test_canonicalize_fixed_size_list_nullable() {
        let items = [1.5f64, 2.5]
            .into_iter()
            .map(|v| Scalar::primitive(v, Nullability::NonNullable))
            .collect::<Vec<_>>();
        let fsl_scalar =
            Scalar::fixed_size_list(non_null(PType::F64), items, Nullability::Nullable);

        let canonical = ConstantArray::new(fsl_scalar, 3)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 2);
        assert_eq!(canonical.validity(), &Validity::AllValid);

        let flat = canonical.elements().to_primitive();
        assert_eq!(flat.as_slice::<f64>(), [1.5, 2.5, 1.5, 2.5, 1.5, 2.5]);
    }

    /// Null fixed-size list constant: all-invalid, with default-filled elements.
    #[test]
    fn test_canonicalize_fixed_size_list_null() {
        let fsl_dtype = DType::FixedSizeList(non_null(PType::U64), 4, Nullability::Nullable);

        let canonical = ConstantArray::new(Scalar::null(fsl_dtype), 5)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 5);
        assert_eq!(canonical.list_size(), 4);
        assert_eq!(canonical.validity(), &Validity::AllInvalid);

        // 5 lists * 4 elements each, all defaulted to zero.
        let flat = canonical.elements().to_primitive();
        assert_eq!(flat.len(), 20);
        assert!(flat.as_slice::<u64>().iter().all(|&x| x == 0));
    }

    /// Size-0 lists (edge case): rows exist but the elements array is empty.
    #[test]
    fn test_canonicalize_fixed_size_list_empty() {
        let fsl_scalar =
            Scalar::fixed_size_list(non_null(PType::I8), vec![], Nullability::NonNullable);

        let canonical = ConstantArray::new(fsl_scalar, 10)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 10);
        assert_eq!(canonical.list_size(), 0);
        assert_eq!(canonical.validity(), &Validity::NonNullable);
        assert!(canonical.elements().is_empty());
    }

    /// Fixed-size list of strings: the elements repeat in order.
    #[test]
    fn test_canonicalize_fixed_size_list_nested() {
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            vec![Scalar::from("hello"), Scalar::from("world")],
            Nullability::NonNullable,
        );

        let canonical = ConstantArray::new(fsl_scalar, 2)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 2);
        assert_eq!(canonical.list_size(), 2);

        let strings = canonical.elements().to_varbinview();
        for pair_start in [0, 2] {
            assert_eq!(strings.scalar_at(pair_start), "hello".into());
            assert_eq!(strings.scalar_at(pair_start + 1), "world".into());
        }
    }

    /// A single list holding a single element.
    #[test]
    fn test_canonicalize_fixed_size_list_single_element() {
        let fsl_scalar = Scalar::fixed_size_list(
            non_null(PType::I16),
            vec![Scalar::primitive(42i16, Nullability::NonNullable)],
            Nullability::NonNullable,
        );

        let canonical = ConstantArray::new(fsl_scalar, 1)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 1);
        assert_eq!(canonical.list_size(), 1);

        let flat = canonical.elements().to_primitive();
        assert_eq!(flat.as_slice::<i16>(), [42]);
    }

    /// Nullable element type with a null in the middle of the constant list.
    #[test]
    fn test_canonicalize_fixed_size_list_with_null_elements() {
        let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);
        let fsl_scalar = Scalar::fixed_size_list(
            Arc::new(nullable_i32.clone()),
            vec![
                Scalar::primitive(100i32, Nullability::Nullable),
                Scalar::null(nullable_i32),
                Scalar::primitive(200i32, Nullability::Nullable),
            ],
            Nullability::NonNullable,
        );

        let canonical = ConstantArray::new(fsl_scalar, 3)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.list_size(), 3);
        assert_eq!(canonical.validity(), &Validity::NonNullable);

        // The first list's raw values; the null slot is stored as 0.
        let flat = canonical.elements().to_primitive();
        assert_eq!(&flat.as_slice::<i32>()[..3], [100, 0, 200]);

        // Element validity follows valid/null/valid, and the pattern repeats for the next list.
        let element_validity = flat.validity();
        let expected_pattern = [true, false, true, true, false, true];
        for (idx, expected) in expected_pattern.into_iter().enumerate() {
            assert_eq!(element_validity.is_valid(idx), expected);
        }
    }

    /// A large constant array repeats the list pattern across all rows.
    #[test]
    fn test_canonicalize_fixed_size_list_large() {
        let items = (1u8..=5)
            .map(|v| Scalar::primitive(v, Nullability::NonNullable))
            .collect::<Vec<_>>();
        let fsl_scalar =
            Scalar::fixed_size_list(non_null(PType::U8), items, Nullability::NonNullable);

        let canonical = ConstantArray::new(fsl_scalar, 1000)
            .into_array()
            .to_fixed_size_list();

        assert_eq!(canonical.len(), 1000);
        assert_eq!(canonical.list_size(), 5);

        let flat = canonical.elements().to_primitive();
        assert_eq!(flat.len(), 5000);

        // Every consecutive group of five elements must be [1, 2, 3, 4, 5].
        for row in flat.as_slice::<u8>().chunks_exact(5) {
            assert_eq!(row, [1u8, 2, 3, 4, 5]);
        }
    }
}