Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::hash::Hasher;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use crate::ArrayEq;
17use crate::ArrayHash;
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::ExecutionResult;
21use crate::IntoArray;
22use crate::Precision;
23use crate::array::Array;
24use crate::array::ArrayId;
25use crate::array::ArrayView;
26use crate::array::VTable;
27use crate::arrays::constant::ConstantData;
28use crate::arrays::constant::compute::rules::PARENT_RULES;
29use crate::arrays::constant::vtable::canonical::constant_canonicalize;
30use crate::buffer::BufferHandle;
31use crate::builders::ArrayBuilder;
32use crate::builders::BoolBuilder;
33use crate::builders::DecimalBuilder;
34use crate::builders::NullBuilder;
35use crate::builders::PrimitiveBuilder;
36use crate::builders::VarBinViewBuilder;
37use crate::canonical::Canonical;
38use crate::dtype::DType;
39use crate::match_each_decimal_value;
40use crate::match_each_native_ptype;
41use crate::scalar::DecimalValue;
42use crate::scalar::Scalar;
43use crate::scalar::ScalarValue;
44use crate::serde::ArrayChildren;
45pub(crate) mod canonical;
46mod operations;
47mod validity;
48
/// A [`Constant`]-encoded Vortex array.
///
/// This is an alias for [`Array`] parameterized by the [`Constant`] encoding marker.
pub type ConstantArray = Array<Constant>;
51
/// Marker type for the constant encoding.
///
/// It carries no data of its own; the encoding's behavior comes from its [`VTable`]
/// implementation, whose per-array payload is [`ConstantData`].
#[derive(Clone, Debug)]
pub struct Constant;
54
55impl ArrayHash for ConstantData {
56    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
57        self.scalar.hash(state);
58    }
59}
60
61impl ArrayEq for ConstantData {
62    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
63        self.scalar == other.scalar
64    }
65}
66
67impl VTable for Constant {
68    type ArrayData = ConstantData;
69
70    type OperationsVTable = Self;
71    type ValidityVTable = Self;
72
73    fn id(&self) -> ArrayId {
74        static ID: CachedId = CachedId::new("vortex.constant");
75        *ID
76    }
77
78    fn validate(
79        &self,
80        data: &ConstantData,
81        dtype: &DType,
82        _len: usize,
83        _slots: &[Option<ArrayRef>],
84    ) -> VortexResult<()> {
85        vortex_ensure!(
86            data.scalar.dtype() == dtype,
87            "ConstantArray scalar dtype does not match outer dtype"
88        );
89        Ok(())
90    }
91
92    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
93        1
94    }
95
96    fn buffer(array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
97        match idx {
98            0 => BufferHandle::new_host(
99                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
100            ),
101            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
102        }
103    }
104
105    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
106        match idx {
107            0 => Some("scalar".to_string()),
108            _ => None,
109        }
110    }
111
112    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
113        vortex_panic!("ConstantArray slot_name index {idx} out of bounds")
114    }
115
116    fn serialize(
117        _array: ArrayView<'_, Self>,
118        _session: &VortexSession,
119    ) -> VortexResult<Option<Vec<u8>>> {
120        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
121        // metadata at all.
122        Ok(Some(vec![]))
123    }
124
125    fn deserialize(
126        &self,
127        dtype: &DType,
128        len: usize,
129        _metadata: &[u8],
130
131        buffers: &[BufferHandle],
132        _children: &dyn ArrayChildren,
133        session: &VortexSession,
134    ) -> VortexResult<crate::array::ArrayParts<Self>> {
135        vortex_ensure!(
136            buffers.len() == 1,
137            "Expected 1 buffer, got {}",
138            buffers.len()
139        );
140
141        let buffer = buffers[0].clone().try_to_host_sync()?;
142        let bytes: &[u8] = buffer.as_ref();
143
144        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
145        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
146
147        Ok(crate::array::ArrayParts::new(
148            self.clone(),
149            dtype.clone(),
150            len,
151            ConstantData::new(scalar),
152        ))
153    }
154
155    fn reduce_parent(
156        array: ArrayView<'_, Self>,
157        parent: &ArrayRef,
158        child_idx: usize,
159    ) -> VortexResult<Option<ArrayRef>> {
160        PARENT_RULES.evaluate(array, parent, child_idx)
161    }
162
163    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
164        Ok(ExecutionResult::done(constant_canonicalize(
165            array.as_view(),
166        )?))
167    }
168
169    fn append_to_builder(
170        array: ArrayView<'_, Self>,
171        builder: &mut dyn ArrayBuilder,
172        ctx: &mut ExecutionCtx,
173    ) -> VortexResult<()> {
174        let n = array.len();
175        let scalar = array.scalar();
176
177        match array.dtype() {
178            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
179            DType::Bool(_) => {
180                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
181                    b.append_values(
182                        scalar
183                            .as_bool()
184                            .value()
185                            .vortex_expect("non-null bool scalar must have a value"),
186                        n,
187                    );
188                })
189            }
190            DType::Primitive(ptype, _) => {
191                match_each_native_ptype!(ptype, |P| {
192                    append_value_or_nulls::<PrimitiveBuilder<P>>(
193                        builder,
194                        scalar.is_null(),
195                        n,
196                        |b| {
197                            let value = P::try_from(scalar)
198                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
199                            b.append_n_values(value, n);
200                        },
201                    );
202                });
203            }
204            DType::Decimal(..) => {
205                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
206                    let value = scalar
207                        .as_decimal()
208                        .decimal_value()
209                        .vortex_expect("non-null decimal scalar must have a value");
210                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
211                });
212            }
213            DType::Utf8(_) => {
214                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
215                    let typed = scalar.as_utf8();
216                    let value = typed
217                        .value()
218                        .vortex_expect("non-null utf8 scalar must have a value");
219                    b.append_n_values(value.as_bytes(), n);
220                });
221            }
222            DType::Binary(_) => {
223                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
224                    let typed = scalar.as_binary();
225                    let value = typed
226                        .value()
227                        .vortex_expect("non-null binary scalar must have a value");
228                    b.append_n_values(value, n);
229                });
230            }
231            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
232            _ => {
233                let canonical = array
234                    .array()
235                    .clone()
236                    .execute::<Canonical>(ctx)?
237                    .into_array();
238                builder.extend_from_array(&canonical);
239            }
240        }
241
242        Ok(())
243    }
244}
245
246/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
247/// builder depending on `is_null`.
248///
249/// `is_null` must only be `true` when the builder is nullable.
250fn append_value_or_nulls<B: ArrayBuilder + 'static>(
251    builder: &mut dyn ArrayBuilder,
252    is_null: bool,
253    n: usize,
254    fill: impl FnOnce(&mut B),
255) {
256    let b = builder
257        .as_any_mut()
258        .downcast_mut::<B>()
259        .vortex_expect("builder dtype must match array dtype");
260    if is_null {
261        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
262        unsafe { b.append_nulls_unchecked(n) };
263    } else {
264        fill(b);
265    }
266}
267
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Appends `array` into a fresh builder and checks that the finished builder equals the
    /// output of `constant_canonicalize` for the same array.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        let mut exec_ctx = ExecutionCtx::new(VortexSession::empty());

        let expected = constant_canonicalize(array.as_view())?.into_array();

        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut exec_ctx)?;
        let actual = builder.finish();

        assert_arrays_eq!(&actual, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        let array = ConstantArray::new(Scalar::null(DType::Null), 5);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::bool(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let array = ConstantArray::new(scalar, n);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // ≤12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::utf8(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let scalar = Scalar::binary(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    /// A non-null struct constant takes the canonicalize-and-extend fallback path.
    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let dtype = DType::Struct(fields, Nullability::NonNullable);
        let scalar = Scalar::struct_(
            dtype,
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    /// A null struct constant also goes through the fallback path.
    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        let scalar = Scalar::null(dtype);
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }
}