1use std::hash::{Hash, Hasher};
5use std::sync::LazyLock;
6
7use enum_map::{Enum, EnumMap, enum_map};
8use vortex_buffer::ByteBuffer;
9use vortex_compute::arithmetic::{
10 Add, Arithmetic, CheckedArithmetic, CheckedOperator, Div, Mul, Operator, Sub,
11};
12use vortex_dtype::{DType, NativePType, PTypeDowncastExt, match_each_native_ptype};
13use vortex_error::{VortexExpect, VortexResult, vortex_err};
14use vortex_scalar::{PValue, Scalar};
15use vortex_vector::primitive::PVector;
16
17use crate::arrays::ConstantArray;
18use crate::execution::{BatchKernelRef, BindCtx, kernel};
19use crate::serde::ArrayChildren;
20use crate::stats::{ArrayStats, StatsSetRef};
21use crate::vtable::{
22 ArrayVTable, NotSupported, OperatorVTable, SerdeVTable, VTable, VisitorVTable,
23};
24use crate::{
25 Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef,
26 DeserializeMetadata, EmptyMetadata, EncodingId, EncodingRef, IntoArray, Precision, vtable,
27};
28
/// The element-wise binary arithmetic operation performed by an [`ArithmeticArray`].
///
/// Each variant has its own encoding (see `ENCODINGS` and `ArithmeticVTable::id`), so the
/// operator of an array is recovered from its encoding rather than stored on the array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Enum)]
pub enum ArithmeticOperator {
    /// Element-wise addition (`lhs + rhs`).
    Add,
    /// Element-wise subtraction (`lhs - rhs`).
    Sub,
    /// Element-wise multiplication (`lhs * rhs`).
    Mul,
    /// Element-wise division (`lhs / rhs`).
    Div,
}
41
// Declares the vtable plumbing for this encoding. `ArithmeticVTable`, used throughout the
// impls below, is not defined elsewhere in this file and presumably comes from this macro.
vtable!(Arithmetic);
43
/// An array whose values are `lhs <op> rhs`, computed element-wise over two primitive arrays
/// of identical length and dtype.
///
/// The operator is not stored as a field; it is carried by `encoding`, which is one of the
/// per-operator entries in `ENCODINGS`.
#[derive(Debug, Clone)]
pub struct ArithmeticArray {
    /// Per-operator encoding; also identifies the operator (see [`ArithmeticArray::operator`]).
    encoding: EncodingRef,
    /// Left-hand operand.
    lhs: ArrayRef,
    /// Right-hand operand; same length and dtype as `lhs` (enforced by `new`).
    rhs: ArrayRef,
    /// Statistics storage exposed through the array vtable's `stats`.
    stats: ArrayStats,
}
51
52impl ArithmeticArray {
53 pub fn new(lhs: ArrayRef, rhs: ArrayRef, operator: ArithmeticOperator) -> Self {
55 assert_eq!(
56 lhs.len(),
57 rhs.len(),
58 "Arithmetic arrays require lhs and rhs to have the same length"
59 );
60
61 assert!(matches!(lhs.dtype(), DType::Primitive(..)));
63 assert_eq!(lhs.dtype(), rhs.dtype());
64
65 Self {
66 encoding: ENCODINGS[operator].clone(),
67 lhs,
68 rhs,
69 stats: ArrayStats::default(),
70 }
71 }
72
73 pub fn operator(&self) -> ArithmeticOperator {
75 self.encoding.as_::<ArithmeticVTable>().operator
76 }
77}
78
/// Encoding for [`ArithmeticArray`]; `ENCODINGS` holds one instance per
/// [`ArithmeticOperator`].
#[derive(Debug, Clone)]
pub struct ArithmeticEncoding {
    /// The operator applied by arrays of this encoding.
    operator: ArithmeticOperator,
}
86
/// One shared encoding per operator, built lazily on first access.
///
/// `ArithmeticArray::new` clones the entry for its operator, and
/// `ArithmeticArray::operator` reads the operator back out of it.
// NOTE(review): the `mem_forget` allowance presumably silences a lint triggered inside the
// `enum_map!` / `to_encoding` expansion — confirm and record the exact reason here.
#[allow(clippy::mem_forget)]
static ENCODINGS: LazyLock<EnumMap<ArithmeticOperator, EncodingRef>> = LazyLock::new(|| {
    enum_map! {
        operator => ArithmeticEncoding { operator }.to_encoding(),
    }
});
93
94impl VTable for ArithmeticVTable {
95 type Array = ArithmeticArray;
96 type Encoding = ArithmeticEncoding;
97 type ArrayVTable = Self;
98 type CanonicalVTable = NotSupported;
99 type OperationsVTable = NotSupported;
100 type ValidityVTable = NotSupported;
101 type VisitorVTable = Self;
102 type ComputeVTable = NotSupported;
103 type EncodeVTable = NotSupported;
104 type SerdeVTable = Self;
105 type OperatorVTable = Self;
106
107 fn id(encoding: &Self::Encoding) -> EncodingId {
108 match encoding.operator {
109 ArithmeticOperator::Add => EncodingId::from("vortex.add"),
110 ArithmeticOperator::Sub => EncodingId::from("vortex.sub"),
111 ArithmeticOperator::Mul => EncodingId::from("vortex.mul"),
112 ArithmeticOperator::Div => EncodingId::from("vortex.div"),
113 }
114 }
115
116 fn encoding(array: &Self::Array) -> EncodingRef {
117 array.encoding.clone()
118 }
119}
120
121impl ArrayVTable<ArithmeticVTable> for ArithmeticVTable {
122 fn len(array: &ArithmeticArray) -> usize {
123 array.lhs.len()
124 }
125
126 fn dtype(array: &ArithmeticArray) -> &DType {
127 array.lhs.dtype()
128 }
129
130 fn stats(array: &ArithmeticArray) -> StatsSetRef<'_> {
131 array.stats.to_ref(array.as_ref())
132 }
133
134 fn array_hash<H: Hasher>(array: &ArithmeticArray, state: &mut H, precision: Precision) {
135 array.lhs.array_hash(state, precision);
136 array.rhs.array_hash(state, precision);
137 }
138
139 fn array_eq(array: &ArithmeticArray, other: &ArithmeticArray, precision: Precision) -> bool {
140 array.lhs.array_eq(&other.lhs, precision) && array.rhs.array_eq(&other.rhs, precision)
141 }
142}
143
144impl VisitorVTable<ArithmeticVTable> for ArithmeticVTable {
145 fn visit_buffers(_array: &ArithmeticArray, _visitor: &mut dyn ArrayBufferVisitor) {
146 }
148
149 fn visit_children(array: &ArithmeticArray, visitor: &mut dyn ArrayChildVisitor) {
150 visitor.visit_child("lhs", array.lhs.as_ref());
151 visitor.visit_child("rhs", array.rhs.as_ref());
152 }
153}
154
155impl SerdeVTable<ArithmeticVTable> for ArithmeticVTable {
156 type Metadata = EmptyMetadata;
157
158 fn metadata(_array: &ArithmeticArray) -> VortexResult<Option<Self::Metadata>> {
159 Ok(Some(EmptyMetadata))
160 }
161
162 fn build(
163 encoding: &ArithmeticEncoding,
164 dtype: &DType,
165 len: usize,
166 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
167 buffers: &[ByteBuffer],
168 children: &dyn ArrayChildren,
169 ) -> VortexResult<ArithmeticArray> {
170 assert!(buffers.is_empty());
171
172 Ok(ArithmeticArray::new(
173 children.get(0, dtype, len)?,
174 children.get(1, dtype, len)?,
175 encoding.operator,
176 ))
177 }
178}
179
impl OperatorVTable<ArithmeticVTable> for ArithmeticVTable {
    /// Attempts to simplify the array by inspecting its children, without building kernels.
    ///
    /// - If both operands are constant, the operation is folded into a single
    ///   [`ConstantArray`]. An error is returned if the checked scalar operation fails.
    /// - Otherwise, if either operand is a constant null, the result is an all-null
    ///   [`ConstantArray`] of this array's dtype and length.
    /// - Otherwise returns `Ok(None)` (no simplification).
    fn reduce_children(array: &ArithmeticArray) -> VortexResult<Option<ArrayRef>> {
        match (array.lhs.as_constant(), array.rhs.as_constant()) {
            // NOTE(review): when both operands are constant and one of them is null, this arm
            // (not the null arms below) handles it; presumably `checked_binary_numeric`
            // propagates the null — confirm.
            (Some(lhs), Some(rhs)) => {
                // Map our operator onto the scalar crate's equivalent.
                let op: vortex_scalar::NumericOperator = match array.operator() {
                    ArithmeticOperator::Add => vortex_scalar::NumericOperator::Add,
                    ArithmeticOperator::Sub => vortex_scalar::NumericOperator::Sub,
                    ArithmeticOperator::Mul => vortex_scalar::NumericOperator::Mul,
                    ArithmeticOperator::Div => vortex_scalar::NumericOperator::Div,
                };
                // `checked_binary_numeric` yields `None` on failure. For `Div` this may also
                // cover division by zero, which the message below reports as "overflow".
                let result = lhs
                    .as_primitive()
                    .checked_binary_numeric(&rhs.as_primitive(), op)
                    .ok_or_else(|| {
                        vortex_err!("Constant arithmetic operation resulted in overflow")
                    })?;
                return Ok(Some(
                    ConstantArray::new(Scalar::from(result), array.len()).into_array(),
                ));
            }
            // A constant null on either side makes every element of the result null.
            (Some(lhs), _) if lhs.is_null() => {
                return Ok(Some(
                    ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
                        .into_array(),
                ));
            }
            (_, Some(rhs)) if rhs.is_null() => {
                return Ok(Some(
                    ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
                        .into_array(),
                ));
            }
            _ => {}
        }

        Ok(None)
    }

