vortex_compute/arithmetic/buffer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{Buffer, BufferMut};
5
6use crate::arithmetic::{Arithmetic, Operator};
7
8/// Implementation that attempts to downcast to a mutable buffer and operates in-place.
9impl<Op, T> Arithmetic<Op, &Buffer<T>> for Buffer<T>
10where
11    T: Copy,
12    BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
13    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
14{
15    type Output = Buffer<T>;
16
17    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
18        match self.try_into_mut() {
19            Ok(lhs) => lhs.eval(rhs),
20            Err(lhs) => (&lhs).eval(rhs), // (&lhs) to delegate to borrowed impl
21        }
22    }
23}
24
25/// Implementation that operates in-place over a mutable buffer.
26impl<Op, T> Arithmetic<Op, &Buffer<T>> for BufferMut<T>
27where
28    T: Copy + num_traits::Zero,
29    Op: Operator<T>,
30{
31    type Output = Buffer<T>;
32
33    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
34        assert_eq!(self.len(), rhs.len());
35
36        let mut i = 0;
37        self.map_each_in_place(|a| {
38            // SAFETY: lengths are equal, so index is in bounds
39            let b = unsafe { *rhs.get_unchecked(i) };
40            i += 1;
41
42            Op::apply(&a, &b)
43        })
44        .freeze()
45    }
46}
47
48/// Implementation that allocates a new output buffer.
49impl<Op, T> Arithmetic<Op> for &Buffer<T>
50where
51    Op: Operator<T>,
52{
53    type Output = Buffer<T>;
54
55    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
56        assert_eq!(self.len(), rhs.len());
57        Buffer::<T>::from_trusted_len_iter(
58            self.iter().zip(rhs.iter()).map(|(a, b)| Op::apply(a, b)),
59        )
60    }
61}
62
63/// Implementation that attempts to downcast to a mutable buffer and operates in-place against
64/// a scalar RHS value.
65impl<Op, T> Arithmetic<Op, &T> for Buffer<T>
66where
67    BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
68    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
69{
70    type Output = Buffer<T>;
71
72    fn eval(self, rhs: &T) -> Self::Output {
73        match self.try_into_mut() {
74            Ok(lhs) => lhs.eval(rhs),
75            Err(lhs) => (&lhs).eval(rhs),
76        }
77    }
78}
79
80/// Implementation that operates in-place over a mutable buffer against a scalar RHS value.
81impl<Op, T> Arithmetic<Op, &T> for BufferMut<T>
82where
83    T: Copy,
84    Op: Operator<T>,
85{
86    type Output = Buffer<T>;
87
88    fn eval(self, rhs: &T) -> Self::Output {
89        self.map_each_in_place(|a| Op::apply(&a, rhs)).freeze()
90    }
91}
92
93/// Implementation that allocates a new output buffer operating against a scalar RHS value.
94impl<Op, T> Arithmetic<Op, &T> for &Buffer<T>
95where
96    Op: Operator<T>,
97{
98    type Output = Buffer<T>;
99
100    fn eval(self, rhs: &T) -> Self::Output {
101        Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| Op::apply(a, rhs)))
102    }
103}
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub};

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = Arithmetic::<WrappingSub, _>::eval(buf, &5);
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingMul, _>::eval(buf, &10);
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    /// A cloned LHS keeps a second handle alive, exercising the non-in-place fallback.
    /// The alias must still see the original values afterwards.
    #[test]
    fn test_add_buffers_shared_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    /// Scalar counterpart of the shared-LHS test.
    #[test]
    fn test_sub_scalar_shared_lhs() {
        let buf = buffer![10u32, 20, 30, 40];
        let alias = buf.clone();

        let result = Arithmetic::<WrappingSub, _>::eval(buf, &5);
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
        assert_eq!(alias, buffer![10u32, 20, 30, 40]);
    }

    /// Evaluating directly on borrowed buffers allocates a new output and leaves inputs intact.
    #[test]
    fn test_mul_borrowed_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingMul, _>::eval(&left, &right);
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
        assert_eq!(left, buffer![2u32, 3, 4, 5]);
    }

    #[test]
    fn test_add_scalar_wraps_on_overflow() {
        let buf = buffer![u32::MAX, 1];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &1);
        assert_eq!(result, buffer![0u32, 2]);
    }

    #[test]
    #[should_panic]
    fn test_mismatched_lengths_panic() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];
        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }
}