1use std::ops::BitAnd;
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::NativePType;
9use vortex_vector::VectorMutOps;
10use vortex_vector::VectorOps;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PVectorMut;
13
14use crate::arithmetic::Arithmetic;
15use crate::arithmetic::Operator;
16
17impl<Op, T> Arithmetic<Op, &PVector<T>> for PVector<T>
19where
20 T: NativePType,
21 Op: Operator<T>,
22{
23 type Output = PVector<T>;
24
25 fn eval(self, rhs: &PVector<T>) -> Self::Output {
26 match self.try_into_mut() {
27 Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
28 Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
29 }
30 }
31}
32
33impl<Op, T> Arithmetic<Op, &PVector<T>> for PVectorMut<T>
35where
36 T: NativePType,
37 Op: Operator<T>,
38 BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
39{
40 type Output = PVector<T>;
41
42 fn eval(self, other: &PVector<T>) -> Self::Output {
43 assert_eq!(self.len(), other.len());
44
45 let (lhs_buffer, lhs_validity) = self.into_parts();
46
47 let validity = lhs_validity.freeze().bitand(other.validity());
50 let elements = Arithmetic::<Op, _>::eval(lhs_buffer, other.elements());
51
52 PVector::new(elements, validity)
53 }
54}
55
56impl<Op, T> Arithmetic<Op, &PVector<T>> for &PVector<T>
58where
59 T: NativePType,
60 Op: Operator<T>,
61 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
62{
63 type Output = PVector<T>;
64
65 fn eval(self, rhs: &PVector<T>) -> Self::Output {
66 assert_eq!(self.len(), rhs.len());
67
68 let validity = self.validity().bitand(rhs.validity());
71
72 let elements = Arithmetic::<Op, _>::eval(self.elements(), rhs.elements());
73 PVector::new(elements, validity)
74 }
75}
76
77impl<Op, T> Arithmetic<Op, &T> for PVector<T>
80where
81 T: NativePType,
82 Op: Operator<T>,
83 PVectorMut<T>: for<'a> Arithmetic<Op, &'a T, Output = PVector<T>>,
84{
85 type Output = PVector<T>;
86
87 fn eval(self, rhs: &T) -> Self::Output {
88 match self.try_into_mut() {
89 Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
90 Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
91 }
92 }
93}
94
95impl<Op, T> Arithmetic<Op, &T> for PVectorMut<T>
97where
98 T: NativePType,
99 Op: Operator<T>,
100 BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
101{
102 type Output = PVector<T>;
103
104 fn eval(self, rhs: &T) -> Self::Output {
105 let (lhs_buffer, lhs_validity) = self.into_parts();
106 let validity = lhs_validity.freeze();
107
108 let elements = Arithmetic::<Op, _>::eval(lhs_buffer, rhs);
109
110 PVector::new(elements, validity)
111 }
112}
113
114impl<Op, T> Arithmetic<Op, &T> for &PVector<T>
116where
117 T: NativePType,
118 Op: Operator<T>,
119 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
120{
121 type Output = PVector<T>;
122
123 fn eval(self, rhs: &T) -> Self::Output {
124 let buffer = Arithmetic::<Op, _>::eval(self.elements(), rhs);
125 PVector::new(buffer, self.validity().clone())
126 }
127}
128
129impl<Op, T> Arithmetic<Op, PVector<T>> for &T
131where
132 T: NativePType,
133 Op: Operator<T>,
134 BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
135 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
136{
137 type Output = PVector<T>;
138
139 fn eval(self, rhs: PVector<T>) -> Self::Output {
140 match rhs.try_into_mut() {
141 Ok(rhs_mut) => {
142 let (rhs_buffer, rhs_validity) = rhs_mut.into_parts();
143 let elements = rhs_buffer
145 .map_each_in_place(|v| Op::apply(self, &v))
146 .freeze();
147 PVector::new(elements, rhs_validity.freeze())
148 }
149 Err(rhs) => {
150 let buffer = Buffer::<T>::from_trusted_len_iter(
152 rhs.as_ref().iter().map(|v| Op::apply(self, v)),
153 );
154 PVector::new(buffer, rhs.validity().clone())
155 }
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::Arithmetic;
    use crate::arithmetic::WrappingAdd;
    use crate::arithmetic::WrappingMul;
    use crate::arithmetic::WrappingSub;

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        let expected = PVector::from(buffer![11u32, 22, 33, 44]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        let expected = PVector::from(buffer![11u32, 12, 13, 14]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_borrowed_pvectors_intersects_validity() {
        // Exercises the `&PVector op &PVector` path; nulls from either side propagate.
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, true, false]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([false, true, true]));

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([false, true, false]));
        assert_eq!(result, expected);
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![1u32, 2], Mask::new_true(2));
        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        let expected = PVector::from(buffer![9u32, 18, 27, 36]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
        let expected = PVector::from(buffer![5u32, 15, 25, 35]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_lhs_sub() {
        // Scalar on the left must compute `scalar - element`.
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingSub, _>::eval(&100u32, vec);
        let expected = PVector::from(buffer![99u32, 98, 97, 96]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_lhs_sub_wraps() {
        let vec = PVector::from(buffer![1u32, 0]);
        let result = Arithmetic::<WrappingSub, _>::eval(&0u32, vec);
        let expected = PVector::from(buffer![u32::MAX, 0]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_lhs_shared_input_is_untouched() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        // A second live handle prevents in-place mutation, forcing the copying path.
        let original = vec.clone();

        let result = Arithmetic::<WrappingSub, _>::eval(&100u32, vec);
        let expected = PVector::new(buffer![99u32, 98, 97], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
        assert_eq!(
            original,
            PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]))
        );
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        let expected = PVector::from(buffer![20u32, 60, 120, 200]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
        let expected = PVector::from(buffer![10u32, 20, 30, 40]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        let expected = PVector::new(buffer![11u32, 12, 13], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
    }
}