Skip to main content

vortex_array/arrays/scalar_fn/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::IntoArray;
8use crate::LEGACY_SESSION;
9use crate::VortexSessionExecute;
10use crate::arrays::ConstantArray;
11use crate::arrays::scalar_fn::array::ScalarFnArray;
12use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
13use crate::columnar::Columnar;
14use crate::scalar::Scalar;
15use crate::scalar_fn::VecExecutionArgs;
16use crate::vtable::OperationsVTable;
17
18impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
19    fn scalar_at(array: &ScalarFnArray, index: usize) -> VortexResult<Scalar> {
20        let inputs: Vec<_> = array
21            .children
22            .iter()
23            .map(|child| Ok(ConstantArray::new(child.scalar_at(index)?, 1).into_array()))
24            .collect::<VortexResult<_>>()?;
25
26        let mut ctx = LEGACY_SESSION.create_execution_ctx();
27        let args = VecExecutionArgs::new(inputs, 1);
28        let result = array.scalar_fn.execute(&args, &mut ctx)?;
29
30        let scalar = match result.execute::<Columnar>(&mut ctx)? {
31            Columnar::Canonical(arr) => {
32                tracing::info!(
33                    "Scalar function {} returned non-constant array from execution over all scalar inputs",
34                    array.scalar_fn,
35                );
36                arr.as_ref().scalar_at(0)?
37            }
38            Columnar::Constant(constant) => constant.scalar().clone(),
39        };
40
41        debug_assert_eq!(
42            scalar.dtype(),
43            &array.dtype,
44            "Scalar function {} returned dtype {:?} but expected {:?}",
45            array.scalar_fn,
46            scalar.dtype(),
47            array.dtype
48        );
49
50        Ok(scalar)
51    }
52}
53
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::array::ScalarFnArray;
    use crate::assert_arrays_eq;
    use crate::scalar_fn::ScalarFn;
    use crate::scalar_fn::fns::binary::Binary;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::validity::Validity;

    /// Row count shared by every fixture in this module.
    const ROWS: usize = 3;

    /// Builds a [`ScalarFnArray`] computing `lhs <op> rhs`, canonicalizes it, and asserts the
    /// canonical result equals `expected`.
    fn check_binary(
        op: Operator,
        lhs: impl IntoArray,
        rhs: impl IntoArray,
        expected: impl IntoArray,
    ) -> VortexResult<()> {
        let function = ScalarFn::new(Binary, op).erased();
        let children = vec![lhs.into_array(), rhs.into_array()];
        let fn_array = ScalarFnArray::try_new(function, children, ROWS)?;

        let actual = fn_array.to_canonical()?.into_array();
        let expected = expected.into_array();
        assert_arrays_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_add() -> VortexResult<()> {
        check_binary(
            Operator::Add,
            buffer![1i32, 2, 3],
            buffer![10i32, 20, 30],
            buffer![11i32, 22, 33],
        )
    }

    #[test]
    fn test_scalar_fn_mul() -> VortexResult<()> {
        check_binary(
            Operator::Mul,
            buffer![2i32, 3, 4],
            buffer![5i32, 6, 7],
            buffer![10i32, 18, 28],
        )
    }

    #[test]
    fn test_scalar_fn_with_nullable() -> VortexResult<()> {
        // A null on either side must surface as a null in the output row.
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid);
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        );
        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        );

        check_binary(Operator::Add, lhs, rhs, expected)
    }

    #[test]
    fn test_scalar_fn_comparison() -> VortexResult<()> {
        check_binary(
            Operator::Eq,
            buffer![1i32, 5, 3],
            buffer![2i32, 5, 1],
            BoolArray::from_iter([false, true, false]),
        )
    }
}