vortex_array/arrays/constant/
canonical.rs1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19 StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::{ArrayBuilderExt, builder_with_capacity};
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
28 let scalar = array.scalar();
29
30 let validity = match array.dtype().nullability() {
31 Nullability::NonNullable => Validity::NonNullable,
32 Nullability::Nullable => match scalar.is_null() {
33 true => Validity::AllInvalid,
34 false => Validity::AllValid,
35 },
36 };
37
38 Ok(match array.dtype() {
39 DType::Null => Canonical::Null(NullArray::new(array.len())),
40 DType::Bool(..) => Canonical::Bool(BoolArray::new(
41 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
42 BooleanBuffer::new_set(array.len())
43 } else {
44 BooleanBuffer::new_unset(array.len())
45 },
46 validity,
47 )),
48 DType::Primitive(ptype, ..) => {
49 match_each_native_ptype!(ptype, |P| {
50 Canonical::Primitive(PrimitiveArray::new(
51 if scalar.is_valid() {
52 Buffer::full(
53 P::try_from(scalar)
54 .vortex_expect("Couldn't unwrap scalar to primitive"),
55 array.len(),
56 )
57 } else {
58 Buffer::zeroed(array.len())
59 },
60 validity,
61 ))
62 })
63 }
64 DType::Decimal(decimal_type, ..) => {
65 let size = smallest_storage_type(decimal_type);
66 let decimal = scalar.as_decimal();
67 let Some(value) = decimal.decimal_value() else {
68 let all_null = match_each_decimal_value_type!(size, |D| {
69 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
70 });
71 return Ok(Canonical::Decimal(all_null));
72 };
73
74 let decimal_array = match_each_decimal_value!(value, |value| {
75 DecimalArray::new(Buffer::full(value, array.len()), *decimal_type, validity)
76 });
77 Canonical::Decimal(decimal_array)
78 }
79 DType::Utf8(_) => {
80 let value = Utf8Scalar::try_from(scalar)?.value();
81 let const_value = value.as_ref().map(|v| v.as_bytes());
82 Canonical::VarBinView(canonical_byte_view(
83 const_value,
84 array.dtype(),
85 array.len(),
86 )?)
87 }
88 DType::Binary(_) => {
89 let value = BinaryScalar::try_from(scalar)?.value();
90 let const_value = value.as_ref().map(|v| v.as_slice());
91 Canonical::VarBinView(canonical_byte_view(
92 const_value,
93 array.dtype(),
94 array.len(),
95 )?)
96 }
97 DType::Struct(struct_dtype, _) => {
98 let value = StructScalar::try_from(scalar)?;
99 let fields: Vec<_> = match value.fields() {
100 Some(fields) => fields
101 .into_iter()
102 .map(|s| ConstantArray::new(s, array.len()).into_array())
103 .collect(),
104 None => {
105 assert!(validity.all_invalid());
106 struct_dtype
107 .fields()
108 .map(|dt| {
109 let scalar = Scalar::default_value(dt);
110 ConstantArray::new(scalar, array.len()).into_array()
111 })
112 .collect()
113 }
114 };
115 Canonical::Struct(StructArray::try_new_with_dtype(
116 fields,
117 struct_dtype.clone(),
118 array.len(),
119 validity,
120 )?)
121 }
122 DType::List(..) => {
123 let value = ListScalar::try_from(scalar)?;
124 Canonical::List(canonical_list_array(
125 value.elements(),
126 value.element_dtype(),
127 value.dtype().nullability(),
128 array.len(),
129 )?)
130 }
131 DType::FixedSizeList(..) => {
132 unimplemented!("TODO(connor)[FixedSizeList]")
133 }
134 DType::Extension(ext_dtype) => {
135 let s = ExtScalar::try_from(scalar)?;
136
137 let storage_scalar = s.storage();
138 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
139 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
140 }
141 })
142 }
143}
144
145fn canonical_byte_view(
146 scalar_bytes: Option<&[u8]>,
147 dtype: &DType,
148 len: usize,
149) -> VortexResult<VarBinViewArray> {
150 match scalar_bytes {
151 None => {
152 let views = buffer![BinaryView::from(0_u128); len];
153
154 unsafe {
156 Ok(VarBinViewArray::new_unchecked(
157 views,
158 Default::default(),
159 dtype.clone(),
160 Validity::AllInvalid,
161 ))
162 }
163 }
164 Some(scalar_bytes) => {
165 let view = BinaryView::make_view(scalar_bytes, 0, 0);
168 let mut buffers = Vec::new();
169 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
170 buffers.push(Buffer::copy_from(scalar_bytes));
171 }
172
173 let views = buffer![view; len];
175
176 unsafe {
178 Ok(VarBinViewArray::new_unchecked(
179 views,
180 Arc::from(buffers),
181 dtype.clone(),
182 Validity::from(dtype.nullability()),
183 ))
184 }
185 }
186 }
187}
188
189fn canonical_list_array(
190 values: Option<Vec<Scalar>>,
191 element_dtype: &DType,
192 list_nullability: Nullability,
193 len: usize,
194) -> VortexResult<ListArray> {
195 match values {
196 None => ListArray::try_new(
197 Canonical::empty(element_dtype).into_array(),
198 ConstantArray::new(
199 Scalar::new(
200 DType::Primitive(PType::U64, Nullability::NonNullable),
201 ScalarValue::from(0),
202 ),
203 len + 1,
204 )
205 .into_array(),
206 Validity::AllInvalid,
207 ),
208 Some(vs) => {
209 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
210 for _ in 0..len {
211 for v in &vs {
212 elements_builder.append_scalar(v)?;
213 }
214 }
215 let offsets = if vs.is_empty() {
216 Buffer::zeroed(len + 1)
217 } else {
218 Buffer::from_trusted_len_iter(
219 (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
220 )
221 };
222
223 ListArray::try_new(
224 elements_builder.finish(),
225 offsets.into_array(),
226 Validity::from(list_nullability),
227 )
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null().unwrap();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        // Canonicalizing a constant string produces a VarBinView array.
        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i), "four".into());
        }
    }

    /// Values on both sides of the binary-view inline size limit (presumably 12 bytes,
    /// `BinaryView::MAX_INLINED_SIZE`), plus a long value, must round-trip through
    /// canonicalization.
    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        for value in [
            "twelve_bytes",
            "thirteen_byte",
            "a string that is far longer than any inline binary view",
        ] {
            let canonical = ConstantArray::new(value.to_string(), 3)
                .to_varbinview()
                .unwrap();
            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i), value.into());
            }
        }
    }

    /// A null string constant canonicalizes to rows that are all null.
    #[test]
    fn test_canonicalize_null_str() {
        let const_null =
            ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);
        let canonical = const_null.to_varbinview().unwrap();
        assert_eq!(canonical.len(), 3);
        for i in 0..3 {
            assert!(canonical.scalar_at(i).is_null());
        }
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar, 4).into_array();
        let stats = const_array
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics();

        let stats_ref = stats.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            // Skip stats that are undefined for this dtype.
            if stat.dtype(canonical.as_ref().dtype()).is_none() {
                continue;
            }
            assert_eq!(
                canonical_stats.get(stat),
                stats_ref.get(stat),
                "stat mismatch {stat}"
            );
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_value = f16::from_f32(5.722046e-6);
        let f16_scalar = Scalar::primitive(f16_value, Nullability::NonNullable);

        let const_array = ConstantArray::new(f16_scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive().unwrap();

        assert_eq!(canonical_const.scalar_at(0), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert_eq!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct().unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        // The field keeps its declared non-nullable dtype even though the struct is null.
        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}