1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::{NumericOperator, Scalar};
11
12use crate::arrays::ConstantArray;
13use crate::arrow::{Datum, from_arrow_array_with_len};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::vtable::VTable;
16use crate::{Array, ArrayRef, IntoArray};
17
18pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
24 numeric(lhs, rhs, NumericOperator::Add)
25}
26
27pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
29 numeric(
30 lhs,
31 &ConstantArray::new(rhs, lhs.len()).into_array(),
32 NumericOperator::Add,
33 )
34}
35
36pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
38 numeric(lhs, rhs, NumericOperator::Sub)
39}
40
41pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
43 numeric(
44 lhs,
45 &ConstantArray::new(rhs, lhs.len()).into_array(),
46 NumericOperator::Sub,
47 )
48}
49
50pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
52 numeric(lhs, rhs, NumericOperator::Mul)
53}
54
55pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
57 numeric(
58 lhs,
59 &ConstantArray::new(rhs, lhs.len()).into_array(),
60 NumericOperator::Mul,
61 )
62}
63
64pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
66 numeric(lhs, rhs, NumericOperator::Div)
67}
68
69pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
71 numeric(
72 lhs,
73 &ConstantArray::new(rhs, lhs.len()).into_array(),
74 NumericOperator::Mul,
75 )
76}
77
78pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
80 NUMERIC_FN
81 .invoke(&InvocationArgs {
82 inputs: &[lhs.into(), rhs.into()],
83 options: &op,
84 })?
85 .unwrap_array()
86}
87
/// A type-erased numeric kernel registration.
///
/// Instances are gathered with `inventory`. Each one's kernel is registered on [`NUMERIC_FN`]
/// when that static is first initialized.
pub struct NumericKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(NumericKernelRef);
90
/// Implemented by an encoding's [`VTable`] to provide a specialized numeric kernel for arrays
/// of that encoding.
pub trait NumericKernel: VTable {
    /// Computes `array <op> other`.
    ///
    /// When called through [`NumericKernelAdapter`], `array` is the left-hand input downcast to
    /// this encoding. Dispatch may also retry with the operands exchanged and `op` replaced by
    /// `op.swap()`.
    ///
    /// Return `Ok(None)` to decline. Dispatch then moves on to the next kernel, and finally to
    /// the Arrow fallback.
    fn numeric(
        &self,
        array: &Self::Array,
        other: &dyn Array,
        op: NumericOperator,
    ) -> VortexResult<Option<ArrayRef>>;
}
99
/// Adapts an encoding's [`NumericKernel`] implementation into a type-erased compute [`Kernel`].
///
/// The adapter only handles invocations whose left-hand input is of encoding `V`. For any other
/// left-hand input it returns `Ok(None)`.
#[derive(Debug)]
pub struct NumericKernelAdapter<V: VTable>(pub V);
102
103impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
104 pub const fn lift(&'static self) -> NumericKernelRef {
105 NumericKernelRef(ArcRef::new_ref(self))
106 }
107}
108
109impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
110 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
111 let inputs = NumericArgs::try_from(args)?;
112 let Some(lhs) = inputs.lhs.as_opt::<V>() else {
113 return Ok(None);
114 };
115 Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
116 }
117}
118
119pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
120 let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
121 for kernel in inventory::iter::<NumericKernelRef> {
122 compute.register_kernel(kernel.0.clone());
123 }
124 compute
125});
126
/// The [`ComputeFnVTable`] backing [`NUMERIC_FN`].
struct Numeric;
128
129impl ComputeFnVTable for Numeric {
130 fn invoke(
131 &self,
132 args: &InvocationArgs,
133 kernels: &[ArcRef<dyn Kernel>],
134 ) -> VortexResult<Output> {
135 let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
136
137 for kernel in kernels {
139 if let Some(output) = kernel.invoke(args)? {
140 return Ok(output);
141 }
142 }
143 if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
144 return Ok(output);
145 }
146
147 let inverted_args = InvocationArgs {
149 inputs: &[rhs.into(), lhs.into()],
150 options: &operator.swap(),
151 };
152 for kernel in kernels {
153 if let Some(output) = kernel.invoke(&inverted_args)? {
154 return Ok(output);
155 }
156 }
157 if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
158 return Ok(output);
159 }
160
161 log::debug!(
162 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
163 lhs.encoding_id(),
164 rhs.encoding_id(),
165 operator,
166 );
167
168 Ok(arrow_numeric(lhs, rhs, operator)?.into())
170 }
171
172 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
173 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
174 if !matches!(
175 (lhs.dtype(), rhs.dtype()),
176 (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
177 ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
178 {
179 vortex_bail!(
180 "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
181 lhs.dtype(),
182 rhs.dtype()
183 )
184 }
185 Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
186 }
187
188 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
189 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
190 if lhs.len() != rhs.len() {
191 vortex_bail!(
192 "Numeric operations aren't supported on arrays of different lengths {} {}",
193 lhs.len(),
194 rhs.len()
195 )
196 }
197 Ok(lhs.len())
198 }
199
200 fn is_elementwise(&self) -> bool {
201 true
202 }
203}
204
/// The decoded arguments of a `numeric` invocation: the two array inputs and the operator
/// carried in the invocation options.
struct NumericArgs<'a> {
    /// Left-hand operand (`inputs[0]`).
    lhs: &'a dyn Array,
    /// Right-hand operand (`inputs[1]`).
    rhs: &'a dyn Array,
    /// Operator downcast from the invocation options.
    operator: NumericOperator,
}
210
211impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
212 type Error = VortexError;
213
214 fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
215 if args.inputs.len() != 2 {
216 vortex_bail!("Numeric operations require exactly 2 inputs");
217 }
218 let lhs = args.inputs[0]
219 .array()
220 .ok_or_else(|| vortex_err!("LHS is not an array"))?;
221 let rhs = args.inputs[1]
222 .array()
223 .ok_or_else(|| vortex_err!("RHS is not an array"))?;
224 let operator = *args
225 .options
226 .as_any()
227 .downcast_ref::<NumericOperator>()
228 .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
229 Ok(Self { lhs, rhs, operator })
230 }
231}
232
233impl Options for NumericOperator {
234 fn as_any(&self) -> &dyn Any {
235 self
236 }
237}
238
239fn arrow_numeric(
244 lhs: &dyn Array,
245 rhs: &dyn Array,
246 operator: NumericOperator,
247) -> VortexResult<ArrayRef> {
248 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
249 let len = lhs.len();
250
251 let left = Datum::try_new(lhs)?;
252 let right = Datum::try_new(rhs)?;
253
254 let array = match operator {
255 NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
256 NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
257 NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
258 NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
259 NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
260 NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
261 };
262
263 from_arrow_array_with_len(array.as_ref(), len, nullable)
264}
265
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    /// Subtracting a scalar from a non-nullable unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let values: Vec<u16> = difference.as_slice::<u16>().to_vec();
        assert_eq!(values, &[0u16, 1, 2]);
    }

    /// Subtracting a negative scalar from a signed array.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let values: Vec<i64> = difference.as_slice::<i64>().to_vec();
        assert_eq!(values, &[2i64, 3, 4]);
    }

    /// Nulls in the input stay null after subtraction.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let actual: Vec<Scalar> = (0..difference.len())
            .map(|idx| difference.scalar_at(idx).unwrap())
            .collect();
        let expected: Vec<Scalar> = [Some(0u16), Some(1), None, Some(2)]
            .into_iter()
            .map(Scalar::from)
            .collect();
        assert_eq!(actual, expected);
    }

    /// Subtracting a negative float scalar.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let difference = sub_scalar(&input, (-1f64).into())
            .unwrap()
            .to_primitive()
            .unwrap();
        let values: Vec<f64> = difference.as_slice::<f64>().to_vec();
        assert_eq!(values, &[2.0f64, 3.0, 4.0]);
    }

    /// Float underflow must not produce an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    /// Mixing an integer array with a float scalar is rejected.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let outcome = sub_scalar(&input, 1.5f64.into());
        assert!(outcome.is_err(), "Expected type mismatch error");
    }
}