vortex_array/arrays/constant/
canonical.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6    BinaryScalar, BoolScalar, ExtScalar, ListScalar, Scalar, ScalarValue, StructScalar, Utf8Scalar,
7};
8
9use crate::array::ArrayCanonicalImpl;
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13    BinaryView, BoolArray, ExtensionArray, ListArray, NullArray, StructArray, VarBinViewArray,
14};
15use crate::builders::{ArrayBuilderExt, builder_with_capacity};
16use crate::validity::Validity;
17use crate::{Array, Canonical, IntoArray};
18
19impl ArrayCanonicalImpl for ConstantArray {
20    fn _to_canonical(&self) -> VortexResult<Canonical> {
21        let scalar = self.scalar();
22
23        let validity = match self.dtype().nullability() {
24            Nullability::NonNullable => Validity::NonNullable,
25            Nullability::Nullable => match scalar.is_null() {
26                true => Validity::AllInvalid,
27                false => Validity::AllValid,
28            },
29        };
30
31        Ok(match self.dtype() {
32            DType::Null => Canonical::Null(NullArray::new(self.len())),
33            DType::Bool(..) => Canonical::Bool(BoolArray::new(
34                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
35                    BooleanBuffer::new_set(self.len())
36                } else {
37                    BooleanBuffer::new_unset(self.len())
38                },
39                validity,
40            )),
41            DType::Primitive(ptype, ..) => {
42                match_each_native_ptype!(ptype, |$P| {
43                    Canonical::Primitive(PrimitiveArray::new(
44                        if scalar.is_valid() {
45                            Buffer::full(
46                                $P::try_from(scalar)
47                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
48                                self.len(),
49                            )
50                        } else {
51                            Buffer::zeroed(self.len())
52                        },
53                        validity,
54                    ))
55                })
56            }
57            DType::Utf8(_) => {
58                let value = Utf8Scalar::try_from(scalar)?.value();
59                let const_value = value.as_ref().map(|v| v.as_bytes());
60                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
61            }
62            DType::Binary(_) => {
63                let value = BinaryScalar::try_from(scalar)?.value();
64                let const_value = value.as_ref().map(|v| v.as_slice());
65                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
66            }
67            DType::Struct(struct_dtype, _) => {
68                let value = StructScalar::try_from(scalar)?;
69                let fields = value.fields().map(|fields| {
70                    fields
71                        .into_iter()
72                        .map(|s| ConstantArray::new(s, self.len()).into_array())
73                        .collect::<Vec<_>>()
74                });
75                Canonical::Struct(StructArray::try_new_with_dtype(
76                    fields.unwrap_or_default(),
77                    struct_dtype.clone(),
78                    self.len(),
79                    validity,
80                )?)
81            }
82            DType::List(..) => {
83                let value = ListScalar::try_from(scalar)?;
84                Canonical::List(canonical_list_array(
85                    value.elements(),
86                    value.element_dtype(),
87                    value.dtype().nullability(),
88                    self.len(),
89                )?)
90            }
91            DType::Extension(ext_dtype) => {
92                let s = ExtScalar::try_from(scalar)?;
93
94                let storage_scalar = s.storage();
95                let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
96                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
97            }
98        })
99    }
100}
101
102fn canonical_byte_view(
103    scalar_bytes: Option<&[u8]>,
104    dtype: &DType,
105    len: usize,
106) -> VortexResult<VarBinViewArray> {
107    match scalar_bytes {
108        None => {
109            let views = buffer![BinaryView::from(0_u128); len];
110
111            VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
112        }
113        Some(scalar_bytes) => {
114            // Create a view to hold the scalar bytes.
115            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
116            let view = BinaryView::make_view(scalar_bytes, 0, 0);
117            let mut buffers = Vec::new();
118            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
119                buffers.push(Buffer::copy_from(scalar_bytes));
120            }
121
122            // Clone our constant view `len` times.
123            // TODO(aduffy): switch this out for a ConstantArray once we
124            //   add u128 PType, see https://github.com/spiraldb/vortex/issues/1110
125            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
126            for _ in 0..len {
127                views.push(view);
128            }
129
130            VarBinViewArray::try_new(
131                views.freeze(),
132                buffers,
133                dtype.clone(),
134                Validity::from(dtype.nullability()),
135            )
136        }
137    }
138}
139
140fn canonical_list_array(
141    values: Option<Vec<Scalar>>,
142    element_dtype: &DType,
143    list_nullability: Nullability,
144    len: usize,
145) -> VortexResult<ListArray> {
146    match values {
147        None => ListArray::try_new(
148            Canonical::empty(element_dtype).into_array(),
149            ConstantArray::new(
150                Scalar::new(
151                    DType::Primitive(PType::U64, Nullability::NonNullable),
152                    ScalarValue::from(0),
153                ),
154                len + 1,
155            )
156            .into_array(),
157            Validity::AllInvalid,
158        ),
159        Some(vs) => {
160            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
161            for _ in 0..len {
162                for v in &vs {
163                    elements_builder.append_scalar(v)?;
164                }
165            }
166            let offsets = if vs.is_empty() {
167                Buffer::zeroed(len + 1)
168            } else {
169                (0..=len * vs.len())
170                    .step_by(vs.len())
171                    .map(|i| i as u64)
172                    .collect::<Buffer<_>>()
173            };
174
175            ListArray::try_new(
176                elements_builder.finish(),
177                offsets.into_array(),
178                Validity::from(list_nullability),
179            )
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let null_constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let nulls = null_constant.to_null().unwrap();

        assert_eq!(nulls.len(), 42);
        assert_eq!(scalar_at(&nulls, 33).unwrap(), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let strings = ConstantArray::new("four".to_string(), 4);
        let views = strings.to_varbinview().unwrap();

        assert_eq!(views.len(), 4);

        // Check all values correct.
        for row in 0..4 {
            assert_eq!(scalar_at(&views, row).unwrap(), "four".into());
        }
    }

    /// Every statistic of the canonical array matches both the original constant array's
    /// statistics and the reference constant stats set.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);

            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose dtype is `f16` but whose stored value was built from a `u8` must
    /// canonicalize to a value equal to both that scalar and the expected `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let u8_backed = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );

        let constant = ConstantArray::new(u8_backed.clone(), 1).into_array();
        let primitive = constant.to_primitive().unwrap();
        let first = scalar_at(&primitive, 0).unwrap();

        assert_eq!(first, u8_backed);
        assert_eq!(scalar_at(&primitive, 0).unwrap(), expected_f16);
    }

    /// A constant two-element list repeated twice yields concatenated elements and
    /// evenly spaced offsets.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// A constant empty list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A constant null list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}