Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayEq;
16use crate::ArrayHash;
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::ExecutionResult;
20use crate::IntoArray;
21use crate::Precision;
22use crate::array::Array;
23use crate::array::ArrayId;
24use crate::array::ArrayView;
25use crate::array::VTable;
26use crate::arrays::constant::ConstantData;
27use crate::arrays::constant::compute::rules::PARENT_RULES;
28use crate::arrays::constant::vtable::canonical::constant_canonicalize;
29use crate::buffer::BufferHandle;
30use crate::builders::ArrayBuilder;
31use crate::builders::BoolBuilder;
32use crate::builders::DecimalBuilder;
33use crate::builders::NullBuilder;
34use crate::builders::PrimitiveBuilder;
35use crate::builders::VarBinViewBuilder;
36use crate::canonical::Canonical;
37use crate::dtype::DType;
38use crate::match_each_decimal_value;
39use crate::match_each_native_ptype;
40use crate::scalar::DecimalValue;
41use crate::scalar::Scalar;
42use crate::scalar::ScalarValue;
43use crate::serde::ArrayChildren;
44pub(crate) mod canonical;
45mod operations;
46mod validity;
47
/// A [`Constant`]-encoded Vortex array.
///
/// This is an alias for [`Array`] parameterized with the [`Constant`] vtable.
pub type ConstantArray = Array<Constant>;
50
/// Marker type for the constant encoding; it is a unit struct and carries no data itself.
///
/// The encoding's behaviour is supplied by the [`VTable`] implementation on this type.
#[derive(Clone, Debug)]
pub struct Constant;
53
impl Constant {
    /// The [`ArrayId`] of this encoding, as returned by [`VTable::id`].
    pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
}
57
58impl ArrayHash for ConstantData {
59    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
60        self.scalar.hash(state);
61    }
62}
63
64impl ArrayEq for ConstantData {
65    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
66        self.scalar == other.scalar
67    }
68}
69
70impl VTable for Constant {
71    type ArrayData = ConstantData;
72
73    type OperationsVTable = Self;
74    type ValidityVTable = Self;
75
76    fn id(&self) -> ArrayId {
77        Self::ID
78    }
79
80    fn validate(
81        &self,
82        data: &ConstantData,
83        dtype: &DType,
84        _len: usize,
85        _slots: &[Option<ArrayRef>],
86    ) -> VortexResult<()> {
87        vortex_ensure!(
88            data.scalar.dtype() == dtype,
89            "ConstantArray scalar dtype does not match outer dtype"
90        );
91        Ok(())
92    }
93
94    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
95        1
96    }
97
98    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
99        match idx {
100            0 => BufferHandle::new_host(
101                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
102            ),
103            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
104        }
105    }
106
107    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
108        match idx {
109            0 => Some("scalar".to_string()),
110            _ => None,
111        }
112    }
113
114    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
115        vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
116    }
117
118    fn serialize(
119        _array: ArrayView<'_, Self>,
120        _session: &VortexSession,
121    ) -> VortexResult<Option<Vec<u8>>> {
122        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
123        // metadata at all.
124        Ok(Some(vec![]))
125    }
126
127    fn deserialize(
128        &self,
129        dtype: &DType,
130        len: usize,
131        _metadata: &[u8],
132
133        buffers: &[BufferHandle],
134        _children: &dyn ArrayChildren,
135        session: &VortexSession,
136    ) -> VortexResult<crate::array::ArrayParts<Self>> {
137        vortex_ensure!(
138            buffers.len() == 1,
139            "Expected 1 buffer, got {}",
140            buffers.len()
141        );
142
143        let buffer = buffers[0].clone().try_to_host_sync()?;
144        let bytes: &[u8] = buffer.as_ref();
145
146        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
147        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
148
149        Ok(crate::array::ArrayParts::new(
150            self.clone(),
151            dtype.clone(),
152            len,
153            ConstantData::new(scalar),
154        ))
155    }
156
157    fn reduce_parent(
158        array: ArrayView<'_, Self>,
159        parent: &ArrayRef,
160        child_idx: usize,
161    ) -> VortexResult<Option<ArrayRef>> {
162        PARENT_RULES.evaluate(array, parent, child_idx)
163    }
164
165    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
166        Ok(ExecutionResult::done(constant_canonicalize(
167            array.as_view(),
168        )?))
169    }
170
171    fn append_to_builder(
172        array: ArrayView<'_, Self>,
173        builder: &mut dyn ArrayBuilder,
174        ctx: &mut ExecutionCtx,
175    ) -> VortexResult<()> {
176        let n = array.len();
177        let scalar = array.scalar();
178
179        match array.dtype() {
180            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
181            DType::Bool(_) => {
182                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
183                    b.append_values(
184                        scalar
185                            .as_bool()
186                            .value()
187                            .vortex_expect("non-null bool scalar must have a value"),
188                        n,
189                    );
190                })
191            }
192            DType::Primitive(ptype, _) => {
193                match_each_native_ptype!(ptype, |P| {
194                    append_value_or_nulls::<PrimitiveBuilder<P>>(
195                        builder,
196                        scalar.is_null(),
197                        n,
198                        |b| {
199                            let value = P::try_from(scalar)
200                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
201                            b.append_n_values(value, n);
202                        },
203                    );
204                });
205            }
206            DType::Decimal(..) => {
207                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
208                    let value = scalar
209                        .as_decimal()
210                        .decimal_value()
211                        .vortex_expect("non-null decimal scalar must have a value");
212                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
213                });
214            }
215            DType::Utf8(_) => {
216                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
217                    let typed = scalar.as_utf8();
218                    let value = typed
219                        .value()
220                        .vortex_expect("non-null utf8 scalar must have a value");
221                    b.append_n_values(value.as_bytes(), n);
222                });
223            }
224            DType::Binary(_) => {
225                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
226                    let typed = scalar.as_binary();
227                    let value = typed
228                        .value()
229                        .vortex_expect("non-null binary scalar must have a value");
230                    b.append_n_values(value, n);
231                });
232            }
233            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
234            _ => {
235                let canonical = array
236                    .array()
237                    .clone()
238                    .execute::<Canonical>(ctx)?
239                    .into_array();
240                builder.extend_from_array(&canonical);
241            }
242        }
243
244        Ok(())
245    }
246}
247
248/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
249/// builder depending on `is_null`.
250///
251/// `is_null` must only be `true` when the builder is nullable.
252fn append_value_or_nulls<B: ArrayBuilder + 'static>(
253    builder: &mut dyn ArrayBuilder,
254    is_null: bool,
255    n: usize,
256    fill: impl FnOnce(&mut B),
257) {
258    let b = builder
259        .as_any_mut()
260        .downcast_mut::<B>()
261        .vortex_expect("builder dtype must match array dtype");
262    if is_null {
263        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
264        unsafe { b.append_nulls_unchecked(n) };
265    } else {
266        fill(b);
267    }
268}
269
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Pushes `array` through `append_to_builder` into a fresh builder and checks that the
    /// finished array equals the output of `constant_canonicalize` for the same input.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        let expected = constant_canonicalize(array.as_view())?.into_array();

        let mut builder = builder_with_capacity(array.dtype(), array.len());
        let mut exec_ctx = ExecutionCtx::new(VortexSession::empty());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut exec_ctx)?;

        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        let array = ConstantArray::new(Scalar::null(DType::Null), 5);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::bool(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let array = ConstantArray::new(scalar, n);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // ≤12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::utf8(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let scalar = Scalar::binary(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        let fields = StructFields::new(["x", "y"].into(), field_dtypes);
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![DType::Primitive(PType::I32, Nullability::Nullable)];
        let fields = StructFields::new(["x"].into(), field_dtypes);
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}