Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_buffer::ByteBufferMut;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::ExecutionStep;
17use crate::IntoArray;
18use crate::Precision;
19use crate::arrays::ConstantArray;
20use crate::arrays::constant::compute::rules::PARENT_RULES;
21use crate::arrays::constant::vtable::canonical::constant_canonicalize;
22use crate::buffer::BufferHandle;
23use crate::builders::ArrayBuilder;
24use crate::builders::BoolBuilder;
25use crate::builders::DecimalBuilder;
26use crate::builders::NullBuilder;
27use crate::builders::PrimitiveBuilder;
28use crate::builders::VarBinViewBuilder;
29use crate::canonical::Canonical;
30use crate::dtype::DType;
31use crate::match_each_decimal_value;
32use crate::match_each_native_ptype;
33use crate::scalar::DecimalValue;
34use crate::scalar::Scalar;
35use crate::scalar::ScalarValue;
36use crate::serde::ArrayChildren;
37use crate::stats::StatsSetRef;
38use crate::vtable;
39use crate::vtable::ArrayId;
40use crate::vtable::VTable;
41pub(crate) mod canonical;
42mod operations;
43mod validity;
44
45vtable!(Constant);
46
47#[derive(Debug)]
48pub struct ConstantVTable;
49
50impl ConstantVTable {
51    pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
52}
53
54impl VTable for ConstantVTable {
55    type Array = ConstantArray;
56
57    type Metadata = Scalar;
58    type OperationsVTable = Self;
59    type ValidityVTable = Self;
60
61    fn id(_array: &Self::Array) -> ArrayId {
62        Self::ID
63    }
64
65    fn len(array: &ConstantArray) -> usize {
66        array.len
67    }
68
69    fn dtype(array: &ConstantArray) -> &DType {
70        array.scalar.dtype()
71    }
72
73    fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
74        array.stats_set.to_ref(array.as_ref())
75    }
76
77    fn array_hash<H: std::hash::Hasher>(
78        array: &ConstantArray,
79        state: &mut H,
80        _precision: Precision,
81    ) {
82        array.scalar.hash(state);
83        array.len.hash(state);
84    }
85
86    fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
87        array.scalar == other.scalar && array.len == other.len
88    }
89
90    fn nbuffers(_array: &ConstantArray) -> usize {
91        1
92    }
93
94    fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
95        match idx {
96            0 => BufferHandle::new_host(
97                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
98            ),
99            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
100        }
101    }
102
103    fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
104        match idx {
105            0 => Some("scalar".to_string()),
106            _ => None,
107        }
108    }
109
110    fn nchildren(_array: &ConstantArray) -> usize {
111        0
112    }
113
114    fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
115        vortex_panic!("ConstantArray child index {idx} out of bounds")
116    }
117
118    fn child_name(_array: &ConstantArray, idx: usize) -> String {
119        vortex_panic!("ConstantArray child_name index {idx} out of bounds")
120    }
121
122    fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
123        Ok(array.scalar().clone())
124    }
125
126    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
127        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
128        // metadata at all.
129        Ok(Some(vec![]))
130    }
131
132    fn deserialize(
133        _bytes: &[u8],
134        dtype: &DType,
135        _len: usize,
136        buffers: &[BufferHandle],
137        session: &VortexSession,
138    ) -> VortexResult<Self::Metadata> {
139        vortex_ensure!(
140            buffers.len() == 1,
141            "Expected 1 buffer, got {}",
142            buffers.len()
143        );
144
145        let buffer = buffers[0].clone().try_to_host_sync()?;
146        let bytes: &[u8] = buffer.as_ref();
147
148        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
149        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
150
151        Ok(scalar)
152    }
153
154    fn build(
155        _dtype: &DType,
156        len: usize,
157        metadata: &Self::Metadata,
158        _buffers: &[BufferHandle],
159        _children: &dyn ArrayChildren,
160    ) -> VortexResult<ConstantArray> {
161        Ok(ConstantArray::new(metadata.clone(), len))
162    }
163
164    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
165        vortex_ensure!(
166            children.is_empty(),
167            "ConstantArray has no children, got {}",
168            children.len()
169        );
170        Ok(())
171    }
172
173    fn reduce_parent(
174        array: &Self::Array,
175        parent: &ArrayRef,
176        child_idx: usize,
177    ) -> VortexResult<Option<ArrayRef>> {
178        PARENT_RULES.evaluate(array, parent, child_idx)
179    }
180
181    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
182        Ok(ExecutionStep::Done(
183            constant_canonicalize(array)?.into_array(),
184        ))
185    }
186
187    fn append_to_builder(
188        array: &ConstantArray,
189        builder: &mut dyn ArrayBuilder,
190        ctx: &mut ExecutionCtx,
191    ) -> VortexResult<()> {
192        let n = array.len();
193        let scalar = array.scalar();
194
195        match array.dtype() {
196            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
197            DType::Bool(_) => {
198                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
199                    b.append_values(
200                        scalar
201                            .as_bool()
202                            .value()
203                            .vortex_expect("non-null bool scalar must have a value"),
204                        n,
205                    );
206                })
207            }
208            DType::Primitive(ptype, _) => {
209                match_each_native_ptype!(ptype, |P| {
210                    append_value_or_nulls::<PrimitiveBuilder<P>>(
211                        builder,
212                        scalar.is_null(),
213                        n,
214                        |b| {
215                            let value = P::try_from(scalar)
216                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
217                            b.append_n_values(value, n);
218                        },
219                    );
220                });
221            }
222            DType::Decimal(..) => {
223                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
224                    let value = scalar
225                        .as_decimal()
226                        .decimal_value()
227                        .vortex_expect("non-null decimal scalar must have a value");
228                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
229                });
230            }
231            DType::Utf8(_) => {
232                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
233                    let typed = scalar.as_utf8();
234                    let value = typed
235                        .value()
236                        .vortex_expect("non-null utf8 scalar must have a value");
237                    b.append_n_values(value.as_bytes(), n);
238                });
239            }
240            DType::Binary(_) => {
241                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
242                    let typed = scalar.as_binary();
243                    let value = typed
244                        .value()
245                        .vortex_expect("non-null binary scalar must have a value");
246                    b.append_n_values(value, n);
247                });
248            }
249            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
250            _ => {
251                let canonical = array
252                    .clone()
253                    .into_array()
254                    .execute::<Canonical>(ctx)?
255                    .into_array();
256                builder.extend_from_array(&canonical);
257            }
258        }
259
260        Ok(())
261    }
262}
263
264/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
265/// builder depending on `is_null`.
266///
267/// `is_null` must only be `true` when the builder is nullable.
268fn append_value_or_nulls<B: ArrayBuilder + 'static>(
269    builder: &mut dyn ArrayBuilder,
270    is_null: bool,
271    n: usize,
272    fill: impl FnOnce(&mut B),
273) {
274    let b = builder
275        .as_any_mut()
276        .downcast_mut::<B>()
277        .vortex_expect("builder dtype must match array dtype");
278    if is_null {
279        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
280        unsafe { b.append_nulls_unchecked(n) };
281    } else {
282        fill(b);
283    }
284}
285
/// Tests for `ConstantVTable::append_to_builder`.
///
/// Each case appends a constant array into a fresh builder and checks the result against
/// `constant_canonicalize`, covering every fast path plus the canonicalizing fallback.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Appends `array` into a fresh builder and asserts the result matches `constant_canonicalize`.
    fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
        let expected = constant_canonicalize(&array)?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(
        #[case] value: bool,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    /// Covers the `DType::Decimal` fast path, for both a present value and a null.
    /// Precision 10 is chosen so the scalar's i64 storage matches the builder's value width.
    #[rstest]
    #[case::decimal(
        Scalar::decimal(
            DecimalValue::I64(12345),
            DecimalDType::new(10, 2),
            Nullability::NonNullable,
        ),
        5
    )]
    #[case::decimal_null(
        Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable)),
        3
    )]
    fn test_decimal_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // ≤12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(
        #[case] value: &str,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    /// Struct dtypes take the canonicalizing fallback path.
    #[test]
    fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}