vortex_compute/arithmetic/
datum.rs1use vortex_vector::PrimitiveDatum;
5use vortex_vector::ScalarOps;
6use vortex_vector::VectorMutOps;
7use vortex_vector::VectorOps;
8use vortex_vector::primitive::PrimitiveScalar;
9use vortex_vector::primitive::PrimitiveVector;
10
11use crate::arithmetic::Arithmetic;
12use crate::arithmetic::CheckedArithmetic;
13
14impl<Op> CheckedArithmetic<Op> for PrimitiveDatum
15where
16 for<'a> &'a PrimitiveScalar: CheckedArithmetic<Op, Output = PrimitiveScalar>,
17 for<'a> PrimitiveVector: CheckedArithmetic<Op, &'a PrimitiveVector, Output = PrimitiveVector>,
18{
19 type Output = PrimitiveDatum;
20
21 fn checked_eval(self, rhs: PrimitiveDatum) -> Option<Self::Output> {
22 match (self, rhs) {
23 (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Scalar(sc2)) => {
24 (&sc1).checked_eval(&sc2).map(PrimitiveDatum::Scalar)
25 }
26 (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Vector(vec2)) => {
27 vec1.checked_eval(&vec2).map(PrimitiveDatum::Vector)
28 }
29 (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Scalar(sc2)) => {
30 let len = vec1.len();
31 vec1.checked_eval(&sc2.repeat(len).freeze().into_primitive())
32 .map(PrimitiveDatum::Vector)
33 }
34 (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Vector(vec2)) => {
35 let len = vec2.len();
36 sc1.repeat(len)
37 .freeze()
38 .into_primitive()
39 .checked_eval(&vec2)
40 .map(PrimitiveDatum::Vector)
41 }
42 }
43 }
44}
45
46impl<Op> Arithmetic<Op> for PrimitiveDatum
47where
48 for<'a> &'a PrimitiveScalar: Arithmetic<Op, &'a PrimitiveScalar, Output = PrimitiveScalar>,
49 for<'a> &'a PrimitiveScalar: Arithmetic<Op, PrimitiveVector, Output = PrimitiveDatum>,
50 for<'a> PrimitiveVector: Arithmetic<Op, &'a PrimitiveVector, Output = PrimitiveVector>,
51 for<'a> PrimitiveVector: Arithmetic<Op, &'a PrimitiveScalar, Output = PrimitiveDatum>,
52{
53 type Output = PrimitiveDatum;
54
55 fn eval(self, rhs: PrimitiveDatum) -> Self::Output {
56 match (self, rhs) {
57 (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Scalar(sc2)) => {
58 PrimitiveDatum::Scalar((&sc1).eval(&sc2))
59 }
60 (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Vector(vec2)) => {
61 PrimitiveDatum::Vector(vec1.eval(&vec2))
62 }
63 (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Scalar(sc2)) => vec1.eval(&sc2),
64 (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Vector(vec2)) => (&sc1).eval(vec2),
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PTypeDowncast;
    use vortex_mask::Mask;
    use vortex_vector::Datum;
    use vortex_vector::PrimitiveDatum;
    use vortex_vector::Vector;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;
    use vortex_vector::primitive::PrimitiveVector;

    use crate::arithmetic::Add;
    use crate::arithmetic::Arithmetic;

    /// Returns the vector held by `datum`, panicking if it holds a scalar instead.
    fn expect_vector(datum: PrimitiveDatum) -> PrimitiveVector {
        match datum {
            PrimitiveDatum::Vector(v) => v,
            _ => panic!("Expected primitive vector result"),
        }
    }

    /// Vector + vector with a uniquely owned left operand reuses the left buffer.
    #[test]
    fn test_datum_arithmetic_in_place() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right = PVector::new(buffer![10f32, 20.0, 30.0, 40.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = Datum::Vector(Vector::from(left));
        let right_datum = Datum::Vector(Vector::from(right));

        let result = expect_vector(Arithmetic::<Add, _>::eval(
            left_datum.into_primitive(),
            right_datum.into_primitive(),
        ));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result);

        assert_eq!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place when input has unique ownership"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 22.0, 33.0, 44.0]);
    }

    /// Vector + vector with a *shared* left operand must not mutate the shared
    /// buffer: the result lives in a different allocation, is still correct, and
    /// the other handle keeps observing the original values.
    #[test]
    fn test_datum_arithmetic_in_place_fail() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right = PVector::new(buffer![10f32, 20.0, 30.0, 40.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = Datum::Vector(Vector::from(left));
        // A second live handle keeps the left buffer shared during the operation.
        let left_datum2 = left_datum.clone();
        let right_datum = Datum::Vector(Vector::from(right));

        let result = expect_vector(Arithmetic::<Add, _>::eval(
            left_datum.into_primitive(),
            right_datum.into_primitive(),
        ));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result);

        assert_ne!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer must not be modified in place when input is shared"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 22.0, 33.0, 44.0]);

        // The clone must be unaffected by the arithmetic on the other handle.
        let untouched = expect_vector(left_datum2.into_primitive());
        let untouched_pvec: &PVector<f32> = PTypeDowncast::into_f32(&untouched);
        assert_eq!(untouched_pvec.elements(), &buffer![1f32, 2.0, 3.0, 4.0]);
    }

    /// Vector + scalar with a uniquely owned vector reuses the vector's buffer.
    #[test]
    fn test_datum_vector_scalar_in_place() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = PrimitiveDatum::Vector(left.into());
        let right_datum = PrimitiveDatum::Scalar(PScalar::new(Some(10f32)).into());

        let result = expect_vector(Arithmetic::<Add, _>::eval(left_datum, right_datum));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result);

        assert_eq!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place for vector-scalar arithmetic"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 12.0, 13.0, 14.0]);
    }

    /// Scalar + vector with a uniquely owned vector reuses the vector's buffer.
    #[test]
    fn test_datum_scalar_vector_in_place() {
        let right = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right_ptr = right.elements().as_ptr();

        let left_datum = PrimitiveDatum::Scalar(PScalar::new(Some(10f32)).into());
        let right_datum = PrimitiveDatum::Vector(right.into());

        let result = expect_vector(Arithmetic::<Add, _>::eval(left_datum, right_datum));
        let result_pvec = result.into_f32();

        assert_eq!(
            right_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place for scalar-vector arithmetic"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 12.0, 13.0, 14.0]);
    }
}