1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, BufferMut, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19 StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::{ArrayBuilderExt, builder_with_capacity};
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
28 let scalar = array.scalar();
29
30 let validity = match array.dtype().nullability() {
31 Nullability::NonNullable => Validity::NonNullable,
32 Nullability::Nullable => match scalar.is_null() {
33 true => Validity::AllInvalid,
34 false => Validity::AllValid,
35 },
36 };
37
38 Ok(match array.dtype() {
39 DType::Null => Canonical::Null(NullArray::new(array.len())),
40 DType::Bool(..) => Canonical::Bool(BoolArray::new(
41 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
42 BooleanBuffer::new_set(array.len())
43 } else {
44 BooleanBuffer::new_unset(array.len())
45 },
46 validity,
47 )),
48 DType::Primitive(ptype, ..) => {
49 match_each_native_ptype!(ptype, |P| {
50 Canonical::Primitive(PrimitiveArray::new(
51 if scalar.is_valid() {
52 Buffer::full(
53 P::try_from(scalar)
54 .vortex_expect("Couldn't unwrap scalar to primitive"),
55 array.len(),
56 )
57 } else {
58 Buffer::zeroed(array.len())
59 },
60 validity,
61 ))
62 })
63 }
64 DType::Decimal(decimal_type, ..) => {
65 let size = smallest_storage_type(decimal_type);
66 let decimal = scalar.as_decimal();
67 let Some(value) = decimal.decimal_value() else {
68 let all_null = match_each_decimal_value_type!(size, |D| {
69 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
70 });
71 return Ok(Canonical::Decimal(all_null));
72 };
73
74 let decimal_array = match_each_decimal_value!(value, |value| {
75 DecimalArray::new(Buffer::full(*value, array.len()), *decimal_type, validity)
76 });
77 Canonical::Decimal(decimal_array)
78 }
79 DType::Utf8(_) => {
80 let value = Utf8Scalar::try_from(scalar)?.value();
81 let const_value = value.as_ref().map(|v| v.as_bytes());
82 Canonical::VarBinView(canonical_byte_view(
83 const_value,
84 array.dtype(),
85 array.len(),
86 )?)
87 }
88 DType::Binary(_) => {
89 let value = BinaryScalar::try_from(scalar)?.value();
90 let const_value = value.as_ref().map(|v| v.as_slice());
91 Canonical::VarBinView(canonical_byte_view(
92 const_value,
93 array.dtype(),
94 array.len(),
95 )?)
96 }
97 DType::Struct(struct_dtype, _) => {
98 let value = StructScalar::try_from(scalar)?;
99 let fields: Vec<_> = match value.fields() {
100 Some(fields) => fields
101 .into_iter()
102 .map(|s| ConstantArray::new(s, array.len()).into_array())
103 .collect(),
104 None => {
105 assert!(validity.all_invalid()?);
106 struct_dtype
107 .fields()
108 .map(|dt| {
109 let scalar = Scalar::default_value(dt);
110 ConstantArray::new(scalar, array.len()).into_array()
111 })
112 .collect()
113 }
114 };
115 Canonical::Struct(StructArray::try_new_with_dtype(
116 fields,
117 struct_dtype.clone(),
118 array.len(),
119 validity,
120 )?)
121 }
122 DType::List(..) => {
123 let value = ListScalar::try_from(scalar)?;
124 Canonical::List(canonical_list_array(
125 value.elements(),
126 value.element_dtype(),
127 value.dtype().nullability(),
128 array.len(),
129 )?)
130 }
131 DType::Extension(ext_dtype) => {
132 let s = ExtScalar::try_from(scalar)?;
133
134 let storage_scalar = s.storage();
135 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
136 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
137 }
138 })
139 }
140}
141
142fn canonical_byte_view(
143 scalar_bytes: Option<&[u8]>,
144 dtype: &DType,
145 len: usize,
146) -> VortexResult<VarBinViewArray> {
147 match scalar_bytes {
148 None => {
149 let views = buffer![BinaryView::from(0_u128); len];
150
151 VarBinViewArray::try_new(
152 views,
153 Default::default(),
154 dtype.clone(),
155 Validity::AllInvalid,
156 )
157 }
158 Some(scalar_bytes) => {
159 let view = BinaryView::make_view(scalar_bytes, 0, 0);
162 let mut buffers = Vec::new();
163 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
164 buffers.push(Buffer::copy_from(scalar_bytes));
165 }
166
167 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
171 for _ in 0..len {
172 views.push(view);
173 }
174
175 VarBinViewArray::try_new(
176 views.freeze(),
177 Arc::from(buffers),
178 dtype.clone(),
179 Validity::from(dtype.nullability()),
180 )
181 }
182 }
183}
184
185fn canonical_list_array(
186 values: Option<Vec<Scalar>>,
187 element_dtype: &DType,
188 list_nullability: Nullability,
189 len: usize,
190) -> VortexResult<ListArray> {
191 match values {
192 None => ListArray::try_new(
193 Canonical::empty(element_dtype).into_array(),
194 ConstantArray::new(
195 Scalar::new(
196 DType::Primitive(PType::U64, Nullability::NonNullable),
197 ScalarValue::from(0),
198 ),
199 len + 1,
200 )
201 .into_array(),
202 Validity::AllInvalid,
203 ),
204 Some(vs) => {
205 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
206 for _ in 0..len {
207 for v in &vs {
208 elements_builder.append_scalar(v)?;
209 }
210 }
211 let offsets = if vs.is_empty() {
212 Buffer::zeroed(len + 1)
213 } else {
214 Buffer::from_trusted_len_iter(
215 (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
216 )
217 };
218
219 ListArray::try_new(
220 elements_builder.finish(),
221 offsets.into_array(),
222 Validity::from(list_nullability),
223 )
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt};
    use crate::{Array, IntoArray};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_elements() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant of dtype `Null` canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let nulls = ConstantArray::new(Scalar::null(DType::Null), 42)
            .to_null()
            .unwrap();

        assert_eq!(nulls.len(), 42);
        assert_eq!(nulls.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to views that all read back the same string.
    #[test]
    fn test_canonicalize_const_str() {
        let views = ConstantArray::new("four".to_string(), 4)
            .to_varbinview()
            .unwrap();

        assert_eq!(views.len(), 4);

        for idx in 0..4 {
            assert_eq!(views.scalar_at(idx).unwrap(), "four".into());
        }
    }

    /// Every stat computed on the constant array matches the one on its canonical form.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let constant =
            ConstantArray::new(Scalar::bool(true, Nullability::NonNullable), 4).into_array();
        let every_stat = all::<Stat>().collect_vec();
        let expected = constant.statistics().compute_all(&every_stat).unwrap();

        let canonical = constant.to_canonical().unwrap();
        let actual = canonical.as_ref().statistics().to_owned();

        for stat in all::<Stat>() {
            // Stats that have no dtype for this array are skipped.
            let Some(stat_dtype) = stat.dtype(canonical.as_ref().dtype()) else {
                continue;
            };
            let canonical_stat = actual.get_scalar(stat, &stat_dtype);
            let original_stat = expected.get_scalar(stat, &stat_dtype);
            assert_eq!(canonical_stat, original_stat, "stat mismatch {stat}");
        }
    }

    /// An f16-typed scalar holding a `u8` value reads back equal to both the original
    /// scalar and the corresponding f16 scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let raw_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            raw_value,
        );
        let f16_scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);

        let primitive = ConstantArray::new(scalar.clone(), 1)
            .into_array()
            .to_primitive()
            .unwrap();

        assert_eq!(primitive.scalar_at(0).unwrap(), scalar);
        assert_eq!(primitive.scalar_at(0).unwrap(), f16_scalar);
    }

    /// A two-element list repeated twice yields repeated elements and strided offsets.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_elements(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// An empty list constant yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_elements(), vec![], Nullability::NonNullable);
        let lists = ConstantArray::new(list_scalar, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null list constant yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_dtype = DType::List(u64_elements(), Nullability::Nullable);
        let lists = ConstantArray::new(Scalar::null(list_dtype), 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A null struct constant keeps its non-nullable field dtype and has no valid rows.
    #[test]
    fn test_canonicalize_nullable_struct() {
        let field_dtype = DType::Primitive(PType::I8, Nullability::NonNullable);
        let struct_dtype = DType::struct_(
            [("non_null_field", field_dtype.clone())],
            Nullability::Nullable,
        );

        let struct_array = ConstantArray::new(Scalar::null(struct_dtype), 3)
            .to_struct()
            .unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count().unwrap(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(field.dtype(), &field_dtype);
    }
}