vortex_array/arrays/constant/
canonical.rs1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6 BinaryScalar, BoolScalar, DecimalValue, ExtScalar, ListScalar, Scalar, ScalarValue,
7 StructScalar, Utf8Scalar,
8};
9
10use crate::array::ArrayCanonicalImpl;
11use crate::arrays::constant::ConstantArray;
12use crate::arrays::primitive::PrimitiveArray;
13use crate::arrays::{
14 BinaryView, BoolArray, DecimalArray, ExtensionArray, ListArray, NullArray, StructArray,
15 VarBinViewArray, precision_to_storage_size,
16};
17use crate::builders::{ArrayBuilderExt, builder_with_capacity};
18use crate::validity::Validity;
19use crate::{Array, Canonical, IntoArray, match_each_decimal_value, match_each_decimal_value_type};
20
21impl ArrayCanonicalImpl for ConstantArray {
22 fn _to_canonical(&self) -> VortexResult<Canonical> {
23 let scalar = self.scalar();
24
25 let validity = match self.dtype().nullability() {
26 Nullability::NonNullable => Validity::NonNullable,
27 Nullability::Nullable => match scalar.is_null() {
28 true => Validity::AllInvalid,
29 false => Validity::AllValid,
30 },
31 };
32
33 Ok(match self.dtype() {
34 DType::Null => Canonical::Null(NullArray::new(self.len())),
35 DType::Bool(..) => Canonical::Bool(BoolArray::new(
36 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
37 BooleanBuffer::new_set(self.len())
38 } else {
39 BooleanBuffer::new_unset(self.len())
40 },
41 validity,
42 )),
43 DType::Primitive(ptype, ..) => {
44 match_each_native_ptype!(ptype, |$P| {
45 Canonical::Primitive(PrimitiveArray::new(
46 if scalar.is_valid() {
47 Buffer::full(
48 $P::try_from(scalar)
49 .vortex_expect("Couldn't unwrap scalar to primitive"),
50 self.len(),
51 )
52 } else {
53 Buffer::zeroed(self.len())
54 },
55 validity,
56 ))
57 })
58 }
59 DType::Decimal(decimal_type, ..) => {
60 let size = precision_to_storage_size(decimal_type);
61 let decimal = scalar.as_decimal();
62 let Some(value) = decimal.decimal_value() else {
63 let all_null = match_each_decimal_value_type!(size, |$D| {
64 DecimalArray::new(
65 Buffer::<$D>::zeroed(self.len()),
66 *decimal_type,
67 Validity::AllInvalid,
68 )
69 });
70 return Ok(Canonical::Decimal(all_null));
71 };
72
73 let decimal_array = match_each_decimal_value!(value, |$V| {
74 DecimalArray::new(
75 Buffer::full(*$V, self.len()),
76 *decimal_type,
77 Validity::AllValid,
78 )
79 });
80 Canonical::Decimal(decimal_array)
81 }
82 DType::Utf8(_) => {
83 let value = Utf8Scalar::try_from(scalar)?.value();
84 let const_value = value.as_ref().map(|v| v.as_bytes());
85 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
86 }
87 DType::Binary(_) => {
88 let value = BinaryScalar::try_from(scalar)?.value();
89 let const_value = value.as_ref().map(|v| v.as_slice());
90 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
91 }
92 DType::Struct(struct_dtype, _) => {
93 let value = StructScalar::try_from(scalar)?;
94 let fields = value.fields().map(|fields| {
95 fields
96 .into_iter()
97 .map(|s| ConstantArray::new(s, self.len()).into_array())
98 .collect::<Vec<_>>()
99 });
100 Canonical::Struct(StructArray::try_new_with_dtype(
101 fields.unwrap_or_default(),
102 struct_dtype.clone(),
103 self.len(),
104 validity,
105 )?)
106 }
107 DType::List(..) => {
108 let value = ListScalar::try_from(scalar)?;
109 Canonical::List(canonical_list_array(
110 value.elements(),
111 value.element_dtype(),
112 value.dtype().nullability(),
113 self.len(),
114 )?)
115 }
116 DType::Extension(ext_dtype) => {
117 let s = ExtScalar::try_from(scalar)?;
118
119 let storage_scalar = s.storage();
120 let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
121 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
122 }
123 })
124 }
125}
126
127fn canonical_byte_view(
128 scalar_bytes: Option<&[u8]>,
129 dtype: &DType,
130 len: usize,
131) -> VortexResult<VarBinViewArray> {
132 match scalar_bytes {
133 None => {
134 let views = buffer![BinaryView::from(0_u128); len];
135
136 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
137 }
138 Some(scalar_bytes) => {
139 let view = BinaryView::make_view(scalar_bytes, 0, 0);
142 let mut buffers = Vec::new();
143 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
144 buffers.push(Buffer::copy_from(scalar_bytes));
145 }
146
147 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
151 for _ in 0..len {
152 views.push(view);
153 }
154
155 VarBinViewArray::try_new(
156 views.freeze(),
157 buffers,
158 dtype.clone(),
159 Validity::from(dtype.nullability()),
160 )
161 }
162 }
163}
164
165fn canonical_list_array(
166 values: Option<Vec<Scalar>>,
167 element_dtype: &DType,
168 list_nullability: Nullability,
169 len: usize,
170) -> VortexResult<ListArray> {
171 match values {
172 None => ListArray::try_new(
173 Canonical::empty(element_dtype).into_array(),
174 ConstantArray::new(
175 Scalar::new(
176 DType::Primitive(PType::U64, Nullability::NonNullable),
177 ScalarValue::from(0),
178 ),
179 len + 1,
180 )
181 .into_array(),
182 Validity::AllInvalid,
183 ),
184 Some(vs) => {
185 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
186 for _ in 0..len {
187 for v in &vs {
188 elements_builder.append_scalar(v)?;
189 }
190 }
191 let offsets = if vs.is_empty() {
192 Buffer::zeroed(len + 1)
193 } else {
194 (0..=len * vs.len())
195 .step_by(vs.len())
196 .map(|i| i as u64)
197 .collect::<Buffer<_>>()
198 };
199
200 ListArray::try_new(
201 elements_builder.finish(),
202 offsets.into_array(),
203 Validity::from(list_nullability),
204 )
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// Non-nullable `u64` element dtype shared by the list tests.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let constant = ConstantArray::new(Scalar::null(DType::Null), 42);
        let nulls = constant.to_null().unwrap();

        assert_eq!(nulls.len(), 42);
        assert_eq!(scalar_at(&nulls, 33).unwrap(), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let constant = ConstantArray::new("four".to_string(), 4);
        let views = constant.to_varbinview().unwrap();

        assert_eq!(views.len(), 4);
        for row in 0..views.len() {
            assert_eq!(scalar_at(&views, row).unwrap(), "four".into());
        }
    }

    /// Canonicalization keeps the statistics of the constant array.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let scalar = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(scalar.clone(), 4).into_array();
        let stats_before = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let stats_after = canonical.as_ref().statistics().to_owned();
        let expected = StatsSet::constant(scalar, 4);

        for stat in all::<Stat>() {
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();
            let after = stats_after.get_scalar(stat, &stat_dtype);
            let want = expected.get_scalar(stat, &stat_dtype);
            let before = stats_before.get_scalar(stat, &stat_dtype);
            assert_eq!(after, want);
            assert_eq!(after, before);
        }
    }

    /// A `u8` value stored under an `F16` dtype reads back as the corresponding f16.
    /// (96 × 2^-24 ≈ 5.722046e-6, i.e. the f16 whose bit pattern is 96.)
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 =
            Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );

        let constant = ConstantArray::new(raw_scalar.clone(), 1).into_array();
        let primitives = constant.to_primitive().unwrap();
        let first = scalar_at(&primitives, 0).unwrap();

        assert_eq!(first, raw_scalar);
        assert_eq!(first, expected_f16);
    }

    /// A constant two-element list repeats its elements and produces strided offsets.
    #[test]
    fn test_canonicalize_lists() {
        let list_scalar = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 2).into_array();
        let lists = constant.to_list().unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// A constant empty list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let list_scalar = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let constant = ConstantArray::new(list_scalar, 2).into_array();
        let lists = constant.to_list().unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A constant null list yields no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let list_scalar = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let constant = ConstantArray::new(list_scalar, 2).into_array();
        let lists = constant.to_list().unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}