vortex_compute/arithmetic/
primitive_scalar.rs1use vortex_dtype::half::f16;
7use vortex_error::vortex_panic;
8use vortex_vector::PrimitiveDatum;
9use vortex_vector::match_each_float_pscalar_pair;
10use vortex_vector::match_each_integer_pscalar_pair;
11use vortex_vector::primitive::PScalar;
12use vortex_vector::primitive::PVector;
13use vortex_vector::primitive::PrimitiveScalar;
14use vortex_vector::primitive::PrimitiveVector;
15
16use crate::arithmetic::Arithmetic;
17use crate::arithmetic::CheckedArithmetic;
18
19impl<Op> CheckedArithmetic<Op> for &PrimitiveScalar
20where
21 for<'a> &'a PScalar<i8>: CheckedArithmetic<Op, &'a PScalar<i8>, Output = PScalar<i8>>,
22 for<'a> &'a PScalar<i16>: CheckedArithmetic<Op, &'a PScalar<i16>, Output = PScalar<i16>>,
23 for<'a> &'a PScalar<i32>: CheckedArithmetic<Op, &'a PScalar<i32>, Output = PScalar<i32>>,
24 for<'a> &'a PScalar<i64>: CheckedArithmetic<Op, &'a PScalar<i64>, Output = PScalar<i64>>,
25 for<'a> &'a PScalar<u8>: CheckedArithmetic<Op, &'a PScalar<u8>, Output = PScalar<u8>>,
26 for<'a> &'a PScalar<u16>: CheckedArithmetic<Op, &'a PScalar<u16>, Output = PScalar<u16>>,
27 for<'a> &'a PScalar<u32>: CheckedArithmetic<Op, &'a PScalar<u32>, Output = PScalar<u32>>,
28 for<'a> &'a PScalar<u64>: CheckedArithmetic<Op, &'a PScalar<u64>, Output = PScalar<u64>>,
29{
30 type Output = PrimitiveScalar;
31
32 fn checked_eval(self, rhs: &PrimitiveScalar) -> Option<Self::Output> {
33 match_each_integer_pscalar_pair!(
34 (&self, &rhs),
35 |l, r| { CheckedArithmetic::<Op, _>::checked_eval(l, r).map(Into::into) },
36 { vortex_panic!("cannot compare primitive scalar of different types") }
37 )
38 }
39}
40
41impl<Op> Arithmetic<Op, &PrimitiveScalar> for &PrimitiveScalar
42where
43 for<'a> &'a PScalar<f16>: Arithmetic<Op, &'a PScalar<f16>, Output = PScalar<f16>>,
44 for<'a> &'a PScalar<f32>: Arithmetic<Op, &'a PScalar<f32>, Output = PScalar<f32>>,
45 for<'a> &'a PScalar<f64>: Arithmetic<Op, &'a PScalar<f64>, Output = PScalar<f64>>,
46{
47 type Output = PrimitiveScalar;
48
49 fn eval(self, rhs: &PrimitiveScalar) -> Self::Output {
50 match_each_float_pscalar_pair!(
51 (self, rhs),
52 |l, r| { Arithmetic::<Op, _>::eval(l, r).into() },
53 {
54 vortex_panic!(
55 "Cannot perform arithmetic on PrimitiveScalars of different types: {:?} and {:?}",
56 self,
57 rhs
58 )
59 }
60 )
61 }
62}
63
64impl<Op> Arithmetic<Op, PrimitiveVector> for &PrimitiveScalar
67where
68 for<'a> &'a f16: Arithmetic<Op, PVector<f16>, Output = PVector<f16>>,
69 for<'a> &'a f32: Arithmetic<Op, PVector<f32>, Output = PVector<f32>>,
70 for<'a> &'a f64: Arithmetic<Op, PVector<f64>, Output = PVector<f64>>,
71{
72 type Output = PrimitiveDatum;
73
74 fn eval(self, rhs: PrimitiveVector) -> Self::Output {
75 match (self, rhs) {
76 (PrimitiveScalar::F16(s), PrimitiveVector::F16(v)) => match s.value() {
77 Some(scalar_val) => {
78 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
79 }
80 None => PrimitiveDatum::Scalar(s.clone().into()),
81 },
82 (PrimitiveScalar::F32(s), PrimitiveVector::F32(v)) => match s.value() {
83 Some(scalar_val) => {
84 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
85 }
86 None => PrimitiveDatum::Scalar(s.clone().into()),
87 },
88 (PrimitiveScalar::F64(s), PrimitiveVector::F64(v)) => match s.value() {
89 Some(scalar_val) => {
90 PrimitiveDatum::Vector(Arithmetic::<Op, _>::eval(&scalar_val, v).into())
91 }
92 None => PrimitiveDatum::Scalar(s.clone().into()),
93 },
94 (s, v) => vortex_panic!(
95 "Cannot perform arithmetic between scalar {:?} and vector {:?}",
96 s,
97 v
98 ),
99 }
100 }
101}
102
#[cfg(test)]
mod tests {
    use vortex_vector::primitive::PScalar;

    use super::*;
    use crate::arithmetic::Add;

    /// Checked `Add` on two `i32` scalars yields an `I32` scalar holding the sum.
    #[test]
    fn test_checked_add_i32() {
        let lhs: PrimitiveScalar = PScalar::new(Some(5i32)).into();
        let rhs: PrimitiveScalar = PScalar::new(Some(3i32)).into();

        let sum = CheckedArithmetic::<Add, _>::checked_eval(&lhs, &rhs).unwrap();
        match sum {
            PrimitiveScalar::I32(v) => assert_eq!(v.value(), Some(8)),
            _ => panic!("Expected I32 result"),
        }
    }

    /// Checked `Add` reports `None` when `i32::MAX + 1` is requested.
    #[test]
    fn test_checked_add_overflow() {
        let lhs: PrimitiveScalar = PScalar::new(Some(i32::MAX)).into();
        let rhs: PrimitiveScalar = PScalar::new(Some(1i32)).into();

        let outcome = CheckedArithmetic::<Add, _>::checked_eval(&lhs, &rhs);
        assert!(outcome.is_none());
    }

    /// Unchecked `Add` on two `f64` scalars yields an `F64` scalar holding the sum.
    #[test]
    fn test_float_add() {
        let lhs: PrimitiveScalar = PScalar::new(Some(1.5f64)).into();
        let rhs: PrimitiveScalar = PScalar::new(Some(2.5f64)).into();

        let sum = Arithmetic::<Add, _>::eval(&lhs, &rhs);
        match sum {
            PrimitiveScalar::F64(v) => assert_eq!(v.value(), Some(4.0)),
            _ => panic!("Expected F64 result"),
        }
    }
}