Skip to main content

vortex_array/arrays/constant/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use vortex_buffer::ByteBufferMut;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::IntoArray;
19use crate::Precision;
20use crate::arrays::ConstantArray;
21use crate::arrays::constant::compute::rules::PARENT_RULES;
22use crate::arrays::constant::vtable::canonical::constant_canonicalize;
23use crate::buffer::BufferHandle;
24use crate::builders::ArrayBuilder;
25use crate::builders::BoolBuilder;
26use crate::builders::DecimalBuilder;
27use crate::builders::NullBuilder;
28use crate::builders::PrimitiveBuilder;
29use crate::builders::VarBinViewBuilder;
30use crate::canonical::Canonical;
31use crate::dtype::DType;
32use crate::match_each_decimal_value;
33use crate::match_each_native_ptype;
34use crate::scalar::DecimalValue;
35use crate::scalar::Scalar;
36use crate::scalar::ScalarValue;
37use crate::serde::ArrayChildren;
38use crate::stats::StatsSetRef;
39use crate::vtable;
40use crate::vtable::Array;
41use crate::vtable::ArrayId;
42use crate::vtable::VTable;
43pub(crate) mod canonical;
44mod operations;
45mod validity;
46
47vtable!(Constant);
48
49#[derive(Clone, Debug)]
50pub struct Constant;
51
52impl Constant {
53    pub const ID: ArrayId = ArrayId::new_ref("vortex.constant");
54}
55
56impl VTable for Constant {
57    type Array = ConstantArray;
58
59    type Metadata = Scalar;
60    type OperationsVTable = Self;
61    type ValidityVTable = Self;
62
63    fn vtable(_array: &Self::Array) -> &Self {
64        &Constant
65    }
66
67    fn id(&self) -> ArrayId {
68        Self::ID
69    }
70
71    fn len(array: &ConstantArray) -> usize {
72        array.len
73    }
74
75    fn dtype(array: &ConstantArray) -> &DType {
76        array.scalar.dtype()
77    }
78
79    fn stats(array: &ConstantArray) -> StatsSetRef<'_> {
80        array.stats_set.to_ref(array.as_ref())
81    }
82
83    fn array_hash<H: std::hash::Hasher>(
84        array: &ConstantArray,
85        state: &mut H,
86        _precision: Precision,
87    ) {
88        array.scalar.hash(state);
89        array.len.hash(state);
90    }
91
92    fn array_eq(array: &ConstantArray, other: &ConstantArray, _precision: Precision) -> bool {
93        array.scalar == other.scalar && array.len == other.len
94    }
95
96    fn nbuffers(_array: &ConstantArray) -> usize {
97        1
98    }
99
100    fn buffer(array: &ConstantArray, idx: usize) -> BufferHandle {
101        match idx {
102            0 => BufferHandle::new_host(
103                ScalarValue::to_proto_bytes::<ByteBufferMut>(array.scalar.value()).freeze(),
104            ),
105            _ => vortex_panic!("ConstantArray buffer index {idx} out of bounds"),
106        }
107    }
108
109    fn buffer_name(_array: &ConstantArray, idx: usize) -> Option<String> {
110        match idx {
111            0 => Some("scalar".to_string()),
112            _ => None,
113        }
114    }
115
116    fn nchildren(_array: &ConstantArray) -> usize {
117        0
118    }
119
120    fn child(_array: &ConstantArray, idx: usize) -> ArrayRef {
121        vortex_panic!("ConstantArray child index {idx} out of bounds")
122    }
123
124    fn child_name(_array: &ConstantArray, idx: usize) -> String {
125        vortex_panic!("ConstantArray child_name index {idx} out of bounds")
126    }
127
128    fn metadata(array: &ConstantArray) -> VortexResult<Self::Metadata> {
129        Ok(array.scalar().clone())
130    }
131
132    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
133        // HACK: Because the scalar is stored in the buffers, we do not need to serialize the
134        // metadata at all.
135        Ok(Some(vec![]))
136    }
137
138    fn deserialize(
139        _bytes: &[u8],
140        dtype: &DType,
141        _len: usize,
142        buffers: &[BufferHandle],
143        session: &VortexSession,
144    ) -> VortexResult<Self::Metadata> {
145        vortex_ensure!(
146            buffers.len() == 1,
147            "Expected 1 buffer, got {}",
148            buffers.len()
149        );
150
151        let buffer = buffers[0].clone().try_to_host_sync()?;
152        let bytes: &[u8] = buffer.as_ref();
153
154        let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype, session)?;
155        let scalar = Scalar::try_new(dtype.clone(), scalar_value)?;
156
157        Ok(scalar)
158    }
159
160    fn build(
161        _dtype: &DType,
162        len: usize,
163        metadata: &Self::Metadata,
164        _buffers: &[BufferHandle],
165        _children: &dyn ArrayChildren,
166    ) -> VortexResult<ConstantArray> {
167        Ok(ConstantArray::new(metadata.clone(), len))
168    }
169
170    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
171        vortex_ensure!(
172            children.is_empty(),
173            "ConstantArray has no children, got {}",
174            children.len()
175        );
176        Ok(())
177    }
178
179    fn reduce_parent(
180        array: &Array<Self>,
181        parent: &ArrayRef,
182        child_idx: usize,
183    ) -> VortexResult<Option<ArrayRef>> {
184        PARENT_RULES.evaluate(array, parent, child_idx)
185    }
186
187    fn execute(array: Arc<Array<Self>>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
188        Ok(ExecutionResult::done(
189            constant_canonicalize(&array)?.into_array(),
190        ))
191    }
192
193    fn append_to_builder(
194        array: &ConstantArray,
195        builder: &mut dyn ArrayBuilder,
196        ctx: &mut ExecutionCtx,
197    ) -> VortexResult<()> {
198        let n = array.len();
199        let scalar = array.scalar();
200
201        match array.dtype() {
202            DType::Null => append_value_or_nulls::<NullBuilder>(builder, true, n, |_| {}),
203            DType::Bool(_) => {
204                append_value_or_nulls::<BoolBuilder>(builder, scalar.is_null(), n, |b| {
205                    b.append_values(
206                        scalar
207                            .as_bool()
208                            .value()
209                            .vortex_expect("non-null bool scalar must have a value"),
210                        n,
211                    );
212                })
213            }
214            DType::Primitive(ptype, _) => {
215                match_each_native_ptype!(ptype, |P| {
216                    append_value_or_nulls::<PrimitiveBuilder<P>>(
217                        builder,
218                        scalar.is_null(),
219                        n,
220                        |b| {
221                            let value = P::try_from(scalar)
222                                .vortex_expect("Couldn't unwrap constant scalar to primitive");
223                            b.append_n_values(value, n);
224                        },
225                    );
226                });
227            }
228            DType::Decimal(..) => {
229                append_value_or_nulls::<DecimalBuilder>(builder, scalar.is_null(), n, |b| {
230                    let value = scalar
231                        .as_decimal()
232                        .decimal_value()
233                        .vortex_expect("non-null decimal scalar must have a value");
234                    match_each_decimal_value!(value, |v| { b.append_n_values(v, n) });
235                });
236            }
237            DType::Utf8(_) => {
238                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
239                    let typed = scalar.as_utf8();
240                    let value = typed
241                        .value()
242                        .vortex_expect("non-null utf8 scalar must have a value");
243                    b.append_n_values(value.as_bytes(), n);
244                });
245            }
246            DType::Binary(_) => {
247                append_value_or_nulls::<VarBinViewBuilder>(builder, scalar.is_null(), n, |b| {
248                    let typed = scalar.as_binary();
249                    let value = typed
250                        .value()
251                        .vortex_expect("non-null binary scalar must have a value");
252                    b.append_n_values(value, n);
253                });
254            }
255            // TODO: add fast paths for DType::Struct, DType::List, DType::FixedSizeList, DType::Extension.
256            _ => {
257                let canonical = array
258                    .clone()
259                    .into_array()
260                    .execute::<Canonical>(ctx)?
261                    .into_array();
262                builder.extend_from_array(&canonical);
263            }
264        }
265
266        Ok(())
267    }
268}
269
270/// Downcasts `builder` to `B`, then either appends `n` nulls or calls `fill` with the typed
271/// builder depending on `is_null`.
272///
273/// `is_null` must only be `true` when the builder is nullable.
274fn append_value_or_nulls<B: ArrayBuilder + 'static>(
275    builder: &mut dyn ArrayBuilder,
276    is_null: bool,
277    n: usize,
278    fill: impl FnOnce(&mut B),
279) {
280    let b = builder
281        .as_any_mut()
282        .downcast_mut::<B>()
283        .vortex_expect("builder dtype must match array dtype");
284    if is_null {
285        // SAFETY: is_null=true only when the scalar (and thus the builder) is nullable.
286        unsafe { b.append_nulls_unchecked(n) };
287    } else {
288        fill(b);
289    }
290}
291
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::constant::vtable::canonical::constant_canonicalize;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::DecimalValue;
    use crate::scalar::Scalar;

    fn ctx() -> ExecutionCtx {
        ExecutionCtx::new(VortexSession::empty())
    }

    /// Appends `array` into a fresh builder and asserts the result matches `constant_canonicalize`.
    fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
        let expected = constant_canonicalize(&array)?.into_array();
        let mut builder = builder_with_capacity(array.dtype(), array.len());
        array
            .into_array()
            .append_to_builder(builder.as_mut(), &mut ctx())?;
        let result = builder.finish();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
    }

    #[rstest]
    #[case::bool_true(true, 5)]
    #[case::bool_false(false, 3)]
    fn test_bool_constant_append(
        #[case] value: bool,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::bool(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Bool(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::i32(Scalar::primitive(42i32, Nullability::NonNullable), 5)]
    #[case::u8(Scalar::primitive(7u8, Nullability::NonNullable), 3)]
    #[case::f64(Scalar::primitive(1.5f64, Nullability::NonNullable), 4)]
    #[case::i32_null(Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 3)]
    #[case::i32_n_zero(Scalar::primitive(42i32, Nullability::NonNullable), 0)]
    fn test_primitive_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    /// Non-null values in a nullable dtype exercise the `fill` path against a nullable builder,
    /// which must record every appended value as valid.
    #[rstest]
    #[case::bool(Scalar::bool(true, Nullability::Nullable), 4)]
    #[case::i32(Scalar::primitive(42i32, Nullability::Nullable), 4)]
    #[case::utf8(Scalar::utf8("hello world!!", Nullability::Nullable), 3)]
    #[case::binary(Scalar::binary(vec![1u8, 2, 3], Nullability::Nullable), 3)]
    fn test_nullable_valid_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    /// Covers the `DType::Decimal` fast path for both non-null and null constants.
    #[rstest]
    #[case::decimal_i64(
        Scalar::decimal(
            DecimalValue::I64(12_345),
            DecimalDType::new(10, 2),
            Nullability::NonNullable
        ),
        5
    )]
    #[case::decimal_i128(
        Scalar::decimal(
            DecimalValue::I128(1),
            DecimalDType::new(38, 4),
            Nullability::NonNullable
        ),
        3
    )]
    #[case::decimal_null(
        Scalar::null(DType::Decimal(DecimalDType::new(10, 2), Nullability::Nullable)),
        4
    )]
    fn test_decimal_constant_append(
        #[case] scalar: Scalar,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(scalar, n))
    }

    #[rstest]
    #[case::utf8_inline("hi", 5)] // <=12 bytes: inlined in BinaryView
    #[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
    #[case::utf8_empty("", 3)]
    #[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
    fn test_utf8_constant_append(
        #[case] value: &str,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::utf8(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Utf8(Nullability::Nullable)),
            4,
        ))
    }

    #[rstest]
    #[case::binary_inline(vec![1u8, 2, 3], 5)] // <=12 bytes: inlined
    #[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
    fn test_binary_constant_append(
        #[case] value: Vec<u8>,
        #[case] n: usize,
    ) -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::binary(value, Nullability::NonNullable),
            n,
        ))
    }

    #[test]
    fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
        assert_append_matches_canonical(ConstantArray::new(
            Scalar::null(DType::Binary(Nullability::Nullable)),
            4,
        ))
    }

    #[test]
    fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x", "y"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
            ],
        );
        let scalar = Scalar::struct_(
            DType::Struct(fields, Nullability::NonNullable),
            [
                Scalar::primitive(42i32, Nullability::NonNullable),
                Scalar::utf8("hi", Nullability::NonNullable),
            ],
        );
        assert_append_matches_canonical(ConstantArray::new(scalar, 3))
    }

    #[test]
    fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
        let fields = StructFields::new(
            ["x"].into(),
            vec![DType::Primitive(PType::I32, Nullability::Nullable)],
        );
        let dtype = DType::Struct(fields, Nullability::Nullable);
        assert_append_matches_canonical(ConstantArray::new(Scalar::null(dtype), 4))
    }
}