vortex_array/arrays/constant/vtable/
mod.rs1use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use crate::ArrayEq;
17use crate::ArrayHash;
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::ExecutionResult;
21use crate::IntoArray;
22use crate::Precision;
23use crate::array::Array;
24use crate::array::ArrayId;
25use crate::array::ArrayView;
26use crate::array::VTable;
27use crate::arrays::constant::ConstantData;
28use crate::arrays::constant::compute::rules::PARENT_RULES;
29use crate::arrays::constant::vtable::canonical::constant_canonicalize;
30use crate::buffer::BufferHandle;
31use crate::builders::ArrayBuilder;
32use crate::builders::BoolBuilder;
33use crate::builders::DecimalBuilder;
34use crate::builders::NullBuilder;
35use crate::builders::PrimitiveBuilder;
36use crate::builders::VarBinViewBuilder;
37use crate::canonical::Canonical;
38use crate::dtype::DType;
39use crate::match_each_decimal_value;
40use crate::match_each_native_ptype;
41use crate::scalar::DecimalValue;
42use crate::scalar::Scalar;
43use crate::scalar::ScalarValue;
44use crate::serde::ArrayChildren;
45pub(crate) mod canonical;
46mod operations;
47mod validity;
48
/// Typed array handle for the constant encoding: an [`Array`] parameterized by the
/// [`Constant`] vtable, whose data is a single scalar repeated `len` times.
pub type ConstantArray = Array<Constant>;
51
/// Zero-sized marker type carrying the [`VTable`] implementation for constant arrays.
///
/// It holds no state of its own; the repeated scalar lives in `ConstantData`.
#[derive(Debug, Clone)]
pub struct Constant;
54
55impl ArrayHash for ConstantData {
56 fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
57 self.scalar.hash(state);
58 }
59}
60
61impl ArrayEq for ConstantData {
62 fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
63 self.scalar == other.scalar
64 }
65}
66
/// [`VTable`] implementation for the constant encoding: `len` logical copies of the single
/// [`Scalar`] held in [`ConstantData`].
impl VTable for Constant {
    type ArrayData = ConstantData;

    // Operations and validity dispatch are both implemented on `Constant` itself
    // (presumably in the `operations` and `validity` submodules declared in this module).
    type OperationsVTable = Self;
    type ValidityVTable = Self;

    /// Returns the encoding id `"vortex.constant"`, held in a `static` [`CachedId`].
    fn id(&self) -> ArrayId {
        static ID: CachedId = CachedId::new("vortex.constant");
        *ID
    }

    /// Checks that the stored scalar's dtype is exactly the array's outer `dtype`.
    ///
    /// The length and slots are not inspected.
    ///
    /// # Errors
    ///
    /// Returns an error if the two dtypes differ.
    fn validate(
        &self,
        data: &ConstantData,
        dtype: &DType,
        _len: usize,
        _slots: &[Option<ArrayRef>],
    ) -> VortexResult<()> {
        vortex_ensure!(
            data.scalar.dtype() == dtype,
            "ConstantArray scalar dtype does not match outer dtype"
        );
        Ok(())
    }

