vortex_array/arrays/constant/
canonical.rs1use arrow_array::builder::make_view;
2use arrow_buffer::BooleanBuffer;
3use vortex_buffer::{Buffer, BufferMut, buffer};
4use vortex_dtype::{DType, Nullability, PType, match_each_native_ptype};
5use vortex_error::{VortexExpect, VortexResult};
6use vortex_scalar::{
7 BinaryScalar, BoolScalar, ExtScalar, ListScalar, Scalar, ScalarValue, StructScalar, Utf8Scalar,
8};
9
10use crate::array::ArrayCanonicalImpl;
11use crate::arrays::constant::ConstantArray;
12use crate::arrays::primitive::PrimitiveArray;
13use crate::arrays::{
14 BinaryView, BoolArray, ExtensionArray, ListArray, NullArray, StructArray, VarBinViewArray,
15};
16use crate::builders::{ArrayBuilderExt, builder_with_capacity};
17use crate::validity::Validity;
18use crate::{Array, Canonical, IntoArray};
19
20impl ArrayCanonicalImpl for ConstantArray {
21 fn _to_canonical(&self) -> VortexResult<Canonical> {
22 let scalar = self.scalar();
23
24 let validity = match self.dtype().nullability() {
25 Nullability::NonNullable => Validity::NonNullable,
26 Nullability::Nullable => match scalar.is_null() {
27 true => Validity::AllInvalid,
28 false => Validity::AllValid,
29 },
30 };
31
32 Ok(match self.dtype() {
33 DType::Null => Canonical::Null(NullArray::new(self.len())),
34 DType::Bool(..) => Canonical::Bool(BoolArray::new(
35 if BoolScalar::try_from(scalar)?.value().unwrap_or_default() {
36 BooleanBuffer::new_set(self.len())
37 } else {
38 BooleanBuffer::new_unset(self.len())
39 },
40 validity,
41 )),
42 DType::Primitive(ptype, ..) => {
43 match_each_native_ptype!(ptype, |$P| {
44 Canonical::Primitive(PrimitiveArray::new(
45 if scalar.is_valid() {
46 Buffer::full(
47 $P::try_from(scalar)
48 .vortex_expect("Couldn't unwrap scalar to primitive"),
49 self.len(),
50 )
51 } else {
52 Buffer::zeroed(self.len())
53 },
54 validity,
55 ))
56 })
57 }
58 DType::Utf8(_) => {
59 let value = Utf8Scalar::try_from(scalar)?.value();
60 let const_value = value.as_ref().map(|v| v.as_bytes());
61 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
62 }
63 DType::Binary(_) => {
64 let value = BinaryScalar::try_from(scalar)?.value();
65 let const_value = value.as_ref().map(|v| v.as_slice());
66 Canonical::VarBinView(canonical_byte_view(const_value, self.dtype(), self.len())?)
67 }
68 DType::Struct(..) => {
69 let value = StructScalar::try_from(scalar)?;
70 let fields = value.fields().map(|fields| {
71 fields
72 .into_iter()
73 .map(|s| ConstantArray::new(s, self.len()).into_array())
74 .collect::<Vec<_>>()
75 });
76 Canonical::Struct(StructArray::try_new(
77 value.struct_dtype().names().clone(),
78 fields.unwrap_or_default(),
79 self.len(),
80 validity,
81 )?)
82 }
83 DType::List(..) => {
84 let value = ListScalar::try_from(scalar)?;
85 Canonical::List(canonical_list_array(
86 value.elements(),
87 value.element_dtype(),
88 value.dtype().nullability(),
89 self.len(),
90 )?)
91 }
92 DType::Extension(ext_dtype) => {
93 let s = ExtScalar::try_from(scalar)?;
94
95 let storage_scalar = s.storage();
96 let storage_self = ConstantArray::new(storage_scalar, self.len()).into_array();
97 Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
98 }
99 })
100 }
101}
102
103fn canonical_byte_view(
104 scalar_bytes: Option<&[u8]>,
105 dtype: &DType,
106 len: usize,
107) -> VortexResult<VarBinViewArray> {
108 match scalar_bytes {
109 None => {
110 let views = buffer![BinaryView::from(0_u128); len];
111
112 VarBinViewArray::try_new(views, Vec::new(), dtype.clone(), Validity::AllInvalid)
113 }
114 Some(scalar_bytes) => {
115 let view = BinaryView::from(make_view(scalar_bytes, 0, 0));
118 let mut buffers = Vec::new();
119 if scalar_bytes.len() >= BinaryView::MAX_INLINED_SIZE {
120 buffers.push(Buffer::copy_from(scalar_bytes));
121 }
122
123 let mut views = BufferMut::with_capacity_aligned(len, align_of::<u128>().into());
127 for _ in 0..len {
128 views.push(view);
129 }
130
131 VarBinViewArray::try_new(
132 views.freeze(),
133 buffers,
134 dtype.clone(),
135 Validity::from(dtype.nullability()),
136 )
137 }
138 }
139}
140
141fn canonical_list_array(
142 values: Option<Vec<Scalar>>,
143 element_dtype: &DType,
144 list_nullability: Nullability,
145 len: usize,
146) -> VortexResult<ListArray> {
147 match values {
148 None => ListArray::try_new(
149 ConstantArray::new(Scalar::null(element_dtype.clone()), 1).into_array(),
150 ConstantArray::new(
151 Scalar::new(
152 DType::Primitive(PType::U64, Nullability::NonNullable),
153 ScalarValue::from(0),
154 ),
155 len,
156 )
157 .into_array(),
158 Validity::AllInvalid,
159 ),
160 Some(vs) => {
161 let mut elements_builder = builder_with_capacity(element_dtype, len * vs.len());
162 for _ in 0..len {
163 for v in &vs {
164 elements_builder.append_scalar(v)?;
165 }
166 }
167 let offsets = (0..=len * vs.len())
168 .step_by(vs.len())
169 .map(|i| i as u64)
170 .collect::<Buffer<_>>();
171
172 ListArray::try_new(
173 elements_builder.finish(),
174 offsets.into_array(),
175 Validity::from(list_nullability),
176 )
177 }
178 }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use enum_iterator::all;
    use vortex_dtype::half::f16;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::ConstantArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;
    use crate::stats::{Stat, StatsProviderExt, StatsSet};

    /// A constant of the null dtype canonicalizes to a null array of the same length.
    #[test]
    fn test_canonicalize_null() {
        let nulls = ConstantArray::new(Scalar::null(DType::Null), 42);
        let canonical = nulls.to_null().unwrap();

        assert_eq!(canonical.len(), 42);
        assert_eq!(
            scalar_at(&canonical, 33).unwrap(),
            Scalar::null(DType::Null)
        );
    }

    /// A constant string canonicalizes to a view array holding that string at every index.
    #[test]
    fn test_canonicalize_const_str() {
        let expected = "four";
        let strings = ConstantArray::new(expected.to_string(), 4);

        let canonical = strings.to_varbinview().unwrap();

        assert_eq!(canonical.len(), 4);

        for idx in 0..4 {
            assert_eq!(scalar_at(&canonical, idx).unwrap(), expected.into());
        }
    }

    /// Every statistic of the canonical array agrees with both the original constant
    /// array and a reference `StatsSet::constant`.
    #[test]
    fn test_canonicalize_propagates_stats() {
        let value = Scalar::bool(true, Nullability::NonNullable);
        let constant = ConstantArray::new(value.clone(), 4).into_array();
        let original_stats = constant.statistics().to_owned();

        let canonical = constant.to_canonical().unwrap();
        let canonical_stats = canonical.as_ref().statistics().to_owned();

        let reference_stats = StatsSet::constant(value, 4);
        for stat in all::<Stat>() {
            let stat_dtype = stat.dtype(canonical.as_ref().dtype()).unwrap();

            let from_canonical = canonical_stats.get_scalar(stat, &stat_dtype);
            let from_reference = reference_stats.get_scalar(stat, &stat_dtype);
            let from_original = original_stats.get_scalar(stat, &stat_dtype);

            assert_eq!(from_canonical, from_reference);
            assert_eq!(from_canonical, from_original);
        }
    }

    /// An F16-typed scalar whose value was taken from a `u8` scalar (96) must read back
    /// from the canonical array equal to itself and to the natively built f16 scalar.
    #[test]
    fn test_canonicalize_scalar_values() {
        let native_f16 = Scalar::primitive(f16::from_f32(5.722046e-6), Nullability::NonNullable);
        let retyped = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Scalar::primitive(96u8, Nullability::NonNullable).into_value(),
        );

        let constant = ConstantArray::new(retyped.clone(), 1).into_array();
        let canonical = constant.to_primitive().unwrap();
        let first = scalar_at(&canonical, 0).unwrap();

        assert_eq!(first, retyped);
        assert_eq!(first, native_f16);
    }

    /// Two copies of the list `[1, 2]` yield elements `[1, 2, 1, 2]` and offsets `[0, 2, 4]`.
    #[test]
    fn test_canonicalize_lists() {
        let element_dtype = Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable));
        let pair = Scalar::list(
            element_dtype,
            vec![1u64.into(), 2u64.into()],
            Nullability::NonNullable,
        );

        let constant = ConstantArray::new(pair, 2).into_array();
        let canonical = constant.to_list().unwrap();

        let elements = canonical.elements().to_primitive().unwrap();
        assert_eq!(elements.as_slice::<u64>(), [1u64, 2, 1, 2]);

        let offsets = canonical.offsets().to_primitive().unwrap();
        assert_eq!(offsets.as_slice::<u64>(), [0u64, 2, 4]);
    }
}