vortex_array/arrays/scalar_fn/vtable/
operations.rs1use vortex_error::VortexResult;
5
6use crate::ExecutionCtx;
7use crate::IntoArray;
8use crate::array::ArrayView;
9use crate::array::OperationsVTable;
10use crate::arrays::ConstantArray;
11use crate::arrays::scalar_fn::ScalarFnArrayExt;
12use crate::arrays::scalar_fn::vtable::ScalarFn;
13use crate::columnar::Columnar;
14use crate::scalar::Scalar;
15use crate::scalar_fn::VecExecutionArgs;
16
17impl OperationsVTable<ScalarFn> for ScalarFn {
18 fn scalar_at(
19 array: ArrayView<'_, ScalarFn>,
20 index: usize,
21 ctx: &mut ExecutionCtx,
22 ) -> VortexResult<Scalar> {
23 let inputs: Vec<_> = array
24 .children()
25 .iter()
26 .map(|child| Ok(ConstantArray::new(child.execute_scalar(index, ctx)?, 1).into_array()))
27 .collect::<VortexResult<_>>()?;
28
29 let args = VecExecutionArgs::new(inputs, 1);
30 let result = array.scalar_fn().execute(&args, ctx)?;
31
32 let scalar = match result.execute::<Columnar>(ctx)? {
33 Columnar::Canonical(arr) => {
34 tracing::info!(
35 "Scalar function {} returned non-constant array from execution over all scalar inputs",
36 array.scalar_fn(),
37 );
38 arr.into_array().execute_scalar(0, ctx)?
39 }
40 Columnar::Constant(constant) => constant.scalar().clone(),
41 };
42
43 debug_assert_eq!(
44 scalar.dtype(),
45 array.dtype(),
46 "Scalar function {} returned dtype {:?} but expected {:?}",
47 array.scalar_fn(),
48 scalar.dtype(),
49 array.dtype()
50 );
51
52 Ok(scalar)
53 }
54}
55
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::ScalarFnArray;
    use crate::arrays::scalar_fn::ScalarFnArrayExt;
    use crate::assert_arrays_eq;
    use crate::scalar::Scalar;
    use crate::scalar_fn::TypedScalarFnInstance;
    use crate::scalar_fn::fns::binary::Binary;
    use crate::scalar_fn::fns::literal::Literal;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::validity::Validity;

    /// Element-wise addition of two non-nullable i32 arrays.
    #[test]
    fn test_scalar_fn_add() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        assert_eq!(scalar_fn_array.len(), 3);

        // Reuse the same execution context for execution and comparison.
        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = buffer![11i32, 22, 33].into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    /// Children whose lengths disagree must be rejected at construction time.
    #[test]
    fn test_scalar_fn_inferred_len_rejects_mismatched_children() {
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let err = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])
            .expect_err("ScalarFnArray::try_new must reject mismatched child lengths");

        assert!(
            err.to_string()
                .contains("ScalarFnArray must have children equal to the array length")
        );
    }

    /// A childless function has no length to infer, so it needs an explicit length.
    #[test]
    fn test_scalar_fn_without_children_requires_explicit_len() -> VortexResult<()> {
        let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased();

        let Err(err) = ScalarFnArray::try_new(scalar_fn.clone(), vec![]) else {
            panic!("ScalarFnArray::try_new should reject zero children");
        };
        assert!(
            err.to_string()
                .contains("ScalarFnArray length cannot be inferred without children")
        );

        let scalar_fn_array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?;
        assert_eq!(scalar_fn_array.len(), 3);
        assert_eq!(scalar_fn_array.child_count(), 0);

        Ok(())
    }

    /// Element-wise multiplication of two non-nullable i32 arrays.
    #[test]
    fn test_scalar_fn_mul() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![2i32, 3, 4].into_array();
        let rhs = buffer![5i32, 6, 7].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Mul).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = buffer![10i32, 18, 28].into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    /// A null on either side of the addition yields a null in the output.
    #[test]
    fn test_scalar_fn_with_nullable() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = PrimitiveArray::new(
            buffer![11i32, 0, 33],
            Validity::from_iter([true, false, true]),
        )
        .into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    /// Equality comparison produces a boolean array.
    #[test]
    fn test_scalar_fn_comparison() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 5, 3].into_array();
        let rhs = buffer![2i32, 5, 1].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Eq).erased();
        let scalar_fn_array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?;

        let result = scalar_fn_array
            .into_array()
            .execute::<Canonical>(&mut ctx)?
            .into_array();
        let expected = BoolArray::from_iter([false, true, false]).into_array();
        assert_arrays_eq!(result, expected, &mut ctx);

        Ok(())
    }

    /// `scalar_at` evaluates the function over the children's values at one row.
    #[test]
    fn test_scalar_fn_scalar_at() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = buffer![1i32, 2, 3].into_array();
        let rhs = buffer![10i32, 20, 30].into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?.into_array();

        assert_eq!(array.execute_scalar(0, &mut ctx)?, Scalar::from(11i32));
        assert_eq!(array.execute_scalar(1, &mut ctx)?, Scalar::from(22i32));
        assert_eq!(array.execute_scalar(2, &mut ctx)?, Scalar::from(33i32));

        Ok(())
    }

    /// `scalar_at` propagates a null child value into a null result.
    #[test]
    fn test_scalar_fn_scalar_at_propagates_nulls() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array();
        let rhs = PrimitiveArray::new(
            buffer![10i32, 20, 30],
            Validity::from_iter([true, false, true]),
        )
        .into_array();

        let scalar_fn = TypedScalarFnInstance::new(Binary, Operator::Add).erased();
        let array = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs])?.into_array();

        assert!(!array.execute_scalar(0, &mut ctx)?.is_null());
        assert!(array.execute_scalar(1, &mut ctx)?.is_null());
        assert!(!array.execute_scalar(2, &mut ctx)?.is_null());

        Ok(())
    }

    /// `scalar_at` also works for a childless function, which runs over empty inputs.
    #[test]
    fn test_scalar_fn_scalar_at_without_children() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let scalar_fn = TypedScalarFnInstance::new(Literal, Scalar::from(1i32)).erased();
        let array = ScalarFnArray::try_new_with_len(scalar_fn, vec![], 3)?.into_array();

        assert_eq!(array.execute_scalar(0, &mut ctx)?, Scalar::from(1i32));
        assert_eq!(array.execute_scalar(2, &mut ctx)?, Scalar::from(1i32));

        Ok(())
    }
}