vortex_array/arrays/constant/vtable/operator.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_dtype::{
5    DType, DecimalType, PrecisionScale, match_each_decimal_value_type, match_each_native_ptype,
6};
7use vortex_error::{VortexExpect, VortexResult};
8use vortex_scalar::{BinaryScalar, BoolScalar, DecimalScalar, PrimitiveScalar, Scalar, Utf8Scalar};
9use vortex_vector::binaryview::{BinaryVectorMut, StringVectorMut};
10use vortex_vector::bool::BoolVectorMut;
11use vortex_vector::decimal::{DVectorMut, DecimalVectorMut};
12use vortex_vector::null::NullVectorMut;
13use vortex_vector::primitive::{PVectorMut, PrimitiveVectorMut};
14use vortex_vector::{VectorMut, VectorMutOps};
15
16use crate::ArrayRef;
17use crate::arrays::{ConstantArray, ConstantVTable};
18use crate::execution::{BatchKernelRef, BindCtx, kernel};
19use crate::vtable::OperatorVTable;
20
21impl OperatorVTable<ConstantVTable> for ConstantVTable {
22    fn bind(
23        array: &ConstantArray,
24        selection: Option<&ArrayRef>,
25        ctx: &mut dyn BindCtx,
26    ) -> VortexResult<BatchKernelRef> {
27        let mask = ctx.bind_selection(array.len, selection)?;
28        let scalar = array.scalar().clone();
29
30        Ok(kernel(move || {
31            // TODO(ngates): would be good to do a sum aggregation, rather than execution.
32            let mask = mask.execute()?;
33            Ok(to_vector(scalar, mask.true_count()).freeze())
34        }))
35    }
36}
37
38fn to_vector(scalar: Scalar, len: usize) -> VectorMut {
39    match scalar.dtype() {
40        DType::Null => NullVectorMut::new(len).into(),
41        DType::Bool(_) => to_vector_bool(scalar.as_bool(), len).into(),
42        DType::Primitive(..) => to_vector_primitive(scalar.as_primitive(), len).into(),
43        DType::Decimal(..) => to_vector_decimal(scalar.as_decimal(), len).into(),
44        DType::Utf8(_) => to_vector_utf8(scalar.as_utf8(), len).into(),
45        DType::Binary(_) => to_vector_binary(scalar.as_binary(), len).into(),
46        DType::List(..) => unimplemented!("List constant vectorization"),
47        DType::FixedSizeList(..) => unimplemented!("FixedSizeList constant vectorization"),
48        DType::Struct(..) => unimplemented!("Struct constant vectorization"),
49        DType::Extension(_) => to_vector(scalar.as_extension().storage(), len),
50    }
51}
52
53fn to_vector_bool(scalar: BoolScalar, len: usize) -> BoolVectorMut {
54    let mut vec = BoolVectorMut::with_capacity(len);
55    match scalar.value() {
56        Some(v) => vec.append_values(v, len),
57        None => vec.append_nulls(len),
58    }
59    vec
60}
61
62fn to_vector_primitive(scalar: PrimitiveScalar, len: usize) -> PrimitiveVectorMut {
63    match_each_native_ptype!(scalar.ptype(), |T| {
64        let mut vec = PVectorMut::<T>::with_capacity(len);
65        match scalar.typed_value::<T>() {
66            Some(v) => vec.append_values(v, len),
67            None => vec.append_nulls(len),
68        }
69        vec.into()
70    })
71}
72
73fn to_vector_decimal(scalar: DecimalScalar, len: usize) -> DecimalVectorMut {
74    let decimal_dtype = scalar
75        .dtype()
76        .as_decimal_opt()
77        .vortex_expect("Decimal scalar must have a decimal type");
78    let decimal_type = DecimalType::smallest_decimal_value_type(decimal_dtype);
79
80    match_each_decimal_value_type!(decimal_type, |D| {
81        let ps = PrecisionScale::<D>::new(decimal_dtype.precision(), decimal_dtype.scale());
82        let mut vec = DVectorMut::<D>::with_capacity(ps, len);
83        match scalar.decimal_value() {
84            Some(v) => vec
85                .try_append_n(v.cast::<D>().vortex_expect("known to fit"), len)
86                .vortex_expect("known to fit"),
87            None => vec.append_nulls(len),
88        }
89        vec.into()
90    })
91}
92
93fn to_vector_utf8(scalar: Utf8Scalar, len: usize) -> StringVectorMut {
94    let mut vec = StringVectorMut::with_capacity(len);
95    match scalar.value() {
96        Some(v) => vec.append_values(v.as_ref(), len),
97        None => vec.append_nulls(len),
98    }
99    vec
100}
101
102fn to_vector_binary(scalar: BinaryScalar, len: usize) -> BinaryVectorMut {
103    let mut vec = BinaryVectorMut::with_capacity(len);
104    match scalar.value() {
105        Some(v) => vec.append_values(v.as_ref(), len),
106        None => vec.append_nulls(len),
107    }
108    vec
109}