    /// A constant array always has exactly one buffer: the serialized scalar value.
    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
        1
    }

    /// Returns buffer `idx`.
    ///
    /// Buffer 0 is the scalar value encoded with [`ScalarValue::to_proto_bytes`] and wrapped
    /// via `BufferHandle::new_host`. The encoding is recomputed on every call.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not 0.
    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
        match idx {
            0 => BufferHandle::new_host(
                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
            ),
            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
        }
    }

    /// Names buffer 0 `"scalar"`; every other index has no name.
    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
        match idx {
            0 => Some("scalar".to_string()),
            _ => None,
        }
    }

    /// This encoding does not name any slot.
    ///
    /// # Panics
    ///
    /// Always panics, for any `idx`.
    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
        vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
    }

    /// Returns empty metadata.
    ///
    /// The scalar is carried in buffer 0, and `deserialize` ignores the metadata bytes.
    fn serialize(
        _array: ArrayView<'_, Self>,
        _session: &VortexSession,
    ) -> VortexResult<Option<Vec<u8>>> {
        Ok(Some(vec![]))
    }

    /// Rebuilds a constant array of `len` elements from its single scalar buffer.
    ///
    /// The buffer is fetched with `try_to_host_sync`, decoded with
    /// [`ScalarValue::from_proto_bytes`] against `dtype`, and checked by [`Scalar::try_new`].
    /// Metadata and children are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if there is not exactly one buffer, or if the host transfer,
    /// decoding, or scalar construction fails.
    fn deserialize(
        &self,
        dtype: &DType,
        len: usize,
        _metadata: &[u8],
        buffers: &[BufferHandle],
        _children: &dyn ArrayChildren,
        session: &VortexSession,
    ) -> VortexResult<crate::array::ArrayParts<Self>> {
        vortex_ensure!(
            buffers.len() == 1,
            "Expected 1 buffer, got {}",
            buffers.len()
        );

        let buffer = buffers[0].clone().try_to_host_sync()?;
        let bytes: &[u8] = buffer.as_ref();

        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;

        Ok(crate::array::ArrayParts::new(
            self.clone(),
            dtype.clone(),
            len,
            ConstantData::new(scalar),
        ))
    }

    /// Delegates parent-reduction rewrites to the constant encoding's `PARENT_RULES`.
    fn reduce_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_RULES.evaluate(array, parent, child_idx)
    }

    /// Materializes the constant into its canonical form via `constant_canonicalize`.
    ///
    /// Execution completes in this single step (`ExecutionResult::done`).
    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        Ok(ExecutionResult::done(constant_canonicalize(
            array.as_view(),
        )?))
    }

    /// Appends `array.len()` copies of the constant scalar (or that many nulls) to `builder`.
    ///
    /// For null, bool, primitive, decimal, utf8 and binary dtypes, `builder` is downcast to
    /// its concrete builder type and filled directly with a bulk append. No canonical array
    /// is materialized on these paths. Every other dtype (for example structs) is
    /// canonicalized with `ctx` and then copied into the builder.
    ///
    /// # Panics
    ///
    /// On the fast paths, panics if `builder` is not the concrete builder type for the
    /// array's dtype, or if a non-null scalar cannot be unwrapped to its value.
    ///
    /// # Errors
    ///
    /// Propagates errors from canonicalization on the fallback path.
    fn append_to_builder(
        array: ArrayView<'_, Self>,
        builder: &mut dyn ArrayBuilder,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<()> {
        let n = array.len();
        let scalar = array.scalar();

        match array.dtype() {
            // A `DType::Null` array is all nulls, so the fill closure is never needed.
            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
            DType::Bool(_) => {
                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
                    b.append_values(
                        scalar
                            .as_bool()
                            .value()
                            .vortex_expect("non-null bool scalar must have a value"),
                        n,
                    );
                })
            }
            DType::Primitive(ptype, _) => {
                // Instantiate the append for the concrete native type `P` of this ptype.
                match_each_native_ptype!(ptype, |P| {
                    append_value_or_nulls::<PrimitiveBuilder<P>>(
                        builder,
                        scalar.is_null(),
                        n,
                        |b| {
                            let value = P::try_from(scalar)
                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
                            b.append_n_values(value, n);
                        },
                    );
                });
            }
            DType::Decimal(..) => {
                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
                    let value = scalar
                        .as_decimal()
                        .decimal_value()
                        .vortex_expect("non-null decimal scalar must have a value");
                    // Unpack whichever `DecimalValue` variant is stored and append it n times.
                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
                });
            }
            DType::Utf8(_) => {
                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
                    let typed = scalar.as_utf8();
                    let value = typed
                        .value()
                        .vortex_expect("non-null utf8 scalar must have a value");
                    b.append_n_values(value.as_bytes(), n);
                });
            }
            DType::Binary(_) => {
                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
                    let typed = scalar.as_binary();
                    let value = typed
                        .value()
                        .vortex_expect("non-null binary scalar must have a value");
                    b.append_n_values(value, n);
                });
            }
            _ => {
                // Slow path: materialize the canonical array, then copy it into the builder.
                let canonical = array
                    .array()
                    .clone()
                    .execute::<Canonical>(ctx)?
                    .into_array();
                builder.extend_from_array(&canonical);
            }
        }

        Ok(())
    }
}
245
246fn append_value_or_nulls<B: ArrayBuilder + 'static>(
251 builder: &mut dyn ArrayBuilder,
252 is_null: bool,
253 n: usize,
254 fill: impl FnOnce(&mut B),
255) {
256 let b = builder
257 .as_any_mut()
258 .downcast_mut::<B>()
259 .vortex_expect("builder dtype must match array dtype");
260 if is_null {
261 unsafe { b.append_nulls_unchecked(n) };
263 } else {
264 fill(b);
265 }
266}
267
#[cfg(test)]
mod tests {
    //! Checks that `append_to_builder` on a constant array produces exactly the array that
    //! full canonicalization produces. Covers every dtype family with a dedicated fast path
    //! (null, bool, primitive, decimal, utf8, binary) and the struct fallback path.

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    /// Fresh execution context backed by an empty session.
    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Appends `array` to a builder created for its dtype and asserts that the finished
    /// builder equals the output of `constant_canonicalize` for the same array.
    fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
        let expected = constant_canonicalize(array.as_view())?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(
        #[case] value: bool,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    /// Exercises the `DType::Decimal` fast path, including an empty (`n = 0`) append.
    #[rstest]
    #[case::i64(DecimalValue::I64(12_345), 5)]
    #[case::i64_n_zero(DecimalValue::I64(12_345), 0)]
    fn test_decimal_constant_append(
        #[case] value: DecimalValue,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::decimal(value, DecimalDType::new(10, 2), Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_decimal_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Decimal(
                DecimalDType::new(10, 2),
                Nullability::Nullable,
            )),
            4,
        ))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)]
    #[case::utf8_noninline("hello world!!", 5)]
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)]
    fn test_utf8_constant_append(
        #[case] value: &str,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)]
    #[case::binary_noninline(vec![0u8; 13], 5)]
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    /// Structs have no dedicated fast path and go through the canonicalize-then-extend
    /// fallback.
    #[test]
    fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}