vortex_array/arrays/constant/vtable/
mod.rs1use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayEq;
16use crate::ArrayHash;
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::ExecutionResult;
20use crate::IntoArray;
21use crate::Precision;
22use crate::array::Array;
23use crate::array::ArrayId;
24use crate::array::ArrayView;
25use crate::array::VTable;
26use crate::arrays::constant::ConstantData;
27use crate::arrays::constant::compute::rules::PARENT_RULES;
28use crate::arrays::constant::vtable::canonical::constant_canonicalize;
29use crate::buffer::BufferHandle;
30use crate::builders::ArrayBuilder;
31use crate::builders::BoolBuilder;
32use crate::builders::DecimalBuilder;
33use crate::builders::NullBuilder;
34use crate::builders::PrimitiveBuilder;
35use crate::builders::VarBinViewBuilder;
36use crate::canonical::Canonical;
37use crate::dtype::DType;
38use crate::match_each_decimal_value;
39use crate::match_each_native_ptype;
40use crate::scalar::DecimalValue;
41use crate::scalar::Scalar;
42use crate::scalar::ScalarValue;
43use crate::serde::ArrayChildren;
44pub(crate) mod canonical;
45mod operations;
46mod validity;
47
48pub type ConstantArray = Array<Constant>;
50
51#[derive(Clone, Debug)]
52pub struct Constant;
53
54impl Constant {
55 pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
56}
57
58impl ArrayHash for ConstantData {
59 fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
60 self.scalar.hash(state);
61 }
62}
63
64impl ArrayEq for ConstantData {
65 fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
66 self.scalar == other.scalar
67 }
68}
69
70impl VTable for Constant {
71 type ArrayData = ConstantData;
72
73 type OperationsVTable = Self;
74 type ValidityVTable = Self;
75
76 fn id(&self) -> ArrayId {
77 Self::ID
78 }
79
80 fn validate(
81 &self,
82 data: &ConstantData,
83 dtype: &DType,
84 _len: usize,
85 _slots: &[Option<ArrayRef>],
86 ) -> VortexResult<()> {
87 vortex_ensure!(
88 data.scalar.dtype() == dtype,
89 "ConstantArray scalar dtype does not match outer dtype"
90 );
91 Ok(())
92 }
93
94 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
95 1
96 }
97
98 fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
99 match idx {
100 0 => BufferHandle::new_host(
101 ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
102 ),
103 _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
104 }
105 }
106
107 fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
108 match idx {
109 0 => Some("scalar".to_string()),
110 _ => None,
111 }
112 }
113
114 fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
115 vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
116 }
117
118 fn serialize(
119 _array: ArrayView<'_, Self>,
120 _session: &VortexSession,
121 ) -> VortexResult<Option<Vec<u8>>> {
122 Ok(Some(vec![]))
125 }
126
127 fn deserialize(
128 &self,
129 dtype: &DType,
130 len: usize,
131 _metadata: &[u8],
132
133 buffers: &[BufferHandle],
134 _children: &dyn ArrayChildren,
135 session: &VortexSession,
136 ) -> VortexResult<crate::array::ArrayParts<Self>> {
137 vortex_ensure!(
138 buffers.len() == 1,
139 "Expected 1 buffer, got {}",
140 buffers.len()
141 );
142
143 let buffer = buffers[0].clone().try_to_host_sync()?;
144 let bytes: &[u8] = buffer.as_ref();
145
146 let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
147 let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
148
149 Ok(crate::array::ArrayParts::new(
150 self.clone(),
151 dtype.clone(),
152 len,
153 ConstantData::new(scalar),
154 ))
155 }
156
157 fn reduce_parent(
158 array: ArrayView<'_, Self>,
159 parent: &ArrayRef,
160 child_idx: usize,
161 ) -> VortexResult<Option<ArrayRef>> {
162 PARENT_RULES.evaluate(array, parent, child_idx)
163 }
164
165 fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
166 Ok(ExecutionResult::done(constant_canonicalize(
167 array.as_view(),
168 )?))
169 }
170
171 fn append_to_builder(
172 array: ArrayView<'_, Self>,
173 builder: &mut dyn ArrayBuilder,
174 ctx: &mut ExecutionCtx,
175 ) -> VortexResult<()> {
176 let n = array.len();
177 let scalar = array.scalar();
178
179 match array.dtype() {
180 DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
181 DType::Bool(_) => {
182 append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
183 b.append_values(
184 scalar
185 .as_bool()
186 .value()
187 .vortex_expect("non-null bool scalar must have a value"),
188 n,
189 );
190 })
191 }
192 DType::Primitive(ptype, _) => {
193 match_each_native_ptype!(ptype, |P| {
194 append_value_or_nulls::<PrimitiveBuilder<P>>(
195 builder,
196 scalar.is_null(),
197 n,
198 |b| {
199 let value = P::try_from(scalar)
200 .vortex_expect("Couldn't unwrap constant scalar to primitive");
201 b.append_n_values(value, n);
202 },
203 );
204 });
205 }
206 DType::Decimal(..) => {
207 append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
208 let value = scalar
209 .as_decimal()
210 .decimal_value()
211 .vortex_expect("non-null decimal scalar must have a value");
212 match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
213 });
214 }
215 DType::Utf8(_) => {
216 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
217 let typed = scalar.as_utf8();
218 let value = typed
219 .value()
220 .vortex_expect("non-null utf8 scalar must have a value");
221 b.append_n_values(value.as_bytes(), n);
222 });
223 }
224 DType::Binary(_) => {
225 append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
226 let typed = scalar.as_binary();
227 let value = typed
228 .value()
229 .vortex_expect("non-null binary scalar must have a value");
230 b.append_n_values(value, n);
231 });
232 }
233 _ => {
235 let canonical = array
236 .array()
237 .clone()
238 .execute::<Canonical>(ctx)?
239 .into_array();
240 builder.extend_from_array(&canonical);
241 }
242 }
243
244 Ok(())
245 }
246}
247
248fn append_value_or_nulls<B: ArrayBuilder + 'static>(
253 builder: &mut dyn ArrayBuilder,
254 is_null: bool,
255 n: usize,
256 fill: impl FnOnce(&mut B),
257) {
258 let b = builder
259 .as_any_mut()
260 .downcast_mut::<B>()
261 .vortex_expect("builder dtype must match array dtype");
262 if is_null {
263 unsafe { b.append_nulls_unchecked(n) };
265 } else {
266 fill(b);
267 }
268}
269
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Checks that `append_to_builder` produces the same array as canonicalizing directly.
    fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
        let expected = constant_canonicalize(array.as_view())?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(
        #[case] value: bool,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    /// Covers the dedicated `DType::Decimal` fast path in `append_to_builder`.
    #[test]
    fn test_decimal_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::decimal(
                DecimalValue::I64(12345),
                DecimalDType::new(10, 2),
                Nullability::NonNullable,
            ),
            5,
        ))
    }

    #[test]
    fn test_decimal_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Decimal(
                DecimalDType::new(10, 2),
                Nullability::Nullable,
            )),
            4,
        ))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)]
    #[case::utf8_noninline("hello world!!", 5)]
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)]
    fn test_utf8_constant_append(
        #[case] value: &str,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)]
    #[case::binary_noninline(vec![0u8; 13], 5)]
    #[case::binary_n_zero(vec![0u8; 13], 0)]
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    #[test]
    fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}