// vortex_array/arrays/constant/vtable/
mod.rs1use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_buffer::ByteBufferMut;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::ExecutionStep;
17use crate::IntoArray;
18use crate::Precision;
19use crate::arrays::ConstantArray;
20use crate::arrays::constant::compute::rules::PARENT_RULES;
21use crate::arrays::constant::vtable::canonical::constant_canonicalize;
22use crate::buffer::BufferHandle;
23use crate::builders::ArrayBuilder;
24use crate::builders::BoolBuilder;
25use crate::builders::DecimalBuilder;
26use crate::builders::NullBuilder;
27use crate::builders::PrimitiveBuilder;
28use crate::builders::VarBinViewBuilder;
29use crate::canonical::Canonical;
30use crate::dtype::DType;
31use crate::match_each_decimal_value;
32use crate::match_each_native_ptype;
33use crate::scalar::DecimalValue;
34use crate::scalar::Scalar;
35use crate::scalar::ScalarValue;
36use crate::serde::ArrayChildren;
37use crate::stats::StatsSetRef;
38use crate::vtable;
39use crate::vtable::ArrayId;
40use crate::vtable::VTable;
41pub(crate) mod canonical;
42mod operations;
43mod validity;
44
45vtable!(Constant);
46
47#[derive(Debug)]
48pub struct ConstantVTable;
49
50impl ConstantVTable {
51 pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
52}
53
54impl VTable for ConstantVTable {
55 type Array = ConstantArray;
56
57 type Metadata = Scalar;
58 type OperationsVTable = Self;
59 type ValidityVTable = Self;
60
61 fn id(_array: &Self::Array) -> ArrayId {
62 Self::ID
63 }
64
65 fn len(array: &ConstantArray) -> usize {
66 array.len
67 }
68
69 fn dtype(array: &ConstantArray) -> &DType {
70 array.scalar.dtype()
71 }
72
73 fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
74 array.stats_set.to_ref(array.as_ref())
75 }
76
77 fn array_hash<H: std::hash::Hasher>(
78 array: &ConstantArray,
79 state: &mut H,
80 _precision: Precision,
81 ) {
82 array.scalar.hash(state);
83 array.len.hash(state);
84 }
85
86 fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
87 array.scalar == other.scalar && array.len == other.len
88 }
89
90 fn nbuffers(_array: &ConstantArray) -> usize {
91 1
92 }
93
94 fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
95 match idx {
96 0 => BufferHandle::new_host(
97 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
98 ),
99 _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
100 }
101 }
102
103 fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
104 match idx {
105 0 => Some("scalar".to_string()),
106 _ => None,
107 }
108 }
109
110 fn nchildren(_array: &ConstantArray) -> usize {
111 0
112 }
113
114 fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
115 vortex_panic!("ConstantArray child index {idx} out of bounds")
116 }
117
118 fn child_name(_array: &ConstantArray, idx: usize) -> String {
119 vortex_panic!("ConstantArray child_name index {idx} out of bounds")
120 }
121
122 fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
123 Ok(array.scalar().clone())
124 }
125
126 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
127 Ok(Some(vec![]))
130 }
131
132 fn deserialize(
133 _bytes: &[u8],
134 dtype: &DType,
135 _len: usize,
136 buffers: &[BufferHandle],
137 session: &VortexSession,
138 ) -> VortexResult<Self::Metadata> {
139 vortex_ensure!(
140 buffers.len() == 1,
141 "Expected 1 buffer, got {}",
142 buffers.len()
143 );
144
145 let buffer = buffers[0].clone().try_to_host_sync()?;
146 let bytes: &[u8] = buffer.as_ref();
147
148 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
149 let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
150
151 Ok(scalar)
152 }
153
154 fn build(
155 _dtype: &DType,
156 len: usize,
157 metadata: &Self::Metadata,
158 _buffers: &[BufferHandle],
159 _children: &dyn ArrayChildren,
160 ) -> VortexResult<ConstantArray> {
161 Ok(ConstantArray::new(metadata.clone(), len))
162 }
163
164 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
165 vortex_ensure!(
166 children.is_empty(),
167 "ConstantArray has no children, got {}",
168 children.len()
169 );
170 Ok(())
171 }
172
173 fn reduce_parent(
174 array: &Self::Array,
175 parent: &ArrayRef,
176 child_idx: usize,
177 ) -> VortexResult<Option<ArrayRef>> {
178 PARENT_RULES.evaluate(array, parent, child_idx)
179 }
180
181 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
182 Ok(ExecutionStep::Done(
183 constant_canonicalize(array)?.into_array(),
184 ))
185 }
186
187 fn append_to_builder(
188 array: &ConstantArray,
189 builder: &mut dyn ArrayBuilder,
190 ctx: &mut ExecutionCtx,
191 ) -> VortexResult<()> {
192 let n = array.len();
193 let scalar = array.scalar();
194
195 match array.dtype() {
196 DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
197 DType::Bool(_) => {
198 append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
199 b.append_values(
200 scalar
201 .as_bool()
202 .value()
203 .vortex_expect("non-null bool scalar must have a value"),
204 n,
205 );
206 })
207 }
208 DType::Primitive(ptype, _) => {
209 match_each_native_ptype!(ptype, |P| {
210 append_value_or_nulls::<PrimitiveBuilder<P>>(
211 builder,
212 scalar.is_null(),
213 n,
214 |b| {
215 let value = P::try_from(scalar)
216 .vortex_expect("Couldn't unwrap constant scalar to primitive");
217 b.append_n_values(value, n);
218 },
219 );
220 });
221 }
222 DType::Decimal(..) => {
223 append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
224 let value = scalar
225 .as_decimal()
226 .decimal_value()
227 .vortex_expect("non-null decimal scalar must have a value");
228 match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
229 });
230 }
231 DType::Utf8(_) => {
232 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
233 let typed = scalar.as_utf8();
234 let value = typed
235 .value()
236 .vortex_expect("non-null utf8 scalar must have a value");
237 b.append_n_values(value.as_bytes(), n);
238 });
239 }
240 DType::Binary(_) => {
241 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
242 let typed = scalar.as_binary();
243 let value = typed
244 .value()
245 .vortex_expect("non-null binary scalar must have a value");
246 b.append_n_values(value, n);
247 });
248 }
249 _ => {
251 let canonical = array
252 .clone()
253 .into_array()
254 .execute::<Canonical>(ctx)?
255 .into_array();
256 builder.extend_from_array(&canonical);
257 }
258 }
259
260 Ok(())
261 }
262}
263
264fn append_value_or_nulls<B: ArrayBuilder + 'static>(
269 builder: &mut dyn ArrayBuilder,
270 is_null: bool,
271 n: usize,
272 fill: impl FnOnce(&mut B),
273) {
274 let b = builder
275 .as_any_mut()
276 .downcast_mut::<B>()
277 .vortex_expect("builder dtype must match array dtype");
278 if is_null {
279 unsafe { b.append_nulls_unchecked(n) };
281 } else {
282 fill(b);
283 }
284}
285
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Checks that appending `array` to a fresh builder gives the same result as
    /// `constant_canonicalize`.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        let expected = constant_canonicalize(&array)?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    // Nullable dtype with a non-null value: the nullable builder must take the fill path.
    #[case::i32_nullable_value(Scalar::primitive(42i32, Nullability::Nullable), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    // Covers the `DType::Decimal` branch of `append_to_builder`, for a value and for null.
    #[rstest]
    #[case::decimal(
        Scalar::decimal(
            DecimalValue::I64(12_345),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ),
        4
    )]
    #[case::decimal_null(
        Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable)),
        3
    )]
    fn test_decimal_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)]
    #[case::utf8_noninline("hello world!!", 5)]
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)]
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)]
    #[case::binary_noninline(vec![0u8; 13], 5)]
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}