Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use crate::ArrayEq;
17use crate::ArrayHash;
18use crate::ArrayParts;
19use crate::ArrayRef;
20use crate::EqMode;
21use crate::ExecutionCtx;
22use crate::ExecutionResult;
23use crate::IntoArray;
24use crate::array::Array;
25use crate::array::ArrayId;
26use crate::array::ArrayView;
27use crate::array::VTable;
28use crate::array::unsupported_buffer_replacement;
29use crate::arrays::constant::ConstantData;
30use crate::arrays::constant::compute::rules::PARENT_RULES;
31use crate::arrays::constant::vtable::canonical::constant_canonicalize;
32use crate::buffer::BufferHandle;
33use crate::builders::ArrayBuilder;
34use crate::builders::BoolBuilder;
35use crate::builders::DecimalBuilder;
36use crate::builders::NullBuilder;
37use crate::builders::PrimitiveBuilder;
38use crate::builders::VarBinViewBuilder;
39use crate::canonical::Canonical;
40use crate::dtype::DType;
41use crate::match_each_decimal_value;
42use crate::match_each_native_ptype;
43use crate::scalar::DecimalValue;
44use crate::scalar::Scalar;
45use crate::scalar::ScalarValue;
46use crate::serde::ArrayChildren;
47pub(crate) mod canonical;
48mod operations;
49mod validity;
50
51/// A [`Constant`]-encoded Vortex array.
52pub type ConstantArray = Array<Constant>;
53
54#[derive(Clone, Debug)]
55pub struct Constant;
56
57impl ArrayHash for ConstantData {
58    fn array_hash<H: Hasher>(&self, state: &mut H, _accuracy: EqMode) {
59        self.scalar.hash(state);
60    }
61}
62
63impl ArrayEq for ConstantData {
64    fn array_eq(&self, other: &Self, _accuracy: EqMode) -> bool {
65        self.scalar == other.scalar
66    }
67}
68
69impl VTable for Constant {
70    type TypedArrayData = ConstantData;
71
72    type OperationsVTable = Self;
73    type ValidityVTable = Self;
74
75    fn id(&self) -> ArrayId {
76        static ID: CachedId = CachedId::new("vortex.constant");
77        *ID
78    }
79
80    fn validate(
81        &self,
82        data: &ConstantData,
83        dtype: &DType,
84        _len: usize,
85        _slots: &[Option<ArrayRef>],
86    ) -> VortexResult<()> {
87        vortex_ensure!(
88            data.scalar.dtype() == dtype,
89            "ConstantArray scalar dtype does not match outer dtype"
90        );
91        Ok(())
92    }
93
94    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
95        1
96    }
97
98    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
99        match idx {
100            0 => BufferHandle::new_host(
101                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
102            ),
103            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
104        }
105    }
106
107    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
108        match idx {
109            0 => Some("scalar".to_string()),
110            _ => None,
111        }
112    }
113
114    fn with_buffers(
115        &self,
116        array: ArrayView<'_, Self>,
117        buffers: &[BufferHandle],
118    ) -> VortexResult<ArrayParts<Self>> {
119        unsupported_buffer_replacement(array, buffers)
120    }
121
122    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
123        vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
124    }
125
126    fn serialize(
127        _array: ArrayView<'_, Self>,
128        _session: &VortexSession,
129    ) -> VortexResult<Option<Vec<u8>>> {
130        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
131        // metadata at all.
132        Ok(Some(vec![]))
133    }
134
135    fn deserialize(
136        &self,
137        dtype: &DType,
138        len: usize,
139        _metadata: &[u8],
140
141        buffers: &[BufferHandle],
142        _children: &dyn ArrayChildren,
143        session: &VortexSession,
144    ) -> VortexResult<ArrayParts<Self>> {
145        vortex_ensure!(
146            buffers.len() == 1,
147            "Expected 1 buffer, got {}",
148            buffers.len()
149        );
150
151        let buffer = buffers[0].clone().try_to_host_sync()?;
152        let bytes: &[u8] = buffer.as_ref();
153
154        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
155        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
156
157        Ok(ArrayParts::new(
158            self.clone(),
159            dtype.clone(),
160            len,
161            ConstantData::new(scalar),
162        ))
163    }
164
165    fn reduce_parent(
166        array: ArrayView<'_, Self>,
167        parent: &ArrayRef,
168        child_idx: usize,
169    ) -> VortexResult<Option<ArrayRef>> {
170        PARENT_RULES.evaluate(array, parent, child_idx)
171    }
172
173    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
174        Ok(ExecutionResult::done(constant_canonicalize(
175            array.as_view(),
176            ctx,
177        )?))
178    }
179
180    fn append_to_builder(
181        array: ArrayView<'_, Self>,
182        builder: &mut dyn ArrayBuilder,
183        ctx: &mut ExecutionCtx,
184    ) -> VortexResult<()> {
185        let n = array.len();
186        let scalar = array.scalar();
187
188        match array.dtype() {
189            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
190            DType::Bool(_) => {
191                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
192                    b.append_values(
193                        scalar
194                            .as_bool()
195                            .value()
196                            .vortex_expect("non-null bool scalar must have a value"),
197                        n,
198                    );
199                })
200            }
201            DType::Primitive(ptype, _) => {
202                match_each_native_ptype!(ptype, |P| {
203                    append_value_or_nulls::<PrimitiveBuilder<P>>(
204                        builder,
205                        scalar.is_null(),
206                        n,
207                        |b| {
208                            let value = P::try_from(scalar)
209                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
210                            b.append_n_values(value, n);
211                        },
212                    );
213                });
214            }
215            DType::Decimal(..) => {
216                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
217                    let value = scalar
218                        .as_decimal()
219                        .decimal_value()
220                        .vortex_expect("non-null decimal scalar must have a value");
221                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
222                });
223            }
224            DType::Utf8(_) => {
225                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
226                    let typed = scalar.as_utf8();
227                    let value = typed
228                        .value()
229                        .vortex_expect("non-null utf8 scalar must have a value");
230                    b.append_n_values(value.as_bytes(), n);
231                });
232            }
233            DType::Binary(_) => {
234                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
235                    let typed = scalar.as_binary();
236                    let value = typed
237                        .value()
238                        .vortex_expect("non-null binary scalar must have a value");
239                    b.append_n_values(value, n);
240                });
241            }
242            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
243            _ => {
244                let canonical = array
245                    .array()
246                    .clone()
247                    .execute::<Canonical>(ctx)?
248                    .into_array();
249                builder.extend_from_array(&canonical);
250            }
251        }
252
253        Ok(())
254    }
255}
256
257/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
258/// builder depending on `is_null`.
259///
260/// `is_null` must only be `true` when the builder is nullable.
261fn append_value_or_nulls<B: ArrayBuilder + 'static>(
262    builder: &mut dyn ArrayBuilder,
263    is_null: bool,
264    n: usize,
265    fill: impl FnOnce(&mut B),
266) {
267    let b = builder
268        .as_any_mut()
269        .downcast_mut::<B>()
270        .vortex_expect("builder dtype must match array dtype");
271    if is_null {
272        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
273        unsafe { b.append_nulls_unchecked(n) };
274    } else {
275        fill(b);
276    }
277}
278
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Checks that appending `array` to a freshly created builder yields the same array as
    /// `constant_canonicalize` does.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        let mut ctx = crate::array_session().create_execution_ctx();

        // Reference result: the canonical form of the constant array.
        let expected = constant_canonicalize(array.as_view(), &mut ctx)?.into_array();

        // Result under test: the builder path.
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        let appended = array.into_array();
        appended.append_to_builder(builder.as_mut(), &mut ctx)?;
        let actual = builder.finish();

        assert_arrays_eq!(&actual, &expected, &mut ctx);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        let nulls = ConstantArray::new(Scalar::null(DType::Null), 5);
        assert_append_matches_canonical(nulls)
    }

    #[test]
    fn test_with_buffers_rejects_serialized_scalar_buffer() {
        let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
        let array = ConstantArray::new(scalar, 3).into_array();
        let buffers = array.buffer_handles();

        // SAFETY: the replacement buffers are the array's existing buffers, so the logical values
        // would be unchanged if the encoding supported buffer replacement.
        let outcome = unsafe { array.with_buffers(buffers) };
        match outcome {
            Ok(_) => panic!("ConstantArray should reject replacing its serialized scalar buffer"),
            Err(err) => assert!(
                err.to_string()
                    .contains("does not support in-memory buffer replacement")
            ),
        }
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::bool(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let array = ConstantArray::new(scalar, n);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // ≤12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::utf8(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(#[case] value: Vec<u8>, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::binary(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        let fields = StructFields::new(["x", "y"].into(), field_dtypes);
        let dtype = DType::Struct(fields, Nullability::NonNullable);

        let field_values = [
            Scalar::primitive(42i32, Nullability::NonNullable),
            Scalar::utf8("hi", Nullability::NonNullable),
        ];
        let scalar = Scalar::struct_(dtype, field_values);

        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![DType::Primitive(PType::I32, Nullability::Nullable)];
        let fields = StructFields::new(["x"].into(), field_dtypes);
        let dtype = DType::Struct(fields, Nullability::Nullable);

        let array = ConstantArray::new(Scalar::null(dtype), 4);
        assert_append_matches_canonical(array)
    }
}
428}