    /// Builds the batch kernel that computes this array, restricted to `selection` if given.
    ///
    /// If `rhs` is constant, a scalar kernel applies that value directly to the bound `lhs`.
    /// If the constant is null, the whole result is bound as an all-null constant instead.
    /// Otherwise both children are bound and combined element-wise. Only a constant `rhs`
    /// takes the scalar path; a constant `lhs` goes through the general path.
    ///
    /// Integral types use checked kernels, which fail when the checked operation fails.
    /// Floating-point types use unchecked kernels.
    fn bind(
        array: &ArithmeticArray,
        selection: Option<&ArrayRef>,
        ctx: &mut dyn BindCtx,
    ) -> VortexResult<BatchKernelRef> {
        if let Some(rhs_scalar) = array.rhs.as_constant() {
            // Null rhs: every output element is null, so no arithmetic is needed.
            if rhs_scalar.is_null() {
                return ConstantArray::new(Scalar::null(array.dtype().clone()), array.len())
                    .into_array()
                    .bind(selection, ctx);
            }

            let lhs = ctx.bind(&array.lhs, selection)?;
            // Dispatch on the physical type: integral -> checked, floating -> unchecked.
            return match_each_native_ptype!(
                array.dtype().as_ptype(),
                integral: |T| {
                    let rhs: T = rhs_scalar
                        .as_primitive()
                        .typed_value::<T>()
                        .vortex_expect("Already checked for null above");
                    Ok(match array.operator() {
                        ArithmeticOperator::Add => checked_arithmetic_scalar_kernel::<Add, T>(lhs, rhs),
                        ArithmeticOperator::Sub => checked_arithmetic_scalar_kernel::<Sub, T>(lhs, rhs),
                        ArithmeticOperator::Mul => checked_arithmetic_scalar_kernel::<Mul, T>(lhs, rhs),
                        ArithmeticOperator::Div => checked_arithmetic_scalar_kernel::<Div, T>(lhs, rhs),
                    })
                },
                floating: |T| {
                    let rhs: T = rhs_scalar
                        .as_primitive()
                        .typed_value::<T>()
                        .vortex_expect("Already checked for null above");
                    Ok(match array.operator() {
                        ArithmeticOperator::Add => arithmetic_scalar_kernel::<Add, T>(lhs, rhs),
                        ArithmeticOperator::Sub => arithmetic_scalar_kernel::<Sub, T>(lhs, rhs),
                        ArithmeticOperator::Mul => arithmetic_scalar_kernel::<Mul, T>(lhs, rhs),
                        ArithmeticOperator::Div => arithmetic_scalar_kernel::<Div, T>(lhs, rhs),
                    })
                }
            );
        }

        // General case: bind both operands with the same selection, lhs first.
        let lhs = ctx.bind(&array.lhs, selection)?;
        let rhs = ctx.bind(&array.rhs, selection)?;

        match_each_native_ptype!(
            array.dtype().as_ptype(),
            integral: |T| {
                Ok(match array.operator() {
                    ArithmeticOperator::Add => checked_arithmetic_kernel::<Add, T>(lhs, rhs),
                    ArithmeticOperator::Sub => checked_arithmetic_kernel::<Sub, T>(lhs, rhs),
                    ArithmeticOperator::Mul => checked_arithmetic_kernel::<Mul, T>(lhs, rhs),
                    ArithmeticOperator::Div => checked_arithmetic_kernel::<Div, T>(lhs, rhs),
                })
            },
            floating: |T| {
                Ok(match array.operator() {
                    ArithmeticOperator::Add => arithmetic_kernel::<Add, T>(lhs, rhs),
                    ArithmeticOperator::Sub => arithmetic_kernel::<Sub, T>(lhs, rhs),
                    ArithmeticOperator::Mul => arithmetic_kernel::<Mul, T>(lhs, rhs),
                    ArithmeticOperator::Div => arithmetic_kernel::<Div, T>(lhs, rhs),
                })
            }
        )
    }
}
288
289fn arithmetic_kernel<Op, T>(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef
290where
291 T: NativePType,
292 Op: Operator<T>,
293{
294 kernel(move || {
295 let lhs = lhs.execute()?.into_primitive().downcast::<T>();
296 let rhs = rhs.execute()?.into_primitive().downcast::<T>();
297 let result = Arithmetic::<Op, _>::eval(lhs, &rhs);
298 Ok(result.into())
299 })
300}
301
302fn arithmetic_scalar_kernel<Op, T>(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef
303where
304 T: NativePType + TryFrom<PValue>,
305 Op: Operator<T>,
306{
307 kernel(move || {
308 let lhs = lhs.execute()?.into_primitive().downcast::<T>();
309 let result = Arithmetic::<Op, _>::eval(lhs, &rhs);
310 Ok(result.into())
311 })
312}
313
314fn checked_arithmetic_kernel<Op, T>(lhs: BatchKernelRef, rhs: BatchKernelRef) -> BatchKernelRef
315where
316 T: NativePType,
317 Op: CheckedOperator<T>,
318 PVector<T>: for<'a> CheckedArithmetic<Op, &'a PVector<T>, Output = PVector<T>>,
319{
320 kernel(move || {
321 let lhs = lhs.execute()?.into_primitive().downcast::<T>();
322 let rhs = rhs.execute()?.into_primitive().downcast::<T>();
323 let result = CheckedArithmetic::<Op, _>::checked_eval(lhs, &rhs)
324 .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
325 Ok(result.into())
326 })
327}
328
329fn checked_arithmetic_scalar_kernel<Op, T>(lhs: BatchKernelRef, rhs: T) -> BatchKernelRef
330where
331 T: NativePType + TryFrom<PValue>,
332 Op: CheckedOperator<T>,
333 PVector<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = PVector<T>>,
334{
335 kernel(move || {
336 let lhs = lhs.execute()?.into_primitive().downcast::<T>();
337 let result = CheckedArithmetic::<Op, _>::checked_eval(lhs, &rhs)
338 .ok_or_else(|| vortex_err!("Arithmetic operation resulted in overflow"))?;
339 Ok(result.into())
340 })
341}
342
#[cfg(test)]
mod tests {
    use vortex_buffer::{bitbuffer, buffer};
    use vortex_dtype::PTypeDowncastExt;

    use crate::arrays::{ConstantArray, PrimitiveArray};
    use crate::compute::arrays::arithmetic::{ArithmeticArray, ArithmeticOperator};
    use crate::{ArrayRef, IntoArray};

    fn add(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Add).into_array()
    }

    fn sub(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Sub).into_array()
    }

    fn mul(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Mul).into_array()
    }

    fn div(lhs: ArrayRef, rhs: ArrayRef) -> ArrayRef {
        ArithmeticArray::new(lhs, rhs, ArithmeticOperator::Div).into_array()
    }

    #[test]
    fn test_add() {
        let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array();
        let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();
        let result = add(lhs, rhs)
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub() {
        let lhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();
        let rhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array();
        let result = sub(lhs, rhs)
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![9u32, 18, 27]);
    }

    #[test]
    fn test_mul() {
        let lhs = PrimitiveArray::from_iter([2u32, 3, 4]).into_array();
        let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();
        let result = mul(lhs, rhs)
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![20u32, 60, 120]);
    }

    #[test]
    fn test_div() {
        let lhs = PrimitiveArray::from_iter([100u32, 200, 300]).into_array();
        let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();
        let result = div(lhs, rhs)
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![10u32, 10, 10]);
    }

    #[test]
    fn test_add_with_selection() {
        let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array();
        let rhs = PrimitiveArray::from_iter([10u32, 20, 30]).into_array();

        let result = add(lhs, rhs)
            .execute_with_selection(&bitbuffer![1 0 1].into())
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![11u32, 33]);
    }

    #[test]
    fn test_add_constant_rhs() {
        // A constant rhs takes the scalar-kernel path in `bind`.
        let lhs = PrimitiveArray::from_iter([1u32, 2, 3]).into_array();
        let rhs = ConstantArray::new(10u32, 3).into_array();
        let result = add(lhs, rhs)
            .execute()
            .unwrap()
            .into_primitive()
            .downcast::<u32>();
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }

    #[test]
    fn test_add_overflow_errors() {
        // Integral types use checked kernels, so overflow must surface as an error
        // rather than wrapping silently.
        let lhs = PrimitiveArray::from_iter([u8::MAX, 0]).into_array();
        let rhs = PrimitiveArray::from_iter([1u8, 2]).into_array();
        assert!(add(lhs, rhs).execute().is_err());
    }
}