vortex_array/arrays/constant/
canonical.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::VortexExpect;
10use vortex_scalar::{
11    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19    StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::builder_with_capacity;
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27    fn canonicalize(array: &ConstantArray) -> Canonical {
28        let scalar = array.scalar();
29
30        let validity = match array.dtype().nullability() {
31            Nullability::NonNullable => Validity::NonNullable,
32            Nullability::Nullable => match scalar.is_null() {
33                true => Validity::AllInvalid,
34                false => Validity::AllValid,
35            },
36        };
37
38        match array.dtype() {
39            DType::Null => Canonical::Null(NullArray::new(array.len())),
40            DType::Bool(..) => Canonical::Bool(BoolArray::new(
41                if BoolScalar::try_from(scalar)
42                    .vortex_expect("must be bool")
43                    .value()
44                    .unwrap_or_default()
45                {
46                    BooleanBuffer::new_set(array.len())
47                } else {
48                    BooleanBuffer::new_unset(array.len())
49                },
50                validity,
51            )),
52            DType::Primitive(ptype, ..) => {
53                match_each_native_ptype!(ptype, |P| {
54                    Canonical::Primitive(PrimitiveArray::new(
55                        if scalar.is_valid() {
56                            Buffer::full(
57                                P::try_from(scalar)
58                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
59                                array.len(),
60                            )
61                        } else {
62                            Buffer::zeroed(array.len())
63                        },
64                        validity,
65                    ))
66                })
67            }
68            DType::Decimal(decimal_type, ..) => {
69                let size = smallest_storage_type(decimal_type);
70                let decimal = scalar.as_decimal();
71                let Some(value) = decimal.decimal_value() else {
72                    let all_null = match_each_decimal_value_type!(size, |D| {
73                        DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
74                    });
75                    return Canonical::Decimal(all_null);
76                };
77
78                let decimal_array = match_each_decimal_value!(value, |value| {
79                    DecimalArray::new(Buffer::full(value, array.len()), *decimal_type, validity)
80                });
81                Canonical::Decimal(decimal_array)
82            }
83            DType::Utf8(_) => {
84                let value = Utf8Scalar::try_from(scalar)
85                    .vortex_expect("Must be a utf8 scalar")
86                    .value();
87                let const_value = value.as_ref().map(|v| v.as_bytes());
88                Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
89            }
90            DType::Binary(_) => {
91                let value = BinaryScalar::try_from(scalar)
92                    .vortex_expect("must be a binary scalar")
93                    .value();
94                let const_value = value.as_ref().map(|v| v.as_slice());
95                Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
96            }
97            DType::Struct(struct_dtype, _) => {
98                let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
99                let fields: Vec<_> = match value.fields() {
100                    Some(fields) => fields
101                        .into_iter()
102                        .map(|s| ConstantArray::new(s, array.len()).into_array())
103                        .collect(),
104                    None => {
105                        assert!(validity.all_invalid());
106                        struct_dtype
107                            .fields()
108                            .map(|dt| {
109                                let scalar = Scalar::default_value(dt);
110                                ConstantArray::new(scalar, array.len()).into_array()
111                            })
112                            .collect()
113                    }
114                };
115                Canonical::Struct(StructArray::new_unchecked(
116                    fields,
117                    struct_dtype.clone(),
118                    array.len(),
119                    validity,
120                ))
121            }
122            DType::List(..) => {
123                let value = ListScalar::try_from(scalar).vortex_expect("must be list");
124                Canonical::List(canonical_list_array(
125                    value.elements(),
126                    value.element_dtype(),
127                    value.dtype().nullability(),
128                    array.len(),
129                ))
130            }
131            DType::FixedSizeList(..) => {
132                unimplemented!("TODO(connor)[FixedSizeList]")
133            }
134            DType::Extension(ext_dtype) => {
135                let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
136
137                let storage_scalar = s.storage();
138                let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
139                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
140            }
141        }
142    }
143}
144
145fn canonical_byte_view(scalar_bytes: Option<&[u8]>, dtype: &DType, len: usize) -> VarBinViewArray {
146    match scalar_bytes {
147        None => {
148            let views = buffer![BinaryView::from(0_u128); len];
149
150            // SAFETY: for all-null the views and buffers are just zeroed, never accessed.
151            unsafe {
152                VarBinViewArray::new_unchecked(
153                    views,
154                    Default::default(),
155                    dtype.clone(),
156                    Validity::AllInvalid,
157                )
158            }
159        }
160        Some(scalar_bytes) => {
161            // Create a view to hold the scalar bytes.
162            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
163            let view = BinaryView::make_view(scalar_bytes, 0, 0);
164            let mut buffers = Vec::new();
165            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
166                buffers.push(Buffer::copy_from(scalar_bytes));
167            }
168
169            // Clone our constant view `len` times.
170            let views = buffer![view; len];
171
172            // SAFETY: all the views are identical and point to a constant value.
173            unsafe {
174                VarBinViewArray::new_unchecked(
175                    views,
176                    Arc::from(buffers),
177                    dtype.clone(),
178                    Validity::from(dtype.nullability()),
179                )
180            }
181        }
182    }
183}
184
185fn canonical_list_array(
186    values: Option<Vec<Scalar>>,
187    element_dtype: &DType,
188    list_nullability: Nullability,
189    len: usize,
190) -> ListArray {
191    match values {
192        None => unsafe {
193            ListArray::new_unchecked(
194                Canonical::empty(element_dtype).into_array(),
195                ConstantArray::new(
196                    Scalar::new(
197                        DType::Primitive(PType::U64, Nullability::NonNullable),
198                        ScalarValue::from(0),
199                    ),
200                    len + 1,
201                )
202                .into_array(),
203                Validity::AllInvalid,
204            )
205        },
206        Some(vs) => {
207            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
208            for _ in 0..len {
209                for v in &vs {
210                    elements_builder
211                        .append_scalar(v)
212                        .vortex_expect("must be a same dtype");
213                }
214            }
215            let offsets = if vs.is_empty() {
216                Buffer::zeroed(len + 1)
217            } else {
218                Buffer::from_trusted_len_iter(
219                    (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
220                )
221            };
222
223            unsafe {
224                ListArray::new_unchecked(
225                    elements_builder.finish(),
226                    offsets.into_array(),
227                    Validity::from(list_nullability),
228                )
229            }
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        // Check all values correct.
        let canonical = const_array.to_varbinview();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    #[test]
    fn test_canonicalize_inline_boundary_str() {
        // 12 bytes sits exactly on the binary-view inlining threshold.
        let value = "exactly12byt";
        assert_eq!(value.len(), 12);

        let canonical = ConstantArray::new(value.to_string(), 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i), value.into());
        }
    }

    #[test]
    fn test_canonicalize_long_str() {
        // Longer than the inline size, so the value must live in a shared data buffer.
        let value = "this string is definitely longer than twelve bytes";

        let canonical = ConstantArray::new(value.to_string(), 3).to_varbinview();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i), value.into());
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let const_array = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);

        let canonical = const_array.to_varbinview();
        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.valid_count(), 0);
        for i in 0..3 {
            assert!(canonical.scalar_at(i).is_null());
        }
    }

    #[test]
    fn test_canonicalize_null_primitive() {
        let const_array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            4,
        );

        let canonical = const_array.to_primitive();
        assert_eq!(canonical.len(), 4);
        assert_eq!(canonical.valid_count(), 0);
        // Null primitive constants are backed by zeroed values.
        assert_eq!(canonical.as_slice::<i32>(), [0i32; 4]);
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        // Create a ConstantArray with the f16 scalar
        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive();

        // Verify the scalar value is preserved through canonicalization
        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert_eq!(
            canonical_const.elements().to_primitive().as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list();
        assert!(canonical_const.elements().to_primitive().is_empty());
        assert_eq!(
            canonical_const.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}