vortex_array/arrays/constant/vtable/
mod.rs1use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use crate::ArrayEq;
17use crate::ArrayHash;
18use crate::ArrayParts;
19use crate::ArrayRef;
20use crate::EqMode;
21use crate::ExecutionCtx;
22use crate::ExecutionResult;
23use crate::IntoArray;
24use crate::array::Array;
25use crate::array::ArrayId;
26use crate::array::ArrayView;
27use crate::array::VTable;
28use crate::array::unsupported_buffer_replacement;
29use crate::arrays::constant::ConstantData;
30use crate::arrays::constant::compute::rules::PARENT_RULES;
31use crate::arrays::constant::vtable::canonical::constant_canonicalize;
32use crate::buffer::BufferHandle;
33use crate::builders::ArrayBuilder;
34use crate::builders::BoolBuilder;
35use crate::builders::DecimalBuilder;
36use crate::builders::NullBuilder;
37use crate::builders::PrimitiveBuilder;
38use crate::builders::VarBinViewBuilder;
39use crate::canonical::Canonical;
40use crate::dtype::DType;
41use crate::match_each_decimal_value;
42use crate::match_each_native_ptype;
43use crate::scalar::DecimalValue;
44use crate::scalar::Scalar;
45use crate::scalar::ScalarValue;
46use crate::serde::ArrayChildren;
47pub(crate) mod canonical;
48mod operations;
49mod validity;
50
51pub type ConstantArray = Array<Constant>;
53
/// Marker type for the constant encoding.
///
/// It carries no state of its own; the [`VTable`] implementation in this module
/// registers it under the id `vortex.constant`.
#[derive(Clone, Debug)]
pub struct Constant;
56
57impl ArrayHash for ConstantData {
58 fn array_hash<H: Hasher>(&self, state: &mut H, _accuracy: EqMode) {
59 self.scalar.hash(state);
60 }
61}
62
63impl ArrayEq for ConstantData {
64 fn array_eq(&self, other: &Self, _accuracy: EqMode) -> bool {
65 self.scalar == other.scalar
66 }
67}
68
69impl VTable for Constant {
70 type TypedArrayData = ConstantData;
71
72 type OperationsVTable = Self;
73 type ValidityVTable = Self;
74
75 fn id(&self) -> ArrayId {
76 static ID: CachedId = CachedId::new("vortex.constant");
77 *ID
78 }
79
80 fn validate(
81 &self,
82 data: &ConstantData,
83 dtype: &DType,
84 _len: usize,
85 _slots: &[Option<ArrayRef>],
86 ) -> VortexResult<()> {
87 vortex_ensure!(
88 data.scalar.dtype() == dtype,
89 "ConstantArray scalar dtype does not match outer dtype"
90 );
91 Ok(())
92 }
93
94 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
95 1
96 }
97
98 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
99 match idx {
100 0 => BufferHandle::new_host(
101 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
102 ),
103 _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
104 }
105 }
106
107 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
108 match idx {
109 0 => Some("scalar".to_string()),
110 _ => None,
111 }
112 }
113
114 fn with_buffers(
115 &self,
116 array: ArrayView<'_, Self>,
117 buffers: &[BufferHandle],
118 ) -> VortexResult<ArrayParts<Self>> {
119 unsupported_buffer_replacement(array, buffers)
120 }
121
122 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
123 vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
124 }
125
126 fn serialize(
127 _array: ArrayView<'_, Self>,
128 _session: &VortexSession,
129 ) -> VortexResult<Option<Vec<u8>>> {
130 Ok(Some(vec![]))
133 }
134
135 fn deserialize(
136 &self,
137 dtype: &DType,
138 len: usize,
139 _metadata: &[u8],
140
141 buffers: &[BufferHandle],
142 _children: &dyn ArrayChildren,
143 session: &VortexSession,
144 ) -> VortexResult<ArrayParts<Self>> {
145 vortex_ensure!(
146 buffers.len() == 1,
147 "Expected 1 buffer, got {}",
148 buffers.len()
149 );
150
151 let buffer = buffers[0].clone().try_to_host_sync()?;
152 let bytes: &[u8] = buffer.as_ref();
153
154 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
155 let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
156
157 Ok(ArrayParts::new(
158 self.clone(),
159 dtype.clone(),
160 len,
161 ConstantData::new(scalar),
162 ))
163 }
164
165 fn reduce_parent(
166 array: ArrayView<'_, Self>,
167 parent: &ArrayRef,
168 child_idx: usize,
169 ) -> VortexResult<Option<ArrayRef>> {
170 PARENT_RULES.evaluate(array, parent, child_idx)
171 }
172
173 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
174 Ok(ExecutionResult::done(constant_canonicalize(
175 array.as_view(),
176 ctx,
177 )?))
178 }
179
180 fn append_to_builder(
181 array: ArrayView<'_, Self>,
182 builder: &mut dyn ArrayBuilder,
183 ctx: &mut ExecutionCtx,
184 ) -> VortexResult<()> {
185 let n = array.len();
186 let scalar = array.scalar();
187
188 match array.dtype() {
189 DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
190 DType::Bool(_) => {
191 append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
192 b.append_values(
193 scalar
194 .as_bool()
195 .value()
196 .vortex_expect("non-null bool scalar must have a value"),
197 n,
198 );
199 })
200 }
201 DType::Primitive(ptype, _) => {
202 match_each_native_ptype!(ptype, |P| {
203 append_value_or_nulls::<PrimitiveBuilder<P>>(
204 builder,
205 scalar.is_null(),
206 n,
207 |b| {
208 let value = P::try_from(scalar)
209 .vortex_expect("Couldn't unwrap constant scalar to primitive");
210 b.append_n_values(value, n);
211 },
212 );
213 });
214 }
215 DType::Decimal(..) => {
216 append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
217 let value = scalar
218 .as_decimal()
219 .decimal_value()
220 .vortex_expect("non-null decimal scalar must have a value");
221 match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
222 });
223 }
224 DType::Utf8(_) => {
225 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
226 let typed = scalar.as_utf8();
227 let value = typed
228 .value()
229 .vortex_expect("non-null utf8 scalar must have a value");
230 b.append_n_values(value.as_bytes(), n);
231 });
232 }
233 DType::Binary(_) => {
234 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
235 let typed = scalar.as_binary();
236 let value = typed
237 .value()
238 .vortex_expect("non-null binary scalar must have a value");
239 b.append_n_values(value, n);
240 });
241 }
242 _ => {
244 let canonical = array
245 .array()
246 .clone()
247 .execute::<Canonical>(ctx)?
248 .into_array();
249 builder.extend_from_array(&canonical);
250 }
251 }
252
253 Ok(())
254 }
255}
256
257fn append_value_or_nulls<B: ArrayBuilder + 'static>(
262 builder: &mut dyn ArrayBuilder,
263 is_null: bool,
264 n: usize,
265 fill: impl FnOnce(&mut B),
266) {
267 let b = builder
268 .as_any_mut()
269 .downcast_mut::<B>()
270 .vortex_expect("builder dtype must match array dtype");
271 if is_null {
272 unsafe { b.append_nulls_unchecked(n) };
274 } else {
275 fill(b);
276 }
277}
278
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Appends `array` to a freshly created builder and asserts that the finished builder
    /// output equals what `constant_canonicalize` produces for the same array.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        let mut exec_ctx = crate::array_session().create_execution_ctx();

        // Compute the reference result first; `array` is consumed by `into_array` below.
        let expected = constant_canonicalize(array.as_view(), &mut exec_ctx)?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut exec_ctx)?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected, &mut exec_ctx);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        let array = ConstantArray::new(Scalar::null(DType::Null), 5);
        assert_append_matches_canonical(array)
    }

    #[test]
    fn test_with_buffers_rejects_serialized_scalar_buffer() {
        let array =
            ConstantArray::new(Scalar::primitive(42i32, Nullability::NonNullable), 3).into_array();
        let buffers = array.buffer_handles();

        // SAFETY: NOTE(review): this test expects the call to fail with an error rather
        // than build a new array; confirm against the safety contract of `with_buffers`.
        let outcome = unsafe { array.with_buffers(buffers) };
        let Err(err) = outcome else {
            panic!("ConstantArray should reject replacing its serialized scalar buffer");
        };
        assert!(
            err.to_string()
                .contains("does not support in-memory buffer replacement")
        );
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::bool(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)]
    #[case::utf8_noninline("hello world!!", 5)]
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)]
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::utf8(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)]
    #[case::binary_noninline(vec![0u8; 13], 5)]
    fn test_binary_constant_append(#[case] value: Vec<u8>, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::binary(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let struct_fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let struct_dtype = DType::Struct(struct_fields, Nullability::NonNullable);
        let scalar = Scalar::struct_(
            struct_dtype,
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let struct_fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let null_struct = Scalar::null(DType::Struct(struct_fields, Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(null_struct, 4))
    }
}