use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use vortex_buffer::ByteBufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_panic;
use vortex_session::VortexSession;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::ExecutionResult;
use crate::IntoArray;
use crate::Precision;
use crate::arrays::ConstantArray;
use crate::arrays::constant::compute::rules::PARENT_RULES;
use crate::arrays::constant::vtable::canonical::constant_canonicalize;
use crate::buffer::BufferHandle;
use crate::builders::ArrayBuilder;
use crate::builders::BoolBuilder;
use crate::builders::DecimalBuilder;
use crate::builders::NullBuilder;
use crate::builders::PrimitiveBuilder;
use crate::builders::VarBinViewBuilder;
use crate::canonical::Canonical;
use crate::dtype::DType;
use crate::match_each_decimal_value;
use crate::match_each_native_ptype;
use crate::scalar::DecimalValue;
use crate::scalar::Scalar;
use crate::scalar::ScalarValue;
use crate::serde::ArrayChildren;
use crate::stats::StatsSetRef;
use crate::vtable;
use crate::vtable::Array;
use crate::vtable::ArrayId;
use crate::vtable::VTable;

pub(crate) mod canonical;
mod operations;
mod validity;
46
47vtable!(Constant);
48
49#[derive(Clone, Debug)]
50pub struct Constant;
51
52impl Constant {
53 pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
54}
55
56impl VTable for Constant {
57 type Array = ConstantArray;
58
59 type Metadata = Scalar;
60 type OperationsVTable = Self;
61 type ValidityVTable = Self;
62
63 fn vtable(_array: &Self::Array) -> &Self {
64 &Constant
65 }
66
67 fn id(&self) -> ArrayId {
68 Self::ID
69 }
70
71 fn len(array: &ConstantArray) -> usize {
72 array.len
73 }
74
75 fn dtype(array: &ConstantArray) -> &DType {
76 array.scalar.dtype()
77 }
78
79 fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
80 array.stats_set.to_ref(array.as_ref())
81 }
82
83 fn array_hash<H: std::hash::Hasher>(
84 array: &ConstantArray,
85 state: &mut H,
86 _precision: Precision,
87 ) {
88 array.scalar.hash(state);
89 array.len.hash(state);
90 }
91
92 fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
93 array.scalar == other.scalar && array.len == other.len
94 }
95
96 fn nbuffers(_array: &ConstantArray) -> usize {
97 1
98 }
99
100 fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
101 match idx {
102 0 => BufferHandle::new_host(
103 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
104 ),
105 _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
106 }
107 }
108
109 fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
110 match idx {
111 0 => Some("scalar".to_string()),
112 _ => None,
113 }
114 }
115
116 fn nchildren(_array: &ConstantArray) -> usize {
117 0
118 }
119
120 fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
121 vortex_panic!("ConstantArray child index {idx} out of bounds")
122 }
123
124 fn child_name(_array: &ConstantArray, idx: usize) -> String {
125 vortex_panic!("ConstantArray child_name index {idx} out of bounds")
126 }
127
128 fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
129 Ok(array.scalar().clone())
130 }
131
132 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
133 Ok(Some(vec![]))
136 }
137
138 fn deserialize(
139 _bytes: &[u8],
140 dtype: &DType,
141 _len: usize,
142 buffers: &[BufferHandle],
143 session: &VortexSession,
144 ) -> VortexResult<Self::Metadata> {
145 vortex_ensure!(
146 buffers.len() == 1,
147 "Expected 1 buffer, got {}",
148 buffers.len()
149 );
150
151 let buffer = buffers[0].clone().try_to_host_sync()?;
152 let bytes: &[u8] = buffer.as_ref();
153
154 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
155 let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
156
157 Ok(scalar)
158 }
159
160 fn build(
161 _dtype: &DType,
162 len: usize,
163 metadata: &Self::Metadata,
164 _buffers: &[BufferHandle],
165 _children: &dyn ArrayChildren,
166 ) -> VortexResult<ConstantArray> {
167 Ok(ConstantArray::new(metadata.clone(), len))
168 }
169
170 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
171 vortex_ensure!(
172 children.is_empty(),
173 "ConstantArray has no children, got {}",
174 children.len()
175 );
176 Ok(())
177 }
178
179 fn reduce_parent(
180 array: &Array<Self>,
181 parent: &ArrayRef,
182 child_idx: usize,
183 ) -> VortexResult<Option<ArrayRef>> {
184 PARENT_RULES.evaluate(array, parent, child_idx)
185 }
186
187 fn execute(array: Arc<Array<Self>>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
188 Ok(ExecutionResult::done(
189 constant_canonicalize(&array)?.into_array(),
190 ))
191 }
192
193 fn append_to_builder(
194 array: &ConstantArray,
195 builder: &mut dyn ArrayBuilder,
196 ctx: &mut ExecutionCtx,
197 ) -> VortexResult<()> {
198 let n = array.len();
199 let scalar = array.scalar();
200
201 match array.dtype() {
202 DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
203 DType::Bool(_) => {
204 append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
205 b.append_values(
206 scalar
207 .as_bool()
208 .value()
209 .vortex_expect("non-null bool scalar must have a value"),
210 n,
211 );
212 })
213 }
214 DType::Primitive(ptype, _) => {
215 match_each_native_ptype!(ptype, |P| {
216 append_value_or_nulls::<PrimitiveBuilder<P>>(
217 builder,
218 scalar.is_null(),
219 n,
220 |b| {
221 let value = P::try_from(scalar)
222 .vortex_expect("Couldn't unwrap constant scalar to primitive");
223 b.append_n_values(value, n);
224 },
225 );
226 });
227 }
228 DType::Decimal(..) => {
229 append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
230 let value = scalar
231 .as_decimal()
232 .decimal_value()
233 .vortex_expect("non-null decimal scalar must have a value");
234 match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
235 });
236 }
237 DType::Utf8(_) => {
238 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
239 let typed = scalar.as_utf8();
240 let value = typed
241 .value()
242 .vortex_expect("non-null utf8 scalar must have a value");
243 b.append_n_values(value.as_bytes(), n);
244 });
245 }
246 DType::Binary(_) => {
247 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
248 let typed = scalar.as_binary();
249 let value = typed
250 .value()
251 .vortex_expect("non-null binary scalar must have a value");
252 b.append_n_values(value, n);
253 });
254 }
255 _ => {
257 let canonical = array
258 .clone()
259 .into_array()
260 .execute::<Canonical>(ctx)?
261 .into_array();
262 builder.extend_from_array(&canonical);
263 }
264 }
265
266 Ok(())
267 }
268}
269
270fn append_value_or_nulls<B: ArrayBuilder + 'static>(
275 builder: &mut dyn ArrayBuilder,
276 is_null: bool,
277 n: usize,
278 fill: impl FnOnce(&mut B),
279) {
280 let b = builder
281 .as_any_mut()
282 .downcast_mut::<B>()
283 .vortex_expect("builder dtype must match array dtype");
284 if is_null {
285 unsafe { b.append_nulls_unchecked(n) };
287 } else {
288 fill(b);
289 }
290}
291
292#[cfg(test)]
293mod tests {
294 use rstest::rstest;
295 use vortex_session::VortexSession;
296
297 use crate::ExecutionCtx;
298 use crate::IntoArray;
299 use crate::arrays::ConstantArray;
300 use crate::arrays::constant::vtable::canonical::constant_canonicalize;
301 use crate::assert_arrays_eq;
302 use crate::builders::builder_with_capacity;
303 use crate::dtype::DType;
304 use crate::dtype::Nullability;
305 use crate::dtype::PType;
306 use crate::dtype::StructFields;
307 use crate::scalar::Scalar;
308
309 fn ctx() -> ExecutionCtx {
310 ExecutionCtx::new(VortexSession::empty())
311 }
312
313 fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
315 let expected = constant_canonicalize(&array)?.into_array();
316 let mut builder = builder_with_capacity(array.dtype(), array.len());
317 array
318 .into_array()
319 .append_to_builder(builder.as_mut(), &mut ctx())?;
320 let result = builder.finish();
321 assert_arrays_eq!(&result, &expected);
322 Ok(())
323 }
324
325 #[test]
326 fn test_null_constant_append() -> vortex_error::VortexResult<()> {
327 assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
328 }
329
330 #[rstest]
331 #[case::bool_true(true, 5)]
332 #[case::bool_false(false, 3)]
333 fn test_bool_constant_append(
334 #[case] value: bool,
335 #[case] n: usize,
336 ) -> vortex_error::VortexResult<()> {
337 assert_append_matches_canonical(ConstantArray::new(
338 Scalar::bool(value, Nullability::NonNullable),
339 n,
340 ))
341 }
342
343 #[test]
344 fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
345 assert_append_matches_canonical(ConstantArray::new(
346 Scalar::null(DType::Bool(Nullability::Nullable)),
347 4,
348 ))
349 }
350
351 #[rstest]
352 #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
353 #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
354 #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
355 #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
356 fn test_primitive_constant_append(
357 #[case] scalar: Scalar,
358 #[case] n: usize,
359 ) -> vortex_error::VortexResult<()> {
360 assert_append_matches_canonical(ConstantArray::new(scalar, n))
361 }
362
363 #[rstest]
364 #[case::utf8_inline("hi", 5)] #[case::utf8_noninline("hello world!!", 5)] #[case::utf8_empty("", 3)]
367 #[case::utf8_n_zero("hello world!!", 0)] fn test_utf8_constant_append(
369 #[case] value: &str,
370 #[case] n: usize,
371 ) -> vortex_error::VortexResult<()> {
372 assert_append_matches_canonical(ConstantArray::new(
373 Scalar::utf8(value, Nullability::NonNullable),
374 n,
375 ))
376 }
377
378 #[test]
379 fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
380 assert_append_matches_canonical(ConstantArray::new(
381 Scalar::null(DType::Utf8(Nullability::Nullable)),
382 4,
383 ))
384 }
385
386 #[rstest]
387 #[case::binary_inline(vec![1u8, 2, 3], 5)] #[case::binary_noninline(vec![0u8; 13], 5)] fn test_binary_constant_append(
390 #[case] value: Vec<u8>,
391 #[case] n: usize,
392 ) -> vortex_error::VortexResult<()> {
393 assert_append_matches_canonical(ConstantArray::new(
394 Scalar::binary(value, Nullability::NonNullable),
395 n,
396 ))
397 }
398
399 #[test]
400 fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
401 assert_append_matches_canonical(ConstantArray::new(
402 Scalar::null(DType::Binary(Nullability::Nullable)),
403 4,
404 ))
405 }
406
407 #[test]
408 fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
409 let fields = StructFields::new(
410 ["x", "y"].into(),
411 vec![
412 DType::Primitive(PType::I32, Nullability::NonNullable),
413 DType::Utf8(Nullability::NonNullable),
414 ],
415 );
416 let scalar = Scalar::struct_(
417 DType::Struct(fields, Nullability::NonNullable),
418 [
419 Scalar::primitive(42i32, Nullability::NonNullable),
420 Scalar::utf8("hi", Nullability::NonNullable),
421 ],
422 );
423 assert_append_matches_canonical(ConstantArray::new(scalar, 3))
424 }
425
426 #[test]
427 fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
428 let fields = StructFields::new(
429 ["x"].into(),
430 vec![DType::Primitive(PType::I32, Nullability::Nullable)],
431 );
432 let dtype = DType::Struct(fields, Nullability::Nullable);
433 assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
434 }
435}