Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::IntoArray;
19use crate::Precision;
20use crate::arrays::ConstantArray;
21use crate::arrays::constant::compute::rules::PARENT_RULES;
22use crate::arrays::constant::vtable::canonical::constant_canonicalize;
23use crate::buffer::BufferHandle;
24use crate::builders::ArrayBuilder;
25use crate::builders::BoolBuilder;
26use crate::builders::DecimalBuilder;
27use crate::builders::NullBuilder;
28use crate::builders::PrimitiveBuilder;
29use crate::builders::VarBinViewBuilder;
30use crate::canonical::Canonical;
31use crate::dtype::DType;
32use crate::match_each_decimal_value;
33use crate::match_each_native_ptype;
34use crate::scalar::DecimalValue;
35use crate::scalar::Scalar;
36use crate::scalar::ScalarValue;
37use crate::serde::ArrayChildren;
38use crate::stats::StatsSetRef;
39use crate::vtable;
40use crate::vtable::ArrayId;
41use crate::vtable::VTable;
42pub(crate) mod canonical;
43mod operations;
44mod validity;
45
46vtable!(Constant);
47
/// Marker type naming the constant array encoding.
///
/// It carries no state: the [`VTable`] implementation for this type always hands back the same
/// unit value, and all per-array data lives in [`ConstantArray`].
#[derive(Debug, Clone)]
pub struct Constant;
50
51impl Constant {
52    pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
53}
54
55impl VTable for Constant {
56    type Array = ConstantArray;
57
58    type Metadata = Scalar;
59    type OperationsVTable = Self;
60    type ValidityVTable = Self;
61
62    fn vtable(_array: &Self::Array) -> &Self {
63        &Constant
64    }
65
66    fn id(&self) -> ArrayId {
67        Self::ID
68    }
69
70    fn len(array: &ConstantArray) -> usize {
71        array.len
72    }
73
74    fn dtype(array: &ConstantArray) -> &DType {
75        array.scalar.dtype()
76    }
77
78    fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
79        array.stats_set.to_ref(array.as_ref())
80    }
81
82    fn array_hash<H: std::hash::Hasher>(
83        array: &ConstantArray,
84        state: &mut H,
85        _precision: Precision,
86    ) {
87        array.scalar.hash(state);
88        array.len.hash(state);
89    }
90
91    fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
92        array.scalar == other.scalar && array.len == other.len
93    }
94
95    fn nbuffers(_array: &ConstantArray) -> usize {
96        1
97    }
98
99    fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
100        match idx {
101            0 => BufferHandle::new_host(
102                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
103            ),
104            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
105        }
106    }
107
108    fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
109        match idx {
110            0 => Some("scalar".to_string()),
111            _ => None,
112        }
113    }
114
115    fn nchildren(_array: &ConstantArray) -> usize {
116        0
117    }
118
119    fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
120        vortex_panic!("ConstantArray child index {idx} out of bounds")
121    }
122
123    fn child_name(_array: &ConstantArray, idx: usize) -> String {
124        vortex_panic!("ConstantArray child_name index {idx} out of bounds")
125    }
126
127    fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
128        Ok(array.scalar().clone())
129    }
130
131    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
132        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
133        // metadata at all.
134        Ok(Some(vec![]))
135    }
136
137    fn deserialize(
138        _bytes: &[u8],
139        dtype: &DType,
140        _len: usize,
141        buffers: &[BufferHandle],
142        session: &VortexSession,
143    ) -> VortexResult<Self::Metadata> {
144        vortex_ensure!(
145            buffers.len() == 1,
146            "Expected 1 buffer, got {}",
147            buffers.len()
148        );
149
150        let buffer = buffers[0].clone().try_to_host_sync()?;
151        let bytes: &[u8] = buffer.as_ref();
152
153        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
154        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
155
156        Ok(scalar)
157    }
158
159    fn build(
160        _dtype: &DType,
161        len: usize,
162        metadata: &Self::Metadata,
163        _buffers: &[BufferHandle],
164        _children: &dyn ArrayChildren,
165    ) -> VortexResult<ConstantArray> {
166        Ok(ConstantArray::new(metadata.clone(), len))
167    }
168
169    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
170        vortex_ensure!(
171            children.is_empty(),
172            "ConstantArray has no children, got {}",
173            children.len()
174        );
175        Ok(())
176    }
177
178    fn reduce_parent(
179        array: &Self::Array,
180        parent: &ArrayRef,
181        child_idx: usize,
182    ) -> VortexResult<Option<ArrayRef>> {
183        PARENT_RULES.evaluate(array, parent, child_idx)
184    }
185
186    fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
187        Ok(ExecutionResult::done(
188            constant_canonicalize(&array)?.into_array(),
189        ))
190    }
191
192    fn append_to_builder(
193        array: &ConstantArray,
194        builder: &mut dyn ArrayBuilder,
195        ctx: &mut ExecutionCtx,
196    ) -> VortexResult<()> {
197        let n = array.len();
198        let scalar = array.scalar();
199
200        match array.dtype() {
201            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
202            DType::Bool(_) => {
203                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
204                    b.append_values(
205                        scalar
206                            .as_bool()
207                            .value()
208                            .vortex_expect("non-null bool scalar must have a value"),
209                        n,
210                    );
211                })
212            }
213            DType::Primitive(ptype, _) => {
214                match_each_native_ptype!(ptype, |P| {
215                    append_value_or_nulls::<PrimitiveBuilder<P>>(
216                        builder,
217                        scalar.is_null(),
218                        n,
219                        |b| {
220                            let value = P::try_from(scalar)
221                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
222                            b.append_n_values(value, n);
223                        },
224                    );
225                });
226            }
227            DType::Decimal(..) => {
228                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
229                    let value = scalar
230                        .as_decimal()
231                        .decimal_value()
232                        .vortex_expect("non-null decimal scalar must have a value");
233                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
234                });
235            }
236            DType::Utf8(_) => {
237                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
238                    let typed = scalar.as_utf8();
239                    let value = typed
240                        .value()
241                        .vortex_expect("non-null utf8 scalar must have a value");
242                    b.append_n_values(value.as_bytes(), n);
243                });
244            }
245            DType::Binary(_) => {
246                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
247                    let typed = scalar.as_binary();
248                    let value = typed
249                        .value()
250                        .vortex_expect("non-null binary scalar must have a value");
251                    b.append_n_values(value, n);
252                });
253            }
254            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
255            _ => {
256                let canonical = array
257                    .clone()
258                    .into_array()
259                    .execute::<Canonical>(ctx)?
260                    .into_array();
261                builder.extend_from_array(&canonical);
262            }
263        }
264
265        Ok(())
266    }
267}
268
269/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
270/// builder depending on `is_null`.
271///
272/// `is_null` must only be `true` when the builder is nullable.
273fn append_value_or_nulls<B: ArrayBuilder + 'static>(
274    builder: &mut dyn ArrayBuilder,
275    is_null: bool,
276    n: usize,
277    fill: impl FnOnce(&mut B),
278) {
279    let b = builder
280        .as_any_mut()
281        .downcast_mut::<B>()
282        .vortex_expect("builder dtype must match array dtype");
283    if is_null {
284        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
285        unsafe { b.append_nulls_unchecked(n) };
286    } else {
287        fill(b);
288    }
289}
290
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Fresh execution context backed by an empty session.
    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Runs `append_to_builder` on `array` against a new builder and checks that the finished
    /// builder equals what `constant_canonicalize` produces for the same array.
    fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
        // Compute the reference result first, while `array` is still borrowable.
        let expected = constant_canonicalize(&array)?.into_array();

        let mut exec_ctx = ctx();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut exec_ctx)?;

        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Null);
        assert_append_matches_canonical(ConstantArray::new(scalar, 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::bool(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_bool_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Bool(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> VortexResult<()> {
        let array = ConstantArray::new(scalar, n);
        assert_append_matches_canonical(array)
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // ≤12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::utf8(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_utf8_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Utf8(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(#[case] value: Vec<u8>, #[case] n: usize) -> VortexResult<()> {
        let scalar = Scalar::binary(value, Nullability::NonNullable);
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[test]
    fn test_binary_null_constant_append() -> VortexResult<()> {
        let scalar = Scalar::null(DType::Binary(Nullability::Nullable));
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }

    /// Struct dtypes have no typed fast path, so this exercises the canonicalizing fallback.
    #[test]
    fn test_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
        ];
        let fields = StructFields::new(["x", "y"].into(), field_dtypes);
        let dtype = DType::Struct(fields, Nullability::NonNullable);

        let scalar = Scalar::struct_(
            dtype,
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    /// Same fallback path as above, but with a null struct scalar.
    #[test]
    fn test_null_struct_constant_append() -> VortexResult<()> {
        let field_dtypes = vec![DType::Primitive(PType::I32, Nullability::Nullable)];
        let fields = StructFields::new(["x"].into(), field_dtypes);
        let dtype = DType::Struct(fields, Nullability::Nullable);

        let scalar = Scalar::null(dtype);
        assert_append_matches_canonical(ConstantArray::new(scalar, 4))
    }
}