vortex_array/arrays/constant/vtable/
mod.rs1use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::IntoArray;
19use crate::Precision;
20use crate::arrays::ConstantArray;
21use crate::arrays::constant::compute::rules::PARENT_RULES;
22use crate::arrays::constant::vtable::canonical::constant_canonicalize;
23use crate::buffer::BufferHandle;
24use crate::builders::ArrayBuilder;
25use crate::builders::BoolBuilder;
26use crate::builders::DecimalBuilder;
27use crate::builders::NullBuilder;
28use crate::builders::PrimitiveBuilder;
29use crate::builders::VarBinViewBuilder;
30use crate::canonical::Canonical;
31use crate::dtype::DType;
32use crate::match_each_decimal_value;
33use crate::match_each_native_ptype;
34use crate::scalar::DecimalValue;
35use crate::scalar::Scalar;
36use crate::scalar::ScalarValue;
37use crate::serde::ArrayChildren;
38use crate::stats::StatsSetRef;
39use crate::vtable;
40use crate::vtable::ArrayId;
41use crate::vtable::VTable;
42pub(crate) mod canonical;
43mod operations;
44mod validity;
45
46vtable!(Constant);
47
48#[derive(Clone, Debug)]
49pub struct Constant;
50
51impl Constant {
52 pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
53}
54
55impl VTable for Constant {
56 type Array = ConstantArray;
57
58 type Metadata = Scalar;
59 type OperationsVTable = Self;
60 type ValidityVTable = Self;
61
62 fn vtable(_array: &Self::Array) -> &Self {
63 &Constant
64 }
65
66 fn id(&self) -> ArrayId {
67 Self::ID
68 }
69
70 fn len(array: &ConstantArray) -> usize {
71 array.len
72 }
73
74 fn dtype(array: &ConstantArray) -> &DType {
75 array.scalar.dtype()
76 }
77
78 fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
79 array.stats_set.to_ref(array.as_ref())
80 }
81
82 fn array_hash<H: std::hash::Hasher>(
83 array: &ConstantArray,
84 state: &mut H,
85 _precision: Precision,
86 ) {
87 array.scalar.hash(state);
88 array.len.hash(state);
89 }
90
91 fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
92 array.scalar == other.scalar && array.len == other.len
93 }
94
95 fn nbuffers(_array: &ConstantArray) -> usize {
96 1
97 }
98
99 fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
100 match idx {
101 0 => BufferHandle::new_host(
102 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
103 ),
104 _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
105 }
106 }
107
108 fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
109 match idx {
110 0 => Some("scalar".to_string()),
111 _ => None,
112 }
113 }
114
115 fn nchildren(_array: &ConstantArray) -> usize {
116 0
117 }
118
119 fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
120 vortex_panic!("ConstantArray child index {idx} out of bounds")
121 }
122
123 fn child_name(_array: &ConstantArray, idx: usize) -> String {
124 vortex_panic!("ConstantArray child_name index {idx} out of bounds")
125 }
126
127 fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
128 Ok(array.scalar().clone())
129 }
130
131 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
132 Ok(Some(vec![]))
135 }
136
137 fn deserialize(
138 _bytes: &[u8],
139 dtype: &DType,
140 _len: usize,
141 buffers: &[BufferHandle],
142 session: &VortexSession,
143 ) -> VortexResult<Self::Metadata> {
144 vortex_ensure!(
145 buffers.len() == 1,
146 "Expected 1 buffer, got {}",
147 buffers.len()
148 );
149
150 let buffer = buffers[0].clone().try_to_host_sync()?;
151 let bytes: &[u8] = buffer.as_ref();
152
153 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
154 let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
155
156 Ok(scalar)
157 }
158
159 fn build(
160 _dtype: &DType,
161 len: usize,
162 metadata: &Self::Metadata,
163 _buffers: &[BufferHandle],
164 _children: &dyn ArrayChildren,
165 ) -> VortexResult<ConstantArray> {
166 Ok(ConstantArray::new(metadata.clone(), len))
167 }
168
169 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
170 vortex_ensure!(
171 children.is_empty(),
172 "ConstantArray has no children, got {}",
173 children.len()
174 );
175 Ok(())
176 }
177
178 fn reduce_parent(
179 array: &Self::Array,
180 parent: &ArrayRef,
181 child_idx: usize,
182 ) -> VortexResult<Option<ArrayRef>> {
183 PARENT_RULES.evaluate(array, parent, child_idx)
184 }
185
186 fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
187 Ok(ExecutionResult::done(
188 constant_canonicalize(&array)?.into_array(),
189 ))
190 }
191
192 fn append_to_builder(
193 array: &ConstantArray,
194 builder: &mut dyn ArrayBuilder,
195 ctx: &mut ExecutionCtx,
196 ) -> VortexResult<()> {
197 let n = array.len();
198 let scalar = array.scalar();
199
200 match array.dtype() {
201 DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
202 DType::Bool(_) => {
203 append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
204 b.append_values(
205 scalar
206 .as_bool()
207 .value()
208 .vortex_expect("non-null bool scalar must have a value"),
209 n,
210 );
211 })
212 }
213 DType::Primitive(ptype, _) => {
214 match_each_native_ptype!(ptype, |P| {
215 append_value_or_nulls::<PrimitiveBuilder<P>>(
216 builder,
217 scalar.is_null(),
218 n,
219 |b| {
220 let value = P::try_from(scalar)
221 .vortex_expect("Couldn't unwrap constant scalar to primitive");
222 b.append_n_values(value, n);
223 },
224 );
225 });
226 }
227 DType::Decimal(..) => {
228 append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
229 let value = scalar
230 .as_decimal()
231 .decimal_value()
232 .vortex_expect("non-null decimal scalar must have a value");
233 match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
234 });
235 }
236 DType::Utf8(_) => {
237 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
238 let typed = scalar.as_utf8();
239 let value = typed
240 .value()
241 .vortex_expect("non-null utf8 scalar must have a value");
242 b.append_n_values(value.as_bytes(), n);
243 });
244 }
245 DType::Binary(_) => {
246 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
247 let typed = scalar.as_binary();
248 let value = typed
249 .value()
250 .vortex_expect("non-null binary scalar must have a value");
251 b.append_n_values(value, n);
252 });
253 }
254 _ => {
256 let canonical = array
257 .clone()
258 .into_array()
259 .execute::<Canonical>(ctx)?
260 .into_array();
261 builder.extend_from_array(&canonical);
262 }
263 }
264
265 Ok(())
266 }
267}
268
269fn append_value_or_nulls<B: ArrayBuilder + 'static>(
274 builder: &mut dyn ArrayBuilder,
275 is_null: bool,
276 n: usize,
277 fill: impl FnOnce(&mut B),
278) {
279 let b = builder
280 .as_any_mut()
281 .downcast_mut::<B>()
282 .vortex_expect("builder dtype must match array dtype");
283 if is_null {
284 unsafe { b.append_nulls_unchecked(n) };
286 } else {
287 fill(b);
288 }
289}
290
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Appends `array` into a fresh builder of the same dtype via `append_to_builder` and
    /// asserts the result equals the output of `constant_canonicalize`.
    fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
        let expected = constant_canonicalize(&array)?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(
        #[case] value: bool,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    // A nullable dtype that holds a real value must take the value path, not the null path.
    #[case::i64_nullable_value(Scalar::primitive(9i64, Nullability::Nullable), 2)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    // Covers the `DType::Decimal` fast path of `append_to_builder`, which had no test.
    #[rstest]
    #[case::decimal(
        Scalar::decimal(
            DecimalValue::I64(12_345),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ),
        4
    )]
    #[case::decimal_nullable_value(
        Scalar::decimal(DecimalValue::I64(-7), DecimalDType::new(10, 2), Nullability::Nullable),
        3
    )]
    #[case::decimal_null(
        Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable)),
        2
    )]
    fn test_decimal_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)]
    #[case::utf8_noninline("hello world!!", 5)]
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)]
    fn test_utf8_constant_append(
        #[case] value: &str,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)]
    #[case::binary_noninline(vec![0u8; 13], 5)]
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    #[test]
    fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}