vortex_array/arrays/constant/
canonical.rs1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
8};
9
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
14 StructArray, VarBinViewArray, precision_to_storage_size,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::vtable::CanonicalVTable;
19use crate::{Canonical, IntoArray};
20
21impl CanonicalVTable<ConstantVTable> for ConstantVTable {
22 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
23 let scalar = array.scalar();
24
25 let validity = match array.dtype().nullability() {
26 Nullability::NonNullable => Validity::NonNullable,
27 Nullability::Nullable => match scalar.is_null() {
28 true => Validity::AllInvalid,
29 false => Validity::AllValid,
30 },
31 };
32
33 Ok(match array.dtype() {
34 DType::Null => Canonical::Null(NullArray::new(array.len())),
35 DType::Bool(..) => Canonical::Bool(BoolArray::new(
36 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37 BooleanBuffer::new_set(array.len())
38 } else {
39 BooleanBuffer::new_unset(array.len())
40 },
41 validity,
42 )),
43 DType::Primitive(ptype, ..) => {
44 match_each_native_ptype!(ptype, |$P| {
45 Canonical::Primitive(PrimitiveArray::new(
46 if scalar.is_valid() {
47 Buffer::full(
48 $P::try_from(scalar)
49 .vortex_expect("Couldn't unwrap scalar to primitive"),
50 array.len(),
51 )
52 } else {
53 Buffer::zeroed(array.len())
54 },
55 validity,
56 ))
57 })
58 }
59 DType::Decimal(decimal_type, ..) => {
60 let size = precision_to_storage_size(decimal_type);
61 let decimal = scalar.as_decimal();
62 let Some(value) = decimal.decimal_value() else {
63 let all_null = match_each_decimal_value_type!(size, |$D| {
64 DecimalArray::new(
65 Buffer::<$D>::zeroed(array.len()),
66 *decimal_type,
67 validity,
68 )
69 });
70 return Ok(Canonical::Decimal(all_null));
71 };
72
73 let decimal_array = match_each_decimal_value!(value, |$V| {
74 DecimalArray::new(
75 Buffer::full(*$V, array.len()),
76 *decimal_type,
77 validity,
78 )
79 });
80 Canonical::Decimal(decimal_array)
81 }
82 DType::Utf8(_) => {
83 let value = Utf8Scalar::try_from(scalar)?.value();
84 let const_value = value.as_ref().map(|v| v.as_bytes());
85 Canonical::VarBinView(canonical_byte_view(
86 const_value,
87 array.dtype(),
88 array.len(),
89 )?)
90 }
91 DType::Binary(_) => {
92 let value = BinaryScalar::try_from(scalar)?.value();
93 let const_value = value.as_ref().map(|v| v.as_slice());
94 Canonical::VarBinView(canonical_byte_view(
95 const_value,
96 array.dtype(),
97 array.len(),
98 )?)
99 }
100 DType::Struct(struct_dtype, _) => {
101 let value = StructScalar::try_from(scalar)?;
102 let fields = value.fields().map(|fields| {
103 fields
104 .into_iter()
105 .map(|s| ConstantArray::new(s, array.len()).into_array())
106 .collect::<Vec<_>>()
107 });
108 Canonical::Struct(StructArray::try_new_with_dtype(
109 fields.unwrap_or_default(),
110 struct_dtype.clone(),
111 array.len(),
112 validity,
113 )?)
114 }
115 DType::List(..) => {
116 let value = ListScalar::try_from(scalar)?;
117 Canonical::List(canonical_list_array(
118 value.elements(),
119 value.element_dtype(),
120 value.dtype().nullability(),
121 array.len(),
122 )?)
123 }
124 DType::Extension(ext_dtype) => {
125 let s = ExtScalar::try_from(scalar)?;
126
127 let storage_scalar = s.storage();
128 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
129 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
130 }
131 })
132 }
133}
134
135fn canonical_byte_view(
136 scalar_bytes: Option<&[u8]>,
137 dtype: &DType,
138 len: usize,
139) -> VortexResult<VarBinViewArray> {
140 match scalar_bytes {
141 None => {
142 let views = buffer![BinaryView::from(0_u128); len];
143
144 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
145 }
146 Some(scalar_bytes) => {
147 let view = BinaryView::make_view(scalar_bytes, 0, 0);
150 let mut buffers = Vec::new();
151 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
152 buffers.push(Buffer::copy_from(scalar_bytes));
153 }
154
155 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
159 for _ in 0..len {
160 views.push(view);
161 }
162
163 VarBinViewArray::try_new(
164 views.freeze(),
165 buffers,
166 dtype.clone(),
167 Validity::from(dtype.nullability()),
168 )
169 }
170 }
171}
172
173fn canonical_list_array(
174 values: Option<Vec<Scalar>>,
175 element_dtype: &DType,
176 list_nullability: Nullability,
177 len: usize,
178) -> VortexResult<ListArray> {
179 match values {
180 None => ListArray::try_new(
181 Canonical::empty(element_dtype).into_array(),
182 ConstantArray::new(
183 Scalar::new(
184 DType::Primitive(PType::U64, Nullability::NonNullable),
185 ScalarValue::from(0),
186 ),
187 len + 1,
188 )
189 .into_array(),
190 Validity::AllInvalid,
191 ),
192 Some(vs) => {
193 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
194 for _ in 0..len {
195 for v in &vs {
196 elements_builder.append_scalar(v)?;
197 }
198 }
199 let offsets = if vs.is_empty() {
200 Buffer::zeroed(len + 1)
201 } else {
202 (0..=len * vs.len())
203 .step_by(vs.len())
204 .map(|i| i as u64)
205 .collect::<Buffer<_>>()
206 };
207
208 ListArray::try_new(
209 elements_builder.finish(),
210 offsets.into_array(),
211 Validity::from(list_nullability),
212 )
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};
    use crate::{Array, IntoArray};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let nulls = constant.to_null().unwrap();

        assert_eq!(nulls.len(), 42);
        assert_eq!(nulls.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    /// Every row of a canonicalized constant string holds that string.
    #[test]
    fn test_canonicalize_const_str() {
        let constant = ConstantArray::new("four".to_string(), 4);
        let views = constant.to_varbinview().unwrap();

        assert_eq!(views.len(), 4);
        for row in 0..4 {
            assert_eq!(views.scalar_at(row).unwrap(), "four".into());
        }
    }

    /// Stats on the canonical array match both the original constant array's
    /// stats and a freshly built constant `StatsSet`.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(scalar.clone(), 4).into_array();
        let original_stats = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let expected_stats = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // Skip stats that have no dtype for this array's dtype.
            let Some(stat_dtype) = stat.dtype(canonical.as_ref().dtype()) else {
                continue;
            };

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = expected_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);
            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar declared as `f16` but carrying a `u8` value canonicalizes to
    /// a value equal to both the original scalar and the expected `f16`.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            raw_value,
        );

        let constant = ConstantArray::new(scalar.clone(), 1).into_array();
        let primitive = constant.to_primitive().unwrap();
        let first = primitive.scalar_at(0).unwrap();

        assert_eq!(first, scalar);
        assert_eq!(primitive.scalar_at(0).unwrap(), expected);
    }

    /// A two-element list repeated twice yields concatenated elements and
    /// evenly spaced offsets.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// An empty list repeated twice has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let list = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null list repeated twice has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_dtype = DType::List(u64_element_dtype(), Nullability::Nullable);
        let list = ConstantArray::new(Scalar::null(list_dtype), 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = list.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = list.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}