vortex_array/arrays/constant/
canonical.rs1use arrow_buffer::BooleanBuffer;
2use vortex_buffer::{Buffer, BufferMut, buffer};
3use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
4use vortex_error::{VortexExpect, VortexResult};
5use vortex_scalar::{
6 BinaryScalar, BoolScalar, ExtScalar, ListScalar, Scalar, ScalarValue, StructScalar, Utf8Scalar,
7};
8
9use crate::array::ArrayCanonicalImpl;
10use crate::arrays::constant::ConstantArray;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::arrays::{
13 BinaryView, BoolArray, ExtensionArray, ListArray, NullArray, StructArray, VarBinViewArray,
14};
15use crate::builders::{ArrayBuilderExt, builder_with_capacity};
16use crate::validity::Validity;
17use crate::{Array, Canonical, IntoArray};
18
19impl ArrayCanonicalImpl for ConstantArray {
20 fn _to_canonical(&self) -> VortexResult<Canonical> {
21 let scalar = self.scalar();
22
23 let validity = match self.dtype().nullability() {
24 Nullability::NonNullable => Validity::NonNullable,
25 Nullability::Nullable => match scalar.is_null() {
26 true => Validity::AllInvalid,
27 false => Validity::AllValid,
28 },
29 };
30
31 Ok(match self.dtype() {
32 DType::Null => Canonical::Null(NullArray::new(self.len())),
33 DType::Bool(..) => Canonical::Bool(BoolArray::new(
34 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
35 BooleanBuffer::new_set(self.len())
36 } else {
37 BooleanBuffer::new_unset(self.len())
38 },
39 validity,
40 )),
41 DType::Primitive(ptype, ..) => {
42 match_each_native_ptype!(ptype, |$P| {
43 Canonical::Primitive(PrimitiveArray::new(
44 if scalar.is_valid() {
45 Buffer::full(
46 $P::try_from(scalar)
47 .vortex_expect("Couldn't unwrap scalar to primitive"),
48 self.len(),
49 )
50 } else {
51 Buffer::zeroed(self.len())
52 },
53 validity,
54 ))
55 })
56 }
57 DType::Utf8(_) => {
58 let value = Utf8Scalar::try_from(scalar)?.value();
59 let const_value = value.as_ref().map(|v| v.as_bytes());
60 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
61 }
62 DType::Binary(_) => {
63 let value = BinaryScalar::try_from(scalar)?.value();
64 let const_value = value.as_ref().map(|v| v.as_slice());
65 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
66 }
67 DType::Struct(struct_dtype, _) => {
68 let value = StructScalar::try_from(scalar)?;
69 let fields = value.fields().map(|fields| {
70 fields
71 .into_iter()
72 .map(|s| ConstantArray::new(s, self.len()).into_array())
73 .collect::<Vec<_>>()
74 });
75 Canonical::Struct(StructArray::try_new_with_dtype(
76 fields.unwrap_or_default(),
77 struct_dtype.clone(),
78 self.len(),
79 validity,
80 )?)
81 }
82 DType::List(..) => {
83 let value = ListScalar::try_from(scalar)?;
84 Canonical::List(canonical_list_array(
85 value.elements(),
86 value.element_dtype(),
87 value.dtype().nullability(),
88 self.len(),
89 )?)
90 }
91 DType::Extension(ext_dtype) => {
92 let s = ExtScalar::try_from(scalar)?;
93
94 let storage_scalar = s.storage();
95 let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
96 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
97 }
98 })
99 }
100}
101
102fn canonical_byte_view(
103 scalar_bytes: Option<&[u8]>,
104 dtype: &DType,
105 len: usize,
106) -> VortexResult<VarBinViewArray> {
107 match scalar_bytes {
108 None => {
109 let views = buffer![BinaryView::from(0_u128); len];
110
111 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
112 }
113 Some(scalar_bytes) => {
114 let view = BinaryView::make_view(scalar_bytes, 0, 0);
117 let mut buffers = Vec::new();
118 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
119 buffers.push(Buffer::copy_from(scalar_bytes));
120 }
121
122 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
126 for _ in 0..len {
127 views.push(view);
128 }
129
130 VarBinViewArray::try_new(
131 views.freeze(),
132 buffers,
133 dtype.clone(),
134 Validity::from(dtype.nullability()),
135 )
136 }
137 }
138}
139
140fn canonical_list_array(
141 values: Option<Vec<Scalar>>,
142 element_dtype: &DType,
143 list_nullability: Nullability,
144 len: usize,
145) -> VortexResult<ListArray> {
146 match values {
147 None => ListArray::try_new(
148 Canonical::empty(element_dtype).into_array(),
149 ConstantArray::new(
150 Scalar::new(
151 DType::Primitive(PType::U64, Nullability::NonNullable),
152 ScalarValue::from(0),
153 ),
154 len + 1,
155 )
156 .into_array(),
157 Validity::AllInvalid,
158 ),
159 Some(vs) => {
160 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
161 for _ in 0..len {
162 for v in &vs {
163 elements_builder.append_scalar(v)?;
164 }
165 }
166 let offsets = if vs.is_empty() {
167 Buffer::zeroed(len + 1)
168 } else {
169 (0..=len * vs.len())
170 .step_by(vs.len())
171 .map(|i| i as u64)
172 .collect::<Buffer<_>>()
173 };
174
175 ListArray::try_new(
176 elements_builder.finish(),
177 offsets.into_array(),
178 Validity::from(list_nullability),
179 )
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// Element dtype shared by the list tests: non-nullable `u64`.
    fn u64_element_dtype() -> Arc<DType> {
        Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable))
    }

    /// A constant null array canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let nulls = ConstantArray::new(Scalar::null(DType::Null), 42)
            .to_null()
            .unwrap();

        assert_eq!(nulls.len(), 42);
        assert_eq!(scalar_at(&nulls, 33).unwrap(), Scalar::null(DType::Null));
    }

    /// A constant string canonicalizes to a view array repeating that string.
    #[test]
    fn test_canonicalize_const_str() {
        let views = ConstantArray::new("four".to_string(), 4)
            .to_varbinview()
            .unwrap();

        assert_eq!(views.len(), 4);

        for row in 0..=3 {
            let expected: Scalar = "four".into();
            assert_eq!(scalar_at(&views, row).unwrap(), expected);
        }
    }

    /// Every statistic of the canonical array matches both the original constant array
    /// and a reference constant stats set.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let value = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(value.clone(), 4).into_array();
        // Snapshot the original statistics before canonicalizing.
        let original_stats = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let expected_stats = StatsSet::constant(value, 4);
        for stat in all::<Stat>() {
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = expected_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);

            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// A scalar whose stored value was built from a `u8` but whose dtype is `F16`
    /// canonicalizes to a primitive array whose element equals both that scalar and the
    /// corresponding `f16` scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let expected_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let raw_value = Scalar::primitive(96u8, Nullability::NonNullable).into_value();
        let f16_typed = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            raw_value,
        );

        let primitive = ConstantArray::new(f16_typed.clone(), 1)
            .into_array()
            .to_primitive()
            .unwrap();

        assert_eq!(scalar_at(&primitive, 0).unwrap(), f16_typed);
        assert_eq!(scalar_at(&primitive, 0).unwrap(), expected_f16);
    }

    /// A constant two-element list repeats its elements per row with evenly spaced offsets.
    #[test]
    fn test_canonicalize_lists() {
        let constant_list = Scalar::list(
            u64_element_dtype(),
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(constant_list, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }

    /// A constant empty list has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_empty_list() {
        let constant_list = Scalar::list(u64_element_dtype(), vec![], Nullability::NonNullable);
        let lists = ConstantArray::new(constant_list, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }

    /// A constant null list has no elements and all-zero offsets.
    #[test]
    fn test_canonicalize_null_list() {
        let null_list = Scalar::null(DType::List(u64_element_dtype(), Nullability::Nullable));
        let lists = ConstantArray::new(null_list, 2)
            .into_array()
            .to_list()
            .unwrap();

        let elements = lists.elements().to_primitive().unwrap();
        assert!(elements.is_empty());

        let offsets = lists.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 0, 0]);
    }
}