vortex_array/arrays/constant/canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, BufferMut, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{
11    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19    StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::{ArrayBuilderExt, builder_with_capacity};
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27    fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
28        let scalar = array.scalar();
29
30        let validity = match array.dtype().nullability() {
31            Nullability::NonNullable => Validity::NonNullable,
32            Nullability::Nullable => match scalar.is_null() {
33                true => Validity::AllInvalid,
34                false => Validity::AllValid,
35            },
36        };
37
38        Ok(match array.dtype() {
39            DType::Null => Canonical::Null(NullArray::new(array.len())),
40            DType::Bool(..) => Canonical::Bool(BoolArray::new(
41                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
42                    BooleanBuffer::new_set(array.len())
43                } else {
44                    BooleanBuffer::new_unset(array.len())
45                },
46                validity,
47            )),
48            DType::Primitive(ptype, ..) => {
49                match_each_native_ptype!(ptype, |P| {
50                    Canonical::Primitive(PrimitiveArray::new(
51                        if scalar.is_valid() {
52                            Buffer::full(
53                                P::try_from(scalar)
54                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
55                                array.len(),
56                            )
57                        } else {
58                            Buffer::zeroed(array.len())
59                        },
60                        validity,
61                    ))
62                })
63            }
64            DType::Decimal(decimal_type, ..) => {
65                let size = smallest_storage_type(decimal_type);
66                let decimal = scalar.as_decimal();
67                let Some(value) = decimal.decimal_value() else {
68                    let all_null = match_each_decimal_value_type!(size, |D| {
69                        DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
70                    });
71                    return Ok(Canonical::Decimal(all_null));
72                };
73
74                let decimal_array = match_each_decimal_value!(value, |value| {
75                    DecimalArray::new(Buffer::full(*value, array.len()), *decimal_type, validity)
76                });
77                Canonical::Decimal(decimal_array)
78            }
79            DType::Utf8(_) => {
80                let value = Utf8Scalar::try_from(scalar)?.value();
81                let const_value = value.as_ref().map(|v| v.as_bytes());
82                Canonical::VarBinView(canonical_byte_view(
83                    const_value,
84                    array.dtype(),
85                    array.len(),
86                )?)
87            }
88            DType::Binary(_) => {
89                let value = BinaryScalar::try_from(scalar)?.value();
90                let const_value = value.as_ref().map(|v| v.as_slice());
91                Canonical::VarBinView(canonical_byte_view(
92                    const_value,
93                    array.dtype(),
94                    array.len(),
95                )?)
96            }
97            DType::Struct(struct_dtype, _) => {
98                let value = StructScalar::try_from(scalar)?;
99                let fields: Vec<_> = match value.fields() {
100                    Some(fields) => fields
101                        .into_iter()
102                        .map(|s| ConstantArray::new(s, array.len()).into_array())
103                        .collect(),
104                    None => {
105                        assert!(validity.all_invalid()?);
106                        struct_dtype
107                            .fields()
108                            .map(|dt| {
109                                let scalar = Scalar::default_value(dt);
110                                ConstantArray::new(scalar, array.len()).into_array()
111                            })
112                            .collect()
113                    }
114                };
115                Canonical::Struct(StructArray::try_new_with_dtype(
116                    fields,
117                    struct_dtype.clone(),
118                    array.len(),
119                    validity,
120                )?)
121            }
122            DType::List(..) => {
123                let value = ListScalar::try_from(scalar)?;
124                Canonical::List(canonical_list_array(
125                    value.elements(),
126                    value.element_dtype(),
127                    value.dtype().nullability(),
128                    array.len(),
129                )?)
130            }
131            DType::Extension(ext_dtype) => {
132                let s = ExtScalar::try_from(scalar)?;
133
134                let storage_scalar = s.storage();
135                let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
136                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
137            }
138        })
139    }
140}
141
142fn canonical_byte_view(
143    scalar_bytes: Option<&[u8]>,
144    dtype: &DType,
145    len: usize,
146) -> VortexResult<VarBinViewArray> {
147    match scalar_bytes {
148        None => {
149            let views = buffer![BinaryView::from(0_u128); len];
150
151            VarBinViewArray::try_new(
152                views,
153                Default::default(),
154                dtype.clone(),
155                Validity::AllInvalid,
156            )
157        }
158        Some(scalar_bytes) => {
159            // Create a view to hold the scalar bytes.
160            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
161            let view = BinaryView::make_view(scalar_bytes, 0, 0);
162            let mut buffers = Vec::new();
163            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
164                buffers.push(Buffer::copy_from(scalar_bytes));
165            }
166
167            // Clone our constant view `len` times.
168            // TODO(aduffy): switch this out for a ConstantArray once we
169            //   add u128 PType, see https://github.com/vortex-data/vortex/issues/1110
170            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
171            for _ in 0..len {
172                views.push(view);
173            }
174
175            VarBinViewArray::try_new(
176                views.freeze(),
177                Arc::from(buffers),
178                dtype.clone(),
179                Validity::from(dtype.nullability()),
180            )
181        }
182    }
183}
184
185fn canonical_list_array(
186    values: Option<Vec<Scalar>>,
187    element_dtype: &DType,
188    list_nullability: Nullability,
189    len: usize,
190) -> VortexResult<ListArray> {
191    match values {
192        None => ListArray::try_new(
193            Canonical::empty(element_dtype).into_array(),
194            ConstantArray::new(
195                Scalar::new(
196                    DType::Primitive(PType::U64, Nullability::NonNullable),
197                    ScalarValue::from(0),
198                ),
199                len + 1,
200            )
201            .into_array(),
202            Validity::AllInvalid,
203        ),
204        Some(vs) => {
205            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
206            for _ in 0..len {
207                for v in &vs {
208                    elements_builder.append_scalar(v)?;
209                }
210            }
211            let offsets = if vs.is_empty() {
212                Buffer::zeroed(len + 1)
213            } else {
214                Buffer::from_trusted_len_iter(
215                    (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
216                )
217            };
218
219            ListArray::try_new(
220                elements_builder.finish(),
221                offsets.into_array(),
222                Validity::from(list_nullability),
223            )
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null().unwrap();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        // Check all values correct.
        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i).unwrap(), "four".into());
        }
    }

    /// A string of exactly 12 bytes sits on the inline/out-of-line boundary of the view layout.
    #[test]
    fn test_canonicalize_const_str_max_inlined_len() {
        let value = "twelve bytes";
        assert_eq!(value.len(), 12);

        let canonical = ConstantArray::new(value.to_string(), 3)
            .to_varbinview()
            .unwrap();

        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i).unwrap(), value.into());
        }
    }

    /// A string too long to inline must be served from the shared data buffer.
    #[test]
    fn test_canonicalize_const_str_out_of_line() {
        let value = "this string is far too long to be inlined";

        let canonical = ConstantArray::new(value.to_string(), 3)
            .to_varbinview()
            .unwrap();

        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert_eq!(canonical.scalar_at(i).unwrap(), value.into());
        }
    }

    /// A null string constant canonicalizes to an all-invalid view array.
    #[test]
    fn test_canonicalize_null_str() {
        let canonical = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3)
            .to_varbinview()
            .unwrap();

        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.dtype(), &DType::Utf8(Nullability::Nullable));
        assert_eq!(canonical.valid_count().unwrap(), 0);
        assert!(canonical.scalar_at(1).unwrap().is_null());
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        for stat in all::<Stat>() {
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            let canonical_stat =
                canonical_stats.get_scalar(stat, &stat.dtype(canonical.as_ref().dtype()).unwrap());
            let original_stat =
                stats.get_scalar(stat, &stat.dtype(canonical.as_ref().dtype()).unwrap());
            assert_eq!(canonical_stat, original_stat, "stat mismatch {stat}");
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );
        let const_array = ConstantArray::new(scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive().unwrap();
        assert_eq!(canonical_const.scalar_at(0).unwrap(), scalar);
        assert_eq!(canonical_const.scalar_at(0).unwrap(), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert_eq!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct().unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count().unwrap(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}