vortex_array/arrays/constant/
canonical.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
8};
9
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
14    StructArray, VarBinViewArray, precision_to_storage_size,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::vtable::CanonicalVTable;
19use crate::{Canonical, IntoArray};
20
21impl CanonicalVTable<ConstantVTable> for ConstantVTable {
22    fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
23        let scalar = array.scalar();
24
25        let validity = match array.dtype().nullability() {
26            Nullability::NonNullable => Validity::NonNullable,
27            Nullability::Nullable => match scalar.is_null() {
28                true => Validity::AllInvalid,
29                false => Validity::AllValid,
30            },
31        };
32
33        Ok(match array.dtype() {
34            DType::Null => Canonical::Null(NullArray::new(array.len())),
35            DType::Bool(..) => Canonical::Bool(BoolArray::new(
36                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37                    BooleanBuffer::new_set(array.len())
38                } else {
39                    BooleanBuffer::new_unset(array.len())
40                },
41                validity,
42            )),
43            DType::Primitive(ptype, ..) => {
44                match_each_native_ptype!(ptype, |$P| {
45                    Canonical::Primitive(PrimitiveArray::new(
46                        if scalar.is_valid() {
47                            Buffer::full(
48                                $P::try_from(scalar)
49                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
50                                array.len(),
51                            )
52                        } else {
53                            Buffer::zeroed(array.len())
54                        },
55                        validity,
56                    ))
57                })
58            }
59            DType::Decimal(decimal_type, ..) => {
60                let size = precision_to_storage_size(decimal_type);
61                let decimal = scalar.as_decimal();
62                let Some(value) = decimal.decimal_value() else {
63                    let all_null = match_each_decimal_value_type!(size, |$D| {
64                       DecimalArray::new(
65                                Buffer::<$D>::zeroed(array.len()),
66                                *decimal_type,
67                                validity,
68                            )
69                    });
70                    return Ok(Canonical::Decimal(all_null));
71                };
72
73                let decimal_array = match_each_decimal_value!(value, |$V| {
74                   DecimalArray::new(
75                        Buffer::full(*$V, array.len()),
76                        *decimal_type,
77                        validity,
78                    )
79                });
80                Canonical::Decimal(decimal_array)
81            }
82            DType::Utf8(_) => {
83                let value = Utf8Scalar::try_from(scalar)?.value();
84                let const_value = value.as_ref().map(|v| v.as_bytes());
85                Canonical::VarBinView(canonical_byte_view(
86                    const_value,
87                    array.dtype(),
88                    array.len(),
89                )?)
90            }
91            DType::Binary(_) => {
92                let value = BinaryScalar::try_from(scalar)?.value();
93                let const_value = value.as_ref().map(|v| v.as_slice());
94                Canonical::VarBinView(canonical_byte_view(
95                    const_value,
96                    array.dtype(),
97                    array.len(),
98                )?)
99            }
100            DType::Struct(struct_dtype, _) => {
101                let value = StructScalar::try_from(scalar)?;
102                let fields = value.fields().map(|fields| {
103                    fields
104                        .into_iter()
105                        .map(|s| ConstantArray::new(s, array.len()).into_array())
106                        .collect::<Vec<_>>()
107                });
108                Canonical::Struct(StructArray::try_new_with_dtype(
109                    fields.unwrap_or_default(),
110                    struct_dtype.clone(),
111                    array.len(),
112                    validity,
113                )?)
114            }
115            DType::List(..) => {
116                let value = ListScalar::try_from(scalar)?;
117                Canonical::List(canonical_list_array(
118                    value.elements(),
119                    value.element_dtype(),
120                    value.dtype().nullability(),
121                    array.len(),
122                )?)
123            }
124            DType::Extension(ext_dtype) => {
125                let s = ExtScalar::try_from(scalar)?;
126
127                let storage_scalar = s.storage();
128                let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
129                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
130            }
131        })
132    }
133}
134
135fn canonical_byte_view(
136    scalar_bytes: Option<&[u8]>,
137    dtype: &DType,
138    len: usize,
139) -> VortexResult<VarBinViewArray> {
140    match scalar_bytes {
141        None => {
142            let views = buffer![BinaryView::from(0_u128); len];
143
144            VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
145        }
146        Some(scalar_bytes) => {
147            // Create a view to hold the scalar bytes.
148            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
149            let view = BinaryView::make_view(scalar_bytes, 0, 0);
150            let mut buffers = Vec::new();
151            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
152                buffers.push(Buffer::copy_from(scalar_bytes));
153            }
154
155            // Clone our constant view `len` times.
156            // TODO(aduffy): switch this out for a ConstantArray once we
157            //   add u128 PType, see https://github.com/vortex-data/vortex/issues/1110
158            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
159            for _ in 0..len {
160                views.push(view);
161            }
162
163            VarBinViewArray::try_new(
164                views.freeze(),
165                buffers,
166                dtype.clone(),
167                Validity::from(dtype.nullability()),
168            )
169        }
170    }
171}
172
173fn canonical_list_array(
174    values: Option<Vec<Scalar>>,
175    element_dtype: &DType,
176    list_nullability: Nullability,
177    len: usize,
178) -> VortexResult<ListArray> {
179    match values {
180        None => ListArray::try_new(
181            Canonical::empty(element_dtype).into_array(),
182            ConstantArray::new(
183                Scalar::new(
184                    DType::Primitive(PType::U64, Nullability::NonNullable),
185                    ScalarValue::from(0),
186                ),
187                len + 1,
188            )
189            .into_array(),
190            Validity::AllInvalid,
191        ),
192        Some(vs) => {
193            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
194            for _ in 0..len {
195                for v in &vs {
196                    elements_builder.append_scalar(v)?;
197                }
198            }
199            let offsets = if vs.is_empty() {
200                Buffer::zeroed(len + 1)
201            } else {
202                (0..=len * vs.len())
203                    .step_by(vs.len())
204                    .map(|i| i as u64)
205                    .collect::<Buffer<_>>()
206            };
207
208            ListArray::try_new(
209                elements_builder.finish(),
210                offsets.into_array(),
211                Validity::from(list_nullability),
212            )
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};
    use crate::{Array, IntoArray};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant of the null dtype canonicalizes to a null array of equal length.
    #[test]
    fn test_canonicalize_null() {
        let null_constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let canonical = null_constant.to_null().unwrap();

        assert_eq!(canonical.len(), 42);
        assert_eq!(canonical.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    /// Every row of a canonicalized string constant holds the original string.
    #[test]
    fn test_canonicalize_const_str() {
        let string_constant = ConstantArray::new("four".to_string(), 4);

        let views = string_constant.to_varbinview().unwrap();
        assert_eq!(views.len(), 4);

        // Check all values correct.
        for row in 0..4 {
            assert_eq!(views.scalar_at(row).unwrap(), Scalar::from("four"));
        }
    }

    /// Statistics of the canonical array match both the constant array's own
    /// statistics and a reference `StatsSet::constant`.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // Skip statistics that are not defined for this dtype.
            let Some(stat_dtype) = stat.dtype(canonical.as_ref().dtype()) else {
                continue;
            };

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);
            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose dtype is `f16` but whose value was built from a `u8`
    /// round-trips through canonicalization and equals the matching `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );

        let constant = ConstantArray::new(raw_scalar.clone(), 1).into_array();
        let primitive = constant.to_primitive().unwrap();
        let first = primitive.scalar_at(0).unwrap();

        assert_eq!(first, raw_scalar);
        assert_eq!(first, expected_f16);
    }

    /// A two-element list repeated twice yields concatenated elements and
    /// offsets stepping by the list length.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// An empty (but non-null) list constant has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null list constant has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}