vortex_array/arrays/constant/
canonical.rs

1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7    StructScalar, Utf8Scalar,
8};
9
10use crate::array::ArrayCanonicalImpl;
11use crate::arrays::constant::ConstantArray;
12use crate::arrays::primitive::PrimitiveArray;
13use crate::arrays::{
14    BinaryView, BoolArray, DecimalArray, ExtensionArray, ListArray, NullArray, StructArray,
15    VarBinViewArray, precision_to_storage_size,
16};
17use crate::builders::{ArrayBuilderExt, builder_with_capacity};
18use crate::validity::Validity;
19use crate::{Array, Canonical, IntoArray, match_each_decimal_value, match_each_decimal_value_type};
20
21impl ArrayCanonicalImpl for ConstantArray {
22    fn _to_canonical(&self) -> VortexResult<Canonical> {
23        let scalar = self.scalar();
24
25        let validity = match self.dtype().nullability() {
26            Nullability::NonNullable => Validity::NonNullable,
27            Nullability::Nullable => match scalar.is_null() {
28                true => Validity::AllInvalid,
29                false => Validity::AllValid,
30            },
31        };
32
33        Ok(match self.dtype() {
34            DType::Null => Canonical::Null(NullArray::new(self.len())),
35            DType::Bool(..) => Canonical::Bool(BoolArray::new(
36                if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37                    BooleanBuffer::new_set(self.len())
38                } else {
39                    BooleanBuffer::new_unset(self.len())
40                },
41                validity,
42            )),
43            DType::Primitive(ptype, ..) => {
44                match_each_native_ptype!(ptype, |$P| {
45                    Canonical::Primitive(PrimitiveArray::new(
46                        if scalar.is_valid() {
47                            Buffer::full(
48                                $P::try_from(scalar)
49                                    .vortex_expect("Couldn't unwrap scalar to primitive"),
50                                self.len(),
51                            )
52                        } else {
53                            Buffer::zeroed(self.len())
54                        },
55                        validity,
56                    ))
57                })
58            }
59            DType::Decimal(decimal_type, ..) => {
60                let size = precision_to_storage_size(decimal_type);
61                let decimal = scalar.as_decimal();
62                let Some(value) = decimal.decimal_value() else {
63                    let all_null = match_each_decimal_value_type!(size, |$D| {
64                       DecimalArray::new(
65                                Buffer::<$D>::zeroed(self.len()),
66                                *decimal_type,
67                                Validity::AllInvalid,
68                            )
69                    });
70                    return Ok(Canonical::Decimal(all_null));
71                };
72
73                let decimal_array = match_each_decimal_value!(value, |$V| {
74                   DecimalArray::new(
75                        Buffer::full(*$V, self.len()),
76                        *decimal_type,
77                        Validity::AllValid,
78                    )
79                });
80                Canonical::Decimal(decimal_array)
81            }
82            DType::Utf8(_) => {
83                let value = Utf8Scalar::try_from(scalar)?.value();
84                let const_value = value.as_ref().map(|v| v.as_bytes());
85                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
86            }
87            DType::Binary(_) => {
88                let value = BinaryScalar::try_from(scalar)?.value();
89                let const_value = value.as_ref().map(|v| v.as_slice());
90                Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
91            }
92            DType::Struct(struct_dtype, _) => {
93                let value = StructScalar::try_from(scalar)?;
94                let fields = value.fields().map(|fields| {
95                    fields
96                        .into_iter()
97                        .map(|s| ConstantArray::new(s, self.len()).into_array())
98                        .collect::<Vec<_>>()
99                });
100                Canonical::Struct(StructArray::try_new_with_dtype(
101                    fields.unwrap_or_default(),
102                    struct_dtype.clone(),
103                    self.len(),
104                    validity,
105                )?)
106            }
107            DType::List(..) => {
108                let value = ListScalar::try_from(scalar)?;
109                Canonical::List(canonical_list_array(
110                    value.elements(),
111                    value.element_dtype(),
112                    value.dtype().nullability(),
113                    self.len(),
114                )?)
115            }
116            DType::Extension(ext_dtype) => {
117                let s = ExtScalar::try_from(scalar)?;
118
119                let storage_scalar = s.storage();
120                let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
121                Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
122            }
123        })
124    }
125}
126
127fn canonical_byte_view(
128    scalar_bytes: Option<&[u8]>,
129    dtype: &DType,
130    len: usize,
131) -> VortexResult<VarBinViewArray> {
132    match scalar_bytes {
133        None => {
134            let views = buffer![BinaryView::from(0_u128); len];
135
136            VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
137        }
138        Some(scalar_bytes) => {
139            // Create a view to hold the scalar bytes.
140            // If the scalar cannot be inlined, allocate a single buffer large enough to hold it.
141            let view = BinaryView::make_view(scalar_bytes, 0, 0);
142            let mut buffers = Vec::new();
143            if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
144                buffers.push(Buffer::copy_from(scalar_bytes));
145            }
146
147            // Clone our constant view `len` times.
148            // TODO(aduffy): switch this out for a ConstantArray once we
149            //   add u128 PType, see https://github.com/spiraldb/vortex/issues/1110
150            let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
151            for _ in 0..len {
152                views.push(view);
153            }
154
155            VarBinViewArray::try_new(
156                views.freeze(),
157                buffers,
158                dtype.clone(),
159                Validity::from(dtype.nullability()),
160            )
161        }
162    }
163}
164
165fn canonical_list_array(
166    values: Option<Vec<Scalar>>,
167    element_dtype: &DType,
168    list_nullability: Nullability,
169    len: usize,
170) -> VortexResult<ListArray> {
171    match values {
172        None => ListArray::try_new(
173            Canonical::empty(element_dtype).into_array(),
174            ConstantArray::new(
175                Scalar::new(
176                    DType::Primitive(PType::U64, Nullability::NonNullable),
177                    ScalarValue::from(0),
178                ),
179                len + 1,
180            )
181            .into_array(),
182            Validity::AllInvalid,
183        ),
184        Some(vs) => {
185            let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
186            for _ in 0..len {
187                for v in &vs {
188                    elements_builder.append_scalar(v)?;
189                }
190            }
191            let offsets = if vs.is_empty() {
192                Buffer::zeroed(len + 1)
193            } else {
194                (0..=len * vs.len())
195                    .step_by(vs.len())
196                    .map(|i| i as u64)
197                    .collect::<Buffer<_>>()
198            };
199
200            ListArray::try_new(
201                elements_builder.finish(),
202                offsets.into_array(),
203                Validity::from(list_nullability),
204            )
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// A constant of `DType::Null` canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let null_constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let canonical = null_constant.to_null().unwrap();

        assert_eq!(canonical.len(), 42);
        assert_eq!(
            scalar_at(&canonical, 33).unwrap(),
            Scalar::null(DType::Null)
        );
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let string_constant = ConstantArray::new("four".to_string(), 4);
        let canonical = string_constant.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        // Every position must read back the same value.
        for idx in 0..4 {
            assert_eq!(scalar_at(&canonical, idx).unwrap(), "four".into());
        }
    }

    /// Statistics of the canonical array match both the constant's own statistics and the
    /// reference statistics of a constant.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = const_array.statistics().to_owned();

        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // All three lookups use the stat's dtype for the canonical array's dtype.
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);

            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose stored value is a `u8` but whose dtype is `F16` reads back, after
    /// canonicalization, equal to both itself and the corresponding `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            raw_value,
        );

        let const_array = ConstantArray::new(scalar.clone(), 1).into_array();
        let canonical = const_array.to_primitive().unwrap();

        assert_eq!(scalar_at(&canonical, 0).unwrap(), scalar);
        assert_eq!(scalar_at(&canonical, 0).unwrap(), expected_f16);
    }

    /// A two-element constant list repeated twice yields four elements and offsets `[0, 2, 4]`.
    #[test]
    fn test_canonicalize_lists() {
        let element_dtype = Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );

        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical = const_array.to_list().unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// An empty constant list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable));
        let list_scalar = Scalar::list(element_dtype, vec![], Nullability::NonNullable);

        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical = const_array.to_list().unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null constant list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let list_scalar = Scalar::null(list_dtype);

        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical = const_array.to_list().unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}