1use std::any::Any;
2use std::sync::LazyLock;
3
4use vortex_dtype::DType;
5use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
6use vortex_scalar::{NumericOperator, Scalar};
7
8use crate::arcref::ArcRef;
9use crate::arrays::ConstantArray;
10use crate::arrow::{Datum, from_arrow_array_with_len};
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
12use crate::encoding::Encoding;
13use crate::{Array, ArrayRef};
14
15pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
17 numeric(lhs, rhs, NumericOperator::Add)
18}
19
20pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
22 numeric(
23 lhs,
24 &ConstantArray::new(rhs, lhs.len()).into_array(),
25 NumericOperator::Add,
26 )
27}
28
29pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
31 numeric(lhs, rhs, NumericOperator::Sub)
32}
33
34pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
36 numeric(
37 lhs,
38 &ConstantArray::new(rhs, lhs.len()).into_array(),
39 NumericOperator::Sub,
40 )
41}
42
43pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
45 numeric(lhs, rhs, NumericOperator::Mul)
46}
47
48pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
50 numeric(
51 lhs,
52 &ConstantArray::new(rhs, lhs.len()).into_array(),
53 NumericOperator::Mul,
54 )
55}
56
57pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
59 numeric(lhs, rhs, NumericOperator::Div)
60}
61
62pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
64 numeric(
65 lhs,
66 &ConstantArray::new(rhs, lhs.len()).into_array(),
67 NumericOperator::Mul,
68 )
69}
70
71pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
73 NUMERIC_FN
74 .invoke(&InvocationArgs {
75 inputs: &[lhs.into(), rhs.into()],
76 options: &op,
77 })?
78 .unwrap_array()
79}
80
81pub struct NumericKernelRef(ArcRef<dyn Kernel>);
82inventory::collect!(NumericKernelRef);
83
84pub trait NumericKernel: Encoding {
85 fn numeric(
86 &self,
87 array: &Self::Array,
88 other: &dyn Array,
89 op: NumericOperator,
90 ) -> VortexResult<Option<ArrayRef>>;
91}
92
93#[derive(Debug)]
94pub struct NumericKernelAdapter<E: Encoding>(pub E);
95
96impl<E: Encoding + NumericKernel> NumericKernelAdapter<E> {
97 pub const fn lift(&'static self) -> NumericKernelRef {
98 NumericKernelRef(ArcRef::new_ref(self))
99 }
100}
101
102impl<E: Encoding + NumericKernel> Kernel for NumericKernelAdapter<E> {
103 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
104 let inputs = NumericArgs::try_from(args)?;
105 let Some(lhs) = inputs.lhs.as_any().downcast_ref::<E::Array>() else {
106 return Ok(None);
107 };
108 Ok(E::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
109 }
110}
111
112pub static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
113 let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
114 for kernel in inventory::iter::<NumericKernelRef> {
115 compute.register_kernel(kernel.0.clone());
116 }
117 compute
118});
119
120struct Numeric;
121
122impl ComputeFnVTable for Numeric {
123 fn invoke(
124 &self,
125 args: &InvocationArgs,
126 kernels: &[ArcRef<dyn Kernel>],
127 ) -> VortexResult<Output> {
128 let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
129
130 for kernel in kernels {
132 if let Some(output) = kernel.invoke(args)? {
133 return Ok(output);
134 }
135 }
136 if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
137 return Ok(output);
138 }
139
140 let inverted_args = InvocationArgs {
142 inputs: &[rhs.into(), lhs.into()],
143 options: &operator.swap(),
144 };
145 for kernel in kernels {
146 if let Some(output) = kernel.invoke(&inverted_args)? {
147 return Ok(output);
148 }
149 }
150 if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
151 return Ok(output);
152 }
153
154 log::debug!(
155 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
156 lhs.encoding(),
157 rhs.encoding(),
158 operator,
159 );
160
161 Ok(arrow_numeric(lhs, rhs, operator)?.into())
163 }
164
165 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
166 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
167 if !matches!(
168 (lhs.dtype(), rhs.dtype()),
169 (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
170 ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
171 {
172 vortex_bail!(
173 "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
174 lhs.dtype(),
175 rhs.dtype()
176 )
177 }
178 Ok(lhs
179 .dtype()
180 .with_nullability((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()))
181 }
182
183 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
184 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
185 if lhs.len() != rhs.len() {
186 vortex_bail!(
187 "Numeric operations aren't supported on arrays of different lengths {} {}",
188 lhs.len(),
189 rhs.len()
190 )
191 }
192 Ok(lhs.len())
193 }
194
195 fn is_elementwise(&self) -> bool {
196 true
197 }
198}
199
200struct NumericArgs<'a> {
201 lhs: &'a dyn Array,
202 rhs: &'a dyn Array,
203 operator: NumericOperator,
204}
205
206impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
207 type Error = VortexError;
208
209 fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
210 if args.inputs.len() != 2 {
211 vortex_bail!("Numeric operations require exactly 2 inputs");
212 }
213 let lhs = args.inputs[0]
214 .array()
215 .ok_or_else(|| vortex_err!("LHS is not an array"))?;
216 let rhs = args.inputs[1]
217 .array()
218 .ok_or_else(|| vortex_err!("RHS is not an array"))?;
219 let operator = *args
220 .options
221 .as_any()
222 .downcast_ref::<NumericOperator>()
223 .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
224 Ok(Self { lhs, rhs, operator })
225 }
226}
227
228impl Options for NumericOperator {
229 fn as_any(&self) -> &dyn Any {
230 self
231 }
232}
233
234fn arrow_numeric(
239 lhs: &dyn Array,
240 rhs: &dyn Array,
241 operator: NumericOperator,
242) -> VortexResult<ArrayRef> {
243 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
244 let len = lhs.len();
245
246 let left = Datum::try_new(lhs)?;
247 let right = Datum::try_new(rhs)?;
248
249 let array = match operator {
250 NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
251 NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
252 NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
253 NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
254 NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
255 NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
256 };
257
258 from_arrow_array_with_len(array, len, nullable)
259}
260
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::{scalar_at, sub_scalar};

    /// Subtracting an unsigned scalar from an unsigned array.
    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let results = primitive.as_slice::<u16>().to_vec();
        assert_eq!(results, &[0u16, 1, 2]);
    }

    /// Subtracting a negative scalar increases every element.
    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let results = primitive.as_slice::<i64>().to_vec();
        assert_eq!(results, &[2i64, 3, 4]);
    }

    /// Null elements stay null; the other elements are decremented.
    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(&input, Some(1u16).into())
            .unwrap()
            .to_primitive()
            .unwrap();

        let mut actual = Vec::with_capacity(difference.len());
        for index in 0..difference.len() {
            actual.push(scalar_at(&difference, index).unwrap());
        }

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        assert_eq!(actual, expected);
    }

    /// Subtracting a negative float scalar increases every element.
    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap();
        let primitive = difference.to_primitive().unwrap();
        let results = primitive.as_slice::<f64>().to_vec();
        assert_eq!(results, &[2.0f64, 3.0, 4.0]);
    }

    /// Subtracting from `f32::MIN` must succeed rather than return an error.
    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _ = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    /// Mixing an integer array with a float scalar must be rejected.
    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let _error = sub_scalar(&input, 1.5f64.into()).expect_err("Expected type mismatch error");
    }
}
345}