Skip to main content

vortex_array/arrays/scalar_fn/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::DynArray;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::arrays::ConstantArray;
12use crate::arrays::scalar_fn::array::ScalarFnArray;
13use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
14use crate::columnar::Columnar;
15use crate::scalar::Scalar;
16use crate::scalar_fn::VecExecutionArgs;
17use crate::vtable::OperationsVTable;
18
19impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable {
20    fn scalar_at(
21        array: &ScalarFnArray,
22        index: usize,
23        _ctx: &mut ExecutionCtx,
24    ) -> VortexResult<Scalar> {
25        let inputs: Vec<_> = array
26            .children
27            .iter()
28            .map(|child| Ok(ConstantArray::new(child.scalar_at(index)?, 1).into_array()))
29            .collect::<VortexResult<_>>()?;
30
31        let mut ctx = LEGACY_SESSION.create_execution_ctx();
32        let args = VecExecutionArgs::new(inputs, 1);
33        let result = array.scalar_fn().execute(&args, &mut ctx)?;
34
35        let scalar = match result.execute::<Columnar>(&mut ctx)? {
36            Columnar::Canonical(arr) => {
37                tracing::info!(
38                    "Scalar function {} returned non-constant array from execution over all scalar inputs",
39                    array.scalar_fn(),
40                );
41                arr.as_ref().scalar_at(0)?
42            }
43            Columnar::Constant(constant) => constant.scalar().clone(),
44        };
45
46        debug_assert_eq!(
47            scalar.dtype(),
48            &array.dtype,
49            "Scalar function {} returned dtype {:?} but expected {:?}",
50            array.scalar_fn(),
51            scalar.dtype(),
52            array.dtype
53        );
54
55        Ok(scalar)
56    }
57}
58
#[cfg(test)]
mod tests {
    //! Tests for `ScalarFnArray`: canonicalization and element-wise `scalar_at`.

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::DynArray;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::scalar_fn::array::ScalarFnArray;
    use crate::assert_arrays_eq;
    use crate::scalar_fn::ScalarFn;
    use crate::scalar_fn::fns::binary::Binary;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::validity::Validity;

    #[test]
    fn test_scalar_fn_add() -> VortexResult<()> {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = buffer![11i32, 22, 33].into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_mul() -> VortexResult<()> {
        let lhs = buffer![2i32, 3, 4].into_array();
        let rhs = buffer![5i32, 6, 7].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Mul).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = buffer![10i32, 18, 28].into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_with_nullable() -> VortexResult<()> {
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_comparison() -> VortexResult<()> {
        let lhs = buffer![1i32, 5, 3].into_array();
        let rhs = buffer![2i32, 5, 1].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Eq).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?;

        let result = scalar_fn_array.to_canonical()?.into_array();
        let expected = BoolArray::from_iter([false, true, false]).into_array();
        assert_arrays_eq!(result, expected);

        Ok(())
    }

    /// `scalar_at` on a scalar-function array must agree element-wise with the expected
    /// canonical result. This exercises the `OperationsVTable::scalar_at` path directly.
    #[test]
    fn test_scalar_fn_scalar_at() -> VortexResult<()> {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?.into_array();

        let expected = buffer![11i32, 22, 33].into_array();
        for index in 0..3 {
            assert_eq!(
                scalar_fn_array.scalar_at(index)?,
                expected.scalar_at(index)?
            );
        }

        Ok(())
    }

    /// A null in any input must surface as a null scalar from `scalar_at`.
    #[test]
    fn test_scalar_fn_scalar_at_nullable() -> VortexResult<()> {
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = ScalarFn::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?.into_array();

        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        for index in 0..3 {
            assert_eq!(
                scalar_fn_array.scalar_at(index)?,
                expected.scalar_at(index)?
            );
        }

        Ok(())
    }
}