vortex_compute/arithmetic/
buffer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6
7use crate::arithmetic::Arithmetic;
8use crate::arithmetic::Operator;
9
10/// Implementation that attempts to downcast to a mutable buffer and operates in-place.
11impl<Op, T> Arithmetic<Op, &Buffer<T>> for Buffer<T>
12where
13    T: Copy,
14    BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
15    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
16{
17    type Output = Buffer<T>;
18
19    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
20        match self.try_into_mut() {
21            Ok(lhs) => lhs.eval(rhs),
22            Err(lhs) => (&lhs).eval(rhs), // (&lhs) to delegate to borrowed impl
23        }
24    }
25}
26
27/// Implementation that operates in-place over a mutable buffer.
28impl<Op, T> Arithmetic<Op, &Buffer<T>> for BufferMut<T>
29where
30    T: Copy + num_traits::Zero,
31    Op: Operator<T>,
32{
33    type Output = Buffer<T>;
34
35    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
36        assert_eq!(self.len(), rhs.len());
37
38        let mut i = 0;
39        self.map_each_in_place(|a| {
40            // SAFETY: lengths are equal, so index is in bounds
41            let b = unsafe { *rhs.get_unchecked(i) };
42            i += 1;
43
44            Op::apply(&a, &b)
45        })
46        .freeze()
47    }
48}
49
50/// Implementation that allocates a new output buffer.
51impl<Op, T> Arithmetic<Op> for &Buffer<T>
52where
53    Op: Operator<T>,
54{
55    type Output = Buffer<T>;
56
57    fn eval(self, rhs: &Buffer<T>) -> Self::Output {
58        assert_eq!(self.len(), rhs.len());
59        Buffer::<T>::from_trusted_len_iter(
60            self.iter().zip(rhs.iter()).map(|(a, b)| Op::apply(a, b)),
61        )
62    }
63}
64
65/// Implementation that attempts to downcast to a mutable buffer and operates in-place against
66/// a scalar RHS value.
67impl<Op, T> Arithmetic<Op, &T> for Buffer<T>
68where
69    BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
70    for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
71{
72    type Output = Buffer<T>;
73
74    fn eval(self, rhs: &T) -> Self::Output {
75        match self.try_into_mut() {
76            Ok(lhs) => lhs.eval(rhs),
77            Err(lhs) => (&lhs).eval(rhs),
78        }
79    }
80}
81
82/// Implementation that operates in-place over a mutable buffer against a scalar RHS value.
83impl<Op, T> Arithmetic<Op, &T> for BufferMut<T>
84where
85    T: Copy,
86    Op: Operator<T>,
87{
88    type Output = Buffer<T>;
89
90    fn eval(self, rhs: &T) -> Self::Output {
91        self.map_each_in_place(|a| Op::apply(&a, rhs)).freeze()
92    }
93}
94
95/// Implementation that allocates a new output buffer operating against a scalar RHS value.
96impl<Op, T> Arithmetic<Op, &T> for &Buffer<T>
97where
98    Op: Operator<T>,
99{
100    type Output = Buffer<T>;
101
102    fn eval(self, rhs: &T) -> Self::Output {
103        Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| Op::apply(a, rhs)))
104    }
105}
106
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::Arithmetic;
    use crate::arithmetic::WrappingAdd;
    use crate::arithmetic::WrappingMul;
    use crate::arithmetic::WrappingSub;

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = Arithmetic::<WrappingSub, _>::eval(buf, &5);
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingMul, _>::eval(buf, &10);
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    /// A second handle keeps the LHS shared, so the allocating fallback must be used and the
    /// aliased data must remain unchanged.
    #[test]
    fn test_add_buffers_shared_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    /// Same as above, but for the scalar RHS path.
    #[test]
    fn test_add_scalar_shared_lhs() {
        let buf = buffer![1u32, 2, 3, 4];
        let alias = buf.clone();

        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    /// Exercises the `&Buffer` implementations directly; the inputs must not be modified.
    #[test]
    fn test_add_borrowed() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);

        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);

        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_scalar_wraps() {
        let buf = buffer![u32::MAX, 0];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &1);
        assert_eq!(result, buffer![0u32, 1]);
    }

    #[test]
    #[should_panic]
    fn test_add_buffers_length_mismatch() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![10u32, 20, 30, 40];
        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }
}