Skip to main content

vortex_array/arrays/scalar_fn/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::IntoArray;
8use crate::LEGACY_SESSION;
9use crate::VortexSessionExecute;
10use crate::arrays::ConstantArray;
11use crate::arrays::scalar_fn::array::ScalarFnArray;
12use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
13use crate::columnar::Columnar;
14use crate::expr::ExecutionArgs;
15use crate::scalar::Scalar;
16use crate::vtable::OperationsVTable;
17
18impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
19    fn scalar_at(array: &ScalarFnArray, index: usize) -> VortexResult<Scalar> {
20        let inputs: Vec<_> = array
21            .children
22            .iter()
23            .map(|child| Ok(ConstantArray::new(child.scalar_at(index)?, 1).into_array()))
24            .collect::<VortexResult<_>>()?;
25
26        let mut ctx = LEGACY_SESSION.create_execution_ctx();
27        let args = ExecutionArgs {
28            inputs,
29            row_count: 1,
30            ctx: &mut ctx,
31        };
32        let result = array.scalar_fn.execute(args)?;
33
34        let scalar = match result.execute::<Columnar>(&mut ctx)? {
35            Columnar::Canonical(arr) => {
36                tracing::info!(
37                    "Scalar function {} returned non-constant array from execution over all scalar inputs",
38                    array.scalar_fn,
39                );
40                arr.as_ref().scalar_at(0)?
41            }
42            Columnar::Constant(constant) => constant.scalar().clone(),
43        };
44
45        debug_assert_eq!(
46            scalar.dtype(),
47            &array.dtype,
48            "Scalar function {} returned dtype {:?} but expected {:?}",
49            array.scalar_fn,
50            scalar.dtype(),
51            array.dtype
52        );
53
54        Ok(scalar)
55    }
56}
57
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::array::ScalarFnArray;
    use crate::assert_arrays_eq;
    use crate::expr::Operator;
    use crate::expr::ScalarFn;
    use crate::expr::binary::Binary;
    use crate::validity::Validity;

    #[test]
    fn test_scalar_fn_add() -> VortexResult<()> {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add);
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = buffer![11i32, 22, 33].into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_mul() -> VortexResult<()> {
        let lhs = buffer![2i32, 3, 4].into_array();
        let rhs = buffer![5i32, 6, 7].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Mul);
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = buffer![10i32, 18, 28].into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_with_nullable() -> VortexResult<()> {
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add);
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_comparison() -> VortexResult<()> {
        let lhs = buffer![1i32, 5, 3].into_array();
        let rhs = buffer![2i32, 5, 1].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Eq);
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = BoolArray::from_iter([false, true, false]).into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    // The tests below exercise `scalar_at` (the operation implemented in this file)
    // directly, instead of going through full canonicalization.

    /// `scalar_at` on a non-nullable arithmetic function matches the expected values.
    #[test]
    fn test_scalar_fn_scalar_at_add() -> VortexResult<()> {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add);
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?.into_array();

        let expected = buffer![11i32, 22, 33].into_array();
        for idx in 0..3 {
            assert_eq!(array.scalar_at(idx)?, expected.scalar_at(idx)?);
        }

        Ok(())
    }

    /// `scalar_at` propagates a null input to a null output at the same position.
    #[test]
    fn test_scalar_fn_scalar_at_nullable() -> VortexResult<()> {
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add);
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?.into_array();

        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        for idx in 0..3 {
            assert_eq!(array.scalar_at(idx)?, expected.scalar_at(idx)?);
        }

        Ok(())
    }

    /// `scalar_at` on a comparison function yields boolean scalars.
    #[test]
    fn test_scalar_fn_scalar_at_comparison() -> VortexResult<()> {
        let lhs = buffer![1i32, 5, 3].into_array();
        let rhs = buffer![2i32, 5, 1].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Eq);
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?.into_array();

        let expected = BoolArray::from_iter([false, true, false]).into_array();
        for idx in 0..3 {
            assert_eq!(array.scalar_at(idx)?, expected.scalar_at(idx)?);
        }

        Ok(())
    }
}