vortex_array/arrays/constant/canonical.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
8};
9
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
14    StructArray, VarBinViewArray, smallest_storage_type,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::vtable::CanonicalVTable;
19use crate::{Canonical, IntoArray};
20
21impl CanonicalVTable<ConstantVTable> for ConstantVTable {
22    fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
23        let scalar = array.scalar();
24
25        let validity = match array.dtype().nullability() {
26            Nullability::NonNullable => Validity::NonNullable,
27            Nullability::Nullable => match scalar.is_null() {
28                true => Validity::AllInvalid,
29                false => Validity::AllValid,
30            },
31        };
32
33        Ok(match array.dtype() {
34            DType::Null => Canonical::Null(NullArray::new(array.len())),
35            DType::Bool(..) => Canonical::Bool(BoolArray::new(
36                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37                    BooleanBuffer::new_set(array.len())
38                } else {
39                    BooleanBuffer::new_unset(array.len())
40                },
41                validity,
42            )),
43            DType::Primitive(ptype, ..) => {
44                match_each_native_ptype!(ptype, |P| {
45                    Canonical::Primitive(PrimitiveArray::new(
46                        if scalar.is_valid() {
47                            Buffer::full(
48                                P::try_from(scalar)
49                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
50                                array.len(),
51                            )
52                        } else {
53                            Buffer::zeroed(array.len())
54                        },
55                        validity,
56                    ))
57                })
58            }
59            DType::Decimal(decimal_type, ..) => {
60                let size = smallest_storage_type(decimal_type);
61                let decimal = scalar.as_decimal();
62                let Some(value) = decimal.decimal_value() else {
63                    let all_null = match_each_decimal_value_type!(size, |D| {
64                        DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
65                    });
66                    return Ok(Canonical::Decimal(all_null));
67                };
68
69                let decimal_array = match_each_decimal_value!(value, |value| {
70                    DecimalArray::new(Buffer::full(*value, array.len()), *decimal_type, validity)
71                });
72                Canonical::Decimal(decimal_array)
73            }
74            DType::Utf8(_) => {
75                let value = Utf8Scalar::try_from(scalar)?.value();
76                let const_value = value.as_ref().map(|v| v.as_bytes());
77                Canonical::VarBinView(canonical_byte_view(
78                    const_value,
79                    array.dtype(),
80                    array.len(),
81                )?)
82            }
83            DType::Binary(_) => {
84                let value = BinaryScalar::try_from(scalar)?.value();
85                let const_value = value.as_ref().map(|v| v.as_slice());
86                Canonical::VarBinView(canonical_byte_view(
87                    const_value,
88                    array.dtype(),
89                    array.len(),
90                )?)
91            }
92            DType::Struct(struct_dtype, _) => {
93                let value = StructScalar::try_from(scalar)?;
94                let fields: Vec<_> = match value.fields() {
95                    Some(fields) => fields
96                        .into_iter()
97                        .map(|s| ConstantArray::new(s, array.len()).into_array())
98                        .collect(),
99                    None => {
100                        assert!(validity.all_invalid()?);
101                        struct_dtype
102                            .fields()
103                            .map(|dt| {
104                                let scalar = Scalar::default_value(dt);
105                                ConstantArray::new(scalar, array.len()).into_array()
106                            })
107                            .collect()
108                    }
109                };
110                Canonical::Struct(StructArray::try_new_with_dtype(
111                    fields,
112                    struct_dtype.clone(),
113                    array.len(),
114                    validity,
115                )?)
116            }
117            DType::List(..) => {
118                let value = ListScalar::try_from(scalar)?;
119                Canonical::List(canonical_list_array(
120                    value.elements(),
121                    value.element_dtype(),
122                    value.dtype().nullability(),
123                    array.len(),
124                )?)
125            }
126            DType::Extension(ext_dtype) => {
127                let s = ExtScalar::try_from(scalar)?;
128
129                let storage_scalar = s.storage();
130                let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
131                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
132            }
133        })
134    }
135}
136
137fn canonical_byte_view(
138    scalar_bytes: Option<&[u8]>,
139    dtype: &DType,
140    len: usize,
141) -> VortexResult<VarBinViewArray> {
142    match scalar_bytes {
143        None => {
144            let views = buffer![BinaryView::from(0_u128); len];
145
146            VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
147        }
148        Some(scalar_bytes) => {
149            // Create a view to hold the scalar bytes.
150            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
151            let view = BinaryView::make_view(scalar_bytes, 0, 0);
152            let mut buffers = Vec::new();
153            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
154                buffers.push(Buffer::copy_from(scalar_bytes));
155            }
156
157            // Clone our constant view `len` times.
158            // TODO(aduffy): switch this out for a ConstantArray once we
159            //   add u128 PType, see https://github.com/vortex-data/vortex/issues/1110
160            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
161            for _ in 0..len {
162                views.push(view);
163            }
164
165            VarBinViewArray::try_new(
166                views.freeze(),
167                buffers,
168                dtype.clone(),
169                Validity::from(dtype.nullability()),
170            )
171        }
172    }
173}
174
175fn canonical_list_array(
176    values: Option<Vec<Scalar>>,
177    element_dtype: &DType,
178    list_nullability: Nullability,
179    len: usize,
180) -> VortexResult<ListArray> {
181    match values {
182        None => ListArray::try_new(
183            Canonical::empty(element_dtype).into_array(),
184            ConstantArray::new(
185                Scalar::new(
186                    DType::Primitive(PType::U64, Nullability::NonNullable),
187                    ScalarValue::from(0),
188                ),
189                len + 1,
190            )
191            .into_array(),
192            Validity::AllInvalid,
193        ),
194        Some(vs) => {
195            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
196            for _ in 0..len {
197                for v in &vs {
198                    elements_builder.append_scalar(v)?;
199                }
200            }
201            let offsets = if vs.is_empty() {
202                Buffer::zeroed(len + 1)
203            } else {
204                (0..=len * vs.len())
205                    .step_by(vs.len())
206                    .map(|i| i as u64)
207                    .collect::<Buffer<_>>()
208            };
209
210            ListArray::try_new(
211                elements_builder.finish(),
212                offsets.into_array(),
213                Validity::from(list_nullability),
214            )
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null().unwrap();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        // Check all values correct.
        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i).unwrap(), "four".into());
        }
    }

    #[test]
    fn test_canonicalize_const_str_inline_boundary_and_out_of_line() {
        // "twelve bytes" is exactly 12 bytes, the largest length the Arrow view layout
        // inlines. The longer value must be stored out of line in a data buffer.
        for value in ["twelve bytes", "this string is far too long to be inlined"] {
            let const_array = ConstantArray::new(value.to_string(), 3);
            let canonical = const_array.to_varbinview().unwrap();

            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i).unwrap(), value.into());
            }
        }
    }

    #[test]
    fn test_canonicalize_null_str() {
        let const_array =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);
        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert!(canonical.scalar_at(i).unwrap().is_null());
        }
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar.clone(), 4).into_array();
        let stats = const_array.statistics().to_owned();

        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // Skip stats that are not defined for this dtype.
            let Some(stat_dtype) = stat.dtype(canonical.as_ref().dtype()) else {
                continue;
            };

            let canonical_stat = canonical_stats.get_scalar(stat, &stat_dtype);
            let reference_stat = reference.get_scalar(stat, &stat_dtype);
            let original_stat = stats.get_scalar(stat, &stat_dtype);
            assert_eq!(canonical_stat, reference_stat);
            assert_eq!(canonical_stat, original_stat);
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );
        let const_array = ConstantArray::new(scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive().unwrap();
        assert_eq!(canonical_const.scalar_at(0).unwrap(), scalar);
        assert_eq!(canonical_const.scalar_at(0).unwrap(), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert_eq!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct().unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count().unwrap(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}