1use std::ops::BitAnd;
5
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::NativePType;
8use vortex_vector::primitive::{PVector, PVectorMut};
9use vortex_vector::{VectorMutOps, VectorOps};
10
11use crate::arithmetic::{Arithmetic, Operator};
12
13impl<Op, T> Arithmetic<Op, &PVector<T>> for PVector<T>
15where
16 T: NativePType,
17 Op: Operator<T>,
18{
19 type Output = PVector<T>;
20
21 fn eval(self, rhs: &PVector<T>) -> Self::Output {
22 match self.try_into_mut() {
23 Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
24 Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
25 }
26 }
27}
28
29impl<Op, T> Arithmetic<Op, &PVector<T>> for PVectorMut<T>
31where
32 T: NativePType,
33 Op: Operator<T>,
34 BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
35{
36 type Output = PVector<T>;
37
38 fn eval(self, other: &PVector<T>) -> Self::Output {
39 assert_eq!(self.len(), other.len());
40
41 let (lhs_buffer, lhs_validity) = self.into_parts();
42
43 let validity = lhs_validity.freeze().bitand(other.validity());
46 let elements = Arithmetic::<Op, _>::eval(lhs_buffer, other.elements());
47
48 PVector::new(elements, validity)
49 }
50}
51
52impl<Op, T> Arithmetic<Op, &PVector<T>> for &PVector<T>
54where
55 T: NativePType,
56 Op: Operator<T>,
57 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
58{
59 type Output = PVector<T>;
60
61 fn eval(self, rhs: &PVector<T>) -> Self::Output {
62 assert_eq!(self.len(), rhs.len());
63
64 let validity = self.validity().bitand(rhs.validity());
67
68 let elements = Arithmetic::<Op, _>::eval(self.elements(), rhs.elements());
69 PVector::new(elements, validity)
70 }
71}
72
73impl<Op, T> Arithmetic<Op, &T> for PVector<T>
76where
77 T: NativePType,
78 Op: Operator<T>,
79 PVectorMut<T>: for<'a> Arithmetic<Op, &'a T, Output = PVector<T>>,
80{
81 type Output = PVector<T>;
82
83 fn eval(self, rhs: &T) -> Self::Output {
84 match self.try_into_mut() {
85 Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
86 Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
87 }
88 }
89}
90
91impl<Op, T> Arithmetic<Op, &T> for PVectorMut<T>
93where
94 T: NativePType,
95 Op: Operator<T>,
96 BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
97{
98 type Output = PVector<T>;
99
100 fn eval(self, rhs: &T) -> Self::Output {
101 let (lhs_buffer, lhs_validity) = self.into_parts();
102 let validity = lhs_validity.freeze();
103
104 let elements = Arithmetic::<Op, _>::eval(lhs_buffer, rhs);
105
106 PVector::new(elements, validity)
107 }
108}
109
110impl<Op, T> Arithmetic<Op, &T> for &PVector<T>
112where
113 T: NativePType,
114 Op: Operator<T>,
115 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
116{
117 type Output = PVector<T>;
118
119 fn eval(self, rhs: &T) -> Self::Output {
120 let buffer = Arithmetic::<Op, _>::eval(self.elements(), rhs);
121 PVector::new(buffer, self.validity().clone())
122 }
123}
124
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub};

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }

    /// Exercises the `&PVector op &PVector` impl directly, with nulls on both sides.
    #[test]
    fn test_add_pvector_refs_intersects_validity() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, true, false]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([false, true, true]));

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
        assert_eq!(result.validity(), &Mask::from_iter([false, true, false]));

        // Borrowed operands must be left untouched.
        assert_eq!(left.elements(), &buffer![1u32, 2, 3]);
        assert_eq!(right.elements(), &buffer![10u32, 20, 30]);
    }

    /// A cloned lhs keeps a second handle on the storage, which should send the
    /// owned impl down its `try_into_mut` fallback; the alias must not observe any change.
    #[test]
    fn test_add_shared_lhs_does_not_mutate_alias() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let alias = left.clone();
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
        assert_eq!(alias.elements(), &buffer![1u32, 2, 3, 4]);
    }

    /// Exercises the `&PVector op &T` impl directly.
    #[test]
    fn test_scalar_on_pvector_ref() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));

        let result = Arithmetic::<WrappingMul, _>::eval(&vec, &3u32);
        assert_eq!(result.elements(), &buffer![3u32, 6, 9]);
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(vec.elements(), &buffer![1u32, 2, 3]);
    }

    #[test]
    fn test_wrapping_add_overflows() {
        let left = PVector::new(buffer![u32::MAX, 1], Mask::new_true(2));
        let right = PVector::new(buffer![1u32, 1], Mask::new_true(2));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![0u32, 2]);
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![1u32, 2], Mask::new_true(2));

        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }
}