use arrow_buffer::BooleanBuffer;
use vortex_buffer::{Buffer, BufferMut, buffer};
use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
use vortex_scalar::{
    BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
    StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
};

use crate::arrays::constant::ConstantArray;
use crate::arrays::primitive::PrimitiveArray;
use crate::arrays::{
    BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
    StructArray, VarBinViewArray, smallest_storage_type,
};
use crate::builders::{ArrayBuilderExt, builder_with_capacity};
use crate::validity::Validity;
use crate::vtable::CanonicalVTable;
use crate::{Canonical, IntoArray};
20
21impl CanonicalVTable<ConstantVTable> for ConstantVTable {
22 fn canonicalize(array: &ConstantArray) -> VortexResult<Canonical> {
23 let scalar = array.scalar();
24
25 let validity = match array.dtype().nullability() {
26 Nullability::NonNullable => Validity::NonNullable,
27 Nullability::Nullable => match scalar.is_null() {
28 true => Validity::AllInvalid,
29 false => Validity::AllValid,
30 },
31 };
32
33 Ok(match array.dtype() {
34 DType::Null => Canonical::Null(NullArray::new(array.len())),
35 DType::Bool(..) => Canonical::Bool(BoolArray::new(
36 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37 BooleanBuffer::new_set(array.len())
38 } else {
39 BooleanBuffer::new_unset(array.len())
40 },
41 validity,
42 )),
43 DType::Primitive(ptype, ..) => {
44 match_each_native_ptype!(ptype, |P| {
45 Canonical::Primitive(PrimitiveArray::new(
46 if scalar.is_valid() {
47 Buffer::full(
48 P::try_from(scalar)
49 .vortex_expect("Couldn't unwrap scalar to primitive"),
50 array.len(),
51 )
52 } else {
53 Buffer::zeroed(array.len())
54 },
55 validity,
56 ))
57 })
58 }
59 DType::Decimal(decimal_type, ..) => {
60 let size = smallest_storage_type(decimal_type);
61 let decimal = scalar.as_decimal();
62 let Some(value) = decimal.decimal_value() else {
63 let all_null = match_each_decimal_value_type!(size, |D| {
64 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
65 });
66 return Ok(Canonical::Decimal(all_null));
67 };
68
69 let decimal_array = match_each_decimal_value!(value, |value| {
70 DecimalArray::new(Buffer::full(*value, array.len()), *decimal_type, validity)
71 });
72 Canonical::Decimal(decimal_array)
73 }
74 DType::Utf8(_) => {
75 let value = Utf8Scalar::try_from(scalar)?.value();
76 let const_value = value.as_ref().map(|v| v.as_bytes());
77 Canonical::VarBinView(canonical_byte_view(
78 const_value,
79 array.dtype(),
80 array.len(),
81 )?)
82 }
83 DType::Binary(_) => {
84 let value = BinaryScalar::try_from(scalar)?.value();
85 let const_value = value.as_ref().map(|v| v.as_slice());
86 Canonical::VarBinView(canonical_byte_view(
87 const_value,
88 array.dtype(),
89 array.len(),
90 )?)
91 }
92 DType::Struct(struct_dtype, _) => {
93 let value = StructScalar::try_from(scalar)?;
94 let fields: Vec<_> = match value.fields() {
95 Some(fields) => fields
96 .into_iter()
97 .map(|s| ConstantArray::new(s, array.len()).into_array())
98 .collect(),
99 None => {
100 assert!(validity.all_invalid()?);
101 struct_dtype
102 .fields()
103 .map(|dt| {
104 let scalar = Scalar::default_value(dt);
105 ConstantArray::new(scalar, array.len()).into_array()
106 })
107 .collect()
108 }
109 };
110 Canonical::Struct(StructArray::try_new_with_dtype(
111 fields,
112 struct_dtype.clone(),
113 array.len(),
114 validity,
115 )?)
116 }
117 DType::List(..) => {
118 let value = ListScalar::try_from(scalar)?;
119 Canonical::List(canonical_list_array(
120 value.elements(),
121 value.element_dtype(),
122 value.dtype().nullability(),
123 array.len(),
124 )?)
125 }
126 DType::Extension(ext_dtype) => {
127 let s = ExtScalar::try_from(scalar)?;
128
129 let storage_scalar = s.storage();
130 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
131 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
132 }
133 })
134 }
135}
136
137fn canonical_byte_view(
138 scalar_bytes: Option<&[u8]>,
139 dtype: &DType,
140 len: usize,
141) -> VortexResult<VarBinViewArray> {
142 match scalar_bytes {
143 None => {
144 let views = buffer![BinaryView::from(0_u128); len];
145
146 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
147 }
148 Some(scalar_bytes) => {
149 let view = BinaryView::make_view(scalar_bytes, 0, 0);
152 let mut buffers = Vec::new();
153 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
154 buffers.push(Buffer::copy_from(scalar_bytes));
155 }
156
157 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
161 for _ in 0..len {
162 views.push(view);
163 }
164
165 VarBinViewArray::try_new(
166 views.freeze(),
167 buffers,
168 dtype.clone(),
169 Validity::from(dtype.nullability()),
170 )
171 }
172 }
173}
174
175fn canonical_list_array(
176 values: Option<Vec<Scalar>>,
177 element_dtype: &DType,
178 list_nullability: Nullability,
179 len: usize,
180) -> VortexResult<ListArray> {
181 match values {
182 None => ListArray::try_new(
183 Canonical::empty(element_dtype).into_array(),
184 ConstantArray::new(
185 Scalar::new(
186 DType::Primitive(PType::U64, Nullability::NonNullable),
187 ScalarValue::from(0),
188 ),
189 len + 1,
190 )
191 .into_array(),
192 Validity::AllInvalid,
193 ),
194 Some(vs) => {
195 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
196 for _ in 0..len {
197 for v in &vs {
198 elements_builder.append_scalar(v)?;
199 }
200 }
201 let offsets = if vs.is_empty() {
202 Buffer::zeroed(len + 1)
203 } else {
204 (0..=len * vs.len())
205 .step_by(vs.len())
206 .map(|i| i as u64)
207 .collect::<Buffer<_>>()
208 };
209
210 ListArray::try_new(
211 elements_builder.finish(),
212 offsets.into_array(),
213 Validity::from(list_nullability),
214 )
215 }
216 }
217}
218
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};
    use crate::{Array, IntoArray};

    #[test]
    fn test_canonicalize_null() {
        let const_null = ConstantArray::new(Scalar::null(DType::Null), 42);
        let actual = const_null.to_null().unwrap();
        assert_eq!(actual.len(), 42);
        assert_eq!(actual.scalar_at(33).unwrap(), Scalar::null(DType::Null));
    }

    #[test]
    fn test_canonicalize_const_str() {
        let const_array = ConstantArray::new("four".to_string(), 4);

        let canonical = const_array.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for i in 0..=3 {
            assert_eq!(canonical.scalar_at(i).unwrap(), "four".into());
        }
    }

    /// Strings just below, at, and just above the 12-byte view inlining threshold must
    /// round-trip, whether they are stored inline or in a shared data buffer.
    #[test]
    fn test_canonicalize_const_str_inline_boundary() {
        for value in ["eleven_byte", "twelve_bytes", "thirteen_byte"] {
            let canonical = ConstantArray::new(value.to_string(), 3)
                .to_varbinview()
                .unwrap();
            assert_eq!(canonical.len(), 3);
            for i in 0..3 {
                assert_eq!(canonical.scalar_at(i).unwrap(), value.into());
            }
        }
    }

    /// A null string constant canonicalizes to an array in which every row is invalid.
    #[test]
    fn test_canonicalize_null_str() {
        let const_null = ConstantArray::new(Scalar::null(DType::Utf8(Nullability::Nullable)), 3);
        let canonical = const_null.to_varbinview().unwrap();
        assert_eq!(canonical.len(), 3);
        assert_eq!(canonical.valid_count().unwrap(), 0);
    }

    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let const_array = ConstantArray::new(scalar.clone(), 4).into_array();
        let stats = const_array.statistics().to_owned();

        let canonical = const_array.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();
        let dtype = canonical.as_ref().dtype();

        let reference = StatsSet::constant(scalar, 4);
        for stat in all::<Stat>() {
            // Skip statistics that are not defined for this dtype.
            let Some(stat_dtype) = stat.dtype(dtype) else {
                continue;
            };

            let canonical_stat = canonical_stats.get_scalar(stat, &stat_dtype);
            let reference_stat = reference.get_scalar(stat, &stat_dtype);
            let original_stat = stats.get_scalar(stat, &stat_dtype);
            assert_eq!(canonical_stat, reference_stat);
            assert_eq!(canonical_stat, original_stat);
        }
    }

    #[test]
    fn test_canonicalize_scalar_values() {
        let f16_scalar = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );
        let const_array = ConstantArray::new(scalar.clone(), 1).into_array();
        let canonical_const = const_array.to_primitive().unwrap();
        assert_eq!(canonical_const.scalar_at(0).unwrap(), scalar);
        assert_eq!(canonical_const.scalar_at(0).unwrap(), f16_scalar);
    }

    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert_eq!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            vec![],
            Nullability::NonNullable,
        );
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let const_array = ConstantArray::new(list_scalar, 2).into_array();
        let canonical_const = const_array.to_list().unwrap();
        // Every list row must be null, not merely empty.
        assert_eq!(canonical_const.valid_count().unwrap(), 0);
        assert!(
            canonical_const
                .elements()
                .to_primitive()
                .unwrap()
                .is_empty()
        );
        assert_eq!(
            canonical_const
                .offsets()
                .to_primitive()
                .unwrap()
                .as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    #[test]
    fn test_canonicalize_nullable_struct() {
        let array = ConstantArray::new(
            Scalar::null(DType::struct_(
                [(
                    "non_null_field",
                    DType::Primitive(PType::I8, Nullability::NonNullable),
                )],
                Nullability::Nullable,
            )),
            3,
        );

        let struct_array = array.to_struct().unwrap();
        assert_eq!(struct_array.len(), 3);
        assert_eq!(struct_array.valid_count().unwrap(), 0);

        let field = struct_array.field_by_name("non_null_field").unwrap();

        assert_eq!(
            field.dtype(),
            &DType::Primitive(PType::I8, Nullability::NonNullable)
        );
    }
}