1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::DType;
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::{NumericOperator, Scalar};
11
12use crate::arrays::ConstantArray;
13use crate::arrow::{Datum, from_arrow_array_with_len};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::vtable::VTable;
16use crate::{Array, ArrayRef, IntoArray};
17
18static NUMERIC_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
19 let compute = ComputeFn::new("numeric".into(), ArcRef::new_ref(&Numeric));
20 for kernel in inventory::iter::<NumericKernelRef> {
21 compute.register_kernel(kernel.0.clone());
22 }
23 compute
24});
25
26pub(crate) fn warm_up_vtable() -> usize {
27 NUMERIC_FN.kernels().len()
28}
29
30pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
36 numeric(lhs, rhs, NumericOperator::Add)
37}
38
39pub fn add_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
41 numeric(
42 lhs,
43 &ConstantArray::new(rhs, lhs.len()).into_array(),
44 NumericOperator::Add,
45 )
46}
47
48pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
50 numeric(lhs, rhs, NumericOperator::Sub)
51}
52
53pub fn sub_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
55 numeric(
56 lhs,
57 &ConstantArray::new(rhs, lhs.len()).into_array(),
58 NumericOperator::Sub,
59 )
60}
61
62pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
64 numeric(lhs, rhs, NumericOperator::Mul)
65}
66
67pub fn mul_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
69 numeric(
70 lhs,
71 &ConstantArray::new(rhs, lhs.len()).into_array(),
72 NumericOperator::Mul,
73 )
74}
75
76pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
78 numeric(lhs, rhs, NumericOperator::Div)
79}
80
81pub fn div_scalar(lhs: &dyn Array, rhs: Scalar) -> VortexResult<ArrayRef> {
83 numeric(
84 lhs,
85 &ConstantArray::new(rhs, lhs.len()).into_array(),
86 NumericOperator::Mul,
87 )
88}
89
90pub fn numeric(lhs: &dyn Array, rhs: &dyn Array, op: NumericOperator) -> VortexResult<ArrayRef> {
92 NUMERIC_FN
93 .invoke(&InvocationArgs {
94 inputs: &[lhs.into(), rhs.into()],
95 options: &op,
96 })?
97 .unwrap_array()
98}
99
/// A type-erased numeric kernel, collected through `inventory` and registered on
/// `NUMERIC_FN` when that compute function is first initialised.
pub struct NumericKernelRef(ArcRef<dyn Kernel>);
// Makes every submitted `NumericKernelRef` iterable via `inventory::iter`.
inventory::collect!(NumericKernelRef);
102
/// Encoding-specific implementation of the `numeric` compute function.
///
/// Implementors receive the left-hand operand already downcast to their own array type
/// (see `NumericKernelAdapter`), the right-hand operand as a type-erased array, and the
/// operator to apply. The dispatcher may also call the kernel with the operands swapped and
/// the operator swapped via `NumericOperator::swap`.
pub trait NumericKernel: VTable {
    /// Applies `op` with `array` as the left-hand operand and `other` as the right-hand
    /// operand. Reversed operators such as `RSub`/`RDiv` evaluate `other op array`.
    ///
    /// Returns `Ok(None)` when this encoding cannot handle the operands, so that the
    /// dispatcher moves on to the next candidate implementation.
    fn numeric(
        &self,
        array: &Self::Array,
        other: &dyn Array,
        op: NumericOperator,
    ) -> VortexResult<Option<ArrayRef>>;
}
111
/// Adapts a [`NumericKernel`] implemented on the vtable `V` into a type-erased [`Kernel`].
#[derive(Debug)]
pub struct NumericKernelAdapter<V: VTable>(pub V);
114
impl<V: VTable + NumericKernel> NumericKernelAdapter<V> {
    /// Wraps a `'static` adapter into a [`NumericKernelRef`].
    ///
    /// Being a `const fn`, this can be evaluated in a static context — presumably for use with
    /// `inventory::submit!` (confirm against implementors).
    pub const fn lift(&'static self) -> NumericKernelRef {
        NumericKernelRef(ArcRef::new_ref(self))
    }
}
120
121impl<V: VTable + NumericKernel> Kernel for NumericKernelAdapter<V> {
122 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
123 let inputs = NumericArgs::try_from(args)?;
124 let Some(lhs) = inputs.lhs.as_opt::<V>() else {
125 return Ok(None);
126 };
127 Ok(V::numeric(&self.0, lhs, inputs.rhs, inputs.operator)?.map(|array| array.into()))
128 }
129}
130
/// Vtable of the `numeric` compute function; its dispatch order lives in the
/// `ComputeFnVTable` impl.
struct Numeric;
132
133impl ComputeFnVTable for Numeric {
134 fn invoke(
135 &self,
136 args: &InvocationArgs,
137 kernels: &[ArcRef<dyn Kernel>],
138 ) -> VortexResult<Output> {
139 let NumericArgs { lhs, rhs, operator } = NumericArgs::try_from(args)?;
140
141 for kernel in kernels {
142 if let Some(output) = kernel.invoke(args)? {
143 return Ok(output);
144 }
145 }
146
147 if let Some(output) = lhs.invoke(&NUMERIC_FN, args)? {
149 return Ok(output);
150 }
151
152 let inverted_args = InvocationArgs {
154 inputs: &[rhs.into(), lhs.into()],
155 options: &operator.swap(),
156 };
157 for kernel in kernels {
158 if let Some(output) = kernel.invoke(&inverted_args)? {
159 return Ok(output);
160 }
161 }
162 if let Some(output) = rhs.invoke(&NUMERIC_FN, &inverted_args)? {
163 return Ok(output);
164 }
165
166 log::debug!(
167 "No numeric implementation found for LHS {}, RHS {}, and operator {:?}",
168 lhs.encoding_id(),
169 rhs.encoding_id(),
170 operator,
171 );
172
173 Ok(arrow_numeric(lhs, rhs, operator)?.into())
175 }
176
177 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
178 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
179 if !matches!(
180 (lhs.dtype(), rhs.dtype()),
181 (DType::Primitive(..), DType::Primitive(..)) | (DType::Decimal(..), DType::Decimal(..))
182 ) || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
183 {
184 vortex_bail!(
185 "Numeric operations are only supported on two arrays sharing the same numeric type: {} {}",
186 lhs.dtype(),
187 rhs.dtype()
188 )
189 }
190 Ok(lhs.dtype().union_nullability(rhs.dtype().nullability()))
191 }
192
193 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
194 let NumericArgs { lhs, rhs, .. } = NumericArgs::try_from(args)?;
195 if lhs.len() != rhs.len() {
196 vortex_bail!(
197 "Numeric operations aren't supported on arrays of different lengths {} {}",
198 lhs.len(),
199 rhs.len()
200 )
201 }
202 Ok(lhs.len())
203 }
204
205 fn is_elementwise(&self) -> bool {
206 true
207 }
208}
209
/// Arguments of a `numeric` invocation, extracted from [`InvocationArgs`].
struct NumericArgs<'a> {
    // Left-hand operand (`inputs[0]`).
    lhs: &'a dyn Array,
    // Right-hand operand (`inputs[1]`).
    rhs: &'a dyn Array,
    // Operator downcast from the invocation options.
    operator: NumericOperator,
}
215
216impl<'a> TryFrom<&InvocationArgs<'a>> for NumericArgs<'a> {
217 type Error = VortexError;
218
219 fn try_from(args: &InvocationArgs<'a>) -> VortexResult<Self> {
220 if args.inputs.len() != 2 {
221 vortex_bail!("Numeric operations require exactly 2 inputs");
222 }
223 let lhs = args.inputs[0]
224 .array()
225 .ok_or_else(|| vortex_err!("LHS is not an array"))?;
226 let rhs = args.inputs[1]
227 .array()
228 .ok_or_else(|| vortex_err!("RHS is not an array"))?;
229 let operator = *args
230 .options
231 .as_any()
232 .downcast_ref::<NumericOperator>()
233 .ok_or_else(|| vortex_err!("Operator is not a numeric operator"))?;
234 Ok(Self { lhs, rhs, operator })
235 }
236}
237
238impl Options for NumericOperator {
239 fn as_any(&self) -> &dyn Any {
240 self
241 }
242}
243
244fn arrow_numeric(
249 lhs: &dyn Array,
250 rhs: &dyn Array,
251 operator: NumericOperator,
252) -> VortexResult<ArrayRef> {
253 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
254 let len = lhs.len();
255
256 let left = Datum::try_new(lhs)?;
257 let right = Datum::try_new(rhs)?;
258
259 let array = match operator {
260 NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
261 NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?,
262 NumericOperator::RSub => arrow_arith::numeric::sub(&right, &left)?,
263 NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?,
264 NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?,
265 NumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
266 };
267
268 Ok(from_arrow_array_with_len(array.as_ref(), len, nullable))
269}
270
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_scalar::Scalar;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::sub_scalar;

    #[test]
    fn test_scalar_subtract_unsigned() {
        let input = buffer![1u16, 2, 3].into_array();
        let difference = sub_scalar(&input, 1u16.into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<u16>().to_vec(), &[0u16, 1, 2]);
    }

    #[test]
    fn test_scalar_subtract_signed() {
        let input = buffer![1i64, 2, 3].into_array();
        let difference = sub_scalar(&input, (-1i64).into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<i64>().to_vec(), &[2i64, 3, 4]);
    }

    #[test]
    fn test_scalar_subtract_nullable() {
        let input = PrimitiveArray::from_option_iter([Some(1u16), Some(2), None, Some(3)]);
        let difference = sub_scalar(input.as_ref(), Some(1u16).into())
            .unwrap()
            .to_primitive();

        let expected = vec![
            Scalar::from(Some(0u16)),
            Scalar::from(Some(1u16)),
            Scalar::from(None::<u16>),
            Scalar::from(Some(2u16)),
        ];
        let actual: Vec<_> = (0..difference.len())
            .map(|position| difference.scalar_at(position))
            .collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_scalar_subtract_float() {
        let input = buffer![1.0f64, 2.0, 3.0].into_array();
        let subtrahend = -1f64;
        let difference = sub_scalar(&input, subtrahend.into()).unwrap().to_primitive();
        assert_eq!(difference.as_slice::<f64>().to_vec(), &[2.0f64, 3.0, 4.0]);
    }

    #[test]
    fn test_scalar_subtract_float_underflow_is_ok() {
        let input = buffer![f32::MIN, 2.0, 3.0].into_array();
        for subtrahend in [1.0f32, f32::MAX] {
            let _difference = sub_scalar(&input, subtrahend.into()).unwrap();
        }
    }

    #[test]
    fn test_scalar_subtract_type_mismatch_fails() {
        let input = buffer![1u64, 2, 3].into_array();
        let _error = sub_scalar(&input, 1.5f64.into()).expect_err("Expected type mismatch error");
    }
}