vortex_array/arrays/constant/
canonical.rs1use std::sync::Arc;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_buffer::{Buffer, buffer};
8use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
9use vortex_error::VortexExpect;
10use vortex_scalar::{
11 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
12 StructScalar, Utf8Scalar, match_each_decimal_value, match_each_decimal_value_type,
13};
14
15use crate::arrays::constant::ConstantArray;
16use crate::arrays::primitive::PrimitiveArray;
17use crate::arrays::{
18 BinaryView, BoolArray, ConstantVTable, DecimalArray, ExtensionArray, ListArray, NullArray,
19 StructArray, VarBinViewArray, smallest_storage_type,
20};
21use crate::builders::builder_with_capacity;
22use crate::validity::Validity;
23use crate::vtable::CanonicalVTable;
24use crate::{Canonical, IntoArray};
25
26impl CanonicalVTable<ConstantVTable> for ConstantVTable {
27 fn canonicalize(array: &ConstantArray) -> Canonical {
28 let scalar = array.scalar();
29
30 let validity = match array.dtype().nullability() {
31 Nullability::NonNullable => Validity::NonNullable,
32 Nullability::Nullable => match scalar.is_null() {
33 true => Validity::AllInvalid,
34 false => Validity::AllValid,
35 },
36 };
37
38 match array.dtype() {
39 DType::Null => Canonical::Null(NullArray::new(array.len())),
40 DType::Bool(..) => Canonical::Bool(BoolArray::new(
41 if BoolScalar::try_from(scalar)
42 .vortex_expect("must be bool")
43 .value()
44 .unwrap_or_default()
45 {
46 BooleanBuffer::new_set(array.len())
47 } else {
48 BooleanBuffer::new_unset(array.len())
49 },
50 validity,
51 )),
52 DType::Primitive(ptype, ..) => {
53 match_each_native_ptype!(ptype, |P| {
54 Canonical::Primitive(PrimitiveArray::new(
55 if scalar.is_valid() {
56 Buffer::full(
57 P::try_from(scalar)
58 .vortex_expect("Couldn't unwrap scalar to primitive"),
59 array.len(),
60 )
61 } else {
62 Buffer::zeroed(array.len())
63 },
64 validity,
65 ))
66 })
67 }
68 DType::Decimal(decimal_type, ..) => {
69 let size = smallest_storage_type(decimal_type);
70 let decimal = scalar.as_decimal();
71 let Some(value) = decimal.decimal_value() else {
72 let all_null = match_each_decimal_value_type!(size, |D| {
73 DecimalArray::new(Buffer::<D>::zeroed(array.len()), *decimal_type, validity)
74 });
75 return Canonical::Decimal(all_null);
76 };
77
78 let decimal_array = match_each_decimal_value!(value, |value| {
79 DecimalArray::new(Buffer::full(value, array.len()), *decimal_type, validity)
80 });
81 Canonical::Decimal(decimal_array)
82 }
83 DType::Utf8(_) => {
84 let value = Utf8Scalar::try_from(scalar)
85 .vortex_expect("Must be a utf8 scalar")
86 .value();
87 let const_value = value.as_ref().map(|v| v.as_bytes());
88 Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
89 }
90 DType::Binary(_) => {
91 let value = BinaryScalar::try_from(scalar)
92 .vortex_expect("must be a binary scalar")
93 .value();
94 let const_value = value.as_ref().map(|v| v.as_slice());
95 Canonical::VarBinView(canonical_byte_view(const_value, array.dtype(), array.len()))
96 }
97 DType::Struct(struct_dtype, _) => {
98 let value = StructScalar::try_from(scalar).vortex_expect("must be struct");
99 let fields: Vec<_> = match value.fields() {
100 Some(fields) => fields
101 .into_iter()
102 .map(|s| ConstantArray::new(s, array.len()).into_array())
103 .collect(),
104 None => {
105 assert!(validity.all_invalid());
106 struct_dtype
107 .fields()
108 .map(|dt| {
109 let scalar = Scalar::default_value(dt);
110 ConstantArray::new(scalar, array.len()).into_array()
111 })
112 .collect()
113 }
114 };
115 Canonical::Struct(StructArray::new_unchecked(
116 fields,
117 struct_dtype.clone(),
118 array.len(),
119 validity,
120 ))
121 }
122 DType::List(..) => {
123 let value = ListScalar::try_from(scalar).vortex_expect("must be list");
124 Canonical::List(canonical_list_array(
125 value.elements(),
126 value.element_dtype(),
127 value.dtype().nullability(),
128 array.len(),
129 ))
130 }
131 DType::FixedSizeList(..) => {
132 unimplemented!("TODO(connor)[FixedSizeList]")
133 }
134 DType::Extension(ext_dtype) => {
135 let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar");
136
137 let storage_scalar = s.storage();
138 let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
139 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
140 }
141 }
142 }
143}
144
145fn canonical_byte_view(scalar_bytes: Option<&[u8]>, dtype: &DType, len: usize) -> VarBinViewArray {
146 match scalar_bytes {
147 None => {
148 let views = buffer![BinaryView::from(0_u128); len];
149
150 unsafe {
152 VarBinViewArray::new_unchecked(
153 views,
154 Default::default(),
155 dtype.clone(),
156 Validity::AllInvalid,
157 )
158 }
159 }
160 Some(scalar_bytes) => {
161 let view = BinaryView::make_view(scalar_bytes, 0, 0);
164 let mut buffers = Vec::new();
165 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
166 buffers.push(Buffer::copy_from(scalar_bytes));
167 }
168
169 let views = buffer![view; len];
171
172 unsafe {
174 VarBinViewArray::new_unchecked(
175 views,
176 Arc::from(buffers),
177 dtype.clone(),
178 Validity::from(dtype.nullability()),
179 )
180 }
181 }
182 }
183}
184
185fn canonical_list_array(
186 values: Option<Vec<Scalar>>,
187 element_dtype: &DType,
188 list_nullability: Nullability,
189 len: usize,
190) -> ListArray {
191 match values {
192 None => unsafe {
193 ListArray::new_unchecked(
194 Canonical::empty(element_dtype).into_array(),
195 ConstantArray::new(
196 Scalar::new(
197 DType::Primitive(PType::U64, Nullability::NonNullable),
198 ScalarValue::from(0),
199 ),
200 len + 1,
201 )
202 .into_array(),
203 Validity::AllInvalid,
204 )
205 },
206 Some(vs) => {
207 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
208 for _ in 0..len {
209 for v in &vs {
210 elements_builder
211 .append_scalar(v)
212 .vortex_expect("must be a same dtype");
213 }
214 }
215 let offsets = if vs.is_empty() {
216 Buffer::zeroed(len + 1)
217 } else {
218 Buffer::from_trusted_len_iter(
219 (0..=len * vs.len()).step_by(vs.len()).map(|i| i as u64),
220 )
221 };
222
223 unsafe {
224 ListArray::new_unchecked(
225 elements_builder.finish(),
226 offsets.into_array(),
227 Validity::from(list_nullability),
228 )
229 }
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use itertools::Itertools;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::stats::{Stat, StatsProvider};
    use crate::{Array, IntoArray};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let nulls = constant.to_null();
        assert_eq!(nulls.len(), 42);
        assert_eq!(nulls.scalar_at(33), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let constant = ConstantArray::new("four".to_string(), 4);

        let views = constant.to_varbinview();

        assert_eq!(views.len(), 4);

        for idx in 0..views.len() {
            assert_eq!(views.scalar_at(idx), "four".into());
        }
    }

    /// Statistics computed on the constant array are still present after canonicalizing.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let constant =
            ConstantArray::new(Scalar::bool(true, Nullability::NonNullable), 4).into_array();
        let computed = constant
            .statistics()
            .compute_all(&all::<Stat>().collect_vec())
            .unwrap();
        let canonical = constant.to_canonical();
        let canonical_stats = canonical.as_ref().statistics();

        let expected = computed.as_typed_ref(canonical.as_ref().dtype());

        for stat in all::<Stat>() {
            // Only compare statistics that are defined for this dtype.
            if stat.dtype(canonical.as_ref().dtype()).is_some() {
                assert_eq!(
                    canonical_stats.get(stat),
                    expected.get(stat),
                    "stat mismatch {stat}"
                );
            }
        }
    }

    /// A small `f16` constant survives canonicalization unchanged.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);

        let constant = ConstantArray::new(expected.clone(), 1).into_array();
        let primitive = constant.to_primitive();

        assert_eq!(primitive.scalar_at(0), expected);
    }

    /// A non-empty constant list repeats its elements and produces stepped offsets.
    #[test]
    fn test_canonicalize_lists() {
        let constant_list = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(constant_list, 2).into_array().to_list();
        assert_eq!(
            lists.elements().to_primitive().as_slice::<u64>(),
            [1u64, 2, 1, 2]
        );
        assert_eq!(
            lists.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 2, 4]
        );
    }

    /// An empty constant list produces no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let constant_list = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let lists = ConstantArray::new(constant_list, 2).into_array().to_list();
        assert!(lists.elements().to_primitive().is_empty());
        assert_eq!(
            lists.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    /// A null constant list produces no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let null_list = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let lists = ConstantArray::new(null_list, 2).into_array().to_list();
        assert!(lists.elements().to_primitive().is_empty());
        assert_eq!(
            lists.offsets().to_primitive().as_slice::<u64>(),
            [0u64, 0, 0]
        );
    }

    /// A null struct constant keeps its non-nullable field dtype after canonicalizing.
    #[test]
    fn test_canonicalize_nullable_struct() {
        let field_dtype = DType::Primitive(PType::I8, Nullability::NonNullable);
        let null_struct = Scalar::null(DType::struct_(
            [("non_null_field", field_dtype.clone())],
            Nullability::Nullable,
        ));
        let constant = ConstantArray::new(null_struct, 3);

        let structs = constant.to_struct();
        assert_eq!(structs.len(), 3);
        assert_eq!(structs.valid_count(), 0);

        let field = structs.field_by_name("non_null_field").unwrap();

        assert_eq!(field.dtype(), &field_dtype);
    }
}