Skip to main content

vortex_array/arrays/scalar_fn/vtable/
operations.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ExecutionCtx;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::array::OperationsVTable;
10use crate::arrays::ConstantArray;
11use crate::arrays::scalar_fn::ScalarFnArrayExt;
12use crate::arrays::scalar_fn::vtable::ScalarFn;
13use crate::columnar::Columnar;
14use crate::scalar::Scalar;
15use crate::scalar_fn::VecExecutionArgs;
16
17impl OperationsVTable<ScalarFn> for ScalarFn {
18    fn scalar_at(
19        array: ArrayView<'_, ScalarFn>,
20        index: usize,
21        ctx: &mut ExecutionCtx,
22    ) -> VortexResult<Scalar> {
23        let inputs: Vec<_> = array
24            .children()
25            .iter()
26            .map(|child| Ok(ConstantArray::new(child.execute_scalar(index, ctx)?, 1).into_array()))
27            .collect::<VortexResult<_>>()?;
28
29        let args = VecExecutionArgs::new(inputs, 1);
30        let result = array.scalar_fn().execute(&args, ctx)?;
31
32        let scalar = match result.execute::<Columnar>(ctx)? {
33            Columnar::Canonical(arr) => {
34                tracing::info!(
35                    "Scalar function {} returned non-constant array from execution over all scalar inputs",
36                    array.scalar_fn(),
37                );
38                arr.into_array().execute_scalar(0, ctx)?
39            }
40            Columnar::Constant(constant) => constant.scalar().clone(),
41        };
42
43        debug_assert_eq!(
44            scalar.dtype(),
45            array.dtype(),
46            "Scalar function {} returned dtype {:?} but expected {:?}",
47            array.scalar_fn(),
48            scalar.dtype(),
49            array.dtype()
50        );
51
52        Ok(scalar)
53    }
54}
55
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::ScalarFnArray;
    use crate::arrays::scalar_fn::ScalarFnArrayExt;
    use crate::assert_arrays_eq;
    use crate::scalar::Scalar;
    use crate::scalar_fn::TypedScalarFnInstance;
    use crate::scalar_fn::fns::binary::Binary;
    use crate::scalar_fn::fns::literal::Literal;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::validity::Validity;

    // Each test creates exactly one execution context and reuses it for both execution and
    // comparison, rather than spinning up a second throwaway session/context.

    #[test]
    fn test_scalar_fn_add() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        assert_eq!(scalar_fn_array.len(), 3);

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = buffer![11i32, 22, 33].into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_inferred_len_rejects_mismatched_children() {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let err = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])
            .expect_err("ScalarFnArray::try_new must reject mismatched child lengths");

        assert!(
            err.to_string()
                .contains("ScalarFnArray must have children equal to the array length")
        );
    }

    #[test]
    fn test_scalar_fn_without_children_requires_explicit_len() -> VortexResult<()> {
        let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased();

        let Err(err) = ScalarFnArray::try_new(scalar_fn.clone(), vec![]) else {
            panic!("ScalarFnArray::try_new should reject zero children");
        };
        assert!(
            err.to_string()
                .contains("ScalarFnArray length cannot be inferred without children")
        );

        let scalar_fn_array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?;
        assert_eq!(scalar_fn_array.len(), 3);
        assert_eq!(scalar_fn_array.child_count(), 0);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_mul() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![2i32, 3, 4].into_array();
        let rhs = buffer![5i32, 6, 7].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Mul).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = buffer![10i32, 18, 28].into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_with_nullable() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_scalar_fn_comparison() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 5, 3].into_array();
        let rhs = buffer![2i32, 5, 1].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Eq).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = BoolArray::from_iter([false, true, false]).into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    /// Exercises `scalar_at` (this file's `OperationsVTable` impl) row by row, which the
    /// canonical-execution tests above never reach.
    #[test]
    fn test_scalar_fn_scalar_at() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?.into_array();

        for (index, expected) in [11i32, 22, 33].into_iter().enumerate() {
            assert_eq!(
                array.execute_scalar(index, &mut ctx)?,
                Scalar::from(expected)
            );
        }

        Ok(())
    }

    /// A null input at a row must surface as a null scalar from `scalar_at` at that row.
    #[test]
    fn test_scalar_fn_scalar_at_nullable() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?.into_array();

        assert!(!array.execute_scalar(0, &mut ctx)?.is_null());
        assert!(array.execute_scalar(1, &mut ctx)?.is_null());
        assert!(!array.execute_scalar(2, &mut ctx)?.is_null());

        Ok(())
    }
}