vortex_compute/arithmetic/pvector.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use vortex_buffer::{Buffer, BufferMut};
7use vortex_dtype::NativePType;
8use vortex_vector::primitive::{PVector, PVectorMut};
9use vortex_vector::{VectorMutOps, VectorOps};
10
11use crate::arithmetic::{Arithmetic, Operator};
12
13/// Implementation that attempts to downcast to a mutable vector and operates in-place.
14impl<Op, T> Arithmetic<Op, &PVector<T>> for PVector<T>
15where
16    T: NativePType,
17    Op: Operator<T>,
18{
19    type Output = PVector<T>;
20
21    fn eval(self, rhs: &PVector<T>) -> Self::Output {
22        match self.try_into_mut() {
23            Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
24            Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
25        }
26    }
27}
28
29/// Implementation that operates in-place over a mutable vector.
30impl<Op, T> Arithmetic<Op, &PVector<T>> for PVectorMut<T>
31where
32    T: NativePType,
33    Op: Operator<T>,
34    BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
35{
36    type Output = PVector<T>;
37
38    fn eval(self, other: &PVector<T>) -> Self::Output {
39        assert_eq!(self.len(), other.len());
40
41        let (lhs_buffer, lhs_validity) = self.into_parts();
42
43        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
44        //  or choose a different implementation.
45        let validity = lhs_validity.freeze().bitand(other.validity());
46        let elements = Arithmetic::<Op, _>::eval(lhs_buffer, other.elements());
47
48        PVector::new(elements, validity)
49    }
50}
51
52/// Implementation that allocates a new output vector.
53impl<Op, T> Arithmetic<Op, &PVector<T>> for &PVector<T>
54where
55    T: NativePType,
56    Op: Operator<T>,
57    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
58{
59    type Output = PVector<T>;
60
61    fn eval(self, rhs: &PVector<T>) -> Self::Output {
62        assert_eq!(self.len(), rhs.len());
63
64        // TODO(ngates): based on the true count of the validity, we may wish to short-circuit here
65        //  or choose a different implementation.
66        let validity = self.validity().bitand(rhs.validity());
67
68        let elements = Arithmetic::<Op, _>::eval(self.elements(), rhs.elements());
69        PVector::new(elements, validity)
70    }
71}
72
73/// Implementation that attempts to downcast to a mutable vector and operates in-place against
74/// a scalar RHS value.
75impl<Op, T> Arithmetic<Op, &T> for PVector<T>
76where
77    T: NativePType,
78    Op: Operator<T>,
79    PVectorMut<T>: for<'a> Arithmetic<Op, &'a T, Output = PVector<T>>,
80{
81    type Output = PVector<T>;
82
83    fn eval(self, rhs: &T) -> Self::Output {
84        match self.try_into_mut() {
85            Ok(lhs) => Arithmetic::<Op, _>::eval(lhs, rhs),
86            Err(lhs) => Arithmetic::<Op, _>::eval(&lhs, rhs),
87        }
88    }
89}
90
91/// Implementation that operates in-place over a mutable vector against a scalar RHS value.
92impl<Op, T> Arithmetic<Op, &T> for PVectorMut<T>
93where
94    T: NativePType,
95    Op: Operator<T>,
96    BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
97{
98    type Output = PVector<T>;
99
100    fn eval(self, rhs: &T) -> Self::Output {
101        let (lhs_buffer, lhs_validity) = self.into_parts();
102        let validity = lhs_validity.freeze();
103
104        let elements = Arithmetic::<Op, _>::eval(lhs_buffer, rhs);
105
106        PVector::new(elements, validity)
107    }
108}
109
110/// Implementation that allocates a new output vector against a scalar RHS value.
111impl<Op, T> Arithmetic<Op, &T> for &PVector<T>
112where
113    T: NativePType,
114    Op: Operator<T>,
115    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
116{
117    type Output = PVector<T>;
118
119    fn eval(self, rhs: &T) -> Self::Output {
120        let buffer = Arithmetic::<Op, _>::eval(self.elements(), rhs);
121        PVector::new(buffer, self.validity().clone())
122    }
123}
124
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;
    use vortex_vector::VectorOps;
    use vortex_vector::primitive::PVector;

    use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub};

    #[test]
    fn test_add_pvectors() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);
        assert_eq!(result.elements(), &buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_with_nulls() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let right = PVector::new(buffer![10u32, 20, 30], Mask::new_true(3));

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        // Validity is AND'd, so if either side is null, result is null
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 22, 33]);
    }

    #[test]
    fn test_sub_pvectors() {
        let left = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let right = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let vec = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));
        let result = Arithmetic::<WrappingSub, _>::eval(vec, &5);
        assert_eq!(result.elements(), &buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_pvectors() {
        let left = PVector::new(buffer![2u32, 3, 4, 5], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result.elements(), &buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let vec = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let result = Arithmetic::<WrappingMul, _>::eval(vec, &10);
        assert_eq!(result.elements(), &buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_scalar_preserves_validity() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));
        let result = Arithmetic::<WrappingAdd, _>::eval(vec, &10);

        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
    }

    /// Exercises the borrowing `&PVector` x `&PVector` implementation directly and checks that
    /// neither input is modified.
    #[test]
    fn test_add_borrowed_pvectors_leaves_inputs_untouched() {
        let left = PVector::new(buffer![1u32, 2, 3, 4], Mask::new_true(4));
        let right = PVector::new(buffer![10u32, 20, 30, 40], Mask::new_true(4));

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);

        assert_eq!(result.elements(), &buffer![11u32, 22, 33, 44]);
        assert_eq!(left.elements(), &buffer![1u32, 2, 3, 4]);
        assert_eq!(right.elements(), &buffer![10u32, 20, 30, 40]);
    }

    /// Exercises the borrowing `&PVector` x `&T` implementation directly.
    #[test]
    fn test_add_borrowed_scalar_leaves_input_untouched() {
        let vec = PVector::new(buffer![1u32, 2, 3], Mask::from_iter([true, false, true]));

        let result = Arithmetic::<WrappingAdd, _>::eval(&vec, &10);

        assert_eq!(result.elements(), &buffer![11u32, 12, 13]);
        assert_eq!(result.validity(), &Mask::from_iter([true, false, true]));
        assert_eq!(vec.elements(), &buffer![1u32, 2, 3]);
    }

    /// When both sides carry nulls, the result is valid only where both are valid.
    #[test]
    fn test_validity_is_intersection_of_both_sides() {
        let left = PVector::new(
            buffer![1u32, 2, 3, 4],
            Mask::from_iter([true, false, true, true]),
        );
        let right = PVector::new(
            buffer![10u32, 20, 30, 40],
            Mask::from_iter([true, true, false, true]),
        );

        let borrowed = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        assert_eq!(
            borrowed.validity(),
            &Mask::from_iter([true, false, false, true])
        );

        let owned = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(
            owned.validity(),
            &Mask::from_iter([true, false, false, true])
        );
    }

    /// The borrowing implementation asserts that both operands have the same length.
    #[test]
    #[should_panic]
    fn test_mismatched_lengths_panic_borrowed() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![1u32, 2], Mask::new_true(2));

        let _ = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
    }

    /// The owned entry point delegates to an implementation that asserts equal lengths.
    #[test]
    #[should_panic]
    fn test_mismatched_lengths_panic_owned() {
        let left = PVector::new(buffer![1u32, 2, 3], Mask::new_true(3));
        let right = PVector::new(buffer![1u32, 2], Mask::new_true(2));

        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }
}