vortex_compute/arithmetic/
pvector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_buffer::Buffer;
7use vortex_buffer::BufferMut;
8use vortex_dtype::NativePType;
9use vortex_vector::VectorMutOps;
10use vortex_vector::VectorOps;
11use vortex_vector::primitive::PVector;
12use vortex_vector::primitive::PVectorMut;
13
14use crate::arithmetic::Arithmetic;
15use crate::arithmetic::Operator;
16
17/// Implementation that attempts to downcast to a mutable vector and operates in-place.
18impl<Op, T> Arithmetic<Op, &PVector<T>> for PVector<T>
19where
20    T: NativePType,
21    Op: Operator<T>,
22{
23    type Output = PVector<T>;
24
25    fn eval(self, rhs: &PVector<T>) -> Self::Output {
26        match self.try_into_mut() {
27            Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
28            Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
29        }
30    }
31}
32
33/// Implementation that operates in-place over a mutable vector.
34impl<Op, T> Arithmetic<Op, &PVector<T>> for PVectorMut<T>
35where
36    T: NativePType,
37    Op: Operator<T>,
38    BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
39{
40    type Output = PVector<T>;
41
42    fn eval(self, other: &PVector<T>) -> Self::Output {
43        assert_eq!(self.len(), other.len());
44
45        let (lhs_buffer, lhs_validity) = self.into_parts();
46
47        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
48        //  or choose a different implementation.
49        let validity = lhs_validity.freeze().bitand(other.validity());
50        let elements = Arithmetic::<Op, _>::eval(lhs_buffer, other.elements());
51
52        PVector::new(elements, validity)
53    }
54}
55
56/// Implementation that allocates a new output vector.
57impl<Op, T> Arithmetic<Op, &PVector<T>> for &PVector<T>
58where
59    T: NativePType,
60    Op: Operator<T>,
61    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
62{
63    type Output = PVector<T>;
64
65    fn eval(self, rhs: &PVector<T>) -> Self::Output {
66        assert_eq!(self.len(), rhs.len());
67
68        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
69        //  or choose a different implementation.
70        let validity = self.validity().bitand(rhs.validity());
71
72        let elements = Arithmetic::<Op, _>::eval(self.elements(), rhs.elements());
73        PVector::new(elements, validity)
74    }
75}
76
77/// Implementation that attempts to downcast to a mutable vector and operates in-place against
78/// a scalar RHS value.
79impl<Op, T> Arithmetic<Op, &T> for PVector<T>
80where
81    T: NativePType,
82    Op: Operator<T>,
83    PVectorMut<T>: for<'a> Arithmetic<Op, &'a T, Output = PVector<T>>,
84{
85    type Output = PVector<T>;
86
87    fn eval(self, rhs: &T) -> Self::Output {
88        match self.try_into_mut() {
89            Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
90            Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
91        }
92    }
93}
94
95/// Implementation that operates in-place over a mutable vector against a scalar RHS value.
96impl<Op, T> Arithmetic<Op, &T> for PVectorMut<T>
97where
98    T: NativePType,
99    Op: Operator<T>,
100    BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
101{
102    type Output = PVector<T>;
103
104    fn eval(self, rhs: &T) -> Self::Output {
105        let (lhs_buffer, lhs_validity) = self.into_parts();
106        let validity = lhs_validity.freeze();
107
108        let elements = Arithmetic::<Op, _>::eval(lhs_buffer, rhs);
109
110        PVector::new(elements, validity)
111    }
112}
113
114/// Implementation that allocates a new output vector against a scalar RHS value.
115impl<Op, T> Arithmetic<Op, &T> for &PVector<T>
116where
117    T: NativePType,
118    Op: Operator<T>,
119    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
120{
121    type Output = PVector<T>;
122
123    fn eval(self, rhs: &T) -> Self::Output {
124        let buffer = Arithmetic::<Op, _>::eval(self.elements(), rhs);
125        PVector::new(buffer, self.validity().clone())
126    }
127}
128
129/// Scalar LHS with owned vector RHS - modifies vector in-place.
130impl<Op, T> Arithmetic<Op, PVector<T>> for &T
131where
132    T: NativePType,
133    Op: Operator<T>,
134    BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
135    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
136{
137    type Output = PVector<T>;
138
139    fn eval(self, rhs: PVector<T>) -> Self::Output {
140        match rhs.try_into_mut() {
141            Ok(rhs_mut) => {
142                let (rhs_buffer, rhs_validity) = rhs_mut.into_parts();
143                // Apply scalar op to each element in-place
144                let elements = rhs_buffer
145                    .map_each_in_place(|v| Op::apply(self, &v))
146                    .freeze();
147                PVector::new(elements, rhs_validity.freeze())
148            }
149            Err(rhs) => {
150                // Can't modify in place, allocate new buffer
151                let buffer = Buffer::<T>::from_trusted_len_iter(
152                    rhs.as_ref().iter().map(|v| Op::apply(self, v)),
153                );
154                PVector::new(buffer, rhs.validity().clone())
155            }
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::Arithmetic;
    use crate::arithmetic::WrappingAdd;
    use crate::arithmetic::WrappingMul;
    use crate::arithmetic::WrappingSub;

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        let expected = PVector::from(buffer![11u32, 22, 33, 44]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        let expected = PVector::from(buffer![11u32, 12, 13, 14]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        // Validity is AND'd, so if either side is null, result is null
        let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        let expected = PVector::from(buffer![9u32, 18, 27, 36]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
        let expected = PVector::from(buffer![5u32, 15, 25, 35]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        let expected = PVector::from(buffer![20u32, 60, 120, 200]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
        let expected = PVector::from(buffer![10u32, 20, 30, 40]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        let expected = PVector::new(buffer![11u32, 12, 13], Mask::from_iter([true, false, true]));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_pvectors_by_ref() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        assert_eq!(result, PVector::from(buffer![11u32, 22, 33, 44]));
        // The allocating path must leave its inputs untouched.
        assert_eq!(left, PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4)));
    }

    #[test]
    fn test_add_by_ref_intersects_validity() {
        // Nulls on both sides: a position is valid only if it is valid in both inputs.
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::from_iter([true, true, false]));

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        let expected = PVector::new(buffer![11u32, 22, 33], Mask::from_iter([true, false, false]));
        assert_eq!(result, expected);
    }

    #[test]
    fn test_sub_scalar_by_ref() {
        let vec = PVector::new(
            buffer![10u32, 20, 30, 40],
            Mask::from_iter([true, false, true, true]),
        );
        let result = Arithmetic::<WrappingSub, _>::eval(&vec, &5u32);
        let expected = PVector::new(
            buffer![5u32, 15, 25, 35],
            Mask::from_iter([true, false, true, true]),
        );
        assert_eq!(result, expected);
    }

    #[test]
    fn test_scalar_lhs_sub() {
        // The scalar is the left operand, so this computes `10 - v`, not `v - 10`.
        let vec = PVector::new(
            buffer![1u32, 2, 3, 4],
            Mask::from_iter([true, true, false, true]),
        );
        let result = Arithmetic::<WrappingSub, _>::eval(&10u32, vec);
        let expected = PVector::new(
            buffer![9u32, 8, 7, 6],
            Mask::from_iter([true, true, false, true]),
        );
        assert_eq!(result, expected);
    }
}