1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::DType;
6use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::{NumericOperator, Scalar};
8
9use crate::arrays::ConstantArray;
10use crate::arrow::{Datum, from_arrow_array_with_len};
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
12use crate::vtable::VTable;
13use crate::{Array, ArrayRef, IntoArray};
14
15pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
21 numeric(lhs, rhs, NumericOperator::Add)
22}
23
24pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
26 numeric(
27 lhs,
28 &ConstantArray::new(rhs, lhs.len()).into_array(),
29 NumericOperator::Add,
30 )
31}
32
33pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
35 numeric(lhs, rhs, NumericOperator::Sub)
36}
37
38pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
40 numeric(
41 lhs,
42 &ConstantArray::new(rhs, lhs.len()).into_array(),
43 NumericOperator::Sub,
44 )
45}
46
47pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
49 numeric(lhs, rhs, NumericOperator::Mul)
50}
51
52pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
54 numeric(
55 lhs,
56 &ConstantArray::new(rhs, lhs.len()).into_array(),
57 NumericOperator::Mul,
58 )
59}
60
61pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
63 numeric(lhs, rhs, NumericOperator::Div)
64}
65
66pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
68 numeric(
69 lhs,
70 &ConstantArray::new(rhs, lhs.len()).into_array(),
71 NumericOperator::Mul,
72 )
73}
74
75pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
77 NUMERIC_FN
78 .invoke(&InvocationArgs {
79 inputs: &[lhs.into(), rhs.into()],
80 options: &op,
81 })?
82 .unwrap_array()
83}
84
/// Registration handle for an encoding-specific numeric [`Kernel`].
///
/// Instances are gathered with `inventory`, and every collected handle is registered
/// into [`NUMERIC_FN`] when that static is first initialised.
pub struct NumericKernelRef(ArcRef<dyn Kernel>);
// Declares `NumericKernelRef` as an `inventory` collection so kernels defined
// elsewhere can submit themselves for registration.
inventory::collect!(NumericKernelRef);
87
/// Encoding-specific implementation of a binary numeric operation.
///
/// Implemented on an encoding's [`VTable`] and registered through
/// [`NumericKernelAdapter`]. The kernel is only consulted when the left-hand input
/// is an array of this encoding.
pub trait NumericKernel: VTable {
    /// Computes `array <op> other`.
    ///
    /// Returning `Ok(None)` means this kernel does not handle the given inputs, and
    /// dispatch moves on to the next candidate. The dispatcher may also call this with
    /// the operands swapped and `op` replaced by `op.swap()`, so `array` may be the
    /// original right-hand operand.
    fn numeric(
        &self,
        array: &Self::Array,
        other: &dyn Array,
        op: NumericOperator,
    ) -> VortexResult<Option<ArrayRef>>;
}
96
/// Adapts a `V: NumericKernel` to the generic [`Kernel`] interface so it can be
/// registered with [`NUMERIC_FN`].
#[derive(Debug)]
pub struct NumericKernelAdapter<V: VTable>(pub V);
99
100impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
101 pub const fn lift(&'static self) -> NumericKernelRef {
102 NumericKernelRef(ArcRef::new_ref(self))
103 }
104}
105
106impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
107 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
108 let inputs = NumericArgs::try_from(args)?;
109 let Some(lhs) = inputs.lhs.as_opt::<V>() else {
110 return Ok(None);
111 };
112 Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
113 }
114}
115
116pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
117 let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
118 for kernel in inventory::iter::<NumericKernelRef> {
119 compute.register_kernel(kernel.0.clone());
120 }
121 compute
122});
123
/// Stateless [`ComputeFnVTable`] implementation backing [`NUMERIC_FN`].
struct Numeric;
125
126impl ComputeFnVTable for Numeric {
127 fn invoke(
128 &self,
129 args: &InvocationArgs,
130 kernels: &[ArcRef<dyn Kernel>],
131 ) -> VortexResult<Output> {
132 let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
133
134 for kernel in kernels {
136 if let Some(output) = kernel.invoke(args)? {
137 return Ok(output);
138 }
139 }
140 if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
141 return Ok(output);
142 }
143
144 let inverted_args = InvocationArgs {
146 inputs: &[rhs.into(), lhs.into()],
147 options: &operator.swap(),
148 };
149 for kernel in kernels {
150 if let Some(output) = kernel.invoke(&inverted_args)? {
151 return Ok(output);
152 }
153 }
154 if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
155 return Ok(output);
156 }
157
158 log::debug!(
159 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
160 lhs.encoding_id(),
161 rhs.encoding_id(),
162 operator,
163 );
164
165 Ok(arrow_numeric(lhs, rhs, operator)?.into())
167 }
168
169 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
170 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
171 if !matches!(
172 (lhs.dtype(), rhs.dtype()),
173 (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
174 ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
175 {
176 vortex_bail!(
177 "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
178 lhs.dtype(),
179 rhs.dtype()
180 )
181 }
182 Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
183 }
184
185 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
186 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
187 if lhs.len() != rhs.len() {
188 vortex_bail!(
189 "Numeric operations aren't supported on arrays of different lengths {} {}",
190 lhs.len(),
191 rhs.len()
192 )
193 }
194 Ok(lhs.len())
195 }
196
197 fn is_elementwise(&self) -> bool {
198 true
199 }
200}
201
/// Typed view of the [`InvocationArgs`] passed to the numeric compute function.
struct NumericArgs<'a> {
    // Left-hand operand (first input).
    lhs: &'a dyn Array,
    // Right-hand operand (second input).
    rhs: &'a dyn Array,
    // Operator extracted from the invocation options.
    operator: NumericOperator,
}
207
208impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
209 type Error = VortexError;
210
211 fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
212 if args.inputs.len() != 2 {
213 vortex_bail!("Numeric operations require exactly 2 inputs");
214 }
215 let lhs = args.inputs[0]
216 .array()
217 .ok_or_else(|| vortex_err!("LHS is not an array"))?;
218 let rhs = args.inputs[1]
219 .array()
220 .ok_or_else(|| vortex_err!("RHS is not an array"))?;
221 let operator = *args
222 .options
223 .as_any()
224 .downcast_ref::<NumericOperator>()
225 .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
226 Ok(Self { lhs, rhs, operator })
227 }
228}
229
230impl Options for NumericOperator {
231 fn as_any(&self) -> &dyn Any {
232 self
233 }
234}
235
236fn arrow_numeric(
241 lhs: &dyn Array,
242 rhs: &dyn Array,
243 operator: NumericOperator,
244) -> VortexResult<ArrayRef> {
245 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
246 let len = lhs.len();
247
248 let left = Datum::try_new(lhs)?;
249 let right = Datum::try_new(rhs)?;
250
251 let array = match operator {
252 NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
253 NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
254 NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
255 NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
256 NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
257 NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
258 };
259
260 from_arrow_array_with_len(array.as_ref(), len, nullable)
261}
262
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    /// Subtracting a scalar from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let observed = primitive.as_slice::<u16>().to_vec();
        assert_eq!(observed, &[0u16, 1, 2]);
    }

    /// Subtracting a negative scalar from a signed array adds its magnitude.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let observed = primitive.as_slice::<i64>().to_vec();
        assert_eq!(observed, &[2i64, 3, 4]);
    }

    /// Null elements stay null after subtraction.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let primitive = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let observed: Vec<Scalar> = (0..primitive.len())
            .map(|idx| primitive.scalar_at(idx).unwrap())
            .collect();
        let expected: Vec<Scalar> = [Some(0u16), Some(1), None, Some(2)]
            .into_iter()
            .map(Scalar::from)
            .collect();
        assert_eq!(observed, expected);
    }

    /// Subtracting a negative float scalar.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let observed = primitive.as_slice::<f64>().to_vec();
        assert_eq!(observed, &[2.0f64, 3.0, 4.0]);
    }

    /// Float underflow must not be reported as an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        let _small = sub_scalar(&input, 1.0f32.into()).unwrap();
        let _large = sub_scalar(&input, f32::MAX.into()).unwrap();
    }

    /// Mixing a u64 array with an f64 scalar is rejected.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let _err = sub_scalar(&input, 1.5f64.into()).expect_err("Expected type mismatch error");
    }
}