vortex_array/arrays/constant/
canonical.rs

1use arrow_array::builder::make_view;
2use arrow_buffer::BooleanBuffer;
3use vortex_buffer::{Buffer, BufferMut, buffer};
4use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
5use vortex_error::{VortexExpect, VortexResult};
6use vortex_scalar::{
7    BinaryScalar, BoolScalar, ExtScalar, ListScalar, Scalar, ScalarValue, StructScalar, Utf8Scalar,
8};
9
10use crate::array::ArrayCanonicalImpl;
11use crate::arrays::constant::ConstantArray;
12use crate::arrays::primitive::PrimitiveArray;
13use crate::arrays::{
14    BinaryView, BoolArray, ExtensionArray, ListArray, NullArray, StructArray, VarBinViewArray,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::{Array, Canonical, IntoArray};
19
20impl ArrayCanonicalImpl for ConstantArray {
21    fn _to_canonical(&self) -> VortexResult<Canonical> {
22        let scalar = self.scalar();
23
24        let validity = match self.dtype().nullability() {
25            Nullability::NonNullable => Validity::NonNullable,
26            Nullability::Nullable => match scalar.is_null() {
27                true => Validity::AllInvalid,
28                false => Validity::AllValid,
29            },
30        };
31
32        Ok(match self.dtype() {
33            DType::Null => Canonical::Null(NullArray::new(self.len())),
34            DType::Bool(..) => Canonical::Bool(BoolArray::new(
35                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
36                    BooleanBuffer::new_set(self.len())
37                } else {
38                    BooleanBuffer::new_unset(self.len())
39                },
40                validity,
41            )),
42            DType::Primitive(ptype, ..) => {
43                match_each_native_ptype!(ptype, |$P| {
44                    Canonical::Primitive(PrimitiveArray::new(
45                        if scalar.is_valid() {
46                            Buffer::full(
47                                $P::try_from(scalar)
48                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
49                                self.len(),
50                            )
51                        } else {
52                            Buffer::zeroed(self.len())
53                        },
54                        validity,
55                    ))
56                })
57            }
58            DType::Utf8(_) => {
59                let value = Utf8Scalar::try_from(scalar)?.value();
60                let const_value = value.as_ref().map(|v| v.as_bytes());
61                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
62            }
63            DType::Binary(_) => {
64                let value = BinaryScalar::try_from(scalar)?.value();
65                let const_value = value.as_ref().map(|v| v.as_slice());
66                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
67            }
68            DType::Struct(..) => {
69                let value = StructScalar::try_from(scalar)?;
70                let fields = value.fields().map(|fields| {
71                    fields
72                        .into_iter()
73                        .map(|s| ConstantArray::new(s, self.len()).into_array())
74                        .collect::<Vec<_>>()
75                });
76                Canonical::Struct(StructArray::try_new(
77                    value.struct_dtype().names().clone(),
78                    fields.unwrap_or_default(),
79                    self.len(),
80                    validity,
81                )?)
82            }
83            DType::List(..) => {
84                let value = ListScalar::try_from(scalar)?;
85                Canonical::List(canonical_list_array(
86                    value.elements(),
87                    value.element_dtype(),
88                    value.dtype().nullability(),
89                    self.len(),
90                )?)
91            }
92            DType::Extension(ext_dtype) => {
93                let s = ExtScalar::try_from(scalar)?;
94
95                let storage_scalar = s.storage();
96                let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
97                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
98            }
99        })
100    }
101}
102
103fn canonical_byte_view(
104    scalar_bytes: Option<&[u8]>,
105    dtype: &DType,
106    len: usize,
107) -> VortexResult<VarBinViewArray> {
108    match scalar_bytes {
109        None => {
110            let views = buffer![BinaryView::from(0_u128); len];
111
112            VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
113        }
114        Some(scalar_bytes) => {
115            // Create a view to hold the scalar bytes.
116            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
117            let view = BinaryView::from(make_view(scalar_bytes, 0, 0));
118            let mut buffers = Vec::new();
119            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
120                buffers.push(Buffer::copy_from(scalar_bytes));
121            }
122
123            // Clone our constant view `len` times.
124            // TODO(aduffy): switch this out for a ConstantArray once we
125            //   add u128 PType, see https://github.com/spiraldb/vortex/issues/1110
126            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
127            for _ in 0..len {
128                views.push(view);
129            }
130
131            VarBinViewArray::try_new(
132                views.freeze(),
133                buffers,
134                dtype.clone(),
135                Validity::from(dtype.nullability()),
136            )
137        }
138    }
139}
140
141fn canonical_list_array(
142    values: Option<Vec<Scalar>>,
143    element_dtype: &DType,
144    list_nullability: Nullability,
145    len: usize,
146) -> VortexResult<ListArray> {
147    match values {
148        None => ListArray::try_new(
149            ConstantArray::new(Scalar::null(element_dtype.clone()), 1).into_array(),
150            ConstantArray::new(
151                Scalar::new(
152                    DType::Primitive(PType::U64, Nullability::NonNullable),
153                    ScalarValue::from(0),
154                ),
155                len,
156            )
157            .into_array(),
158            Validity::AllInvalid,
159        ),
160        Some(vs) => {
161            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
162            for _ in 0..len {
163                for v in &vs {
164                    elements_builder.append_scalar(v)?;
165                }
166            }
167            let offsets = (0..=len * vs.len())
168                .step_by(vs.len())
169                .map(|i| i as u64)
170                .collect::<Buffer<_>>();
171
172            ListArray::try_new(
173                elements_builder.finish(),
174                offsets.into_array(),
175                Validity::from(list_nullability),
176            )
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let null_array = const_null.to_null().unwrap();

        assert_eq!(null_array.len(), 42);
        assert_eq!(
            scalar_at(&null_array, 33).unwrap(),
            Scalar::null(DType::Null)
        );
    }

    /// A constant string array canonicalizes to views that all read back the string.
    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);
        let views = const_array.to_varbinview().unwrap();

        assert_eq!(views.len(), 4);

        // Every row must hold the constant value.
        let expected: Scalar = "four".into();
        for row in 0..4 {
            assert_eq!(scalar_at(&views, row).unwrap(), expected);
        }
    }

    /// Statistics of the canonical array agree with both the original constant array and
    /// a reference constant stats set, for every known stat.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = const_array.statistics().to_owned();

        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // All three lookups use the stat's dtype relative to the canonical array.
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);

            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose stored value is a `u8` but whose dtype is `f16` canonicalizes to a
    /// primitive array that compares equal to both that scalar and the `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
        let mixed_scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            raw_value,
        );

        let const_array = ConstantArray::new(mixed_scalar.clone(), 1).into_array();
        let primitive = const_array.to_primitive().unwrap();

        assert_eq!(scalar_at(&primitive, 0).unwrap(), mixed_scalar);
        assert_eq!(scalar_at(&primitive, 0).unwrap(), expected_f16);
    }

    /// A constant list array canonicalizes to repeated elements with evenly spaced offsets.
    #[test]
    fn test_canonicalize_lists() {
        let element_dtype = Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let lists = const_array.to_list().unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }
}